从代码学习深度学习 - 多头留意力 PyTorch 版
<hr> 媒介在深度学习范畴,留意力机制(Attention Mechanism)是自然语言处置处罚(NLP)和盘算机视觉(CV)等任务中的核心组件之一。特殊是多头留意力(Multi-Head Attention),作为 Transformer 模子的根本,极大地提升了模子对复杂依赖关系的捕捉本领。本文通太过析一个完备的 PyTorch 实现,带你深入明白多头留意力的原理和代码实现。我们将从代码入手,徐徐分析每个函数和类的功能,联合笔墨阐明,让你不但能运行代码,还能明白其背后的计划逻辑。无论你是初学者还是有肯定履历的开发者,这篇博客都将资助你更直观地把握多头留意力机制。
完备代码:下载链接
<hr> 一、多头留意力机制先容
多头留意力(Multi-Head Attention)是 Transformer 模子的核心组件之一,广泛应用于自然语言处置处罚(NLP)、盘算机视觉(CV)等范畴。它通过并行运行多个留意力头(Attention Heads),允许模子同时关注输入序列中的差异部门,从而捕捉更丰富的语义和上下文依赖关系。相比单一的留意力机制,多头留意力极大地增强了模子的表达本领,可以或许处置处罚复杂的模式和长间隔依赖。
1.1 工作原理
多头留意力的核心头脑是将输入的查询(Queries)、键(Keys)和值(Values)通过线性变更映射到多个子空间,每个子空间由一个独立的留意力头处置处罚。具体步调如下:
[*]线性变更:对输入的查询、键和值分别应用线性层,将其映射到隐蔽维度(num_hiddens),并分割为多个头的表现。
[*]缩放点积留意力:每个留意力头独立盘算缩放点积留意力(Scaled Dot-Product Attention),即通过查询和键的点积盘算留意力分数,再与值加权求和。
[*]并行盘算:多个留意力头并行运行,每个头关注输入的差异方面,天生各自的输出。
[*]合并与变更:将全部头的输出拼接起来,并通过一个线性层融合,得到终极的多头留意力输出。
这种计划允许模子在差异子空间中学习差异的特性,例如在 NLP 任务中,一个头大概关注句法结构,另一个头大概关注语义关系。
https://i-blog.csdnimg.cn/direct/403fe0a2f6204068872acadd3b41e67e.png
1.2 上风
[*]多样性:多头机制使模子可以或许从多个角度明白输入,捕捉多样化的模式。
[*]并行性:多头盘算可以高效并行化,提升盘算效率。
[*]稳固性:通过缩放点积(除以特性维度的平方根),缓解了高维点积导致的数值不稳固题目。
1.3 代码实现概述
在本文的实现中,我们使用 PyTorch 构建了一个完备的多头留意力模块,包罗以下关键部门:
[*]序列掩码:处置处罚变长序列,屏蔽无效位置。
[*]缩放点积留意力:实现单个留意力头的盘算逻辑。
[*]张量转换:通过 transpose_qkv 和 transpose_output 函数实现多头分割与合并。
[*]多头留意力类:整合全部组件,完成并行盘算和输出融合。
接下来的代码分析将具体展示这些部门的实现,资助你从代码层面深入明白多头留意力的每一步盘算逻辑。
二、代码分析
以下是代码的完备实现和具体分析,代码按照 Jupyter Notebook(在最开始给出了完备代码下载链接) 的结构构造,并附上笔墨阐明,资助你明白每个部门的逻辑。
2.1 导入依赖
起首,我们导入须要的 Python 包,包括数学运算库 math 和 PyTorch 的核心模块 torch 和 nn。
# 导入包
import math
import torch
from torch import nn
[*]math:用于盘算缩放点积留意力中的归一化因子(即特性维度的平方根)。
[*]torch:PyTorch 的核心库,提供张量运算和自动求导功能。
[*]nn:PyTorch 的神经网络模块,包罗 nn.Module 和 nn.Linear 等工具,用于构建神经网络层。
<hr> 序列掩码函数
在处置处罚序列数据(如句子)时,差异序列的长度大概差异,我们必要通过掩码(Mask)来屏蔽无效位置,防止模子关注这些添补区域。以下是 sequence_mask 函数的实现:
def sequence_mask(X, valid_len, value=0):
"""
在序列中屏蔽不相关的项,使超出有效长度的位置被设置为指定值
参数:
X: 输入张量,形状 (batch_size, 最大序列长度, 特征维度) 或 (batch_size, 最大序列长度)
valid_len: 有效长度张量,形状 (batch_size,),表示每个序列的有效长度
value: 屏蔽值,标量,默认值为 0,用于填充无效位置
返回:
输出张量,形状与输入 X 相同,无效位置被设置为 value
"""
maxlen = X.size(1)# 最大序列长度,标量
# 创建掩码,形状 (1, 最大序列长度),与 valid_len 比较生成布尔张量,形状 (batch_size, 最大序列长度)
mask = torch.arange(maxlen, dtype=torch.float32, device=X.device) < valid_len[:, None]
# 将掩码取反后,X 的无效位置被设置为 value
X[~mask] = value
return X
分析:
[*]输入:
[*]X:输入张量,通常是序列数据,大概包罗添补(padding)部门。
[*]valid_len:每个样本的有效长度,例如 表现第一个样本有 3 个有效 token,第二个样本有 2 个。
[*]value:用于添补无效位置的值,默以为 0。
[*]逻辑:
[*]maxlen 获取序列的最大长度(即张量的第二维)。
[*]torch.arange(maxlen) 创建一个从 0 到 maxlen-1 的序列,形状为 (1, maxlen)。
[*]通过广播机制,与 valid_len(形状 (batch_size, 1))比较,天生布尔掩码 mask,形状为 (batch_size, maxlen)。
[*]mask 表现哪些位置是有效的(True),哪些是无效的(False)。
[*]使用 ~mask 选择无效位置,将其值设置为 value。
[*]输出:修改后的张量 X,无效位置被设置为 value,形状稳固。
作用:该函数用于在留意力盘算中屏蔽添补区域,确保模子只关注有效 token。
2.2 掩码 Softmax 函数
在留意力机制中,我们必要对留意力分数应用 Softmax 利用,将其转换为概率分布。但由于序列长度差异,必要屏蔽无效位置的贡献。以下是 masked_softmax 函数的实现:
import torch
import torch.nn.functional as F
def masked_softmax(X, valid_lens):
"""
通过在最后一个轴上掩蔽元素来执行softmax操作,忽略无效位置
参数:
X: 输入张量,形状 (batch_size, 查询个数, 键-值对个数),3D张量
valid_lens: 有效长度张量,形状 (batch_size,) 或 (batch_size, 查询个数),1D或2D张量,
表示每个序列的有效长度,即每个查询可以参考的有效键值对长度
返回:
输出张量,形状 (batch_size, 查询个数, 键-值对个数),softmax后的注意力权重
"""
if valid_lens is None:
# 如果没有有效长度,直接在最后一个轴上应用softmax
return F.softmax(X, dim=-1)
shape = X.shape# 保存原始形状,(batch_size, 查询个数, 键-值对个数)
if valid_lens.dim() == 1:
# valid_lens 为 1D,形状 (batch_size,),重复扩展到 (batch_size
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页:
[1]