从代码学习深度学习 - 自注意力和位置编码 PyTorch 版

  论坛元老 | 2025-4-16 18:12:21 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 1715|帖子 1715|积分 5145

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
媒介

深度学习近年来在天然语言处置惩罚、计算机视觉等领域取得了巨大乐成,而 Transformer 模型无疑是其中的明星架构。自注意力和位置编码作为 Transformer 的两大焦点组件,不光赋予了模型强盛的序列建模能力,还推动了 BERT、GPT 等模型的广泛应用。然而,理解这些概念的理论公式往往令人望而生畏,直接从代码入手则能让学习过程更加直观和风趣。
在这篇博客中,我们将基于 PyTorch,通太过析提供的代码文件(utils_for_huitu.py、MultiHeadAttention.py 以及一个 Jupyter 笔记本),深入探讨自注意力机制和位置编码的实现细节。从多头注意力的矩阵运算到位置编码的正弦余弦设计,我们将一步步拆解代码,揭示 Transformer 的工作原理。同时,通过可视化工具,我们将直观展示这些机制的内部表示,资助读者建立对深度学习模型的感性认知。
无论你是深度学习初学者,还是希望通过代码加深对 Transformer 理解的开发者,这篇文章都将为你提供一个清晰的学习路径。让我们一起从代码中发现深度学习的魅力吧!
完整代码:下载链接
<hr> 一、自注意力:Transformer 的焦点

自注意力机制(Self-Attention)是 Transformer 模型的根本,它允许模型在处置惩罚序列数据时动态地关注输入序列的不同部门。这种机制在天然语言处置惩罚使命(如 BERT、GPT)中表现尤为精彩。让我们从代码入手,探索自注意力机制的具体实现。
1.1 多头注意力机制的实现

MultiHeadAttention.py 文件中的 MultiHeadAttention 类实现了多头注意力机制,通过并行计算多个注意力头来加强模型的表达能力。以下是代码的焦点部门:
  1. import math
  2. import torch
  3. from torch import nn
  4. import torch.nn.functional as F
  5. class MultiHeadAttention(nn.Module):
  6.     """多头注意力机制"""
  7.     def __init__(self, key_size, query_size, value_size, num_hiddens,
  8.                  num_heads, dropout, bias=False, **kwargs):
  9.         super(MultiHeadAttention, self).__init__(**kwargs)
  10.         self.num_heads = num_heads
  11.         self.attention = DotProductAttention(dropout)
  12.         self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
  13.         self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
  14.         self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
  15.         self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
  16.     def forward(self, queries, keys, values, valid_lens):
  17.         queries = transpose_qkv(self.W_q(queries), self.num_heads)
  18.         keys = transpose_qkv(self.W_k(keys), self.num_heads)
  19.         values = transpose_qkv(self.W_v(values), self.num_heads)
  20.         if valid_lens is not None:
  21.             valid_lens = torch.repeat_interleave(
  22.                 valid_lens, repeats=self.num_heads, dim=0)
  23.         output = self.attention(queries, keys, values, valid_lens)
  24.         output_concat = transpose_output(output, self.num_heads)
  25.         return self.W_o(output_concat)
复制代码
代码剖析


  • 初始化

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表