llama源码学习·model.py[2]SwiGLU激活函数

打印 上一主题 下一主题

主题 934|帖子 934|积分 2802

一、激活函数的目标

激活函数的目标是为网络引入非线性,并使其能够学习并迫近复杂的数据模式
二、介绍GLU(Gated Linear Unit)

GLU:将输入分成两部分,一部分直接颠末线性变动,另一部分颠末                                    s                         i                         g                         m                         o                         i                         d                              sigmoid                  sigmoid 函数变动,然后将这两部分的输出逐点相乘
                                         G                            L                            U                            (                            x                            ,                            W                            ,                            V                            ,                            B                            ,                            c                            )                            =                            σ                            (                            x                            W                            +                            b                            )                            ⊗                            (                            x                            V                            +                            c                            )                                  GLU(x, W, V, B, c) = \sigma (xW + b) \otimes (xV + c)                     GLU(x,W,V,B,c)=σ(xW+b)⊗(xV+c)


  • $ \sigma $ 是 $ sigmoid $ 激活函数
  • $ W, V $ 权重
  • $ b, c $ 偏置
绘制GLU激活函数

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import matplotlib.pyplot as plt
  5. # 定义GLU激活函数
  6. class GLU(nn.Module):
  7.     def forward(self, x):
  8.         a, b = x.chunk(2, dim=-1)  
  9.         print('a:', a, 'b:', b)
  10.         return a * F.sigmoid(b)  # 应用sigmoid函数然后进行逐元素乘法(权重和偏置为1)
  11. # 实例化GLU模块
  12. glu = GLU()
  13. # torch.linspace(-3, 3, 100):在-3到3中生成一个等距的一维数组,数量为100个
  14. # unsqueeze(-1)将原先 100 个元素 的一维数组,转换成 100*1 的二维数组
  15. # expand(-1, 2)  复制 100*1的单列,生成 100*2的两列
  16. x_range = torch.linspace(-3, 3, 100).unsqueeze(-1).expand(-1, 2)  
  17. y_glu = glu(x_range) # 得到经过GLU变换的结果
  18. plt.figure(figsize=(10, 4))
  19. plt.plot(x_range[:, 0].numpy(), y_glu.detach().numpy(), label='GLU Function')
  20. plt.xlabel('Input value')
  21. plt.ylabel('Output value')
  22. plt.title('GLU Activation Function')
  23. plt.legend()
  24. plt.grid(True)
  25. plt.show()
复制代码

三、介绍Swish激活函数

$ SwiGLU $ 是 $ GLU $ 的一种变体,其中包含了                                    G                         L                         U                              GLU                  GLU 和                                    S                         w                         i                         s                         h                              Swish                  Swish 激活函数。
                                         S                            w                            i                            s                                       h                               β                                      (                            x                            )                            =                            x                            σ                            (                            β                            x                            )                                  Swish_{\beta}(x) = x \sigma(\beta x)                     Swishβ​(x)=xσ(βx)


  • $ \beta $ 是一个可学习参数
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import matplotlib.pyplot as plt
  5. class Swish(nn.Module):
  6.     def forward(self, x, beta):
  7.         print(x)
  8.         return x * F.sigmoid(beta * x)
  9. swish = Swish()
  10. x_range = torch.linspace(-3, 3, 100).unsqueeze(-1)
  11. betas = [0.1, 1.0, 10.0]
  12. plt.figure(figsize=(10, 4))
  13. for beta in betas:
  14.     y_swish = swish(x_range, beta)
  15.     plt.plot(x, y_swish, label=f'beta={beta}')
  16. plt.xlabel('Input value')
  17. plt.ylabel('Output value')
  18. plt.title('Swish Activation Function')
  19. plt.legend()
  20. plt.grid(True)
  21. plt.show()
复制代码

四、介绍SwiGLU

将                                    G                         L                         U                              GLU                  GLU 中的激活函数                                    s                         i                         g                         m                         o                         i                         d                              sigmoid                  sigmoid 改为                                    S                         w                         i                         s                         h                              Swish                  Swish 就是                                    S                         w                         i                         G                         L                         U                              SwiGLU                  SwiGLU 激活函数。
                                         S                            w                            i                            G                            L                            U                            (                            x                            ,                            W                            ,                            V                            ,                            B                            ,                            c                            )                            =                            S                            w                            i                            s                                       h                               β                                      (                            x                            W                            +                            b                            )                            ⊗                            (                            x                            V                            +                            c                            )                                  SwiGLU(x, W, V, B, c) = Swish_\beta(xW + b) \otimes (xV + c)                     SwiGLU(x,W,V,B,c)=Swishβ​(xW+b)⊗(xV+c)
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import matplotlib.pyplot as plt
  5. class SwiGLU(nn.Module):
  6.     def forward(self, x):
  7.         a, b = x.chunk(2, dim=-1)
  8.         return a * F.silu(b)  # 使用Swish激活函数,F.silu就是Swish
  9. swiglu = SwiGLU()
  10. x_range = torch.linspace(-3, 3, 100)  # 创建一个范围为-3到3的线性空间
  11. y_swiglu = swiglu(x_range.unsqueeze(-1).expand(-1, 2))  # 应用 SwiGLU 函数,确保维度是偶数
  12. # 绘制 SwiGLU 函数的图像
  13. plt.figure(figsize=(10, 4))
  14. plt.plot(x_range.numpy(), y_swiglu.detach().numpy(), label='SwiGLU Function')
  15. plt.xlabel('Input value')
  16. plt.ylabel('Output value')
  17. plt.title('SwiGLU Activation Function')
  18. plt.legend()
  19. plt.grid(True)
  20. plt.show()
复制代码

五、GLU 和 SwiGLU 的区别

仅为                                    G                         L                         U                              GLU                  GLU 使用                                    s                         i                         g                         m                         o                         i                         d                              sigmoid                  sigmoid ,                                   S                         w                         i                         G                         L                         U                              SwiGLU                  SwiGLU 使用                                    S                         w                         i                         s                         h                              Swish                  Swish。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

雁过留声

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表