torch.cat()函数对矩阵进行拼接时的dim区别
本帖最后由 Handsome_zhou 于 2023-7-26 11:32 编辑torch.cat(tensors, dim=0)定义两个二维数组A, B
对A,B按列(dim=0)和按行(dim=1)进行拼接。
import torch
A = torch.tensor([
,
,
])
print('A 的维度为{}'.format(A.dim()))
print('A 的形状为{}'.format(A.shape))
out:
A 的维度为2
A 的形状为torch.Size()
B = torch.tensor([
,
,
])
print('B 的维度为{}'.format(B.dim()))
print('B 的形状为{}'.format(B.shape))
out:
B 的维度为2
B 的形状为torch.Size()
按列进行拼接:dim=0
resut = torch.cat((A, B), dim=0) # dim=0, 按列进行拼接
print(resut)
out:
tensor([[ 1,2,3],
[ 4,5,6],
[ 7,8,9],
,
,
])
print('resut 的维度为{}'.format(resut.dim()))
print('resut 的形状为{}'.format(resut.shape))
out:
resut 的维度为2
resut 的形状为torch.Size()
按行进行拼接:
resut1 = torch.cat((A, B), dim=1) # dim=1, 按行进行拼接
print(resut1)
out:
tensor([[ 1,2,3, 12, 22, 33],
[ 4,5,6, 44, 55, 66],
[ 7,8,9, 77, 88, 99]])
print('resut1 的维度为{}'.format(resut1.dim()))
print('resut1 的形状为{}'.format(resut1.shape))
out:
resut1 的维度为2
resut1 的形状为torch.Size()
页:
[1]