MOEFeedForward 模块

[复制链接]
发表于 2025-10-15 13:43:55 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

×
代码
  1. class FeedForward(nn.Module):
  2.     def __init__(self, config: LMConfig):
  3.         super().__init__()
  4.         if config.hidden_dim is None:
  5.             hidden_dim = 4 * config.dim
  6.             hidden_dim = int(2 * hidden_dim / 3)
  7.             config.hidden_dim = config.multiple_of * ((hidden_dim + config.multiple_of - 1) // config.multiple_of)
  8.         self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
  9.         self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
  10.         self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
  11.         self.dropout = nn.Dropout(config.dropout)
  12.     def forward(self, x):
  13.         return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
  14. class MoEGate(nn.Module):
  15.     def __init__(self, config: LMConfig):
  16.         super().__init__()
  17.         self.config = config
  18.         self.top_k = config.num_experts_per_tok
  19.         self.n_routed_experts = config.n_routed_experts
  20.         self.scoring_func = config.scoring_func
  21.         self.alpha = config.aux_loss_alpha
  22.         self.seq_aux = config.seq_aux
  23.         self.norm_topk_prob = config.norm_topk_prob
  24.         self.gating_dim = config.dim
  25.         self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
  26.         self.reset_parameters()
  27.     def reset_parameters(self) -> None:
  28.         import torch.nn.init as init
  29.         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
  30.     def forward(self, hidden_states):
  31.         bsz, seq_len, h = hidden_states.shape
  32.         hidden_states = hidden_states.view(-1, h)
  33.         logits = F.linear(hidden_states, self.weight, None)
  34.         if self.scoring_func == 'softmax':
  35.             scores = logits.softmax(dim=-1)
  36.         else:
  37.             raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
  38.         topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
  39.         if self.top_k > 1 and self.norm_topk_prob:
  40.             denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
  41.             topk_weight = topk_weight / denominator
  42.         if self.training and self.alpha > 0.0:
  43.             scores_for_aux = scores
  44.             aux_topk = self.top_k
  45.             topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
  46.             if self.seq_aux:
  47.                 scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
  48.                 ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
  49.                 ce.scatter_add_(1, topk_idx_for_aux_loss,
  50.                                 torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(
  51.                     seq_len * aux_topk / self.n_routed_experts)
  52.                 aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
  53.             else:
  54.                 mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
  55.                 ce = mask_ce.float().mean(0)
  56.                 Pi = scores_for_aux.mean(0)
  57.                 fi = ce * self.n_routed_experts
  58.                 aux_loss = (Pi * fi).sum() * self.alpha
  59.         else:
  60.             aux_loss = 0
  61.         return topk_idx, topk_weight, aux_loss
  62. class MOEFeedForward(nn.Module):
  63.     def __init__(self, config: LMConfig):
  64.         super().__init__()
  65.         self.config = config
  66.         self.experts = nn.ModuleList([
  67.             FeedForward(config)
  68.             for _ in range(config.n_routed_experts)
  69.         ])
  70.         self.gate = MoEGate(config)
  71.         if config.n_shared_experts is not None:
  72.             self.shared_experts = FeedForward(config)
  73.     def forward(self, x):
  74.         identity = x
  75.         orig_shape = x.shape
  76.         bsz, seq_len, _ = x.shape
  77.         # 使用门控机制选择专家
  78.         topk_idx, topk_weight, aux_loss = self.gate(x)
  79.         x = x.view(-1, x.shape[-1])
  80.         flat_topk_idx = topk_idx.view(-1)
  81.         if self.training:
  82.             # 训练模式下,重复输入数据
  83.             x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
  84.             y = torch.empty_like(x, dtype=torch.float16)
  85.             for i, expert in enumerate(self.experts):
  86.                 y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)  # 确保类型一致
  87.             y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
  88.             y = y.view(*orig_shape)
  89.         else:
  90.             # 推理模式下,只选择最优专家
  91.             y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
  92.         if self.config.n_shared_experts is not None:
  93.             y = y + self.shared_experts(identity)
  94.         self.aux_loss = aux_loss
  95.         return y
  96.     @torch.no_grad()
  97.     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
  98.         expert_cache = torch.zeros_like(x)
  99.         idxs = flat_expert_indices.argsort()
  100.         tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
  101.         token_idxs = idxs // self.config.num_experts_per_tok
  102.         # 例如当tokens_per_expert=[6, 15, 20, 26, 33, 38, 46, 52]
  103.         # 当token_idxs=[3, 7, 19, 21, 24, 25,  4,  5,  6, 10, 11, 12...]
  104.         # 意味着当token_idxs[:6] -> [3,  7, 19, 21, 24, 25,  4]位置的token都由专家0处理,token_idxs[6:15]位置的token都由专家1处理......
  105.         for i, end_idx in enumerate(tokens_per_expert):
  106.             start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
  107.             if start_idx == end_idx:
  108.                 continue
  109.             expert = self.experts[i]
  110.             exp_token_idx = token_idxs[start_idx:end_idx]
  111.             expert_tokens = x[exp_token_idx]
  112.             expert_out = expert(expert_tokens).to(expert_cache.dtype)
  113.             expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
  114.             # 使用 scatter_add_ 进行 sum 操作
  115.             expert_cache.scatter_add_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out)
  116.         return expert_cache
复制代码
代码表明

表明一下这段代码的重要构成部门:

  • FeedForward 类:


  • 实现了一个根本的前馈网络
  • 使用 SwiGLU 激活函数(F.silu(self.w1(x)) * self.w3(x))
  • 包罗三个线性层(w1、w2、w3)和一个 dropout 层

  • MoEGate 类(门控机制):


  • 负责决定每个 token 应该由哪些专家处置惩罚
  • 重要步调:

    • 盘算每个 token 对应每个专家的分数(使用 softmax)
    • 选择 top-k 个最高分的专家
    • 盘算辅助丧失(aux_loss)来平衡专家的使用


  • MOEFeedForward 类(混淆专家体系):


  • 包罗多个专家(FeedForward)和一个门控网络(MoEGate)
  • 练习模式:

    • 使用门控网络选择每个 token 的专家
    • 将输入数据复制多份,分发给差别专家
    • 专家并行处置惩罚数据
    • 根据门控权重归并结果

  • 推理模式(moe_infer):

    • 对专家索引排序,将雷同专家的 token 批量处置惩罚
    • 使用 scatter_add_ 将专家输出累加到准确位置
    • 更高效的推理实现,制止了数据重复



  • 支持共享专家(n_shared_experts)
  • 实现了专家负载平衡(通过辅助丧失)
  • 支持每个 token 选择多个专家(num_experts_per_tok)
这是一个范例的 MoE(Mixture of Experts)实现,用于大型语言模子中进步模子容量和盘算服从。
示例
  1. # 创建 MoE 实例
  2. dim = 512                    # 输入维度
  3. n_routed_experts = 4         # 专家数量
  4. num_experts_per_tok = 2      # 每个token选择的专家数量
  5. moe = MOEFeedForward(
  6.     dim=dim,
  7.     n_routed_experts=n_routed_experts,
  8.     num_experts_per_tok=num_experts_per_tok,
  9.     hidden_dim=None,         # FFN隐藏层维度,None时自动计算
  10.     dropout=0.1             # dropout比率
  11. )
  12. # 创建示例输入
  13. batch_size = 2
  14. seq_len = 10
  15. x = torch.randn(batch_size, seq_len, dim)  # 形状: [2, 10, 512]
  16. moe(x)
复制代码
输出
  1. After gate - topk_idx.shape: torch.Size([20, 2]), topk_weight.shape: torch.Size([20, 2])
  2. After view - x.shape: torch.Size([20, 512]), flat_topk_idx.shape: torch.Size([40])
  3. After repeat_interleave - x.shape: torch.Size([40, 512])
  4. Empty y tensor shape: torch.Size([40, 512])
  5. Expert 0 - input shape: torch.Size([9, 512])
  6. Expert 0 - output shape: torch.Size([9, 512])
  7. Expert 1 - input shape: torch.Size([13, 512])
  8. Expert 1 - output shape: torch.Size([13, 512])
  9. Expert 2 - input shape: torch.Size([11, 512])
  10. Expert 2 - output shape: torch.Size([11, 512])
  11. Expert 3 - input shape: torch.Size([7, 512])
  12. Expert 3 - output shape: torch.Size([7, 512])
  13. Before view - y.shape: torch.Size([40, 512])
  14. topk_weight.shape: torch.Size([20, 2])
  15. After view and sum - y.shape: torch.Size([20, 512])
  16. Final y.shape: torch.Size([2, 10, 512])
复制代码
相应的torch函数
  1. import torch
  2. # empty: 创建未初始化的张量
  3. x = torch.empty((2, 3))  # 创建形状为 2x3 的未初始化张量
  4. # zeros_like: 创建与输入相同形状的全零张量
  5. a = torch.tensor([[1, 2], [3, 4]])
  6. b = torch.zeros_like(a)  # 创建形状为 2x2 的全零张量
  7. print(b)  # tensor([[0, 0], [0, 0]])
复制代码
  1. tensor([[0, 0],
  2.         [0, 0]])
复制代码
  1. import torch.nn.functional as F
  2. x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
  3. # view: 改变张量形状
  4. y = x.view(-1)  # 展平为一维
  5. print(y)  # tensor([1, 2, 3, 4, 5, 6, 7, 8])
  6. # -1 表示自动计算该维度大小
  7. z = x.view(-1, 2)  # 重塑为 4x2
  8. print(z)  # tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
复制代码
  1. tensor([1, 2, 3, 4, 5, 6, 7, 8])
  2. tensor([[1, 2],
  3.         [3, 4],
  4.         [5, 6],
  5.         [7, 8]])
复制代码
  1. # linear: 线性变换 y = xA^T + b
  2. input = torch.randn(2, 3)  # 2个样本,每个3维
  3. weight = torch.randn(4, 3)  # 输出4维
  4. output = F.linear(input, weight)  # 形状变为 [2, 4]
  5. # softmax: 将数值转换为概率分布
  6. logits = torch.tensor([1.0, 2.0, 3.0])
  7. probs = F.softmax(logits, dim=0)
  8. print(probs)  # tensor([0.0900, 0.2447, 0.6652])
复制代码
  1. tensor([0.0900, 0.2447, 0.6652])
复制代码
  1. # 找出最大的k个值及其索引
  2. x = torch.tensor([1, 5, 2, 8, 3])
  3. values, indices = torch.topk(x, k=2)
  4. print(values)   # tensor([8, 5])
  5. print(indices)  # tensor([3, 1])
复制代码
  1. tensor([8, 5])
  2. tensor([3, 1])
复制代码
  1. x = torch.tensor([1, 2, 3])
  2. # 每个元素重复2次
  3. y = x.repeat_interleave(2)
  4. print(y)  # tensor([1, 1, 2, 2, 3, 3])
复制代码
  1. tensor([1, 1, 2, 2, 3, 3])
复制代码
  1. # 统计每个数字出现的次数
  2. x = torch.tensor([1, 1, 2, 3, 1, 2])
  3. counts = x.bincount()
  4. print(counts)  # tensor([0, 3, 2, 1])
  5.   # 0出现0次,1出现3次,2出现2次,3出现1次
复制代码
  1. tensor([0, 3, 2, 1])
复制代码
  1. # 在指定位置累加值
  2. src = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)  # 指定数据类型为 float
  3. index = torch.tensor([[0, 1], [0, 1]])
  4. out = torch.zeros(2, 2, dtype=torch.float)  # 确保与 src 的数据类型相同
  5. out.scatter_add_(0, index, src)
  6. print(out)
复制代码
  1. tensor([[4., 0.],
  2.         [0., 6.]])
复制代码
  1. # 返回排序后的索引
  2. x = torch.tensor([3, 1, 4, 1, 5])
  3. indices = x.argsort()
  4. print(indices)  # tensor([1, 3, 0, 2, 4])
  5.   # 最小值在位置1和3,然后是0,2,4
复制代码
  1. tensor([1, 3, 0, 2, 4])
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
继续阅读请点击广告
回复

使用道具 举报

×
登录参与点评抽奖,加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表