一、激活函数的目标
激活函数的目标是为网络引入非线性,并使其能够学习并迫近复杂的数据模式
二、介绍GLU(Gated Linear Unit)
GLU:将输入分成两部分,一部分直接颠末线性变动,另一部分颠末 s i g m o i d sigmoid sigmoid 函数变动,然后将这两部分的输出逐点相乘
G L U ( x , W , V , B , c ) = σ ( x W + b ) ⊗ ( x V + c ) GLU(x, W, V, B, c) = \sigma (xW + b) \otimes (xV + c) GLU(x,W,V,B,c)=σ(xW+b)⊗(xV+c)
- $ \sigma $ 是 $ sigmoid $ 激活函数
- $ W, V $ 权重
- $ b, c $ 偏置
绘制GLU激活函数
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import matplotlib.pyplot as plt
- # 定义GLU激活函数
- class GLU(nn.Module):
- def forward(self, x):
- a, b = x.chunk(2, dim=-1)
- print('a:', a, 'b:', b)
- return a * F.sigmoid(b) # 应用sigmoid函数然后进行逐元素乘法(权重和偏置为1)
- # 实例化GLU模块
- glu = GLU()
- # torch.linspace(-3, 3, 100):在-3到3中生成一个等距的一维数组,数量为100个
- # unsqueeze(-1)将原先 100 个元素 的一维数组,转换成 100*1 的二维数组
- # expand(-1, 2) 复制 100*1的单列,生成 100*2的两列
- x_range = torch.linspace(-3, 3, 100).unsqueeze(-1).expand(-1, 2)
- y_glu = glu(x_range) # 得到经过GLU变换的结果
- plt.figure(figsize=(10, 4))
- plt.plot(x_range[:, 0].numpy(), y_glu.detach().numpy(), label='GLU Function')
- plt.xlabel('Input value')
- plt.ylabel('Output value')
- plt.title('GLU Activation Function')
- plt.legend()
- plt.grid(True)
- plt.show()
复制代码
三、介绍Swish激活函数
$ SwiGLU $ 是 $ GLU $ 的一种变体,其中包含了 G L U GLU GLU 和 S w i s h Swish Swish 激活函数。
S w i s h β ( x ) = x σ ( β x ) Swish_{\beta}(x) = x \sigma(\beta x) Swishβ(x)=xσ(βx)
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import matplotlib.pyplot as plt
- class Swish(nn.Module):
- def forward(self, x, beta):
- print(x)
- return x * F.sigmoid(beta * x)
- swish = Swish()
- x_range = torch.linspace(-3, 3, 100).unsqueeze(-1)
- betas = [0.1, 1.0, 10.0]
- plt.figure(figsize=(10, 4))
- for beta in betas:
- y_swish = swish(x_range, beta)
- plt.plot(x, y_swish, label=f'beta={beta}')
- plt.xlabel('Input value')
- plt.ylabel('Output value')
- plt.title('Swish Activation Function')
- plt.legend()
- plt.grid(True)
- plt.show()
复制代码
四、介绍SwiGLU
将 G L U GLU GLU 中的激活函数 s i g m o i d sigmoid sigmoid 改为 S w i s h Swish Swish 就是 S w i G L U SwiGLU SwiGLU 激活函数。
S w i G L U ( x , W , V , B , c ) = S w i s h β ( x W + b ) ⊗ ( x V + c ) SwiGLU(x, W, V, B, c) = Swish_\beta(xW + b) \otimes (xV + c) SwiGLU(x,W,V,B,c)=Swishβ(xW+b)⊗(xV+c)
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import matplotlib.pyplot as plt
- class SwiGLU(nn.Module):
- def forward(self, x):
- a, b = x.chunk(2, dim=-1)
- return a * F.silu(b) # 使用Swish激活函数,F.silu就是Swish
- swiglu = SwiGLU()
- x_range = torch.linspace(-3, 3, 100) # 创建一个范围为-3到3的线性空间
- y_swiglu = swiglu(x_range.unsqueeze(-1).expand(-1, 2)) # 应用 SwiGLU 函数,确保维度是偶数
- # 绘制 SwiGLU 函数的图像
- plt.figure(figsize=(10, 4))
- plt.plot(x_range.numpy(), y_swiglu.detach().numpy(), label='SwiGLU Function')
- plt.xlabel('Input value')
- plt.ylabel('Output value')
- plt.title('SwiGLU Activation Function')
- plt.legend()
- plt.grid(True)
- plt.show()
复制代码
五、GLU 和 SwiGLU 的区别
仅为 G L U GLU GLU 使用 s i g m o i d sigmoid sigmoid , S w i G L U SwiGLU SwiGLU 使用 S w i s h Swish Swish。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |