深度学习中的知识蒸馏

嚴華  论坛元老 | 2025-2-18 17:08:32 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 1049|帖子 1049|积分 3147

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
知识蒸馏(Knowledge Distillation)是一种模型压缩技术,旨在将大型、复杂的模型(通常称为西席模型)的知识迁徙到小型、简单的模型(门生模型)中。通过这种方式,门生模型可以在保持较高性能的同时,显著减少计算资源和存储需求。
知识蒸馏广泛用于深度学习领域,尤其在计算资源有限的场景(如移动端设备、嵌入式设备)中,用于加速推理、减少存储本钱,同时尽可能保持模型性能。

核心思想

知识蒸馏的核心思想是利用西席模型的输出(通常是软标签,即概率分布)来引导门生模型的训练。与传统的监督学习差别,知识蒸馏不仅使用真实标签(硬标签),还利用西席模型生成的软标签来通报更多的信息。
通过这种方式,门生模型不仅学习到数据的类别信息,还能够捕获到类别之间的相似性和关系,从而提拔其泛化本领。
步骤



  • 训练西席模型
    首先,训练一个大型、复杂的西席模型,使其在目标任务上到达较高的性能。
    西席模型可以是任何高性能的深度学习模型,如深层神经网络、Transformer等。
  • 生成软标签
    使用西席模型对训练数据举行推理,生成软标签(即概率分布)。
  • 训练门生模型
    门生模型在训练时,不仅使用真实标签,还使用西席模型生成的软标签作为额外的监督信号。
  • 优化与调解
    通过调解温度参数、损失函数权重等超参数,优化门生模型的性能,使其尽可能接近西席模型。
知识蒸馏的核心在于让门生模型不仅仅学习真实标签,还学习西席模型提供的软标签,即西席模型输出的概率分布。这种方式可以让门生模型获得更丰富的信息。
传统神经网络的交叉熵损失

在传统的神经网络训练中,我们通常用交叉熵损失(Cross-Entropy Loss)来训练分类模型:

 
其中:


  • 是真实类别的独热编码。
  • 是模型的预测概率,通常由 Softmax 变换得到。

 其中
是模型最后一层的 logit 值。
 
传统的交叉熵损失函数仅利用了数据的硬标签(hard labels),即
仅在真实类别处为 1,其他类别为 0,导致模型无法学习类别之间的相似性信息。
知识蒸馏的损失函数

在知识蒸馏中,西席模型提供了一种软标签(soft targets),即对全部类别的预测分布,而不仅仅是单个类别。
这些软标签由温度化 Softmax 得到。

 
其中:


  • 其中, Zi是第i类的未归一化分数(logits),T是温度系数, qi是颠末温度调解后的概率。
  • 较高的 T 值会使得概率分布更加平滑,保存更多类别之间的关系信息,从而提供更丰富的知识给门生模型。
在训练门生模型时,通常使用两部分损失函数:

  • 硬标签损失(传统的交叉熵损失)
    用于确保门生模型能够精确分类。

 
2.软标签损失(基于 Kullback-Leibler 散度的损失)
用于让门生模型学习西席模型的类别间关系。

 

其中,
是一个超参数,用于控制硬标签损失和软标签损失的相对重要性。
通过加权组合这两部分损失,可以平衡门生模型对硬标签和软标签的学习。

知识蒸馏的优势



  • 模型压缩:门生模型通常比西席模型小得多,得当在资源受限的设备上部署。
  • 性能保持:通过知识蒸馏,门生模型能够在保持较高性能的同时,显著减少计算资源和存储需求。
  • 泛化本领:软标签提供了更多的信息,有助于门生模型更好地泛化。
知识蒸馏的变种

除了标准的知识蒸馏方法,研究职员还提出了多个改进版本。

  • 自蒸馏(Self-Distillation):模型自身作为西席,将深层网络的知识蒸馏到浅层部分。
  • 多西席蒸馏(Multi-Teacher Distillation):多个西席模型团结引导门生模型,融合差别西席的知识。
  • 在线蒸馏(Online Distillation):西席模型和门生模型同步训练,而不是先训练西席模型再训练门生模型。
案例分享

下面是一个完整的知识蒸馏的示例代码,使用 PyTorch 训练一个西席模型并将其知识蒸馏到门生模型。
这里,我们采用 MNIST 数据集,西席模型使用一个较大的神经网络,而门生模型是一个较小的神经网络。
首先,定义西席模型和门生模型。
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torch.nn.functional as F
  5. from torchvision import datasets, transforms
  6. from torch.utils.data import DataLoader
  7. import matplotlib.pyplot as plt
  8. # 教师模型(较大的神经网络)
  9. class TeacherModel(nn.Module):
  10.     def __init__(self):
  11.         super(TeacherModel, self).__init__()
  12.         self.fc1 = nn.Linear(28 * 28, 512)
  13.         self.fc2 = nn.Linear(512, 256)
  14.         self.fc3 = nn.Linear(256, 10)
  15.     def forward(self, x):
  16.         x = x.view(-1, 28 * 28)
  17.         x = F.relu(self.fc1(x))
  18.         x = F.relu(self.fc2(x))
  19.         x = self.fc3(x)  # 注意这里没有 Softmax
  20.         return x
  21. # 学生模型(较小的神经网络)
  22. class StudentModel(nn.Module):
  23.     def __init__(self):
  24.         super(StudentModel, self).__init__()
  25.         self.fc1 = nn.Linear(28 * 28, 128)
  26.         self.fc2 = nn.Linear(128, 10)
  27.     def forward(self, x):
  28.         x = x.view(-1, 28 * 28)
  29.         x = F.relu(self.fc1(x))
  30.         x = self.fc2(x)  # 注意这里没有 Softmax
  31.         return x
复制代码
  1. # 数据预处理
  2. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
  3. # 加载 MNIST 数据集
  4. train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
  5. test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
  6. train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
  7. test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
复制代码
训练西席模型
  1. def train_teacher(model, train_loader, epochs=5, lr=0.001):
  2.     optimizer = optim.Adam(model.parameters(), lr=lr)
  3.     criterion = nn.CrossEntropyLoss()
  4.     
  5.     for epoch in range(epochs):
  6.         model.train()
  7.         total_loss = 0
  8.         
  9.         for images, labels in train_loader:
  10.             optimizer.zero_grad()
  11.             output = model(images)
  12.             loss = criterion(output, labels)
  13.             loss.backward()
  14.             optimizer.step()
  15.             total_loss += loss.item()
  16.         
  17.         print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")
  18. # 初始化并训练教师模型
  19. teacher_model = TeacherModel()
  20. train_teacher(teacher_model, train_loader)
复制代码
知识蒸馏训练门生模型
  1. def distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):
  2.     """
  3.     计算蒸馏损失,结合知识蒸馏损失和交叉熵损失
  4.     """
  5.     soft_targets = F.softmax(teacher_logits / T, dim=1)  # 教师模型的软标签
  6.     soft_predictions = F.log_softmax(student_logits / T, dim=1)  # 学生模型的预测
  7.     
  8.     distillation_loss = F.kl_div(soft_predictions, soft_targets, reduction="batchmean") * (T ** 2)
  9.     ce_loss = F.cross_entropy(student_logits, labels)
  10.     
  11.     return alpha * ce_loss + (1 - alpha) * distillation_loss
  12. def train_student_with_distillation(student_model, teacher_model, train_loader, epochs=5, lr=0.001, T=3.0, alpha=0.5):
  13.     optimizer = optim.Adam(student_model.parameters(), lr=lr)
  14.     
  15.     teacher_model.eval()  # 设定教师模型为评估模式
  16.     for epoch in range(epochs):
  17.         student_model.train()
  18.         total_loss = 0
  19.         
  20.         for images, labels in train_loader:
  21.             optimizer.zero_grad()
  22.             student_logits = student_model(images)
  23.             with torch.no_grad():
  24.                 teacher_logits = teacher_model(images)  # 获取教师模型输出
  25.             
  26.             loss = distillation_loss(student_logits, teacher_logits, labels, T=T, alpha=alpha)
  27.             loss.backward()
  28.             optimizer.step()
  29.             total_loss += loss.item()
  30.         
  31.         print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")
  32. # 初始化学生模型
  33. student_model = StudentModel()
  34. train_student_with_distillation(student_model, teacher_model, train_loader)
复制代码
评估模型
  1. def evaluate(model, test_loader):
  2.     model.eval()
  3.     correct = 0
  4.     total = 0
  5.     
  6.     with torch.no_grad():
  7.         for images, labels in test_loader:
  8.             outputs = model(images)
  9.             _, predicted = torch.max(outputs, 1)
  10.             correct += (predicted == labels).sum().item()
  11.             total += labels.size(0)
  12.     
  13.     accuracy = 100 * correct / total
  14.     return accuracy
  15. # 评估教师模型
  16. teacher_acc = evaluate(teacher_model, test_loader)
  17. print(f"教师模型准确率: {teacher_acc:.2f}%")
  18. # 评估知识蒸馏训练的学生模型
  19. student_acc_distilled = evaluate(student_model, test_loader)
  20. print(f"知识蒸馏训练的学生模型准确率: {student_acc_distilled:.2f}%")
复制代码


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

嚴華

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表