qidao123.com技术社区-IT企服评测·应用市场
标题:
PyTorch_张量索引操作
[打印本页]
作者:
杀鸡焉用牛刀
时间:
2025-5-7 03:00
标题:
PyTorch_张量索引操作
简单行,列索引操作
import torch
import numpy as np
# 简单行列索引
def test01():
data = torch.randint(0, 10, [4, 5])
print(data)
# 获得指定的某行元素
print(data[0])
# 获得指定某个列的元素
print(data[:, 0]) # 逗号前面表示行,逗号后面表示列。冒号表示所有行或者所有列
# 获得指定位置的某个元素
print(data[1, 2])
# 表示先获得前三行,再获得第三列的数据
print(data[:3, 2])
# 表示获得前三行的前两列
print(data[:3, :2])
# 列表索引
def test02():
data = torch.randint(0, 10, [4, 5])
print(data)
# 如果索引的行列都是一个一维的列表,那么两个列表的长度必须相等
# 表示获得 (0, 0), (2, 1), (3, 2) 三个位置的元素
print(data[[0, 2, 3], [0, 1, 2]])
# 表示获得 0, 2, 3 行的 0, 1, 2 列
print(data[[[0], [2], [3]], [0, 1, 2]])
if __name__ == "__main__":
test02()
复制代码
布尔索引
import torch
import numpy as np
# 布尔索引
def test01():
torch.manual_seed(0)
data = torch.randint(0, 10, [4, 5])
print(data)
# 能够获得该张量中所有大于3的元素
# 张量可以与数字做比较
print(data > 3)
print(data[data > 3])
# 返回第2列元素大于6的行
print(data[data[:, 1] > 6])
# 返回第2行元素大于3的所有列
print(data[:, data[1] > 3])
# 多维索引
def test02():
torch.manual_seed(0)
data = torch.randint(0, 10, [3, 4, 5])
print(data)
# 选择第0行的所有元素
print(data[0, :, :])
# 按照第1哥维度选择第0元素, 是按行
print(data[:, 0, :])
# 按照第2个维度选择第0元素, 是按列
print(data[:, :, 0])
if __name__ == "__main__":
test02()
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
欢迎光临 qidao123.com技术社区-IT企服评测·应用市场 (https://dis.qidao123.com/)
Powered by Discuz! X3.4