鱼C论坛

 找回密码
 立即注册
查看: 1813|回复: 2

在人家的代码中我想加加一个LSTM模型测试一下结果,但是加上之后出现了问题。

[复制链接]
发表于 2023-9-27 19:17:02 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
这个是LSTM模型
  1. import torch
  2. import torch.nn as nn
  3. from torch.autograd import Variable
  4. import numpy as np

  5. class LSTM(nn.Module):
  6.     def __init__(self, input_num, hid_num, layers_num, out_num, batch_first=True):
  7.         super().__init__()
  8.         self.l1 = nn.LSTM(input_size=input_num,hidden_size=hid_num,num_layers=layers_num,batch_first=batch_first)
  9.         self.out = nn.Linear(hid_num,out_num)

  10.     def forward(self,data):
  11.         flow_x = data['flow_x'] #B*T*D
  12.         l_out,(h_n, c_n) = self.l1(flow_x,None) #None表示第一次 hidden_state是0
  13.         print(l_out[:, -1, :])
  14.         out = self.out(l_out[:, -1, :])
  15.         return out
复制代码


这个是我的代码
  1. import os
  2. import time
  3. import h5py
  4. import torch
  5. import numpy as np
  6. import torch.nn as nn
  7. import torch.optim as optim
  8. import torch.nn.functional as F
  9. from torch.utils.data import DataLoader

  10. from traffic_dataset import LoadData
  11. from utils import Evaluation
  12. from utils import visualize_result
  13. from chebnet import ChebNet
  14. from gat import GATNet
  15. from lstm import LSTM

  16. class Baseline(nn.Module):
  17.     def __init__(self, in_c, out_c):
  18.         super(Baseline, self).__init__()
  19.         self.layer = nn.Linear(in_c, out_c)

  20.     def forward(self, data, device):
  21.         flow_x = data["flow_x"].to(device)  # [B, N, H, D]

  22.         B, N = flow_x.size(0), flow_x.size(1)

  23.         flow_x = flow_x.view(B, N, -1)  # [B, N, H*D]  H = 6, D = 1

  24.         output = self.layer(flow_x)  # [B, N, Out_C], Out_C = D

  25.         return output.unsqueeze(2)  # [B, N, 1, D=Out_C]


  26. def main():
  27.     os.environ["CUDA_VISIBLE_DEVICES"] = "0"

  28.     # Loading Dataset

  29.     train_data = LoadData(data_path=["PEMS08/PEMS08.csv", "PEMS08/PEMS08.npz"], num_nodes=170, divide_days=[46, 16],
  30.                           time_interval=5, history_length=6,
  31.                           train_mode="train")
  32.     train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=8)

  33.     test_data = LoadData(data_path=["PEMS08/PEMS08.csv", "PEMS08/PEMS08.npz"], num_nodes=170, divide_days=[46, 16],
  34.                          time_interval=5, history_length=6,
  35.                          train_mode="test")
  36.     test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=8)

  37.     # Loading Model
  38.     # my_net = GATNet(in_c=6 * 1, hid_c=6, out_c=1, n_heads=2)
  39.     # my_net = GATNet(in_c=6 * 1, hid_c=6, out_c=1, n_heads=2, lstm_hidden_dim=1)
  40.     # my_net = GCN(in_c=6,hid_c=6,out_c=1)
  41.     # my_net = ChebNet(in_c=6, hid_c=6, out_c=1, K=5)
  42.     my_net = LSTM(input_num=6,hid_num=6,layers_num=3,out_num=1)

  43.     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  44.     my_net = my_net.to(device)

  45.     criterion = nn.MSELoss()

  46.     optimizer = optim.Adam(params=my_net.parameters())

  47.     # Train model
  48.     Epoch = 100

  49.     my_net.train()
  50.     for epoch in range(Epoch):
  51.         epoch_loss = 0.0
  52.         start_time = time.time()
  53.         for data in train_loader:  # ["graph": [B, N, N] , "flow_x": [B, N, H, D], "flow_y": [B, N, 1, D]]
  54.             my_net.zero_grad()

  55.             predict_value = my_net(data, device).to(torch.device("cpu"))  # [0, 1] -> recover

  56.             loss = criterion(predict_value, data["flow_y"])

  57.             epoch_loss += loss.item()

  58.             loss.backward()

  59.             optimizer.step()
  60.         end_time = time.time()

  61.         print("Epoch: {:04d}, Loss: {:02.4f}, Time: {:02.2f} mins".format(epoch, 1000 * epoch_loss / len(train_data),
  62.                                                                           (end_time-start_time)/60))

  63.     # Test Model
  64.     my_net.eval()
  65.     with torch.no_grad():
  66.         MAE, MAPE, RMSE = [], [], []
  67.         # Target = np.zeros([307, 1, 1]) # [N, 1, D]
  68.         Target = np.zeros([170, 1, 1]) # [N, 1, D]

  69.         Predict = np.zeros_like(Target)  #[N, T, D]

  70.         total_loss = 0.0
  71.         for data in test_loader:

  72.             predict_value = my_net(data, device).to(torch.device("cpu"))  # [B, N, 1, D]  -> [1, N, B(T), D]

  73.             loss = criterion(predict_value, data["flow_y"])

  74.             total_loss += loss.item()

  75.             predict_value = predict_value.transpose(0, 2).squeeze(0)  # [1, N, B(T), D] -> [N, B(T), D] -> [N, T, D]
  76.             target_value = data["flow_y"].transpose(0, 2).squeeze(0)  # [1, N, B(T), D] -> [N, B(T), D] -> [N, T, D]

  77.             performance, data_to_save = compute_performance(predict_value, target_value, test_loader)

  78.             Predict = np.concatenate([Predict, data_to_save[0]], axis=1)
  79.             Target = np.concatenate([Target, data_to_save[1]], axis=1)

  80.             MAE.append(performance[0])
  81.             MAPE.append(performance[1])
  82.             RMSE.append(performance[2])

  83.         print("Test Loss: {:02.4f}".format(1000 * total_loss / len(test_data)))

  84.     print("Performance:  MAE {:2.2f}   MAPE {:2.2f}%  RMSE  {:2.2f}".format(np.mean(MAE), np.mean(MAPE * 100), np.mean(RMSE)))

  85.     Predict = np.delete(Predict, 0, axis=1)
  86.     Target = np.delete(Target, 0, axis=1)

  87.     result_file = "GAT_result.h5"
  88.     file_obj = h5py.File(result_file, "w")

  89.     file_obj["predict"] = Predict
  90.     file_obj["target"] = Target


  91. def compute_performance(prediction, target, data):
  92.     try:
  93.         dataset = data.dataset  # dataloader
  94.     except:
  95.         dataset = data  # dataset

  96.     prediction = LoadData.recover_data(dataset.flow_norm[0], dataset.flow_norm[1], prediction.numpy())
  97.     target = LoadData.recover_data(dataset.flow_norm[0], dataset.flow_norm[1], target.numpy())

  98.     mae, mape, rmse = Evaluation.total(target.reshape(-1), prediction.reshape(-1))

  99.     performance = [mae, mape, rmse]
  100.     recovered_data = [prediction, target]

  101.     return performance, recovered_data


  102. if __name__ == '__main__':
  103.     main()
  104.     # visualize_result(h5_file="GAT_result.h5",
  105.     #                  nodes_id=120,
  106.     #                  time_se=[0, 24 * 12 * 2],
  107.     #                  visualize_file="gat_node_120")

复制代码


这个是我最后报错的情况

  1. Traceback (most recent call last):
  2.   File "D:\traffic\traffic_prediction\traffic_prediction.py", line 196, in <module>
  3.     main()
  4.   File "D:\traffic\traffic_prediction\traffic_prediction.py", line 120, in main
  5.     predict_value = my_net(data, device).to(torch.device("cpu"))  # [0, 1] -> recover
  6.   File "D:\anaconda\envs\traffic_pred\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
  7.     return forward_call(*input, **kwargs)
  8. TypeError: forward() takes 2 positional arguments but 3 were given

  9. 进程已结束,退出代码1
复制代码



我想知道的是,我应该如何做才能解决这个问题,希望有人教我一下,本人没有学过类似的东西,第一次接触。

本帖被以下淘专辑推荐:

小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

发表于 2023-9-27 19:17:11 | 显示全部楼层
问题在于您的LSTM类中的forward方法需要一个参数,但您在main函数中调用my_net(data, device)时传递了两个参数。解决这个问题的方法是将数据和设备作为一个元组传递给my_net,如下所示:

  1. predict_value = my_net((data, device)).to(torch.device("cpu"))
复制代码


然后在LSTM类中的forward方法中进行解包:

  1. def forward(self, data_device_tuple):
  2.     data, device = data_device_tuple
  3.     flow_x = data['flow_x'].to(device)
  4.     ...
复制代码


这样就可以解决这个问题了。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2023-9-28 18:43:34 | 显示全部楼层
没有解决,不过应该是我数据处理的问题,很难过,但是没有办法。
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-6-25 04:48

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表