鱼C论坛

 找回密码
 立即注册
查看: 2271|回复: 7

自己融合了一个网络模型,在跑的时候出现了错误,希望各位朋友们能给点意见

[复制链接]
发表于 2024-1-19 18:43:12 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
下面是我报错的问题。
  1. Traceback (most recent call last):
  2.   File "D:\traffic\proj\traffic_prediction\traffic_prediction\traffic_prediction.py", line 234, in <module>
  3.     main()
  4.   File "D:\traffic\proj\traffic_prediction\traffic_prediction\traffic_prediction.py", line 155, in main
  5.     predict_value = my_net(data, device).to(torch.device("cpu"))  # [0, 1] -> recover
  6.   File "D:\anaconda\envs\traffic_pred\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
  7.     return forward_call(*input, **kwargs)
  8.   File "D:\traffic\proj\traffic_prediction\traffic_prediction\gat.py", line 114, in forward
  9.     prediction = self.subnet(flow, graph).unsqueeze(2)  # [B, N, 1, C]
  10.   File "D:\anaconda\envs\traffic_pred\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
  11.     return forward_call(*input, **kwargs)
  12.   File "D:\traffic\proj\traffic_prediction\traffic_prediction\gat.py", line 67, in forward
  13.     outputs = self.multiDimConv(outputs)
  14.   File "D:\anaconda\envs\traffic_pred\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
  15.     return forward_call(*input, **kwargs)
  16.   File "D:\traffic\proj\traffic_prediction\traffic_prediction\advs.py", line 34, in forward
  17.     x = self.conv(x)
  18.   File "D:\anaconda\envs\traffic_pred\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
  19.     return forward_call(*input, **kwargs)
  20.   File "D:\anaconda\envs\traffic_pred\lib\site-packages\torch\nn\modules\container.py", line 141, in forward
  21.     input = module(input)
  22.   File "D:\anaconda\envs\traffic_pred\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
  23.     return forward_call(*input, **kwargs)
  24.   File "D:\anaconda\envs\traffic_pred\lib\site-packages\torch\nn\modules\conv.py", line 301, in forward
  25.     return self._conv_forward(input, self.weight, self.bias)
  26.   File "D:\anaconda\envs\traffic_pred\lib\site-packages\torch\nn\modules\conv.py", line 298, in _conv_forward
  27.     self.padding, self.dilation, self.groups)
  28. RuntimeError: Given groups=1, weight of size [512, 192, 4], expected input[64, 307, 12] to have 192 channels, but got 307 channels instead

  29. 进程已结束,退出代码1
复制代码

我通过GPT尝试修改代码,但是根据GPT修改后,并没有成功,依旧报的是这个错误。
下面是GPT给我的建议。
根据您提供的错误信息,问题依然是第一个卷积层的权重和输入数据的通道数不匹配。错误信息显示,尽管您已经将所有卷积层改为一维卷积,但第一个卷积层的定义仍然是 Conv1d(192, 512, kernel_size=(4,), stride=(2,), padding=(1,)),这意味着它期望输入数据有192个通道,而实际上输入数据有307个通道。

要解决这个问题,请确保第一个卷积层的输入通道数正确设置为307。根据您的模型结构,您需要在 MultiShapeConv 类中做如下修改:

  1. python
  2. Copy code
  3. class MultiShapeConv(nn.Module):
  4.     # 其他代码保持不变

  5.     def __init__(self, ...):
  6.         # ... 其他初始化代码
  7.         self.conv = nn.Sequential(
  8.             nn.Conv1d(307, 512, kernel_size=(4,), stride=(2,), padding=(1,)), # 修改这里
  9.             nn.BatchNorm1d(512),
  10.             nn.ReLU(inplace=True),
  11.             # ... 其他卷积层保持不变
  12.         )

  13.     # ... 其他方法保持不变
复制代码

在这里,我将 MultiShapeConv 中第一个卷积层的输入通道数从192更改为307。请确保进行这一修改后重新训练模型,因为网络结构的改变会影响模型的性能。

如果您在修改和训练模型后仍然遇到问题,或者需要更多帮助,请随时告知。
小甲鱼最新课程 -> https://ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2024-1-19 18:44:32 | 显示全部楼层
FishC 发表于 2024-1-19 18:44
根据您提供的错误信息和GPT的建议,问题出在您的第一个卷积层上。根据错误信息,该卷积层期望输入数据有192 ...

这个办法我尝试了,没有用,是否还有其他的解决办法
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-1-19 19:32:07 | 显示全部楼层
AFUL 发表于 2024-1-19 19:21
错误信息你不看吗?
“Given groups=1, weight of size [512, 192, 4], expected input[64, 307, 12] to h ...

是有看到的,但是尺寸我并不想改,所以出了问题
小甲鱼最新课程 -> https://ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-11-7 02:57

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表