马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?立即注册
x
PyTorch torch.scatter_reduce 函数
torch.scatter_reduce 是 PyTorch 中的一种高级操作,用于在特定维度上将源张量的值按索引归约到目标张量中。它联合了 scatter 和 reduce 操作,非常得当处理须要对特定索引举行归约(如求和、最大值等)的场景。
函数签名
- torch.scatter_reduce(input, dim, index, src, reduce, *, include_self=True, out=None)
复制代码 参数阐明
- input:
- dim:
- 指定在目标张量 input 中举行归约操作的维度。
- index:
- 张量,表示目标张量中归约操作的索引位置。
- index 的形状必须与 src 兼容,大概可以广播成 src 的形状。
- src:
- reduce:
- 指定归约操作的范例,支持以下选项:
- "sum":按索引举行求和。
- "prod":按索引举行乘积。
- "mean":按索引计算平均值。
- "amax":按索引取最大值。
- "amin":按索引取最小值。
- include_self (可选, 默认 True):
- 是否在归约时包括 input 中的原始值。
- 假如为 False,只使用 src 中的值举行归约。
- out (可选):
返回值
返回一个张量,包含归约操作的结果,形状与 input 相同。
示例
1. 按索引求和 (reduce="sum")
- import torch
- input = torch.zeros(3, 5)
- index = torch.tensor([[0, 1, 2, 0, 1],
- [1, 2, 0, 1, 2]])
- src = torch.tensor([[10., 20., 30., 40., 50.],
- [1., 2., 3., 4., 5.]])
- result = torch.scatter_reduce(input, dim=1, index=index, src=src, reduce="sum")
- print(result)
复制代码 输出:
- tensor([[50., 70., 30., 0., 0.],
- [ 3., 5., 7., 0., 0.],
- [ 0., 0., 0., 0., 0.]])
复制代码 2. 按索引取最大值 (reduce="amax")
- result = torch.scatter_reduce(input, dim=1, index=index, src=src, reduce="amax")
- print(result)
复制代码 输出:
- tensor([[40., 50., 30., 0., 0.],
- [ 3., 4., 5., 0., 0.],
- [ 0., 0., 0., 0., 0.]])
复制代码 3. 使用 include_self=False
- result = torch.scatter_reduce(input, dim=1, index=index, src=src, reduce="sum", include_self=False)
- print(result)
复制代码 输出:
- tensor([[50., 70., 30., 0., 0.],
- [ 3., 5., 7., 0., 0.],
- [ 0., 0., 0., 0., 0.]])
复制代码 注意事项
- index 范围:
- index 的值必须在 [0, input.shape[dim]) 范围内,否则会引发错误。
- 广播规则:
- index 和 src 必须具有相同的形状,大概可以通过广播匹配。
- 性能优化:
- torch.scatter_reduce 对于希罕更新和归约非常高效,制止了循环操作。
应用场景
- 聚合数据(如按索引分组求和或求最大值)。
- 构造希罕张量。
- 实现自定义的归约操作(如图神经网络中的消息传递)。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |