FishC Forum (鱼C论坛)

Error when training a neural network on the GPU

Posted 2024-10-26 23:11:05

import torch
import torch.nn as nn
import torch.optim as optim

# add 1 to avoid log(0)
def MSLE_loss(pred, target):
    log_pred = torch.log(pred + 1)
    log_target = torch.log(target + 1)
    loss = nn.MSELoss()(log_pred, log_target)
    return loss

model = Net().cuda()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
for epoch in range(3500):
    pred = model(X_train)
    pred = pred.squeeze()
    loss = MSLE_loss(pred, y_train)

    if epoch % 500 == 0:
        print(loss)
        print(pred)
        print(y_train)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

The error is as follows:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[54], line 11
      9 optimizer = optim.AdamW(model.parameters(),lr=1e-3,weight_decay=1e-3)
     10 for epoch in range(3500):
---> 11     pred = model(X_train)
     12     pred = pred.squeeze()
     13     loss = MSLE_loss(pred,y_train)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

Cell In[45], line 23, in Net.forward(self, x)
     21 def forward(self,x):
     22     inputs = x
---> 23     layer1_out = torch.nn.functional.gelu(self.hidden_layer1(inputs))
     24     layer1_out = self.dropout(layer1_out)
     25     layer2_out = torch.nn.functional.gelu(self.hidden_layer2(layer1_out))

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py:117, in Linear.forward(self, input)
    116 def forward(self, input: Tensor) -> Tensor:
--> 117     return F.linear(input, self.weight, self.bias)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
Reply posted 2024-10-26 23:11:10
The error says your tensors are on different devices: some are on the GPU (cuda:0) and some on the CPU. Here is how to fix it:

1. **Move the data to the GPU**: make sure `X_train` and `y_train` are on the GPU when they are used, e.g. by calling `.cuda()`:

   X_train = X_train.cuda()
   y_train = y_train.cuda()

2. **Check the model parameters and other tensors**: make sure all of the model's parameters and all input tensors live on the same device. You can call `.to(device)` when defining the model and loading the data, where `device` is the chosen compute device:

   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
   model = Net().to(device)
   X_train = X_train.to(device)
   y_train = y_train.to(device)

Once every computation happens on the same device, the error will go away.
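Putting the two suggestions together, a minimal runnable sketch of the corrected training loop might look like the following. Note the `Net` class below is a hypothetical stand-in (the original post did not show its full definition), the training data is dummy random data, and `torch.log1p(x)` is used as the equivalent of `torch.log(x + 1)`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Hypothetical stand-in for the Net from the original post (its full
# definition was not shown); any nn.Module works to demonstrate the point.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_layer1 = nn.Linear(4, 8)
        self.out = nn.Linear(8, 1)

    def forward(self, x):
        return self.out(F.gelu(self.hidden_layer1(x)))

def MSLE_loss(pred, target):
    # MSLE assumes non-negative values; clamping keeps log1p finite
    return nn.MSELoss()(torch.log1p(pred.clamp(min=0)),
                        torch.log1p(target))

# Pick the device once, then move the model AND all data tensors to it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
X_train = torch.rand(16, 4, device=device)  # dummy data for this sketch
y_train = torch.rand(16, device=device)

optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
for epoch in range(100):
    pred = model(X_train).squeeze()
    loss = MSLE_loss(pred, y_train)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

Because the device is chosen once and everything is moved to it, the same script runs on a CUDA machine or falls back to CPU without changes.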

The reply above comes from the FishC AI assistant; if it did not answer your question correctly, please follow up.