Handsome_zhou 发表于 2023-7-26 11:31:17

torch.cat()函数对矩阵进行拼接时的dim区别

本帖最后由 Handsome_zhou 于 2023-7-26 11:32 编辑

torch.cat(tensors, dim=0)定义两个二维数组A, B
对A,B按列(dim=0)和按行(dim=1)进行拼接。

import torch
A = torch.tensor([
    ,
    ,
   
])

print('A 的维度为{}'.format(A.dim()))
print('A 的形状为{}'.format(A.shape))
out:
A 的维度为2
A 的形状为torch.Size()

B = torch.tensor([
    ,
    ,
   
])

print('B 的维度为{}'.format(B.dim()))
print('B 的形状为{}'.format(B.shape))
out:
B 的维度为2
B 的形状为torch.Size()


按列进行拼接:dim=0

resut = torch.cat((A, B), dim=0) # dim=0, 按列进行拼接
print(resut)
out:
tensor([[ 1,2,3],
      [ 4,5,6],
      [ 7,8,9],
      ,
      ,
      ])
print('resut 的维度为{}'.format(resut.dim()))
print('resut 的形状为{}'.format(resut.shape))
out:
resut 的维度为2
resut 的形状为torch.Size()


按行进行拼接:

resut1 = torch.cat((A, B), dim=1) # dim=1, 按行进行拼接
print(resut1)
out:
tensor([[ 1,2,3, 12, 22, 33],
      [ 4,5,6, 44, 55, 66],
      [ 7,8,9, 77, 88, 99]])

print('resut1 的维度为{}'.format(resut1.dim()))
print('resut1 的形状为{}'.format(resut1.shape))
out:
resut1 的维度为2
resut1 的形状为torch.Size()

页: [1]
查看完整版本: torch.cat()函数对矩阵进行拼接时的dim区别