鱼C论坛

 找回密码
 立即注册
查看: 317|回复: 2

训练模型中如何增加梯度计算

[复制链接]
发表于 2024-11-3 13:14:17 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
  1. class CustomModel(nn.Module):
  2.     def __init__(self, cfg, config_path=None, pretrained=False):
  3.         super().__init__()
  4.         self.cfg = cfg
  5.         if config_path is None:
  6.             self.config = AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)
  7.             self.config.hidden_dropout = 0.
  8.             self.config.hidden_dropout_prob = 0.
  9.             self.config.attention_dropout = 0.
  10.             self.config.attention_probs_dropout_prob = 0.
  11.             self.config.add_pooling_layer = False
  12.         else:
  13.             self.config = torch.load(config_path)
  14.         if pretrained:
  15.             self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
  16.         else:
  17.             self.model = AutoModel.from_config(self.config)
  18.         
  19.         self.model.resize_token_embeddings(len(CFG.tokenizer))
  20.         if self.cfg.gradient_checkpointing:
  21.             self.model.gradient_checkpointing_enable()
  22.         self.fc = nn.Linear(self.config.hidden_size, self.cfg.num_labels)
  23.         self._init_weights(self.fc)
  24.         
  25.     def _init_weights(self, module):
  26.         if isinstance(module, nn.Linear):
  27.             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  28.             if module.bias is not None:
  29.                 module.bias.data.zero_()
  30.         elif isinstance(module, nn.Embedding):
  31.             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  32.             if module.padding_idx is not None:
  33.                 module.weight.data[module.padding_idx].zero_()
  34.         elif isinstance(module, nn.LayerNorm):
  35.             module.bias.data.zero_()
  36.             module.weight.data.fill_(1.0)
  37.         
  38.     def feature(self, input_ids, attention_mask, token_type_ids ):
  39.         outputs = self.model( input_ids, attention_mask, token_type_ids)
  40.         last_hidden_states = outputs[0]
  41.         feature = last_hidden_states[:, 0, :] ## CLS token
  42.         return feature

  43.     def forward(self,  input_ids, attention_mask, token_type_ids):
  44.         feature = self.feature( input_ids, attention_mask, token_type_ids)
  45.         output = self.fc(feature)
  46.         return output.squeeze(-1)
复制代码

  1. # ====================================================
  2. # inference
  3. # ====================================================
  4. def inference_fn(test_loader, model, device):
  5.     preds = []
  6.     model.eval()
  7.     model.to(device)
  8.     tk0 = tqdm(test_loader, total=len(test_loader))
  9.     for inputs in tk0:
  10.         inputs["input_ids"] = inputs["input_ids"].to(device)
  11.         inputs["attention_mask"] = inputs["attention_mask"].to(device)
  12.         inputs['token_type_ids'] = inputs['token_type_ids'].to(device)
  13.         
  14.         with torch.no_grad():
  15.             with torch.cuda.amp.autocast(
  16.                 enabled=True, dtype=torch.float16, cache_enabled=True
  17.             ):
  18.                 y_preds = model(inputs["input_ids"], inputs["attention_mask"],inputs['token_type_ids']  )
  19.         preds.append(y_preds.to("cpu").numpy().astype(np.float32))
  20.     predictions = np.concatenate(preds)
  21.     return predictions
复制代码


上述代码不使用梯度更新仅适用于测试集,如何修改可将其用于训练集
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2024-11-3 13:14:23 | 显示全部楼层
要将上述代码用于训练集,你需要在模型的训练过程中计算梯度并更新模型参数。可以通过以下步骤实现:

1. **设置模型为训练模式**:使用`model.train()`。
2. **定义损失函数和优化器**:选择合适的损失函数(如`nn.CrossEntropyLoss()`)和优化器(如`torch.optim.AdamW()`)。
3. **计算损失并反向传播**:使用`loss.backward()`计算梯度。
4. **更新模型参数**:使用`optimizer.step()`更新参数。
5. **清零梯度**:在每个训练步骤开始前使用`optimizer.zero_grad()`清除上一步的梯度。

以下是如何修改代码以适用于训练集的示例:

  1. # 定义损失函数和优化器
  2. criterion = nn.CrossEntropyLoss()
  3. optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

  4. # 训练函数
  5. def train_fn(train_loader, model, device, optimizer, criterion):
  6.     model.train()
  7.     model.to(device)
  8.     tk0 = tqdm(train_loader, total=len(train_loader))
  9.     for inputs, labels in tk0:
  10.         inputs, labels = inputs.to(device), labels.to(device)
  11.         
  12.         # 清零梯度
  13.         optimizer.zero_grad()
  14.         
  15.         # 前向传播
  16.         outputs = model(inputs, inputs, inputs)
  17.         
  18.         # 计算损失
  19.         loss = criterion(outputs, labels)
  20.         
  21.         # 反向传播
  22.         loss.backward()
  23.         
  24.         # 更新参数
  25.         optimizer.step()
  26.         
  27.         tk0.set_postfix(loss=loss.item())
复制代码


在训练过程中,确保你传入了正确的`train_loader`,并且`inputs`和`labels`是从数据集中获取的。这样就可以在训练集上进行梯度计算和参数更新了。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-11-3 13:23:45 | 显示全部楼层
FishC 发表于 2024-11-3 13:14
要将上述代码用于训练集,你需要在模型的训练过程中计算梯度并更新模型参数。可以通过以下步骤实现:

1. ...
  1. import gc,os
  2. from tqdm.auto import tqdm # visualizing tool for progress

  3. # They will be used to pick the best model.pt given to the valid loss
  4. best_model_epoch, valid_loss_values = [],[]
  5. valid_loss_min = [1] # arbitrary loss I set here
  6. def train(model,device,train_dataloader,valid_dataloader,epochs,loss_fn,optimizer,metric):

  7.     for epoch in range(epochs):
  8.         gc.collect() # memory cleaning垃圾回收机制,减少占用内存
  9.         model.train()

  10.         train_loss = 0
  11.         train_step = 0
  12.         pbar = tqdm(train_dataloader, total=len(train_dataloader))#tqdm参数是一个iterable

  13.         for batch, labels in pbar: # you can also write like "for batch in tqdm(train_dataloader"
  14.             optimizer.zero_grad() # initialize
  15.             train_step += 1
  16.             

  17.             train_input_ids = batch['input_ids'].to(device)#batch是一个字典
  18.             train_attention_mask = batch['attention_mask'].to(device)
  19.             train_token_type_ids = batch['token_type_ids'].to(device)
  20.             train_labels = labels.squeeze().to(device).long()#label真实值long()转化成一维张量
  21.             
  22.             # You can refer to the class "TweetsModel" for understand
  23.             # what would be logits
  24.             logits = model(train_input_ids, train_attention_mask,train_token_type_ids).to(device)
  25.             predictions = torch.argmax(logits, dim=1) # get an index from larger one
  26.             detached_predictions = predictions.detach().cpu().numpy()
  27.             
  28.             loss = loss_fn(logits, train_labels)
  29.             loss.backward()
  30.             optimizer.step()
  31.             model.zero_grad()

  32.             train_loss += loss.detach().cpu().numpy().item()

  33.             pbar.set_postfix({'train_loss':train_loss/train_step})#设置进度条显示信息
  34.         pbar.close()

  35.         with torch.no_grad():
  36.             model.eval()

  37.             valid_loss = 0
  38.             valid_step = 0
  39.             total_valid_score = 0

  40.             y_pred = [] # for getting f1_score that is a metric of the competition
  41.             y_true = []

  42.             pbar = tqdm(valid_dataloader)
  43.             for batch,labels in pbar:
  44.                 valid_step += 1

  45.                 valid_input_ids = batch['input_ids'].to(device)
  46.                 valid_attention_mask = batch['attention_mask'].to(device)
  47.                 valid_token_type_ids = batch['token_type_ids'].to(device)
  48.                 valid_labels = labels.squeeze().to(device).long()

  49.                 logits = model(valid_input_ids, valid_attention_mask).to(device)
  50.                 predictions = torch.argmax(logits, dim=1)
  51.                 detached_predictions = predictions.detach().cpu().numpy()
  52.                
  53.                 loss = loss_fn(logits, valid_labels)
  54.                 valid_loss += loss.detach().cpu().numpy().item()

  55.                 y_pred.extend(predictions.cpu().numpy())
  56.                 y_true.extend(valid_labels.cpu().numpy())

  57.             valid_loss /= valid_step
  58.             f1 = f1_score(y_true,y_pred)

  59.             print(f'Epoch [{epoch+1}/{epochs}] Score: {f1}')
  60.             print(f'Epoch [{epoch+1}/{epochs}] Valid_loss: {valid_loss}')

  61.             if valid_loss < min(valid_loss_min):
  62.                 print('model improved!')
  63.             else:
  64.                 print('model not improved')
  65.    
  66.             torch.save(model.state_dict(), f'epoch:{epoch+1}_model.pt')#state_dict 是一个字典对象,包含了模型的所有可学习参数(如权重和偏置)及其当前值
  67.             print('save checkpoint!')
  68.             valid_loss_min.append(valid_loss)
  69.             print(f'valid_loss_min:{min(valid_loss_min)}')

  70.         best_model_epoch.append(f'/kaggle/working/epoch:{epoch+1}_model.pt')
  71.         valid_loss_values.append(valid_loss)
  72.         print('='*100)

  73.     select_best_model() # refer to below function
  74.     print('Train/Valid Completed!!')
  75.     del train_dataloader, valid_dataloader # memory cleaning
  76.     gc.collect()

  77. def select_best_model():
  78.     best_model = best_model_epoch[np.array(valid_loss_values).argmin()]
  79.     os.rename(best_model, best_model.split('.pt')[0] + '_best.pt')#重命名文件
复制代码

上述代码train_dataloader要求是什么形状的?batch是一个字典
有如下报错
  1. ---------------------------------------------------------------------------
  2. ValueError                                Traceback (most recent call last)
  3. Cell In[128], line 4
  4.       1 print('Training Start!')
  5.       2 print('=' * 100)
  6. ----> 4 train(model,
  7.       5     device,
  8.       6     train_dataloader,
  9.       7     valid_dataloader,
  10.       8     CFG.epochs,
  11.       9     loss_fn,
  12.      10     optimizer,
  13.      11     metric)
  14.      13 del model,train_dataloader, valid_dataloader
  15.      14 gc.collect()

  16. Cell In[127], line 17, in train(model, device, train_dataloader, valid_dataloader, epochs, loss_fn, optimizer, metric)
  17.      14 train_step = 0
  18.      15 pbar = tqdm(train_dataloader, total=len(train_dataloader))#tqdm参数是一个iterable
  19. ---> 17 for batch, labels in pbar: # you can also write like "for batch in tqdm(train_dataloader"
  20.      18     optimizer.zero_grad() # initialize
  21.      19     train_step += 1

  22. ValueError: too many values to unpack (expected 2)
复制代码
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-18 17:02

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表