
Title: A code summary of knowledge distillation loss methods (including logits-based methods: KLDiv, DIST, DKD)

Author: 宝塔山    Time: 2024-8-17 12:59
There are three kinds of knowledge distillation methods:
  1. Using the teacher model's output probabilities (logits-based methods)
  2. Using the teacher model's intermediate features (hint-based methods)
  3. Self-distillation
1. Using the teacher model's output probabilities (logits-based methods)

The loss functions used by this class of methods are given below.
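As a general reference, the classic soft-target KD loss of Hinton et al. (2015), which the KLDiv code in Section 1.2 implements, can be written as follows (z^t and z^s denote teacher and student logits, T the temperature, y the hard label, and alpha the weight of the hard-label loss):

p^T = \mathrm{softmax}(z^t / T), \qquad q^T = \mathrm{softmax}(z^s / T)

\mathcal{L}_{\mathrm{KD}} = T^2 \, \mathrm{KL}\!\left(p^T \,\Vert\, q^T\right) = T^2 \sum_i p_i^T \log \frac{p_i^T}{q_i^T}

\mathcal{L} = \alpha \, \mathcal{L}_{\mathrm{CE}}(q, y) + (1 - \alpha) \, \mathcal{L}_{\mathrm{KD}}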

1.1 DIST (NeurIPS 2022)

Tao Huang, Shan You, Fei Wang, Chen Qian, and Chang Xu. Knowledge distillation from a stronger teacher. In Advances in Neural Information Processing Systems, 2022.
import torch.nn as nn

def cosine_similarity(a, b, eps=1e-8):
    # Row-wise cosine similarity between two batches of vectors
    return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)

def pearson_correlation(a, b, eps=1e-8):
    # Pearson correlation = cosine similarity of the mean-centered vectors
    return cosine_similarity(a - a.mean(1).unsqueeze(1),
                             b - b.mean(1).unsqueeze(1), eps)

def inter_class_relation(soft_student_outputs, soft_teacher_outputs):
    # Match each sample's predicted class distribution between student and teacher
    return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()

def intra_class_relation(soft_student_outputs, soft_teacher_outputs):
    # Match, for each class, the score distribution across the batch (transpose, then reuse)
    return inter_class_relation(soft_student_outputs.transpose(0, 1),
                                soft_teacher_outputs.transpose(0, 1))

class DIST(nn.Module):
    def __init__(self, beta=1.0, gamma=1.0, temp=1.0):
        super(DIST, self).__init__()
        self.beta = beta    # weight of the inter-class (per-sample) term
        self.gamma = gamma  # weight of the intra-class (per-class) term
        self.temp = temp    # softmax temperature

    def forward(self, student_preds, teacher_preds, **kwargs):
        soft_student_outputs = (student_preds / self.temp).softmax(dim=1)
        soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)
        inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)
        intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)
        kd_loss = self.beta * inter_loss + self.gamma * intra_loss
        return kd_loss
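A minimal usage sketch (the batch size, class count, temperature, and random logits below are illustrative assumptions, not from the original post):

import torch

# Hypothetical logits for a batch of 8 samples over 100 classes
student_logits = torch.randn(8, 100)
teacher_logits = torch.randn(8, 100)

dist_criterion = DIST(beta=1.0, gamma=1.0, temp=4.0)
loss = dist_criterion(student_logits, teacher_logits)  # scalar tensor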
1.2 KLDiv (the original 2015 method)

import torch.nn as nn
import torch.nn.functional as F

# In practice the total loss is alpha * hard_loss + (1 - alpha) * kd_loss;
# this module computes only the kd_loss term.
class KLDiv(nn.Module):
    def __init__(self, temp=1.0):
        super(KLDiv, self).__init__()
        self.temp = temp  # softmax temperature

    def forward(self, student_preds, teacher_preds, **kwargs):
        # Softened student log-probabilities and teacher probabilities
        soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)
        soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)
        # KL divergence per sample, averaged over the batch
        kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()
        # Scale by T^2 to keep gradient magnitudes comparable across temperatures
        kd_loss *= self.temp ** 2
        return kd_loss
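Following the comment above, a sketch of how the KD term is typically combined with the hard-label cross-entropy (alpha, the temperature, and the dummy tensors are illustrative assumptions):

import torch
import torch.nn.functional as F

alpha = 0.3                      # weight of the hard-label loss (assumed value)
criterion_kd = KLDiv(temp=4.0)   # assumed temperature

student_logits = torch.randn(8, 100)
teacher_logits = torch.randn(8, 100)
labels = torch.randint(0, 100, (8,))

hard_loss = F.cross_entropy(student_logits, labels)
kd_loss = criterion_kd(student_logits, teacher_logits)
loss = alpha * hard_loss + (1 - alpha) * kd_loss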
1.3 DKD (Decoupled Knowledge Distillation, CVPR 2022)

Borui Zhao, Quan Cui, Renjie Song, Yiyu Qiu, and Jiajun Liang. Decoupled knowledge distillation. In IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2022.
import torch
import torch.nn as nn
import torch.nn.functional as F

def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
    gt_mask = _get_gt_mask(logits_student, target)
    other_mask = _get_other_mask(logits_student, target)
    pred_student = F.softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    # Collapse each distribution into [p(target class), p(all other classes)]
    pred_student = cat_mask(pred_student, gt_mask, other_mask)
    pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
    log_pred_student = torch.log(pred_student)
    # Target-class KD (TCKD): KL on the binary target / non-target split
    tckd_loss = (
        F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')
        * (temperature ** 2)
    )
    # Non-target-class KD (NCKD): KL over the remaining classes, with the
    # target class suppressed by subtracting a large constant before softmax
    pred_teacher_part2 = F.softmax(
        logits_teacher / temperature - 1000.0 * gt_mask, dim=1
    )
    log_pred_student_part2 = F.log_softmax(
        logits_student / temperature - 1000.0 * gt_mask, dim=1
    )
    nckd_loss = (
        F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean')
        * (temperature ** 2)
    )
    return alpha * tckd_loss + beta * nckd_loss

def _get_gt_mask(logits, target):
    # One-hot mask that is True at the ground-truth class of each sample
    target = target.reshape(-1)
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    return mask

def _get_other_mask(logits, target):
    # Complementary mask: True everywhere except the ground-truth class
    target = target.reshape(-1)
    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
    return mask

def cat_mask(t, mask1, mask2):
    # Sum the probabilities inside each mask and concatenate into a 2-column tensor
    t1 = (t * mask1).sum(dim=1, keepdim=True)
    t2 = (t * mask2).sum(dim=1, keepdim=True)
    rt = torch.cat([t1, t2], dim=1)
    return rt

class DKD(nn.Module):
    def __init__(self, alpha=1., beta=2., temperature=1.):
        super(DKD, self).__init__()
        self.alpha = alpha              # weight of the target-class term (TCKD)
        self.beta = beta                # weight of the non-target-class term (NCKD)
        self.temperature = temperature  # softmax temperature

    def forward(self, z_s, z_t, **kwargs):
        target = kwargs['target']
        if len(target.shape) == 2:  # mixup / label smoothing: recover hard labels
            target = target.max(1)[1]
        kd_loss = dkd_loss(z_s, z_t, target, self.alpha, self.beta, self.temperature)
        return kd_loss
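A minimal usage sketch showing how the hard labels are passed via kwargs (shapes, hyperparameter values, and random tensors are illustrative assumptions):

import torch

student_logits = torch.randn(8, 100)
teacher_logits = torch.randn(8, 100)
labels = torch.randint(0, 100, (8,))   # hard class indices

dkd_criterion = DKD(alpha=1.0, beta=8.0, temperature=4.0)
loss = dkd_criterion(student_logits, teacher_logits, target=labels)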
2. Using the teacher model's intermediate features (hint-based methods)

The loss functions used by this class of methods are given below.
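As a representative illustration only (a minimal FitNets-style sketch, not the ReviewKD implementation linked below; the 1x1-conv regressor and the feature shapes are assumptions), a hint-based loss matches the teacher's intermediate feature map against a projection of the student's:

import torch
import torch.nn as nn
import torch.nn.functional as F

class HintLoss(nn.Module):
    # FitNets-style hint loss: project student features to the teacher's
    # channel dimension with a 1x1 conv, then match them with MSE.
    def __init__(self, student_channels, teacher_channels):
        super(HintLoss, self).__init__()
        self.regressor = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, feat_student, feat_teacher):
        return F.mse_loss(self.regressor(feat_student), feat_teacher)

# Illustrative feature maps: batch 8, student 64 channels, teacher 128 channels, 16x16 spatial
hint_criterion = HintLoss(student_channels=64, teacher_channels=128)
loss = hint_criterion(torch.randn(8, 64, 16, 16), torch.randn(8, 128, 16, 16))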

2.1 ReviewKD (CVPR 2021)

Paper:
Pengguang Chen, Shu Liu, Hengshuang Zhao, and Jiaya Jia. Distilling knowledge via knowledge review. In IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2021.
Code:
https://github.com/dvlab-research/ReviewKD
Related hint-based references:
Adriana Romero, Nicolas Ballas, Samira Ebrahimi Kahou, Antoine Chassang, Carlo Gatta, and Yoshua Bengio. FitNets: Hints for thin deep nets. arXiv preprint arXiv:1412.6550, 2014.
Yonglong Tian, Dilip Krishnan, and Phillip Isola. Contrastive representation distillation. In International Conference on Learning Representations, 2020.
Baoyun Peng, Xiao Jin, Jiaheng Liu, Dongsheng Li, Yichao Wu, Yu Liu, Shunfeng Zhou, and Zhaoning Zhang. Correlation congruence for knowledge distillation. In International Conference on Computer Vision, 2019.
3. Self-distillation

ICCV 2019: Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation
A post on self-distillation in knowledge distillation (in Chinese):
https://www.xjx100.cn/news/1098187.html?action=onClick
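A minimal sketch of the self-distillation idea from the paper above: auxiliary classifiers attached to shallower stages are trained against the hard labels and against the softened outputs of the deepest classifier. It reuses the KLDiv module from Section 1.2; the loss weight, temperature, and dummy tensors are assumptions, and the paper's additional feature-level hint term is omitted here.

import torch
import torch.nn.functional as F

kd_criterion = KLDiv(temp=3.0)   # assumed temperature

def self_distillation_loss(final_logits, aux_logits_list, labels, kd_weight=0.5):
    # Deepest classifier is trained on the hard labels only
    loss = F.cross_entropy(final_logits, labels)
    for aux_logits in aux_logits_list:
        # Each shallower classifier: hard-label CE plus KD from the deepest
        # classifier's (detached) softened outputs
        loss = loss + (1 - kd_weight) * F.cross_entropy(aux_logits, labels)
        loss = loss + kd_weight * kd_criterion(aux_logits, final_logits.detach())
    return loss

# Illustrative shapes: 8 samples, 100 classes, two auxiliary classifiers
labels = torch.randint(0, 100, (8,))
final_logits = torch.randn(8, 100)
aux_logits_list = [torch.randn(8, 100), torch.randn(8, 100)]
loss = self_distillation_loss(final_logits, aux_logits_list, labels)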
Articles on knowledge distillation loss functions

An introduction to FitNet (ICLR 2015), Attention (ICLR 2017), Relational KD (CVPR 2019), ICKD (ICCV 2021), Decoupled KD (CVPR 2022), ReviewKD (CVPR 2021), and other methods:
https://zhuanlan.zhihu.com/p/603748226?utm_id=0
To be updated.
