PyTorch中实现开立方

[复制链接]
发表于 2025-6-26 09:30:55 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

×
技术背景

在PyTorch中,没有直接实现cbrt这一算子。这个算子是用于计算一个数的开立方,比方,最简单的-8开立方就是-2。但这里有个问题是,在PyTorch中,因为没有cbrt算子,如果直接用幂次计算去操作数字,就有可能出现问题。
代码示例

起首看一下numpy做开立方的代码示例:
  1. In [1]: import numpy as np
  2. In [2]: a = np.array(-8, np.float32)
  3. In [3]: a**(1/3)
  4. <ipython-input-3-f6e83d4e282e>:1: RuntimeWarning: invalid value encountered in power
  5.   a**(1/3)
  6. Out[3]: np.float32(nan)
  7. In [4]: np.cbrt(a)
  8. Out[4]: np.float32(-2.0)
复制代码
在这个示例中,如果直接开立方,结果会是一个nan,很明显不是我们想要的一个结果。而cbrt是一个单独实现的开立方算子,可以支持负数的输入,计算结果也是正确的。在PyTorch的场景下,只能用幂次运算:
  1. In [1]: import torch as tc
  2. In [2]: a=tc.tensor(-8,dtype=tc.float32)
  3. In [3]: a**(1/3)
  4. Out[3]: tensor(nan)
复制代码
这样得到的结果是错误的。因此必要我们本身实现一个cbrt函数:
  1. In [1]: import torch as tc
  2. In [2]: cbrt=lambda x: tc.sign(x)*tc.abs(x)**(1/3)
  3. In [3]: a=tc.tensor(-8,dtype=tc.float32)
  4. In [4]: cbrt(a)
  5. Out[4]: tensor(-2.)
复制代码
其实逻辑也比较简单,就是先把符号提取出来,然后再转化为正数正常计算就好了。
总结概要

本文介绍了在PyTorch中直接使用幂次函数计算有可能导致的计算结果异常的问题。由于PyTorch中并未像Numpy和MindSpore一样直接支持cbrt开立方函数,因此这里也提供了一个在PyTorch中计算开立方的函数。
版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/cbrt.html
作者ID:DechinPhy
更多原著文章:https://www.cnblogs.com/dechinphy/
请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

×
登录参与点评抽奖,加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表