torch中维度操作总结(repeat,squeeze,unsqueeze,flatten,transpose)

打印 上一主题 下一主题

主题 886|帖子 886|积分 2658

repeat 函数

1.repeat参数个数与tensor向量维数一致

  1. a = torch.tensor([[1, 2, 3],
  2.                   [1, 2, 3]])
  3. b = a.repeat(2, 2)
  4. print(b.shape)
复制代码
效果为:
  1. torch.Size([4,6])
复制代码
即repeat的参数是对应维度的复制个数,上段代码为0维复制两次,1维复制两次,则得到以上运行效果。其余扩展情况依此类推。
2.repeat参数个数与tensor向量维数不一致

在参数个数大于原tensor维度个数时,总是先在第0维扩展一个维数为1的维度,然后按照参数指定的复制次数进行复制。计算输出的形状时,可以按照 对应参数*对应维度维数 得到效果
  1. # a形状(2,3)
  2. a = torch.tensor([[1, 2, 3],
  3.                   [1, 2, 3]])
  4. # repeat参数比维度多
  5. # 首先在第0维扩展一个维度,维数为1,然后按照参数指定的次数进行复制
  6. # 在扩展前先将a的形状扩展为(1,2,3)然后复制
  7. b = a.repeat(1, 2, 1)
  8. print(b.shape)  # 得到结果torch.Size([1, 4, 3])
复制代码
  1. # a形状(2,3)
  2. a = torch.tensor([[1, 2, 3],
  3.                   [1, 2, 3]])
  4. # repeat参数比维度多,在扩展前先将a的形状扩展为(1,2,3)然后复制
  5. b = a.repeat(1, 1, 2)
  6. print(b.shape)  # 得到结果torch.Size([1, 2, 6])
复制代码
  1. # a形状(2,3)
  2. a = torch.tensor([[1, 2, 3],
  3.                   [1, 2, 3]])
  4. # repeat参数比维度多,在扩展前先将a的形状扩展为(1,2,3)然后复制
  5. b = a.repeat(2, 1, 1)
  6. print(b.shape)  # 得到结果torch.Size([2, 2, 3])
复制代码
squeeze 函数

  1. torch.squeeze(A,N)
复制代码
torch.unsqueeze()函数:减少数组A指定位置N的维度。
如果不指定位置参数N,如果数组A的维度为(1,1,3)。
如果指定位置参数,执行 torch.squeeze(A,1) ,A的维度变为 (1,3),中心的维度被删除。
注:

  • 如果指定的维度大于1,那么将操作无效
  • 如果不指定维度N,那么将删除全部维度为1的维度
  1. a=torch.randn(1,1,3)
  2. print(a.shape) # torch.Size([1, 1, 3])
  3. b=torch.squeeze(a)
  4. print(b.shape)        # torch.Size([3])
  5. c=torch.squeeze(a,0)
  6. print(c.shape)  # torch.Size([1, 3])
  7. d=torch.squeeze(a,1)
  8. print(d.shape)        # torch.Size([1, 3])
  9. e=torch.squeeze(a,2)#如果去掉第三维,则数不够放了,所以直接保留
  10. print(e.shape)        # torch.Size([1, 1, 3])
复制代码
unsqueeze 函数

  1. torch.unsqueeze(A,N)
复制代码
torch.unsqueeze()函数:增加数组A指定位置N的维度。
两行三列的数组A维度为(2,3),那么这个数组就有三个位置可以增加维度,分别是
  1. ([位置0], 2,[位置1], 3, [位置2])
  2. 或者
  3. ( [位置-3] ,2,[位置-2], 3 ,[位置-1] )
复制代码
如果执行 torch.unsqueeze(A,1),数据的维度就变为了 (2,1,3)
  1. a=torch.randn(1,3)
  2. print(a.shape)        # torch.Size([1, 3])
  3. b=torch.unsqueeze(a,0)
  4. print(b.shape)        # torch.Size([1, 1, 3])
  5. c=torch.unsqueeze(a,1)
  6. print(c.shape)        # torch.Size([1, 1, 3])
  7. d=torch.unsqueeze(a,2)
  8. print(d.shape)        # torch.Size([1, 3, 1])
复制代码
flatten 函数

flatten() 是对多维数据的降维函数。
flatten(),默认缺省参数为0,也就是说flatten()和flatte(0)效果一样。
python里的flatten(dim)表示,从第dim个维度开始睁开,将后面的维度转化为一维.也就是说,只保留dim之前的维度,其他维度的数据全都挤在dim这一维。
  1. import torch
  2. a = torch.rand(2,3,4)
  3. print(a.shape) # torch.Size([2, 3, 4])
  4. b = a.flatten()
  5. print(b.shape)  # torch.Size([24])
  6. c = a.flatten(0)
  7. print(c.shape)  # torch.Size([24])
  8. d = a.flatten(1)
  9. print(d.shape)  # torch.Size([2, 12])
  10. e = a.flatten(2)
  11. print(e.shape)         # torch.Size([2, 3, 4])
复制代码
transpose函数

二维数组

  1. import numpy as np
  2. X=np.arange(6).reshape((2,3))
  3. print(X)
  4. #[[0 1 2]
  5. # [3 4 5]]
  6. print(X.transpose())
  7. #[[0 3]
  8. # [1 4]
  9. # [2 5]]
  10. print(X.T)
  11. #[[0 3]
  12. # [1 4]
  13. # [2 5]]
复制代码
多维数组

  1. x=np.arange(24).reshape((2,3,4))
  2. print(x.shape)
  3. y = x.transpose((0,1,2))
  4. print(y.shape)
  5. y = x.transpose((0,2,1))
  6. print(y.shape)
  7. y = x.transpose((2,1,0))
  8. print(y.shape)
  9. #(2, 3, 4)
  10. #(2, 3, 4)
  11. #(2, 4, 3)
  12. #(4, 3, 2)
复制代码
参考网址

https://blog.csdn.net/tequila53/article/details/119183678
https://blog.csdn.net/kuan__/article/details/116987162
说明

说明如下,如有侵权,非常抱歉,可联系本人删除对应内容。
会根据平时使用不断更新博客内容。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

正序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

王柳

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表