IT评测·应用市场-qidao123.com

标题: 深度学习中的知识蒸馏 [打印本页]

作者: 嚴華    时间: 2025-2-18 17:08
标题: 深度学习中的知识蒸馏
知识蒸馏(Knowledge Distillation)是一种模型压缩技术,旨在将大型、复杂的模型(通常称为西席模型)的知识迁徙到小型、简单的模型(门生模型)中。通过这种方式,门生模型可以在保持较高性能的同时,显著减少计算资源和存储需求。
知识蒸馏广泛用于深度学习领域,尤其在计算资源有限的场景(如移动端设备、嵌入式设备)中,用于加速推理、减少存储本钱,同时尽可能保持模型性能。

核心思想

知识蒸馏的核心思想是利用西席模型的输出(通常是软标签,即概率分布)来引导门生模型的训练。与传统的监督学习差别,知识蒸馏不仅使用真实标签(硬标签),还利用西席模型生成的软标签来通报更多的信息。
通过这种方式,门生模型不仅学习到数据的类别信息,还能够捕获到类别之间的相似性和关系,从而提拔其泛化本领。
步骤


知识蒸馏的核心在于让门生模型不仅仅学习真实标签,还学习西席模型提供的软标签,即西席模型输出的概率分布。这种方式可以让门生模型获得更丰富的信息。
传统神经网络的交叉熵损失

在传统的神经网络训练中,我们通常用交叉熵损失(Cross-Entropy Loss)来训练分类模型:

 
其中:


 其中
是模型最后一层的 logit 值。
 
传统的交叉熵损失函数仅利用了数据的硬标签(hard labels),即
仅在真实类别处为 1,其他类别为 0,导致模型无法学习类别之间的相似性信息。
知识蒸馏的损失函数

在知识蒸馏中,西席模型提供了一种软标签(soft targets),即对全部类别的预测分布,而不仅仅是单个类别。
这些软标签由温度化 Softmax 得到。

 
其中:

在训练门生模型时,通常使用两部分损失函数:

 
2.软标签损失(基于 Kullback-Leibler 散度的损失)
用于让门生模型学习西席模型的类别间关系。

 

其中,
是一个超参数,用于控制硬标签损失和软标签损失的相对重要性。
通过加权组合这两部分损失,可以平衡门生模型对硬标签和软标签的学习。

知识蒸馏的优势


知识蒸馏的变种

除了标准的知识蒸馏方法,研究职员还提出了多个改进版本。
案例分享

下面是一个完整的知识蒸馏的示例代码,使用 PyTorch 训练一个西席模型并将其知识蒸馏到门生模型。
这里,我们采用 MNIST 数据集,西席模型使用一个较大的神经网络,而门生模型是一个较小的神经网络。
首先,定义西席模型和门生模型。
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torch.nn.functional as F
  5. from torchvision import datasets, transforms
  6. from torch.utils.data import DataLoader
  7. import matplotlib.pyplot as plt
  8. # 教师模型(较大的神经网络)
  9. class TeacherModel(nn.Module):
  10.     def __init__(self):
  11.         super(TeacherModel, self).__init__()
  12.         self.fc1 = nn.Linear(28 * 28, 512)
  13.         self.fc2 = nn.Linear(512, 256)
  14.         self.fc3 = nn.Linear(256, 10)
  15.     def forward(self, x):
  16.         x = x.view(-1, 28 * 28)
  17.         x = F.relu(self.fc1(x))
  18.         x = F.relu(self.fc2(x))
  19.         x = self.fc3(x)  # 注意这里没有 Softmax
  20.         return x
  21. # 学生模型(较小的神经网络)
  22. class StudentModel(nn.Module):
  23.     def __init__(self):
  24.         super(StudentModel, self).__init__()
  25.         self.fc1 = nn.Linear(28 * 28, 128)
  26.         self.fc2 = nn.Linear(128, 10)
  27.     def forward(self, x):
  28.         x = x.view(-1, 28 * 28)
  29.         x = F.relu(self.fc1(x))
  30.         x = self.fc2(x)  # 注意这里没有 Softmax
  31.         return x
复制代码
  1. # 数据预处理
  2. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
  3. # 加载 MNIST 数据集
  4. train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
  5. test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
  6. train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
  7. test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
复制代码
训练西席模型
  1. def train_teacher(model, train_loader, epochs=5, lr=0.001):
  2.     optimizer = optim.Adam(model.parameters(), lr=lr)
  3.     criterion = nn.CrossEntropyLoss()
  4.     
  5.     for epoch in range(epochs):
  6.         model.train()
  7.         total_loss = 0
  8.         
  9.         for images, labels in train_loader:
  10.             optimizer.zero_grad()
  11.             output = model(images)
  12.             loss = criterion(output, labels)
  13.             loss.backward()
  14.             optimizer.step()
  15.             total_loss += loss.item()
  16.         
  17.         print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")
  18. # 初始化并训练教师模型
  19. teacher_model = TeacherModel()
  20. train_teacher(teacher_model, train_loader)
复制代码
知识蒸馏训练门生模型
  1. def distillation_loss(student_logits, teacher_logits, labels, T=3.0, alpha=0.5):
  2.     """
  3.     计算蒸馏损失,结合知识蒸馏损失和交叉熵损失
  4.     """
  5.     soft_targets = F.softmax(teacher_logits / T, dim=1)  # 教师模型的软标签
  6.     soft_predictions = F.log_softmax(student_logits / T, dim=1)  # 学生模型的预测
  7.     
  8.     distillation_loss = F.kl_div(soft_predictions, soft_targets, reduction="batchmean") * (T ** 2)
  9.     ce_loss = F.cross_entropy(student_logits, labels)
  10.     
  11.     return alpha * ce_loss + (1 - alpha) * distillation_loss
  12. def train_student_with_distillation(student_model, teacher_model, train_loader, epochs=5, lr=0.001, T=3.0, alpha=0.5):
  13.     optimizer = optim.Adam(student_model.parameters(), lr=lr)
  14.     
  15.     teacher_model.eval()  # 设定教师模型为评估模式
  16.     for epoch in range(epochs):
  17.         student_model.train()
  18.         total_loss = 0
  19.         
  20.         for images, labels in train_loader:
  21.             optimizer.zero_grad()
  22.             student_logits = student_model(images)
  23.             with torch.no_grad():
  24.                 teacher_logits = teacher_model(images)  # 获取教师模型输出
  25.             
  26.             loss = distillation_loss(student_logits, teacher_logits, labels, T=T, alpha=alpha)
  27.             loss.backward()
  28.             optimizer.step()
  29.             total_loss += loss.item()
  30.         
  31.         print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")
  32. # 初始化学生模型
  33. student_model = StudentModel()
  34. train_student_with_distillation(student_model, teacher_model, train_loader)
复制代码
评估模型
  1. def evaluate(model, test_loader):
  2.     model.eval()
  3.     correct = 0
  4.     total = 0
  5.     
  6.     with torch.no_grad():
  7.         for images, labels in test_loader:
  8.             outputs = model(images)
  9.             _, predicted = torch.max(outputs, 1)
  10.             correct += (predicted == labels).sum().item()
  11.             total += labels.size(0)
  12.     
  13.     accuracy = 100 * correct / total
  14.     return accuracy
  15. # 评估教师模型
  16. teacher_acc = evaluate(teacher_model, test_loader)
  17. print(f"教师模型准确率: {teacher_acc:.2f}%")
  18. # 评估知识蒸馏训练的学生模型
  19. student_acc_distilled = evaluate(student_model, test_loader)
  20. print(f"知识蒸馏训练的学生模型准确率: {student_acc_distilled:.2f}%")
复制代码


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 IT评测·应用市场-qidao123.com (https://dis.qidao123.com/) Powered by Discuz! X3.4