马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
本帖最后由 Handsome_zhou 于 2023-7-26 11:32 编辑
torch.cat(tensors, dim=0)定义两个二维数组A, B
对A,B按列(dim=0)和按行(dim=1)进行拼接。
import torch
A = torch.tensor([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
])
print('A 的维度为{}'.format(A.dim()))
print('A 的形状为{}'.format(A.shape))
out:
A 的维度为2
A 的形状为torch.Size([3, 3])
B = torch.tensor([
[12, 22, 33],
[44, 55, 66],
[77, 88, 99]
])
print('B 的维度为{}'.format(B.dim()))
print('B 的形状为{}'.format(B.shape))
out:
B 的维度为2
B 的形状为torch.Size([3, 3])
按列进行拼接:dim=0
resut = torch.cat((A, B), dim=0) # dim=0, 按列进行拼接
print(resut)
out:
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[12, 22, 33],
[44, 55, 66],
[77, 88, 99]])
print('resut 的维度为{}'.format(resut.dim()))
print('resut 的形状为{}'.format(resut.shape))
out:
resut 的维度为2
resut 的形状为torch.Size([6, 3])
按行进行拼接:
resut1 = torch.cat((A, B), dim=1) # dim=1, 按行进行拼接
print(resut1)
out:
tensor([[ 1, 2, 3, 12, 22, 33],
[ 4, 5, 6, 44, 55, 66],
[ 7, 8, 9, 77, 88, 99]])
print('resut1 的维度为{}'.format(resut1.dim()))
print('resut1 的形状为{}'.format(resut1.shape))
out:
resut1 的维度为2
resut1 的形状为torch.Size([3, 6])
|