|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
- import gc,os
- from tqdm.auto import tqdm # visualizing tool for progress
- # They will be used to pick the best model.pt given to the valid loss
- best_model_epoch, valid_loss_values = [],[]
- valid_loss_min = [1] # arbitrary loss I set here
- def train(model,device,train_dataloader,valid_dataloader,epochs,loss_fn,optimizer,metric):
- for epoch in range(epochs):
- gc.collect() # memory cleaning垃圾回收机制,减少占用内存
- model.train()
- train_loss = 0
- train_step = 0
- pbar = tqdm(train_dataloader)#tqdm参数是一个iterable
- for batch in pbar: # you can also write like "for batch in tqdm(train_dataloader"
- optimizer.zero_grad() # initialize
- train_step += 1
- train_input_ids = batch['input_ids'].to(device)
- train_attention_mask = batch['attention_mask'].to(device)
- train_labels = batch['labels'].squeeze().to(device).long()#long()转化成一维张量
-
- # You can refer to the class "TweetsModel" for understand
- # what would be logits
- logits = model(train_input_ids, train_attention_mask).to(device)
- predictions = torch.argmax(logits, dim=1) # get an index from larger one
- detached_predictions = predictions.detach().cpu().numpy()
-
- loss = loss_fn(logits, train_labels)
- loss.backward()
- optimizer.step()
- model.zero_grad()
- train_loss += loss.detach().cpu().numpy().item()
- pbar.set_postfix({'train_loss':train_loss/train_step})#设置进度条显示信息
- pbar.close()
- with torch.no_grad():
- model.eval()
- valid_loss = 0
- valid_step = 0
- total_valid_score = 0
- y_pred = [] # for getting f1_score that is a metric of the competition
- y_true = []
- pbar = tqdm(valid_dataloader)
- for batch in pbar:
- valid_step += 1
- valid_input_ids = batch['input_ids'].to(device)
- valid_attention_mask = batch['attention_mask'].to(device)
- valid_labels = batch['labels'].squeeze().to(device).long()
- logits = model(valid_input_ids, valid_attention_mask).to(device)
- predictions = torch.argmax(logits, dim=1)
- detached_predictions = predictions.detach().cpu().numpy()
-
- loss = loss_fn(logits, valid_labels)
- valid_loss += loss.detach().cpu().numpy().item()
- y_pred.extend(predictions.cpu().numpy())
- y_true.extend(valid_labels.cpu().numpy())
- valid_loss /= valid_step
- f1 = f1_score(y_true,y_pred)
- print(f'Epoch [{epoch+1}/{epochs}] Score: {f1}')
- print(f'Epoch [{epoch+1}/{epochs}] Valid_loss: {valid_loss}')
- if valid_loss < min(valid_loss_min):
- print('model improved!')
- else:
- print('model not improved')
-
- torch.save(model.state_dict(), f'epoch:{epoch+1}_model.pt')#state_dict 是一个字典对象,包含了模型的所有可学习参数(如权重和偏置)及其当前值
- print('save checkpoint!')
- valid_loss_min.append(valid_loss)
- print(f'valid_loss_min:{min(valid_loss_min)}')
- best_model_epoch.append(f'/kaggle/working/epoch:{epoch+1}_model.pt')
- valid_loss_values.append(valid_loss)
- print('='*100)
- select_best_model() # refer to below function
- print('Train/Valid Completed!!')
- del train_dataloader, valid_dataloader # memory cleaning
- gc.collect()
- def select_best_model():
- best_model = best_model_epoch[np.array(valid_loss_values).argmin()]
- os.rename(best_model, best_model.split('.pt')[0] + '_best.pt')#重命名文件
复制代码
我想在上述代码中返回y_pred y_true的列表并pd.concat起来返回一个新的dataframe,要在函数内部哪个位置添加return语句 |
|