鱼C论坛

 找回密码
 立即注册
查看: 470|回复: 1

seq2seq训练模型定义看不懂

[复制链接]
发表于 2024-10-12 08:26:12 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
  1. #@save
  2. def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
  3.     """训练序列到序列模型"""
  4.     def xavier_init_weights(m):
  5.         if type(m) == nn.Linear:
  6.             nn.init.xavier_uniform_(m.weight)
  7.         if type(m) == nn.GRU:
  8.             for param in m._flat_weights_names:
  9.                 if "weight" in param:
  10.                     nn.init.xavier_uniform_(m._parameters[param])

  11.     net.apply(xavier_init_weights)
  12.     net.to(device)
  13.     optimizer = torch.optim.Adam(net.parameters(), lr=lr)
  14.     loss = MaskedSoftmaxCELoss()
  15.     net.train()
  16.     animator = d2l.Animator(xlabel='epoch', ylabel='loss',
  17.                      xlim=[10, num_epochs])
  18.     for epoch in range(num_epochs):
  19.         timer = d2l.Timer()
  20.         metric = d2l.Accumulator(2)  # 训练损失总和,词元数量
  21.         for batch in data_iter:
  22.             optimizer.zero_grad()
  23.             X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
  24.             bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
  25.                           device=device).reshape(-1, 1)
  26.             dec_input = torch.cat([bos, Y[:, :-1]], 1)  # 强制教学
  27.             Y_hat, _ = net(X, dec_input, X_valid_len)
  28.             l = loss(Y_hat, Y, Y_valid_len)
  29.             l.sum().backward()      # 损失函数的标量进行“反向传播”
  30.             d2l.grad_clipping(net, 1)
  31.             num_tokens = Y_valid_len.sum()
  32.             optimizer.step()
  33.             with torch.no_grad():
  34.                 metric.add(l.sum(), num_tokens)
  35.         if (epoch + 1) % 10 == 0:
  36.             animator.add(epoch + 1, (metric[0] / metric[1],))
  37.     print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
  38.         f'tokens/sec on {str(device)}')
复制代码

问题1:
  1. def xavier_init_weights(m):
  2.         if type(m) == nn.Linear:
  3.             nn.init.xavier_uniform_(m.weight)
  4.         if type(m) == nn.GRU:
  5.             for param in m._flat_weights_names:
  6.                 if "weight" in param:
  7.                     nn.init.xavier_uniform_(m._parameters[param])
复制代码

这段代码的作用是什么?
问题2:
  1. animator = d2l.Animator(xlabel='epoch', ylabel='loss',
  2.                      xlim=[10, num_epochs])
复制代码

animator在干什么?
问题3:
  1. metric = d2l.Accumulator(2)  # 训练损失总和,词元数量
复制代码

metirc、batch的形状是什么样的
问题4:
  1. Y[:, :-1]]
复制代码

为什么要丢弃Y最后一列
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2024-10-12 08:26:23 | 显示全部楼层
针对您的问题,这里是简要的解答:

### 问题1:`xavier_init_weights` 函数的作用
`xavier_init_weights` 函数的作用是对模型中的线性层 (`nn.Linear`) 和 GRU 层 (`nn.GRU`) 的权重参数进行 Xavier(格拉洛福)初始化。这种初始化方法有助于在深层网络中保持信号的传播,避免因初始权重分布不当导致的梯度消失或爆炸问题。

### 问题2:`animator` 的作用
`animator = d2l.Animator` 是用来创建可视化工具,用于动态地绘制训练过程中每个 epoch 的损失变化图。`xlabel='epoch'` 和 `ylabel='loss'` 指定了图表的X轴和Y轴的标签,`xlim` 未给定具体范围,通常用于设置X轴的显示范围。

### 问题3:`metric` 和 `batch` 的形状
- `metric = d2l.Accumulator(2)` 初始化了一个指标累加器,用于累计训练过程中损失的总和和词元的数量。`Accumulator(2)` 表示这个累加器中有两个存储槽位。
- `batch` 的形状根据具体的数据集而定,通常是一个元组,包括输入序列 `X`、输入序列的有效长度 `X_valid_len`、目标序列 `Y`、目标序列的有效长度 `Y_valid_len`。

### 问题4:`dec_input` 中丢弃 `Y` 的最后一列的原因
在构建 `dec_input` 时,通过强制教学策略,将目标序列的开始符(`bos`)与去掉最后一列的目标序列 `Y` 进行拼接。这一步是为了在解码过程中使用实际的目标序列(除去其最后一个标记)作为输入,而非网络预测的输出。这种策略帮助模型更好地学习序列到序列的映射,特别是在训练初期。

如有更多疑问,请继续询问!

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-4-20 02:36

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表