鱼C论坛

 找回密码
 立即注册
查看: 893|回复: 5

deberta训练模型

[复制链接]
发表于 2024-11-2 21:41:49 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
  1. #定义训练模型
  2. class LLMModel(nn.Module):
  3.     def __init__(self, cfg, config_path=None, pretrained=False):
  4.         super().__init__()
  5.         self.cfg = cfg
  6.         if config_path is None:
  7.             self.config = AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)#
  8.             self.config.hidden_dropout = 0.
  9.             self.config.hidden_dropout_prob = 0.
  10.             self.config.attention_dropout = 0.
  11.             self.config.attention_probs_dropout_prob = 0.
  12.             self.config.add_pooling_layer = False
  13.         else:
  14.             self.config = torch.load(config_path)
  15.         if pretrained:
  16.             self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
  17.         else:
  18.             self.model = AutoModel.from_config(self.config)
  19.         
  20.         self.model.resize_token_embeddings(len(tokenizer))#
  21.         if self.cfg.gradient_checkpointing:#使用梯度检查点技术时,模型不会保存前向传播计算的中间结果,减少占用内存
  22.             self.model.gradient_checkpointing_enable()
  23.         self.fc = nn.Linear(self.config.hidden_size, self.cfg.num_labels)
  24.         self._init_weights(self.fc)
  25.         
  26.     def _init_weights(self, module):
  27.         if isinstance(module, nn.Linear):
  28.             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  29.             if module.bias is not None:
  30.                 module.bias.data.zero_()
  31.         elif isinstance(module, nn.Embedding):
  32.             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  33.             if module.padding_idx is not None:
  34.                 module.weight.data[module.padding_idx].zero_()
  35.         elif isinstance(module, nn.LayerNorm):
  36.             module.bias.data.zero_()
  37.             module.weight.data.fill_(1.0)
  38.         
  39.     def feature(self, input_ids, attention_mask, token_type_ids):
  40.         outputs = self.model( input_ids, attention_mask, token_type_ids)
  41.         last_hidden_states = outputs[0] #encoder最后一个隐藏状态的输出传递给decoder做cross attention
  42.         feature = last_hidden_states[:, 0, :] ## CLS token
  43.         return feature

  44.     def forward(self,  input_ids, attention_mask, token_type_ids):#attention_mask避免用注意力机制的时候关注到填充符
  45.         feature = self.feature( input_ids, attention_mask, token_type_ids)
  46.         output = self.fc(feature)
  47.         return output.squeeze(-1)
复制代码


其中,
  1.     def feature(self, input_ids, attention_mask, token_type_ids):
  2.         outputs = self.model( input_ids, attention_mask, token_type_ids)
  3.         last_hidden_states = outputs[0] #encoder最后一个隐藏状态的输出传递给decoder做cross attention
  4.         feature = last_hidden_states[:, 0, :] ## CLS token
  5.         return feature

  6.     def forward(self,  input_ids, attention_mask, token_type_ids):#attention_mask避免用注意力机制的时候关注到填充符
  7.         feature = self.feature( input_ids, attention_mask, token_type_ids)
  8.         output = self.fc(feature)
  9.         return output.squeeze(-1)
复制代码


feature函数最终返回的是什么形状的张量?然后将其作为最后输出全连接层的输入
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2024-11-2 22:15:46 | 显示全部楼层
FishC 发表于 2024-11-2 21:42
在您的代码中,`feature`函数返回的是`self.model`的输出,即`last_hidden_states`。根据Hugging Face的Tra ...

上述模型的decoder部分体现在哪里
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-11-2 22:43:58 | 显示全部楼层
FishC 发表于 2024-11-2 22:27
在您提供的代码中,并没有明确的decoder部分。这个模型主要是一个encoder模型,通常用于特征提取或分类 ...
  1. self.model.resize_token_embeddings(len(tokenizer))
复制代码

这句代码如何解释
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-10-9 04:39

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表