
Title: GPT - Multi-Head Attention Module

Author: 光之使者    Posted: 2025-4-12 17:22
The code in this section implements a Multi-Head Attention module, one of the core components of the Transformer architecture.
 

⭐ For the mathematical background of multi-head self-attention, see the article:
Transformer - 多头自注意力机制复现-CSDN博客
This section expects you to understand the principle first and then implement multi-head attention by hand.
1. Initialization

import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads                  # dimension of each head
        self.q_project = nn.Linear(d_model, d_model)     # query projection
        self.k_project = nn.Linear(d_model, d_model)     # key projection
        self.v_project = nn.Linear(d_model, d_model)     # value projection
        self.o_project = nn.Linear(d_model, d_model)     # output projection
        self.dropout = nn.Dropout(dropout)               # applied to the final output
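Since d_k = d_model // num_heads, d_model must be divisible by num_heads. For example, with d_model = 512 and num_heads = 8 (typical GPT-style values, not specified in this post), each head works in a d_k = 512 / 8 = 64 dimensional subspace, and the four nn.Linear(d_model, d_model) layers keep all heads packed into a single matrix multiply.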

2. Forward pass

def forward(self, x, attn_mask=None):
    batch_size, seq_len, d_model = x.shape
    # Project the input, then split d_model into num_heads heads of size d_k:
    # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
    Q = self.q_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    K = self.k_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    V = self.v_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
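To see what the view/transpose does, here is a minimal shape check with illustrative dimensions (batch=2, seq_len=5, d_model=8, num_heads=2, so d_k=4; these values are my own, not from the post):

import torch

x = torch.randn(2, 5, 8)                  # (batch, seq_len, d_model)
q = x.view(2, 5, 2, 4).transpose(1, 2)    # split d_model=8 into 2 heads of d_k=4
print(q.shape)                            # torch.Size([2, 2, 5, 4]) = (batch, heads, seq, d_k)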

    # Scaled dot-product attention logits, one (seq_len, seq_len) matrix per head:
    # (batch, num_heads, seq_len, seq_len)
    atten_scores = Q @ K.transpose(2, 3) / math.sqrt(self.d_k)
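For reference, this line together with the softmax and weighted sum below implements the standard scaled dot-product attention, restated here for convenience:

Attention(Q, K, V) = softmax( Q Kᵀ / √d_k ) · V

computed independently within each head.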

    if attn_mask is not None:
        # attn_mask: (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len),
        # broadcast over the heads dimension; masked positions get -1e9 so that
        # softmax assigns them (near-)zero weight
        attn_mask = attn_mask.unsqueeze(1)
        atten_scores = atten_scores.masked_fill(attn_mask == 0, -1e9)
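In a GPT-style decoder, attn_mask is typically a causal (lower-triangular) mask so each position can only attend to itself and earlier positions. A minimal sketch of building such a mask (the shape convention is my assumption, chosen to match the unsqueeze(1) above; it is not stated in the post):

import torch

seq_len = 5
# 1 = attend, 0 = blocked; row i allows columns 0..i
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long))
causal_mask = causal_mask.unsqueeze(0)   # (1, seq_len, seq_len), broadcast over the batch
# forward() then unsqueezes dim 1 so the mask also broadcasts over the heads dimension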

    # Normalize the scores into attention weights, then take the weighted sum of V
    atten_scores = torch.softmax(atten_scores, dim=-1)
    out = atten_scores @ V                # (batch, num_heads, seq_len, d_k)

    # Merge the heads back: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model).
    # contiguous() is required because transpose returns a non-contiguous view and
    # view() needs contiguous memory.
    out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
    out = self.o_project(out)
    return self.dropout(out)

Complete code to reproduce:
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.q_project = nn.Linear(d_model, d_model)
        self.k_project = nn.Linear(d_model, d_model)
        self.v_project = nn.Linear(d_model, d_model)
        self.o_project = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        batch_size, seq_len, d_model = x.shape
        Q = self.q_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.k_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.v_project(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        atten_scores = Q @ K.transpose(2, 3) / math.sqrt(self.d_k)
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1)
            atten_scores = atten_scores.masked_fill(attn_mask == 0, -1e9)
        atten_scores = torch.softmax(atten_scores, dim=-1)
        out = atten_scores @ V
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = self.o_project(out)
        return self.dropout(out)
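A quick smoke test of the finished module (hyperparameters and tensor sizes below are illustrative, not from the original post):

import torch

torch.manual_seed(0)
mha = MultiHeadAttention(d_model=64, num_heads=8, dropout=0.1)
x = torch.randn(2, 10, 64)                           # (batch, seq_len, d_model)
mask = torch.tril(torch.ones(10, 10)).unsqueeze(0)   # causal mask, (1, seq_len, seq_len)
out = mha(x, attn_mask=mask)
print(out.shape)                                     # torch.Size([2, 10, 64]), same as the input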





