鱼C论坛

 找回密码
 立即注册
查看: 87|回复: 1

seq2seq训练模型定义看不懂

[复制链接]
发表于 2024-10-12 08:26:12 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
#@save
def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """训练序列到序列模型"""
    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                     xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)  # 训练损失总和,词元数量
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                          device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # 强制教学
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()      # 损失函数的标量进行“反向传播”
            d2l.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
        f'tokens/sec on {str(device)}')
问题1:
def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])
这段代码的作用是什么?
问题2:
animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                     xlim=[10, num_epochs])
animator在干什么?
问题3:
metric = d2l.Accumulator(2)  # 训练损失总和,词元数量
metirc、batch的形状是什么样的
问题4:
Y[:, :-1]]
为什么要丢弃Y最后一列
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2024-10-12 08:26:23 | 显示全部楼层
针对您的问题,这里是简要的解答:

### 问题1:`xavier_init_weights` 函数的作用
`xavier_init_weights` 函数的作用是对模型中的线性层 (`nn.Linear`) 和 GRU 层 (`nn.GRU`) 的权重参数进行 Xavier(格拉洛福)初始化。这种初始化方法有助于在深层网络中保持信号的传播,避免因初始权重分布不当导致的梯度消失或爆炸问题。

### 问题2:`animator` 的作用
`animator = d2l.Animator` 是用来创建可视化工具,用于动态地绘制训练过程中每个 epoch 的损失变化图。`xlabel='epoch'` 和 `ylabel='loss'` 指定了图表的X轴和Y轴的标签,`xlim` 未给定具体范围,通常用于设置X轴的显示范围。

### 问题3:`metric` 和 `batch` 的形状
- `metric = d2l.Accumulator(2)` 初始化了一个指标累加器,用于累计训练过程中损失的总和和词元的数量。`Accumulator(2)` 表示这个累加器中有两个存储槽位。
- `batch` 的形状根据具体的数据集而定,通常是一个元组,包括输入序列 `X`、输入序列的有效长度 `X_valid_len`、目标序列 `Y`、目标序列的有效长度 `Y_valid_len`。

### 问题4:`dec_input` 中丢弃 `Y` 的最后一列的原因
在构建 `dec_input` 时,通过强制教学策略,将目标序列的开始符(`bos`)与去掉最后一列的目标序列 `Y` 进行拼接。这一步是为了在解码过程中使用实际的目标序列(除去其最后一个标记)作为输入,而非网络预测的输出。这种策略帮助模型更好地学习序列到序列的映射,特别是在训练初期。

如有更多疑问,请继续询问!

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-11-16 06:50

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表