鱼C论坛

 找回密码
 立即注册
查看: 1368|回复: 1

我的模型在多轮训练时报错index out of range,求大佬捞我一手,我在这里卡了好久了

[复制链接]
发表于 2021-1-31 11:59:45 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本人刚刚开始学seq2seq的模型,看了很多文章以后开始自己写,但是在训练的时候出了一些问题实在解决不了,来此求助。
这个是我训练的部分:
EPOCHS = 10
for epoch in range(EPOCHS):
    start = time.time()
    total_loss = 0
    for step, (eng, chn) in enumerate(ce_data_loader):
        loss = 0
        enc_output, enc_hidden = encoder(eng)
        dec_hidden = enc_hidden

        dec_input = torch.tensor([word2id['<bos>']] * 32).view(-1, 1)

        for t in range(1, chn.size(1)):
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output.float())
            # print(predictions)
            chn_data = chn_data.float()

            dec_input = predictions
            dec_input = dec_input.long()

            loss += criterion(predictions, chn[:, t].long())

            if dec_input == '<eos>':
                break
        batch_loss = (loss / int(chn.size(1)))
        total_loss += batch_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 32 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                         step,
                                                         batch_loss.detach().item()))

    print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                        total_loss))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
这个是我decoder模型的部分:
class AtteDecoder(nn.Module):
    def __init__(self, output_size, hidden_size, vocab_size, drop_out=0.1, max_length=30):
        super(AtteDecoder, self).__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.drop_out = drop_out
        self.max_length = max_length
        self.embedding = nn.Embedding(self.output_size, 300)
        self.gru = nn.GRU(self.hidden_size + 300, self.hidden_size)
        self.fc = nn.Linear(self.hidden_size, self.output_size)
        # Attention部分
        self.W1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.W2 = nn.Linear(self.hidden_size, self.hidden_size)
        self.V = nn.Linear(self.hidden_size, 1)

    def forward(self, inputs, hidden, encoder_outputs):
        # print(inputs.shape)
        # (max_len. batch, hidden)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        # (batch, max_len, hidden)
        hidden_time = hidden.permute(1, 0, 2)
        # (1, batch, hidden) -> # (batch, 1, hidden)
        score = torch.tanh(self.W1(encoder_outputs) + self.W2(hidden_time))
        attention_weights = torch.softmax(self.V(score), dim=1)
        context_vector = attention_weights * encoder_outputs
        context_vector = torch.sum(context_vector, dim=1)
        inputs = self.embedding(inputs)
        inputs = torch.cat((context_vector.unsqueeze(1), inputs), -1)
        output, state = self.gru(inputs)
        output = output.view(-1, output.size(2))
        x = self.fc(output)
        return x, state, attention_weights

    def init_hidden(self):
        return torch.zeros(1, batch_size, self.output_size)
我在单次测试encoder和decoder的时候都没有问题,其实就相当于跑了一个single batch,可以跑通,每个步骤输出的shape打印出来也没有问题。但是在多个epoch训练的时候,在predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output.float())就会报错,因为dec_input会经过nn.embedding层,这个时候就会报错index out of range。但不是一进去就会报错,我在loss+后面随便打印个东西,会输出两次,意思就是两次就会报出这个错误。有没有大佬帮我看一下这个错误是我哪个部分引起的。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2021-1-31 18:48:37 From FishC Mobile | 显示全部楼层
怕是序列越界了吧,代码有亿点点多a,不大想看了 qwq
我再看看哪些地方有涉及序列访问的代码,单独拿出来跑一下,或者自己先拿一个数据比较小的,打上断点自己看看
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-1-16 15:50

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表