PyTorch_张量索引操作

打印 上一主题 下一主题

主题 2055|帖子 2055|积分 6165

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
简单行,列索引操作

  1. import torch
  2. import numpy as np
  3. # 简单行列索引
  4. def test01():
  5.     data = torch.randint(0, 10, [4, 5])
  6.     print(data)
  7.     # 获得指定的某行元素
  8.     print(data[0])
  9.     # 获得指定某个列的元素
  10.     print(data[:, 0])  # 逗号前面表示行,逗号后面表示列。冒号表示所有行或者所有列
  11.     # 获得指定位置的某个元素
  12.     print(data[1, 2])
  13.     # 表示先获得前三行,再获得第三列的数据
  14.     print(data[:3, 2])
  15.     # 表示获得前三行的前两列
  16.     print(data[:3, :2])
  17. # 列表索引
  18. def test02():
  19.     data = torch.randint(0, 10, [4, 5])
  20.     print(data)
  21.     # 如果索引的行列都是一个一维的列表,那么两个列表的长度必须相等
  22.     # 表示获得 (0, 0), (2, 1), (3, 2) 三个位置的元素
  23.     print(data[[0, 2, 3], [0, 1, 2]])
  24.     # 表示获得 0, 2, 3 行的 0, 1, 2 列
  25.     print(data[[[0], [2], [3]], [0, 1, 2]])
  26. if __name__ == "__main__":
  27.     test02()
复制代码

布尔索引

  1. import torch
  2. import numpy as np
  3. # 布尔索引
  4. def test01():
  5.     torch.manual_seed(0)
  6.     data = torch.randint(0, 10, [4, 5])
  7.     print(data)
  8.     # 能够获得该张量中所有大于3的元素
  9.     # 张量可以与数字做比较
  10.     print(data > 3)
  11.     print(data[data > 3])
  12.     # 返回第2列元素大于6的行
  13.     print(data[data[:, 1] > 6])
  14.     # 返回第2行元素大于3的所有列
  15.     print(data[:, data[1] > 3])
  16. # 多维索引
  17. def test02():
  18.     torch.manual_seed(0)
  19.     data = torch.randint(0, 10, [3, 4, 5])
  20.     print(data)
  21.     # 选择第0行的所有元素
  22.     print(data[0, :, :])
  23.     # 按照第1哥维度选择第0元素, 是按行
  24.     print(data[:, 0, :])
  25.     # 按照第2个维度选择第0元素, 是按列
  26.     print(data[:, :, 0])
  27. if __name__ == "__main__":
  28.     test02()
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

杀鸡焉用牛刀

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表