torch里面sum函数不太明白
A = torch.arange(20*2).reshape(2,5,4)A_sum_axis1 = A.sum(axis=1)
为什么这个A_sum_axis1是tensor([[ 40,45,50,55],])?
A是tensor
([[[ 0,1,2,3],
[ 4,5,6,7],
[ 8,9, 10, 11],
,
],
[,
,
,
,
]])
看看这个有没有帮助
https://pytorch.org/docs/stable/generated/torch.sum.html#torch.sum
页:
[1]