鱼C论坛

 找回密码
 立即注册
查看: 1176|回复: 0

[学习笔记] torch.cat()函数对矩阵进行拼接时的dim区别

[复制链接]
发表于 2023-7-26 11:31:17 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 Handsome_zhou 于 2023-7-26 11:32 编辑

torch.cat(tensors, dim=0)定义两个二维数组A, B
对A,B按列(dim=0)和按行(dim=1)进行拼接。
import torch
A = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])

print('A 的维度为{}'.format(A.dim()))
print('A 的形状为{}'.format(A.shape))
out:
A 的维度为2
A 的形状为torch.Size([3, 3])
B = torch.tensor([
    [12, 22, 33],
    [44, 55, 66],
    [77, 88, 99]
])

print('B 的维度为{}'.format(B.dim()))
print('B 的形状为{}'.format(B.shape))
out:
B 的维度为2
B 的形状为torch.Size([3, 3])


按列进行拼接:dim=0
resut = torch.cat((A, B), dim=0) # dim=0, 按列进行拼接
print(resut)
out:
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [12, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
print('resut 的维度为{}'.format(resut.dim()))
print('resut 的形状为{}'.format(resut.shape))
out:
resut 的维度为2
resut 的形状为torch.Size([6, 3])


按行进行拼接:
resut1 = torch.cat((A, B), dim=1) # dim=1, 按行进行拼接
print(resut1)
out:
tensor([[ 1,  2,  3, 12, 22, 33],
        [ 4,  5,  6, 44, 55, 66],
        [ 7,  8,  9, 77, 88, 99]])

print('resut1 的维度为{}'.format(resut1.dim()))
print('resut1 的形状为{}'.format(resut1.shape))
out:
resut1 的维度为2
resut1 的形状为torch.Size([3, 6])

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-12-23 14:14

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表