pytorch张量列表索引和多维度张量索引比较

打印 上一主题 下一主题

主题 1877|帖子 1877|积分 5631

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
pytorch张量的高级索引取值原明白读
代码:

  1. import torch
  2. x = torch.tensor([[10, 20, 30], [40, 50, 60]])
  3. x1 = x[[[0, 1], [1, 0]]]
  4. x2 = x[torch.tensor([[0, 1], [1, 0]])]
  5. print(f"x1:{x1}")
  6. print(f"x2:{x2}")
复制代码
输出:
  1. x1:tensor([20, 40])
  2. x2:tensor([[[10, 20, 30],
  3.          [40, 50, 60]],
  4.         [[40, 50, 60],
  5.          [10, 20, 30]]])
复制代码
代码解读:

张量 x是一个 2x3 的张量:
x1 的取值
  1. x1 = x[[[0, 1], [1, 0]]]
复制代码


  • 索引机制: 这里的索引 [[0, 1], [1, 0]] 是 高级整数索引

    • 它取的是第 1 维的具体位置。

  • 步骤

    • x[[0, 1], [1, 0]] 等价于以下操纵:

      • x[0, 1] -> 20
      • x[1, 0] -> 40


因此:
  1. x1 = [20, 40]
复制代码
        注:x[[[0, 1], [1, 0]]] 效果同 x[[0, 1], [1, 0]]
x2 的取值
  1. x2 = x[torch.tensor([[0, 1], [1, 0]])]
  2. ### 复杂索引,在0维和1维度都取
  3. #x3 = x[torch.tensor([[0, 1], [1, 0]]),torch.tensor([[0, 1], [1, 0]])]
  4. #print(f"x3:{x3}")
  5. #x 3:tensor([[10, 50],
  6. #        [50, 10]])
  7. #print(f"x3.shape:{x3.shape}")   # x3.shape:torch.Size([2, 2])
复制代码


  • 索引机制: 这里的索引 torch.tensor([[0, 1], [1, 0]]) 是 多维整形张量索引

    • 这种索引会在第 0 维上按张量的外形举行广播

  • 广播行为

    • 索引张量的外形是 (2, 2)。
    • PyTorch 会沿第 0 维取出对应的行,并按照索引效果重新分列。

  • 步骤

    • x[0] -> [10, 20, 30]
    • x[1] -> [40, 50, 60]
    根据索引张量 [[0, 1], [1, 0]],效果分列为:

  1. [[[10, 20, 30],  # 对应索引 (0, 0)
  2.   [40, 50, 60]], # 对应索引 (0, 1)
  3. [[40, 50, 60],  # 对应索引 (1, 0)
  4.   [10, 20, 30]]] # 对应索引 (1, 1)
复制代码
总结:



  • x1 使用的是高级整数索引,按指定的具体位置取值(减少维度)。
  • x2 使用的是多维张量索引,按张量外形广播,天生一个更高维的效果(不减少维度)。


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

宝塔山

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表