自界说前向与反向传播:torch.autograd.Function

打印 上一主题 下一主题

主题 1025|帖子 1025|积分 3075

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
1. 引言

在现代深度学习框架中,自动求导机制是模子训练的核心技术之一。PyTorch 的 torch.autograd 提供了一种强盛的方式来实现这一机制,帮助开辟者在前向传播后自动计算梯度。然而,尽管 PyTorch 提供了丰富的自动求导支持,偶尔我们可能会遇到一些特殊操作,这些操作无法依赖 PyTorch 的自动求导。这时,我们就需要使用 torch.autograd.Function 来自界说前向和反向传播逻辑,从而适应模子的独特需求。
1.1 PyTorch 自动求导机制简介

PyTorch 的核心自动求导工具 torch.autograd 使用了一种基于动态计算图的机制。当你在 Tensor 上调用操作时,PyTorch 会根据这些操作动态地构建一个有向无环图(DAG)。在这个图中,叶子节点体现输入张量,根节点则是输出张量。每个节点都体现一个操作,而 autograd 通过从根节点回溯(backpropagation),徐徐计算各个节点的梯度。
PyTorch 自动求导的强盛之处在于其动态计算图构建方式。在前向传播期间,每当执行一次操作,PyTorch 就会创建相应的计算图,并允许你通过 backward() 调用计算梯度。在这种机制下,PyTorch 既可以或许高效计算复杂网络的梯度,也可以或许灵活地支持不同范例的张量操作。
然而,并非全部的操作都能轻松地通过 PyTorch 内置的机制实现梯度计算。比方,当你想要实现一个新的数学运算或优化方法时,可能会遇到 PyTorch 无法自动处置惩罚的梯度计算题目。这时间,就需要我们通过 torch.autograd.Function 自界说前向传播和反向传播逻辑。
1.2 为什么我们需要自界说 autograd.Function

虽然 PyTorch 的 autograd 足够强盛,但在某些情况下,开辟者可能盼望更加灵活地控制前向传播和反向传播过程。重要的使用场景包括:

  • 非标准操作的梯度计算:对于一些非常规的数学运算,如量子力学中的特定操作,或者在某些科学计算中涉及的复杂自界说函数,PyTorch 的自动求导机制可能并不能自动处置惩罚此类操作的梯度。
  • 性能优化:某些自界说的操作可能具有明确的梯度表达式,但在自动求导过程中计算服从不高。这时,我们可以通过手动界说反向传播,使用更高效的计算方法来加快训练。
  • 数值稳定性题目:在某些情况下,自动求导机制可能会导致数值稳定性题目。比方,在涉及非常小的数值时,梯度计算可能会变得禁绝确。这时,通过自界说 Function 可以对梯度进行准确控制,包管数值稳定性。
  • 实现自界说优化方法:当使用常规的优化方法无法满足需求时,开辟者可以通过自界说 Function 实现新的优化算法。
通过 torch.autograd.Function,我们可以自界说特定操作的前向传播和反向传播,这在处置惩罚复杂模子或需要更高性能时非常有用。
2. torch.autograd.Function 底子概念

2.1 Function 与 Module 的区别

在 PyTorch 中,torch.nn.Module 和 torch.autograd.Function 都能帮助开辟者进行模子扩展,但它们的角色和实现机制不同。


  • Module:实用于界说复杂的神经网络层结构,如卷积层、全毗连层等,并自动处置惩罚前向传播和反向传播中的梯度计算。
    torch.nn.Module 是 PyTorch 中用于构建深度学习模子的核心模块。它为模子的结构界说、参数管理和前向传播提供了标准接口。每个 Module 都可以包罗其他子模块,并通过调用 forward 方法执行前向传播。在使用 Module 时,PyTorch 会自动处置惩罚内部参数的梯度计算,因此开辟者无需关注具体的梯度计算细节。
    常见的 torch.nn.Module 示例包括卷积层(Conv2d)、全毗连层(Linear)和池化层(MaxPool2d)等。这些层已经内置了前向传播和梯度计算的机制,可以或许高效执行各种操作。
  • Function:实用于实现单一操作(如激活函数、损失函数等),需要手动界说前向传播和反向传播逻辑,尤其适合无法自动计算梯度的操作。
    torch.autograd.Function 是 PyTorch 中更底层的计算单元。与 Module 不同的是,Function 需要开辟者手动实现前向传播和反向传播。它实用于那些无法通过自动求导机制直接计算梯度的情况,允许开辟者完全自界说操作的活动。
    使用 Function 时,我们可以界说 forward 和 backward 两个静态方法,分别控制前向传播中的计算过程和反向传播中的梯度计算逻辑。这使得 Function 在特定的应用场景下非常灵活,特别是对于需要精细控制梯度计算的场合。
通过 Module,我们可以方便地设计网络层及其内部的参数。而 Function 则更底层,允许我们自界说具体的操作流程,特别是自界说梯度的计算过程。
2.2 Function 的使用场景与根本用法

torch.autograd.Function 提供了一种方式,允许用户自界说前向传播的计算过程和反向传播中的梯度计算。通过继续 Function 类,我们可以实现两个静态方法:


  • forward(ctx, *args):界说前向传播的计算逻辑。该方法接收输入张量,并将其返回的输出用于下一步的计算。在前向传播过程中,我们可以通过 ctx 生存一些中央结果,以便反向传播时使用。
  • backward(ctx, *grad_outputs):界说反向传播中的梯度计算。该方法接收上游通报的梯度值,并结合前向传播时生存的中央结果来计算输入的梯度。
假如有以下一条前向传播链:
                                                                                       x                                     →                                     f                                     →                                     y                                     →                                     g                                     →                                     z                                                                            (1)                                                       x \rightarrow f \rightarrow y \rightarrow g \rightarrow z \tag{1}                     x→f→y→g→z(1)
即                                    y                         =                         f                         (                         x                         )                              y = f(x)                  y=f(x),                                    z                         =                         g                         (                         y                         )                              z = g(y)                  z=g(y),根据链式法则:
                                                                ∂                                  z                                                      ∂                                  x                                                 =                                                   ∂                                  z                                                      ∂                                  y                                                                        ∂                                  y                                                      ∂                                  x                                                       \frac{\partial z}{\partial x} = \frac{\partial z}{\partial y} \frac{\partial y}{\partial x}                     ∂x∂z​=∂y∂z​∂x∂y​
假如我们想通过 torch.autograd.Function 自界说                                    f                              f                  f,则其中                                                         ∂                               z                                                 ∂                               y                                                 \frac{\partial z}{\partial y}                  ∂y∂z​ 就是 grad_output,我们在 forward 里需要返回                                    f                         (                         x                         )                              f(x)                  f(x),在 backward 里需要返回 grad_output * f'(x)。
具体来讲,假设                                    f                         (                         x                         )                         =                         2                         x                              f(x)=2x                  f(x)=2x,则                                              f                            ′                                  (                         x                         )                         =                         2                              f'(x)=2                  f′(x)=2。那么前向传播就需要返回                                    2                         x                              2x                  2x,其中                                    x                              x                  x 就是 input,反向传播则需要返回 grad_output * 2。
  1. import torch
  2. from torch.autograd import Function
  3. class CustomFunction(Function):
  4.     @staticmethod
  5.     def forward(ctx, input):
  6.         result = input * 2  # 前向传播的简单操作
  7.         ctx.save_for_backward(input)  # 保存输入用于反向传播
  8.         return result
  9.     @staticmethod
  10.     def backward(ctx, grad_output):
  11.         input, = ctx.saved_tensors  # 获取前向传播时保存的输入
  12.         grad_input = grad_output * 2  # 计算输入的梯度
  13.         return grad_input
  14. # 测试自定义的函数
  15. x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
  16. y = CustomFunction.apply(x)
  17. y.sum().backward()
  18. print(x.grad)  # 输出 [2, 2, 2],对应自定义函数的梯度
复制代码
在这个简单的示例中,forward 方法计算输入的两倍,而 backward 方法则根据前向传播时生存的中央结果,计算输入的梯度。通过这种方式,开辟者可以完全控制操作的前向传播和反向传播过程。
3. torch.autograd.Function 的核心方法

3.1 forward 方法

forward 方法负责实现自界说操作的前向传播逻辑。该方法接收输入张量,并将其返回的输出用于下一步的计算。在前向传播过程中,我们通常会生存一些中央计算结果,以便在反向传播时使用。这些数据可以通过 ctx.save_for_backward() 方法进行存储。
示例:自界说前向传播
  1. import torch
  2. from torch.autograd import Function
  3. class MyFunction(Function):
  4.     @staticmethod
  5.     def forward(ctx, input):
  6.         ctx.save_for_backward(input)
  7.         return input ** 2
  8. # 测试自定义前向传播
  9. x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
  10. y = MyFunction.apply(x)
  11. print(y)  # 输出 [1.0, 4.0, 9.0]
复制代码
在这个示例中,我们实现了一个简单的自界说平方函数。在前向传播过程中,我们生存了输入张量,以便在后续的反向传播中使用。
3.2 backward 方法

backward 方法负责反向传播中的梯度计算。它接收上游通报的梯度值 grad_output,并结合前向传播生存的中央结果来计算输入的梯度。
  1. import torch
  2. from torch.autograd import Function
  3. class MyFunction(Function):
  4.     @staticmethod
  5.     def forward(ctx, input):
  6.         ctx.save_for_backward(input)
  7.         return input ** 2
  8.     @staticmethod
  9.     def backward(ctx, grad_output):
  10.         input, = ctx.saved_tensors
  11.         return grad_output * 2 * input
  12. # 测试自定义反向传播
  13. x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
  14. y = MyFunction.apply(x)
  15. y.sum().backward()
  16. print(x.grad)  # 输出 [2.0, 4.0, 6.0],对应 x**2 的梯度
复制代码
这个示例展示了怎样根据前向传播生存的中央结果计算梯度。通过 ctx.saved_tensors,我们可以在反向传播中获取前向传播时生存的张量,并使用它们计算梯度。
3.3 ctx 对象

ctx 是 Function 类中前向传播和反向传播之间的信息桥梁。通过 ctx 对象,我们可以在前向传播中生存数据,并在反向传播中访问这些数据。常见的操作包括:


  • ctx.save_for_backward(*tensors):生存前向传播中计算的张量。
  • ctx.saved_tensors:获取生存的张量。
  • ctx.mark_dirty(*tensors):标记在前向传播中被当场修改的张量。
  • ctx.mark_non_differentiable(*tensors):标记某些张量为不可微分,从而提高计算服从。
ctx.save_for_backward 的使用示例
  1. class MyFunction(Function):
  2.     @staticmethod
  3.     def forward(ctx, input):
  4.         result = input ** 3
  5.         ctx.save_for_backward(result)
  6.         return result
  7.     @staticmethod
  8.     def backward(ctx, grad_output):
  9.         result, = ctx.saved_tensors
  10.         return grad_output * 3 * result ** 2
  11. # 测试带保存数据的自定义函数
  12. x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
  13. y = MyFunction.apply(x)
  14. y.sum().backward()
复制代码
ctx.save_for_backward 方法允许我们在前向传播中存储需要在反向传播中使用的张量数据。通过这种机制,我们可以在梯度计算中复用前向传播的结果,从而克制重复计算。
4. 自界说案例分析

接下来,我们将通过一些案例来演示怎样在 torch.autograd.Function 中自界说前向和反向传播。为了克制抄袭风险,以下案例是基于原有博客中的案例修改而成,并加入了一些全新的自界说操作。
4.1 自界说简单指数函数

在这个案例中,我们通过 torch.autograd.Function 自界说一个简单的指数函数。前向传播计算指数值,反向传播则使用指数函数的导数特性进行梯度计算。
  1. import torch
  2. from torch.autograd import Function
  3. class CustomExp(Function):
  4.     @staticmethod
  5.     def forward(ctx, input):
  6.         result = input.exp()
  7.         ctx.save_for_backward(result)
  8.         return result
  9.     @staticmethod
  10.     def backward(ctx, grad_output):
  11.         result, = ctx.saved_tensors
  12.         return grad_output * result
  13. # 测试自定义的指数函数
  14. x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
  15. y = CustomExp.apply(x)
  16. y.sum().backward()
  17. print(x.grad)  # 输出 [e^1, e^2, e^3] 的梯度
复制代码
该案例展示了怎样通过自界说 Function 实现一个简单的指数操作。反向传播使用指数的导数,即指数函数本身。
4.2 自界说平方和梯度的反向传播

在这一案例中,我们将实现一个计算平方和的自界说函数。前向传播计算输入张量的平方和,而反向传播则计算平方和相对于输入的梯度。
  1. class CustomSquareSum(Function):
  2.     @staticmethod
  3.     def forward(ctx, input):
  4.         ctx.save_for_backward(input)
  5.         return (input ** 2).sum()
  6.     @staticmethod
  7.     def backward(ctx, grad_output):
  8.         input, = ctx.saved_tensors
  9.         return grad_output * 2 * input
  10. # 测试自定义平方和函数
  11. x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
  12. y = CustomSquareSum.apply(x)
  13. y.backward()
  14. print(x.grad)  # 输出 [2*x1, 2*x2, 2*x3] 的梯度
复制代码
在这个案例中,前向传播计算的是输入张量元素的平方和,反向传播计算的是每个输入元素的梯度,遵循平方和的导数公式:
                                                                ∂                                  (                                               x                                     i                                     2                                              )                                                      ∂                                               x                                     i                                                             =                            2                                       x                               i                                            \frac{\partial (x_i^2)}{\partial x_i} = 2x_i                     ∂xi​∂(xi2​)​=2xi​
因此,最终输出的梯度是输入张量的两倍。
4.3 自界说复杂运算的梯度计算

为了展示 Function 可以处置惩罚更复杂的运算,我们设计一个计算输入张量平方根加反转的自界说函数。这个函数的前向传播包括对输入计算平方根以及反转张量的数值,反向传播则使用链式法则,计算梯度传播。
  1. class CustomSqrtInverse(Function):
  2.     @staticmethod
  3.     def forward(ctx, input):
  4.         result = input.sqrt() + torch.reciprocal(input)
  5.         ctx.save_for_backward(input)
  6.         return result
  7.     @staticmethod
  8.     def backward(ctx, grad_output):
  9.         input, = ctx.saved_tensors
  10.         grad_input = (0.5 / input.sqrt()) - (1.0 / input ** 2)
  11.         return grad_output * grad_input
  12. # 测试自定义平方根加反转函数
  13. x = torch.tensor([4.0, 9.0, 16.0], requires_grad=True)
  14. y = CustomSqrtInverse.apply(x)
  15. y.sum().backward()
  16. print(x.grad)  # 输出自定义梯度
复制代码
在这个复杂的案例中,我们自界说了一个同时涉及平方根和倒数的运算。前向传播首先对输入张量进行平方根计算,然后加上其倒数。反向传播的梯度计算需要分别对平方根和倒数求导,使用了以下导数公式:


  • 对平方根的导数:                                                               ∂                                               x                                                                  ∂                                  x                                                 =                                       1                                           2                                               x                                                                   \frac{\partial \sqrt{x}}{\partial x} = \frac{1}{2\sqrt{x}}                     ∂x∂x                     ​​=2x                     ​1​
  • 对倒数的导数:                                                               ∂                                               (                                                   1                                        x                                                  )                                                                  ∂                                  x                                                 =                            −                                       1                                           x                                  2                                                       \frac{\partial \left(\frac{1}{x}\right)}{\partial x} = -\frac{1}{x^2}                     ∂x∂(x1​)​=−x21​
最终,backward 方法结合这两个公式,计算出梯度传播的正确值。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

宝塔山

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表