yolov8(目标检测、图像分割、关键点检测)知识蒸馏:logit和feature-based ...

打印 上一主题 下一主题

主题 808|帖子 808|积分 2424

1.知识蒸馏的原理

在目标检测中,知识蒸馏的原理紧张是利用教师模型(通常是大型的深度神经网络)的丰富知识来指导门生模型(轻量级的神经网络)的学习过程。通过蒸馏,门生模型能够在保持较高性能的同时,减小模型的复杂度和计算本钱。
知识蒸馏实现的方式有多种,但焦点目标是将教师模型学习到的知识迁移到门生中去(通常是通过各种损失函数举行实现)。

本项目支持yolov8检测、分割、关键点任务的知识蒸馏,并对蒸馏代码举行详解,比较容易上手。蒸馏方式多种,支持 logit和 feature-based蒸馏以及在线蒸馏。:
2.logit 蒸馏原理

Logit蒸馏原理紧张基于深度学习中的知识迁移技能,特别是在模型压缩和加速范畴。其焦点思想是利用大型、复杂的教师模型(Teacher Model)的logits(逻辑层的原始输出得分)来指导小型、轻量的门生模型(Student Model)的学习。
Logits是教师模型在做出终极决议之前的原始得分,这些得分在数值上表示了模型对每个类别的预测置信度。相较于终极的类别概率分布,logits包含了更丰富的信息,尤其是当不同类别之间存在细微差异时。
在Logit蒸馏过程中,教师模型的logits被用作额外的监督信号来训练门生模型。通过最小化教师模型和门生模型在logits层面上的差异(通常使用均方误差MSE或KL散度等损失函数),可以使门生模型学习到教师模型在决议边界附近的细致区分本领。这种蒸馏方式有助于提升门生模型在保持较高性能的同时,减小模型的复杂度和计算本钱。
逻辑蒸馏损失定义的代码在:ultralytics/utils/distill_loss.py
  1. class Distill_LogitLoss:
  2.     def __init__(self,p, t_p, alpha =0.25):
  3.         
  4.         t_ft = torch.cuda.FloatTensor if t_p[0].is_cuda else torch.Tensor
  5.         self.p =p
  6.         self.t_p = t_p
  7.         self.logit_loss = t_ft([0])
  8.         self.DLogitLoss = nn.MSELoss(reduction="none")
  9.         self.bs = p[0].shape[0]
  10.         self.alpha = alpha
  11.    
  12.     def __call__(self):
  13.         # per output
  14.         assert len(self.p) == len(self.t_p)
  15.         for i, (pi,t_pi) in enumerate(zip(self.p,self.t_p)):  # layer index, layer predictions
  16.             assert pi.shape == t_pi.shape
  17.             self.logit_loss += torch.mean(self.DLogitLoss(pi, t_pi))
  18.         return self.logit_loss[0]*self.alpha
复制代码

3.feature-base蒸馏原理

Feature-based蒸馏原理是知识蒸馏中的一种紧张方法,其关键在于利用教师模型的潜伏层特征来指导门生模型的学习过程。这种蒸馏方式旨在使门生模型能够学习到教师模型在特征提取和表示方面的本领,从而提升其性能。
具体来说,Feature-based蒸馏通过比较教师模型和门生模型在某一或多个潜伏层的特征表示来实现知识的迁移。在训练过程中,教师模型的潜伏层特征被提取出来,并作为监督信号来指导门生模型相应层的特征学习。通过优化两者在特征层面的差异(如使用均方误差、余弦相似度等作为损失函数),可以使门生模型逐渐逼近教师模型的特征表示本领。
这种蒸馏方式有几个明显的上风。起首,它充分利用了教师模型在特征提取方面的上风,资助门生模型学习到更具判别性的特征表示。其次,通过比较特征层面的差异,可以更加细致地指导门生模型的学习过程,使其在保持较高性能的同时减小模型复杂度。最后,Feature-based蒸馏可以与其他蒸馏方式相结合,形成更为复杂的蒸馏计谋,以进一步提升模型性能。
必要注意的是,在选择举行Feature-based蒸馏的潜伏层时,必要谨慎思量。不同层的特征具有不同的语义信息和抽象程度,因此选择合适的层举行蒸馏对于终极效果至关紧张。此外,蒸馏过程中的损失函数和权重设置也必要根据具体任务和数据集举行调整。
综上所述,Feature-based蒸馏原理是通过利用教师模型的潜伏层特征来指导门生模型的学习过程,从而实现知识的迁移和模型性能的提升。这种方法在深度学习范畴具有广泛的应用远景,尤其在必要进步模型特征提取本领的场景中体现出色。
本文将给出3种feature-base的蒸馏损失方法,代码分别如下


  • MimicLoss
  1. class MimicLoss(nn.Module):
  2.     def __init__(self, channels_s, channels_t):
  3.         super(MimicLoss, self).__init__()
  4.         device = 'cuda' if torch.cuda.is_available() else 'cpu'
  5.         self.mse = nn.MSELoss()
  6.     def forward(self, y_s, y_t):
  7.         """Forward computation.
  8.         Args:
  9.             y_s (list): The student model prediction with
  10.                 shape (N, C, H, W) in list.
  11.             y_t (list): The teacher model prediction with
  12.                 shape (N, C, H, W) in list.
  13.         Return:
  14.             torch.Tensor: The calculated loss value of all stages.
  15.         """
  16.         assert len(y_s) == len(y_t)
  17.         losses = []
  18.         for idx, (s, t) in enumerate(zip(y_s, y_t)):
  19.             assert s.shape == t.shape
  20.             losses.append(self.mse(s, t))
  21.         loss = sum(losses)
  22.         return loss
复制代码


  • CWDLoss
  1. class CWDLoss(nn.Module):
  2.     """PyTorch version of `Channel-wise Distillation for Semantic Segmentation.
  3.     <https://arxiv.org/abs/2011.13256>`_.
  4.     """
  5.     def __init__(self, channels_s, channels_t,tau=1.0):
  6.         super(CWDLoss, self).__init__()
  7.         self.tau = tau
  8.     def forward(self, y_s, y_t):
  9.         """Forward computation.
  10.         Args:
  11.             y_s (list): The student model prediction with
  12.                 shape (N, C, H, W) in list.
  13.             y_t (list): The teacher model prediction with
  14.                 shape (N, C, H, W) in list.
  15.         Return:
  16.             torch.Tensor: The calculated loss value of all stages.
  17.         """
  18.         assert len(y_s) == len(y_t)
  19.         losses = []
  20.         for idx, (s, t) in enumerate(zip(y_s, y_t)):
  21.             assert s.shape == t.shape
  22.             
  23.             N, C, H, W = s.shape
  24.             
  25.             # normalize in channel diemension
  26.             softmax_pred_T = F.softmax(t.view(-1, W * H) / self.tau, dim=1)  # [N*C, H*W]
  27.             
  28.             logsoftmax = torch.nn.LogSoftmax(dim=1)
  29.             cost = torch.sum(
  30.                 softmax_pred_T * logsoftmax(t.view(-1, W * H) / self.tau) -
  31.                 softmax_pred_T * logsoftmax(s.view(-1, W * H) / self.tau)) * (self.tau ** 2)
  32.             losses.append(cost / (C * N))
  33.         loss = sum(losses)
  34.         return loss
复制代码


  • MGDLoss
  1. class MGDLoss(nn.Module):
  2.     def __init__(self, channels_s, channels_t, alpha_mgd=0.00002, lambda_mgd=0.65):
  3.         super(MGDLoss, self).__init__()
  4.         device = 'cuda' if torch.cuda.is_available() else 'cpu'
  5.         self.alpha_mgd = alpha_mgd
  6.         self.lambda_mgd = lambda_mgd
  7.         self.generation = [
  8.             nn.Sequential(
  9.                 nn.Conv2d(channel, channel, kernel_size=3, padding=1),
  10.                 nn.ReLU(inplace=True),
  11.                 nn.Conv2d(channel, channel, kernel_size=3, padding=1)).to(device) for channel in channels_t
  12.         ]
  13.     def forward(self, y_s, y_t):
  14.         """Forward computation.
  15.         Args:
  16.             y_s (list): The student model prediction with
  17.                 shape (N, C, H, W) in list.
  18.             y_t (list): The teacher model prediction with
  19.                 shape (N, C, H, W) in list.
  20.         Return:
  21.             torch.Tensor: The calculated loss value of all stages.
  22.         """
  23.         assert len(y_s) == len(y_t)
  24.         losses = []
  25.         for idx, (s, t) in enumerate(zip(y_s, y_t)):
  26.             assert s.shape == t.shape
  27.             losses.append(self.get_dis_loss(s, t, idx) * self.alpha_mgd)
  28.         loss = sum(losses)
  29.         return loss
  30.     def get_dis_loss(self, preds_S, preds_T, idx):
  31.         loss_mse = nn.MSELoss(reduction='sum')
  32.         N, C, H, W = preds_T.shape
  33.         device = preds_S.device
  34.         mat = torch.rand((N, 1, H, W)).to(device)
  35.         mat = torch.where(mat > 1 - self.lambda_mgd, 0, 1).to(device)
  36.         masked_fea = torch.mul(preds_S, mat)
  37.         new_fea = self.generation[idx](masked_fea)
  38.         dis_loss = loss_mse(new_fea, preds_T) / N
  39.         return dis_loss
复制代码
以上三种feature-based的蒸馏损失,其中MimicLoss是最常见的特征蒸馏损失,而MGD和CWD是当前的SOTA特征蒸馏方案。
4.yolov8 蒸馏代码实现

(1)蒸馏参数的设置

将以下代码放置在ultralytics\engine\trainer.py文件种142行位置处
  1.         self.dfea_loss = 0                           # feature distill loss
  2.         self.dlogit_loss = 0                         # logit distill loss
  3.         self.loss_t = 0                              # teacher model distill online loss
  4.         self.distill_loss =None
  5.         self.model_t = overrides.get("model_t",None)
  6.         self.distill_feat_type = "cwd"               # "cwd","mgd","mimic"
  7.         self.distill_online = True                   # False or True
  8.         self.logit_loss =  True                      # False or True
  9.         
  10.         #self.distill_layers =  [6,8,12,15,18,21]      # distill layers
  11.         self.distill_layers = [2,4,6,8,12,15,18,21]
  12.         # self.distill_layers = [15,18,21]
  13.         # self.model_t: 获取蒸馏训练的教师模型,如果在训练模型时,没传入model_t, 则不会进行蒸馏训练,只进行一般的模型训练
  14.         # self.distill_feat_type: 设置feature - based蒸馏的类型,支持"cwd", "mgd", "mimic", 任意一种
  15.         # self.distill_online: 设置是否使用在线蒸馏, 默认为False即离线蒸馏,你也可以设置为True
  16.         # self.logit_loss: 设置是否使用logit蒸馏
  17.         # self.distill_layers: 设置特征蒸馏的层数,可根据需要选择需要蒸馏的特征层
复制代码

   (2)蒸馏损失代码实现

新建ultralytics/utils/distill_loss.py文件,并将以上有关蒸馏损失放置在其中(完整代码可关注博主加私信获取)

(3) 优化器optimizer修改

(完整代码可关注博主加私信获取。获取后直接更换trainer.py即可)代码在ultralytics/engine/trainer.py的build_optimizer函数中
将如下的300行左右build_optimizer,按下图举行修改

build_optimizer函数内容如下
  1.     def build_optimizer(self, model, model_t,distill_loss,distill_online=False,name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
  2.         """
  3.         Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
  4.         weight decay, and number of iterations.
  5.         Args:
  6.             model (torch.nn.Module): The model for which to build an optimizer.
  7.             name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
  8.                 based on the number of iterations. Default: 'auto'.
  9.             lr (float, optional): The learning rate for the optimizer. Default: 0.001.
  10.             momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
  11.             decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
  12.             iterations (float, optional): The number of iterations, which determines the optimizer if
  13.                 name is 'auto'. Default: 1e5.
  14.         Returns:
  15.             (torch.optim.Optimizer): The constructed optimizer.
  16.         """
  17.         g = [], [], []  # optimizer parameter groups
  18.         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
  19.         if name == 'auto':
  20.             LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, "
  21.                         f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
  22.                         f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ")
  23.             nc = getattr(model, 'nc', 10)  # number of classes
  24.             lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
  25.             name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
  26.             self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam
  27.         for v in model.modules():
  28.             if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
  29.                 g[2].append(v.bias)
  30.             if isinstance(v, bn):  # weight (no decay)
  31.                 g[1].append(v.weight)
  32.             elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
  33.                 g[0].append(v.weight)
  34.         
  35.         if model_t is not None and distill_online:
  36.             for v in model_t.modules():
  37.                 # print(v)
  38.                 if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
  39.                     g[2].append(v.bias)
  40.                 if isinstance(v, bn):  # weight (no decay)
  41.                     g[1].append(v.weight)
  42.                 elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
  43.                     g[0].append(v.weight)
  44.         
  45.         if model_t is not None and distill_loss is not None:
  46.             for k, v in distill_loss.named_modules():
  47.                 # print(v)
  48.                 if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
  49.                     g[2].append(v.bias)
  50.                 if isinstance(v, bn) or 'bn' in k:  # weight (no decay)
  51.                     g[1].append(v.weight)
  52.                 elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
  53.                     g[0].append(v.weight)
  54.         
  55.       
  56.         if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
  57.             optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
  58.         elif name == 'RMSProp':
  59.             optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
  60.         elif name == 'SGD':
  61.             optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
  62.         else:
  63.             raise NotImplementedError(
  64.                 f"Optimizer '{name}' not found in list of available optimizers "
  65.                 f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].'
  66.                 'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.')
  67.         optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
  68.         optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
  69.         LOGGER.info(
  70.             f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
  71.             f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
  72.         return optimizer
复制代码
5.yolov8 蒸馏训练步骤

在项目中,教师模型model_t选择是yolov8l, 门生模型model_s,选择的是yolov8n
(1) 训练教师模型

  1. from ultralytics import YOLO
  2. data = r"ultralytics\datasets\coco128.yaml"
  3. model_t = YOLO(r'weights\yolov8l.pt')
  4. model_t.train(data=data, epochs=300, imgsz=640)
复制代码
(2) 训练门生模型baseline

  1. from ultralytics import YOLO
  2. data = r"ultralytics\datasets\coco128.yaml"
  3. model_s = YOLO(r'weights\yolov8n.pt')
  4. model_s.train(data=data, epochs=300, imgsz=640)
复制代码

(3) 蒸馏训练

将已经训练好的教师模型model_t的知识通过logit与feature-base知识蒸馏的方式迁移到门生模型model_s上,从而提升门生模型的性能。
  1. import torch
  2. from ultralytics import YOLO
  3. data = r"/home/xxx/project/public/yolov8-ultralytics-main/yolov8-ultralytics-main/ultralytics/cfg/datasets/coco128.yaml"
  4. model_t = YOLO(r'/home/xxx/project/public/yolov8-ultralytics-main/yolov8-ultralytics-main/weights/yolov8l.pt')
  5. model_t.model.model[-1].set_Distillation = True
  6. model_s = YOLO(r'/home/yuanwushui/project/public/yolov8-ultralytics-main/yolov8-ultralytics-main/yolov8n.pt')
  7. model_s.train(data=data, epochs=300, imgsz=640, model_t= model_t.model)
复制代码
如果传入了model_t,则会举行蒸馏训练,否则为平凡训练
注:feature-based蒸馏的类型设置(支持"cwd","mgd","mimic", 任意一种);设置是否使用在线蒸馏, (默以为False即离线蒸馏,你也可以设置为True);设置是否使用logit蒸馏;设置特征蒸馏的层数,(可根据必要选择必要蒸馏的特征层)。均在ultralytics/engine/trainer.py中的BaseTrainer类的初始化函数中__init__.py中举行设置。如下图

6.训练成功
注,以上全部代码均可关注博主,私信后获取,仅需在某些位置举行代码与文件的更换即可,基于你的代码改写后并不影响你的原始代码使用,是否开启蒸馏、开启什么样的蒸馏取决于你的参数设置

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

用户国营

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表