yolov11 pruning

Approach: the C3k2 block in yolov11 is structured differently from yolov8's C2f, so the pruning procedure differs slightly from the earlier yolov8 pruning.
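
To see that structural difference concretely, you can list what each C3k2 holds inside. This is a minimal sketch, assuming an ultralytics install with YOLOv11 support; yolo11n.pt is only an example checkpoint and any YOLOv11 weights file works the same way:

from ultralytics import YOLO
from ultralytics.nn.modules import C3k2

yolo = YOLO("yolo11n.pt")  # example checkpoint, substitute your own weights
for name, m in yolo.model.named_modules():
    if isinstance(m, C3k2):
        # yolov8's C2f always keeps plain Bottlenecks in m.m; C3k2 may keep nested C3k
        # blocks instead, which is why the pruning script below finds Bottlenecks via
        # named_modules() rather than iterating m.m directly
        print(name, [type(b).__name__ for b in m.m])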

Follow-up: the full pruning workflow will be written up later, along with distillation, attention modules, and loss changes.

Notes:
1. Change the pruning rate by editing the factor argument of the pruning.get_threshold(yolo.model, 0.65) call inside do_pruning; a sketch of what the factor means follows this list.
2. Put this script in the same directory as the training code.
3. Edit the modelpath/savepath at the bottom of the script to choose where the pruned model is written.
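
On note 1: the factor passed to get_threshold is not the fraction removed. All BN gammas in the model are pooled, sorted in descending order, and the value at position int(len * factor) becomes the global threshold, so roughly that fraction of channels passes the gamma >= threshold test (before the per-layer "keep at least 8 filters" fallback). A minimal, self-contained sketch with dummy values:

import torch

gammas = torch.rand(1000)      # stand-in for the pooled |BN.weight| values of a whole model
factor = 0.65                  # same value as the get_threshold call inside do_pruning below
threshold = torch.sort(gammas, descending=True)[0][int(len(gammas) * factor)]
kept = (gammas >= threshold).sum().item()
print(f"threshold={threshold.item():.4f}, kept {kept}/{len(gammas)} channels")

The full pruning script: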
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect, C3k2
from torch.nn.modules.container import Sequential
import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "2"


class PRUNE():
    def __init__(self) -> None:
        self.threshold = None

    def get_threshold(self, model, factor=0.8):
        # Pool the absolute BN weights (gammas) of the whole model, sort them in descending
        # order and take the value at the `factor` quantile as the global threshold, so
        # roughly `factor` of all channels sit above it.
        ws = []
        bs = []
        for name, m in model.named_modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                w = m.weight.abs().detach()
                b = m.bias.abs().detach()
                ws.append(w)
                bs.append(b)
                print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
                print()
        ws = torch.cat(ws)
        self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]

    def prune_conv(self, conv1: Conv, conv2: Conv):
        ## a. Use the BN parameters to decide which channel indices to keep ==========
        gamma = conv1.bn.weight.data.detach()
        beta = conv1.bn.bias.data.detach()
        keep_idxs = []
        local_threshold = self.threshold
        while len(keep_idxs) < 8:  # if fewer than 8 filters survive, halve the threshold and retry
            keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
            local_threshold = local_threshold * 0.5
        n = len(keep_idxs)
        # n = max(int(len(idxs) * 0.8), p)
        print(n / len(gamma) * 100)  # percentage of channels kept in this layer
        # scale = len(idxs) / n
        ## b. Prune the BN layer (and conv1's output channels) with the kept indices ==
        conv1.bn.weight.data = gamma[keep_idxs]
        conv1.bn.bias.data = beta[keep_idxs]
        conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
        conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
        conv1.bn.num_features = n
        conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
        conv1.conv.out_channels = n
        if isinstance(conv2, list) and len(conv2) > 3 and conv2[-1]._get_name() == "Proto":
            # the segmentation Proto module also consumes this feature map
            proto = conv2.pop()
            proto.cv1.conv.in_channels = n
            proto.cv1.conv.weight.data = proto.cv1.conv.weight.data[:, keep_idxs]
        ## c. Prune conv1's bias with the kept indices =================================
        if conv1.conv.bias is not None:
            conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]
        ## d. Prune the input channels of every consumer in conv2 =====================
        if not isinstance(conv2, list):
            conv2 = [conv2]
        for item in conv2:
            if item is None:
                continue
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            if isinstance(item, Sequential):
                # a (depthwise Conv, pointwise Conv) pair, e.g. inside the detect-head branches
                dw = item[0]
                conv = item[1].conv
                dw.conv.in_channels = n
                dw.conv.out_channels = n
                dw.conv.groups = n
                dw.conv.weight.data = dw.conv.weight.data[keep_idxs, :]
                dw.bn.bias.data = dw.bn.bias.data[keep_idxs]
                dw.bn.weight.data = dw.bn.weight.data[keep_idxs]
                dw.bn.running_var.data = dw.bn.running_var.data[keep_idxs]
                dw.bn.running_mean.data = dw.bn.running_mean.data[keep_idxs]
                dw.bn.num_features = n
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]

    def prune(self, m1, m2):
        # m1 produces the feature map, m2 is the module (or list of modules) consuming it
        if isinstance(m1, C3k2):  # for C3k2 (the yolov11 counterpart of C2f), prune its output conv cv2
            m1 = m1.cv2
        if isinstance(m1, Sequential):
            m1 = m1[1]
        if not isinstance(m2, list):  # m2 is just one module
            m2 = [m2]
        for i, item in enumerate(m2):
            if isinstance(item, C3k2) or isinstance(item, SPPF):
                m2[i] = item.cv1
        self.prune_conv(m1, m2)


def do_pruning(modelpath, savepath):
    pruning = PRUNE()
    ### 0. Load the model
    yolo = YOLO(modelpath)  # load the trained (constraint/sparsity-trained) weights
    pruning.get_threshold(yolo.model, 0.65)  # global BN threshold; the factor (0.65 here) roughly sets the fraction of channels kept
    ### 1. Prune the Bottlenecks inside the C3k2 blocks (including those nested in C3k)
    for name, m in yolo.model.named_modules():
        if isinstance(m, Bottleneck):
            pruning.prune_conv(m.cv1, m.cv2)
    ### 2. Prune the convolutions connecting the specified backbone modules
    seq = yolo.model.model
    for i in [3, 5, 7, 8]:
        pruning.prune(seq[i], seq[i + 1])
    ### 3. Prune the head (a Segment head here, so it also has cv4 and proto)
    # P3: the output of seq[16] feeds seq[17], detect.cv2[0] (box branch), detect.cv3[0] (class branch),
    #     detect.cv4[0] (mask-coefficient branch) and the Proto module
    # P4: the output of seq[19] feeds seq[20], detect.cv2[1], detect.cv3[1], detect.cv4[1]
    # P5: the output of seq[22] feeds detect.cv2[2], detect.cv3[2], detect.cv4[2]
    detect: Detect = seq[-1]
    proto = detect.proto
    last_inputs = [seq[16], seq[19], seq[22]]
    colasts = [seq[17], seq[20], None]
    for idx, (last_input, colast, cv2, cv3, cv4) in enumerate(zip(last_inputs, colasts, detect.cv2, detect.cv3, detect.cv4)):
        if idx == 0:
            pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0], proto])
        else:
            pruning.prune(last_input, [colast, cv2[0], cv3[0], cv4[0]])
        pruning.prune(cv2[0], cv2[1])
        pruning.prune(cv2[1], cv2[2])
        pruning.prune(cv3[0], cv3[1])
        pruning.prune(cv3[1], cv3[2])
        pruning.prune(cv4[0], cv4[1])
        pruning.prune(cv4[1], cv4[2])
    ### 4. Re-enable gradients, validate and save
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True
    yolo.val(data='data.yaml', batch=2, device=0, workers=0)  # quick check that the pruned model still runs
    torch.save(yolo.ckpt, savepath)
    # yolo.model.pt_path = yolo.model.pt_path.replace("last.pt", os.path.basename(savepath))
    # yolo.export(format="onnx")
    #
    # ## Reload the original model and export it as well, to compare ONNX sizes before/after pruning
    # yolo = YOLO(modelpath)
    # yolo.export(format="onnx")


if __name__ == "__main__":
    modelpath = "runs/segment/Constraint/weights/best.pt"
    savepath = "runs/segment/Constraint/weights/last_prune.pt"
    do_pruning(modelpath, savepath)
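
To make the core prune_conv step easier to follow, here is a toy, hypothetical producer/consumer pair run through the PRUNE class defined above (the 16/32/64 channel sizes are made up for illustration): conv1 loses the output channels whose BN gamma falls below the threshold, and conv2 loses the matching input channels, so the pair still composes.

import torch
from ultralytics.nn.modules import Conv

p = PRUNE()
p.threshold = torch.tensor(0.5)              # pretend global threshold
c1, c2 = Conv(16, 32, 3), Conv(32, 64, 3)    # producer / consumer sharing a 32-channel interface
with torch.no_grad():
    c1.bn.weight.copy_(torch.rand(32))       # fake gammas, some of them below the threshold
p.prune_conv(c1, c2)
x = torch.randn(1, 16, 64, 64)
print(c2(c1(x)).shape)                       # still (1, 64, 64, 64): only the hidden width shrank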
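
After pruning, a quick sanity check is to reload both checkpoints and compare parameter counts. A minimal sketch using the same (hypothetical) paths as the script above, assuming the checkpoint saved via torch.save(yolo.ckpt, savepath) loads back through YOLO():

from ultralytics import YOLO

def count_params(path):
    return sum(p.numel() for p in YOLO(path).model.parameters())

before = count_params("runs/segment/Constraint/weights/best.pt")
after = count_params("runs/segment/Constraint/weights/last_prune.pt")
print(f"params: {before:,} -> {after:,} ({after / before:.1%} of original)")

The pruned weights are then normally fine-tuned on the same dataset to recover accuracy, which is part of the full workflow mentioned in the follow-up note above.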

