鱼C论坛

 找回密码
 立即注册
查看: 261|回复: 1

训练模型代码要返回两个列表在哪里添加

[复制链接]
发表于 2024-11-6 11:11:34 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
  1. import gc,os
  2. from tqdm.auto import tqdm # visualizing tool for progress

  3. # They will be used to pick the best model.pt given to the valid loss
  4. best_model_epoch, valid_loss_values = [],[]
  5. valid_loss_min = [1] # arbitrary loss I set here
  6. def train(model,device,train_dataloader,valid_dataloader,epochs,loss_fn,optimizer,metric):

  7.     for epoch in range(epochs):
  8.         gc.collect() # memory cleaning垃圾回收机制,减少占用内存
  9.         model.train()

  10.         train_loss = 0
  11.         train_step = 0
  12.         pbar = tqdm(train_dataloader)#tqdm参数是一个iterable

  13.         for batch in pbar: # you can also write like "for batch in tqdm(train_dataloader"
  14.             optimizer.zero_grad() # initialize
  15.             train_step += 1

  16.             train_input_ids = batch['input_ids'].to(device)
  17.             train_attention_mask = batch['attention_mask'].to(device)
  18.             train_labels = batch['labels'].squeeze().to(device).long()#long()转化成一维张量
  19.             
  20.             # You can refer to the class "TweetsModel" for understand
  21.             # what would be logits
  22.             logits = model(train_input_ids, train_attention_mask).to(device)
  23.             predictions = torch.argmax(logits, dim=1) # get an index from larger one
  24.             detached_predictions = predictions.detach().cpu().numpy()
  25.             
  26.             loss = loss_fn(logits, train_labels)
  27.             loss.backward()
  28.             optimizer.step()
  29.             model.zero_grad()

  30.             train_loss += loss.detach().cpu().numpy().item()

  31.             pbar.set_postfix({'train_loss':train_loss/train_step})#设置进度条显示信息
  32.         pbar.close()

  33.         with torch.no_grad():
  34.             model.eval()

  35.             valid_loss = 0
  36.             valid_step = 0
  37.             total_valid_score = 0

  38.             y_pred = [] # for getting f1_score that is a metric of the competition
  39.             y_true = []

  40.             pbar = tqdm(valid_dataloader)
  41.             for batch in pbar:
  42.                 valid_step += 1

  43.                 valid_input_ids = batch['input_ids'].to(device)
  44.                 valid_attention_mask = batch['attention_mask'].to(device)
  45.                 valid_labels = batch['labels'].squeeze().to(device).long()

  46.                 logits = model(valid_input_ids, valid_attention_mask).to(device)
  47.                 predictions = torch.argmax(logits, dim=1)
  48.                 detached_predictions = predictions.detach().cpu().numpy()
  49.                
  50.                 loss = loss_fn(logits, valid_labels)
  51.                 valid_loss += loss.detach().cpu().numpy().item()

  52.                 y_pred.extend(predictions.cpu().numpy())
  53.                 y_true.extend(valid_labels.cpu().numpy())

  54.             valid_loss /= valid_step
  55.             f1 = f1_score(y_true,y_pred)

  56.             print(f'Epoch [{epoch+1}/{epochs}] Score: {f1}')
  57.             print(f'Epoch [{epoch+1}/{epochs}] Valid_loss: {valid_loss}')

  58.             if valid_loss < min(valid_loss_min):
  59.                 print('model improved!')
  60.             else:
  61.                 print('model not improved')
  62.    
  63.             torch.save(model.state_dict(), f'epoch:{epoch+1}_model.pt')#state_dict 是一个字典对象,包含了模型的所有可学习参数(如权重和偏置)及其当前值
  64.             print('save checkpoint!')
  65.             valid_loss_min.append(valid_loss)
  66.             print(f'valid_loss_min:{min(valid_loss_min)}')

  67.         best_model_epoch.append(f'/kaggle/working/epoch:{epoch+1}_model.pt')
  68.         valid_loss_values.append(valid_loss)
  69.         print('='*100)

  70.     select_best_model() # refer to below function
  71.     print('Train/Valid Completed!!')
  72.     del train_dataloader, valid_dataloader # memory cleaning
  73.     gc.collect()

  74. def select_best_model():
  75.     best_model = best_model_epoch[np.array(valid_loss_values).argmin()]
  76.     os.rename(best_model, best_model.split('.pt')[0] + '_best.pt')#重命名文件
复制代码


我想在上述代码中返回y_pred y_true的列表并pd.concat起来返回一个新的dataframe,要在函数内部哪个位置添加return语句
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2024-11-6 11:18:21 | 显示全部楼层
要在你的`train`函数中返回两个列表 `y_pred` 和 `y_true`,并进一步通过 `pd.concat` 将它们合并成一个新的 DataFrame,可以在函数的最后部分添加以下代码:

  1. import pandas as pd

  2. def train(model, device, train_dataloader, valid_dataloader, epochs, loss_fn, optimizer, metric):
  3.     # ... (existing code)

  4.     # 在训练循环结束后,返回 y_pred 和 y_true 列表
  5.     return y_pred, y_true

  6. # 在调用 train 函数的地方,进行列表的合并
  7. y_pred, y_true = train(model, device, train_dataloader, valid_dataloader, epochs, loss_fn, optimizer, metric)

  8. # 使用 y_pred 和 y_true 构建 DataFrame
  9. df = pd.DataFrame({'y_pred': y_pred, 'y_true': y_true})

  10. # 打印或返回该 DataFrame
  11. print(df)
复制代码


请将 `train` 函数的代码末尾加上 `return y_pred, y_true` 这一行,以确保函数返回这两个列表。然后,在函数调用后,你可以使用 `pd.DataFrame` 将这两个列表组合成一个新的 DataFrame。

注意:在上述代码中,确保你在函数内部初始化了 `y_pred` 和 `y_true` 为列表对象,例如 `y_pred = []` 和 `y_true = []`,这样 `extend` 方法才能正常工作。同时确保 `pandas` 库已经安装并导入。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-18 16:42

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表