鱼C论坛

 找回密码
 立即注册
查看: 795|回复: 0

[学习笔记] torch.cat()函数对矩阵进行拼接时的dim区别

[复制链接]
发表于 2023-7-26 11:31:17 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 Handsome_zhou 于 2023-7-26 11:32 编辑

torch.cat(tensors, dim=0)定义两个二维数组A, B
对A,B按列(dim=0)和按行(dim=1)进行拼接。

  1. import torch
  2. A = torch.tensor([
  3.     [1, 2, 3],
  4.     [4, 5, 6],
  5.     [7, 8, 9]
  6. ])

  7. print('A 的维度为{}'.format(A.dim()))
  8. print('A 的形状为{}'.format(A.shape))
  9. out:
  10. A 的维度为2
  11. A 的形状为torch.Size([3, 3])
复制代码

  1. B = torch.tensor([
  2.     [12, 22, 33],
  3.     [44, 55, 66],
  4.     [77, 88, 99]
  5. ])

  6. print('B 的维度为{}'.format(B.dim()))
  7. print('B 的形状为{}'.format(B.shape))
  8. out:
  9. B 的维度为2
  10. B 的形状为torch.Size([3, 3])
复制代码



按列进行拼接:dim=0

  1. resut = torch.cat((A, B), dim=0) # dim=0, 按列进行拼接
  2. print(resut)
  3. out:
  4. tensor([[ 1,  2,  3],
  5.         [ 4,  5,  6],
  6.         [ 7,  8,  9],
  7.         [12, 22, 33],
  8.         [44, 55, 66],
  9.         [77, 88, 99]])
  10. print('resut 的维度为{}'.format(resut.dim()))
  11. print('resut 的形状为{}'.format(resut.shape))
  12. out:
  13. resut 的维度为2
  14. resut 的形状为torch.Size([6, 3])
复制代码



按行进行拼接:

  1. resut1 = torch.cat((A, B), dim=1) # dim=1, 按行进行拼接
  2. print(resut1)
  3. out:
  4. tensor([[ 1,  2,  3, 12, 22, 33],
  5.         [ 4,  5,  6, 44, 55, 66],
  6.         [ 7,  8,  9, 77, 88, 99]])

  7. print('resut1 的维度为{}'.format(resut1.dim()))
  8. print('resut1 的形状为{}'.format(resut1.shape))
  9. out:
  10. resut1 的维度为2
  11. resut1 的形状为torch.Size([3, 6])
复制代码


想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-6-2 04:04

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表