pytorch torch.scatter_reduce函数先容

打印 上一主题 下一主题

主题 1738|帖子 1738|积分 5214

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
PyTorch torch.scatter_reduce 函数

torch.scatter_reduce 是 PyTorch 中的一种高级操作,用于在特定维度上将源张量的值按索引归约到目标张量中。它联合了 scatter 和 reduce 操作,非常得当处理须要对特定索引举行归约(如求和、最大值等)的场景。

函数签名

  1. torch.scatter_reduce(input, dim, index, src, reduce, *, include_self=True, out=None)
复制代码
参数阐明


  • input:

    • 目标张量,表示归约操作的初始值。

  • dim:

    • 指定在目标张量 input 中举行归约操作的维度。

  • index:

    • 张量,表示目标张量中归约操作的索引位置。
    • index 的形状必须与 src 兼容,大概可以广播成 src 的形状。

  • src:

    • 源张量,提供要归约到 input 中的值。

  • reduce:

    • 指定归约操作的范例,支持以下选项:

      • "sum":按索引举行求和。
      • "prod":按索引举行乘积。
      • "mean":按索引计算平均值。
      • "amax":按索引取最大值。
      • "amin":按索引取最小值。


  • include_self (可选, 默认 True):

    • 是否在归约时包括 input 中的原始值。
    • 假如为 False,只使用 src 中的值举行归约。

  • out (可选):

    • 用于存储结果的张量。假如提供,将直接修改此张量。


返回值

返回一个张量,包含归约操作的结果,形状与 input 相同。

示例

1. 按索引求和 (reduce="sum")

  1. import torch
  2. input = torch.zeros(3, 5)
  3. index = torch.tensor([[0, 1, 2, 0, 1],
  4.                       [1, 2, 0, 1, 2]])
  5. src = torch.tensor([[10., 20., 30., 40., 50.],
  6.                     [1., 2., 3., 4., 5.]])
  7. result = torch.scatter_reduce(input, dim=1, index=index, src=src, reduce="sum")
  8. print(result)
复制代码
输出
  1. tensor([[50., 70., 30.,  0.,  0.],
  2.         [ 3.,  5.,  7.,  0.,  0.],
  3.         [ 0.,  0.,  0.,  0.,  0.]])
复制代码
2. 按索引取最大值 (reduce="amax")

  1. result = torch.scatter_reduce(input, dim=1, index=index, src=src, reduce="amax")
  2. print(result)
复制代码
输出
  1. tensor([[40., 50., 30.,  0.,  0.],
  2.         [ 3.,  4.,  5.,  0.,  0.],
  3.         [ 0.,  0.,  0.,  0.,  0.]])
复制代码
3. 使用 include_self=False

  1. result = torch.scatter_reduce(input, dim=1, index=index, src=src, reduce="sum", include_self=False)
  2. print(result)
复制代码
输出:
  1. tensor([[50., 70., 30.,  0.,  0.],
  2.         [ 3.,  5.,  7.,  0.,  0.],
  3.         [ 0.,  0.,  0.,  0.,  0.]])
复制代码
注意事项


  • index 范围

    • index 的值必须在 [0, input.shape[dim]) 范围内,否则会引发错误。

  • 广播规则

    • index 和 src 必须具有相同的形状,大概可以通过广播匹配。

  • 性能优化

    • torch.scatter_reduce 对于希罕更新和归约非常高效,制止了循环操作。


应用场景



  • 聚合数据(如按索引分组求和或求最大值)。
  • 构造希罕张量。
  • 实现自定义的归约操作(如图神经网络中的消息传递)。


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

惊落一身雪

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表