鱼C论坛

 找回密码
 立即注册
查看: 181|回复: 2

detach()和stack有什么作用

[复制链接]
发表于 2024-10-30 17:35:29 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
#模型期望输入的是PyTorch 张量,先将 DataFrame 转换为张量

X_test = X_test.astype(np.float32)# 将布尔型转为整数型
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)# 转换为 PyTorch 张量

preds = []

for i in range(30):
    test_output = dnn_model(X_test_tensor)
    test_output = test_output.ravel() # 将张量转换为 numpy一维数组
    preds.append(test_output)
pred = torch.mean(torch.stack(preds),dim=0).detach()
print(pred)
submission = pd.DataFrame({'Id': test_dataset.Id, 'SalePrice': pred})
submission.to_csv('submission.csv', index=False)
上述代码中为什么要用detach和stack
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2024-10-30 17:40:41 | 显示全部楼层
在这段代码中,`detach()` 和 `stack()` 的作用如下:

1. **`stack()`**:
   - `stack()` 函数用于将多个张量沿着一个新的维度进行拼接。在这段代码中,`preds` 是一个包含多个张量的列表。通过 `torch.stack(preds)`,这些张量被沿着一个新的维度堆叠成一个新的张量。这样可以方便地对这些张量进行批量操作,比如计算平均值。

2. **`detach()`**:
   - `detach()` 函数用于从当前计算图中分离张量。它返回一个新的张量,该张量与原始张量共享数据,但不再参与梯度计算。在这段代码中,`detach()` 用于确保在计算 `pred` 时,不会对 `dnn_model` 的参数进行梯度更新。这通常用于推理阶段,以避免不必要的计算开销。

总结来说,`stack()` 是为了将多个预测结果组合在一起,方便后续的平均计算,而 `detach()` 是为了确保在推理时不计算梯度。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-10-30 19:17:02 | 显示全部楼层
tack() 用于将多个预测结果组合成一个张量,而 detach() 用于确保这个张量不会参与梯度计算
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-12-4 01:20

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表