模型压缩——训练后剪枝

打印 上一主题 下一主题

主题 1766|帖子 1766|积分 5298

1. 弁言

前文基于粒度的剪枝中主要是基于一个权重矩阵来介绍差别粒度下的剪枝方法,本文会介绍如何对一个实际的神经网络模型来实施剪枝操纵。
剪枝是利用稀疏性来压缩模型的,卷积神经网络(CNN)每每具有较高的参数冗余性,冗余参数被剪除后,每每不会显著影响整体性能。因此,模型剪枝在图像处理范畴的应用较为广泛。
   相比图像模型来说,语言模型的上下文依赖特性使得模型性能对参数的敏感度较高,某些参数的删除可能会影响到模型的整体表现,所以用稀疏化剪枝在语言模型中的应用相对较少。
  本文我们将以一个经典的卷积神经网络LeNet来例,来介绍模型剪枝操纵的具体使用。
2. 模型介绍

LeNet 是一种经典的卷积神经网络,由 Yann LeCun 等人于 1998 年提出,主要用于手写数字识别(Minst数据集)。它的网络结构如下:
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class LeNet(nn.Module):
  5.     def __init__(self, num_classes=10):
  6.         super().__init__()
  7.         self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
  8.         self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
  9.         self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
  10.         self.fc1 = nn.Linear(in_features=16*4*4, out_features=120)
  11.         self.fc2 = nn.Linear(in_features=120, out_features=84)
  12.         self.fc3 = nn.Linear(in_features=84, out_features=num_classes)
  13.     def forward(self, x):
  14.         x = self.maxpool(F.relu(self.conv1(x)))  # 14x14x6
  15.         x = self.maxpool(F.relu(self.conv2(x)))  # 5x5x16
  16.         x = x.view(x.size()[0], -1)              #
  17.         x = F.relu(self.fc1(x))
  18.         x = F.relu(self.fc2(x))
  19.         return self.fc3(x)
复制代码
模型结构分析:


  • 卷积层1:输入层吸收一个32x32的单通道图像,应用6个5x5的卷积核后,得到一个28x28x6的特征图(每个卷积核都会得到一个28x28的特征图),卷积层有助于捕捉图像中的边沿、纹理等局部特征;
  • 池化1:用ReLu函数对卷积层的输出结果举行激活后,应用一个2x2的最大池化操纵,特征图变为14x14x6;       最大池化是一种在小窗口内选择最大值的操纵,例如2x2池化就是在2x2的区域内选择一个值最大的元素来取代这个区域,如许就淘汰特征图的尺寸和参数数量,可以防止过拟合,同时生存最重要的特征。
  • 卷积层2:应用16个5x5的卷积核后,得到一个10x10x16的特征图,进一步提取更高条理的特征和更复杂的模式。
  • 池化2:再次使用ReLu激活,并应用2x2的池化操纵,池化后特征图变为5x5x16;
  • 展平:使用view操纵将池化后的特征图展平成一维向量,形状变为1x400;
  • 全连接层fc1、fc2、fc3:负责将卷积层提取的局部特征整合起来,形玉成局的高级特征;
  1. import numpy as np
  2. import random
  3. random.seed(0)
  4. np.random.seed(0)
  5. torch.manual_seed(0)
  6. def count_parameters(model: nn.Module):
  7.     return sum([param.numel() for param in model.parameters()])
  8. model = LeNet()
  9. print("parameters num:", count_parameters(model))
  10. print("model structure:", model)
复制代码
  1. parameters num: 44426
  2. model structure:
  3. LeNet(
  4.     (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  5.     (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  6.     (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  7.     (fc1): Linear(in_features=256, out_features=120, bias=True)
  8.     (fc2): Linear(in_features=120, out_features=84, bias=True)
  9.     (fc3): Linear(in_features=84, out_features=10, bias=True)
  10.   )
复制代码
我们的主要目标是演示剪枝操纵,所以就不再做模型的训练操纵,直接加载训练好的模型参数。
  1. model.load_state_dict(torch.load("./checkpoint/model.pt"))
复制代码
查看此时模型的稀疏度:
  1. def get_model_sparsity(model: nn.Module) -> float:
  2.     num_nonzeros, num_params = 0, 0
  3.     for param in model.parameters():
  4.         num_nonzeros += param.count_nonzero()
  5.         num_params += param.numel()
  6.     return 1 - float(num_nonzeros) / num_params
  7. get_model_sparsity(model)
复制代码
  1.     0.0
复制代码
此时还没有剪枝,所以模型的稀疏度为0,表现模型所有参数都是有效的。
3. 加载数据集

Minst数据集包含70000个手写数字(0-9)图像,其中有60000个训练样本和10000个测试样本,每张图像都是一个灰度图像,分辨率为 28x28 像素。
先使用datasets.MNIST类来加载数据集,其中:


  • download=True 参数会查抄root指定的目次下是否已经存在 MNIST 数据集文件,如果不存在,它会主动从互联网上下载并解压,如果已经存在,则会直接使用现有的数据集文件。
  • torchvision.transforms 模块用于对数据举行预处理操纵,transforms.ToTensor() 用于将 PIL 图片或 NumPy 数组转换为 PyTorch 张量(Tensor)。transforms.Normalize的作用是将张量从 [0, 255]的像素主动归一化到 均值为0.1307,标准差为0.3081的正态分布上。
  1. from torchvision.transforms import *
  2. from torchvision import datasets
  3. # 设置归一化
  4. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
  5. # 获取数据集,train=True训练集,=False测试集
  6. train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)  
  7. test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
复制代码
  1.     Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
  2.     Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
  3.     100%|██████████| 9912422/9912422 [00:11<00:00, 860380.13it/s]
  4.     Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
  5.    
  6.     Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
  7.     Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
  8.     100%|██████████| 28881/28881 [00:00<00:00, 124555.87it/s]
  9.     Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
  10.    
  11.     Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
  12.     Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
  13.     100%|██████████| 1648877/1648877 [00:03<00:00, 484979.62it/s]
  14.     Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
  15.    
  16.     Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
  17.     Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
  18.     100%|██████████| 4542/4542 [00:00<00:00, 1517365.89it/s]
  19.     Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
复制代码
那这个数据集究竟长什么样呢?我们可以用matplotlib库将这些图像的原始内容显示出来,以便直观察看。
  1. import matplotlib.pyplot as plt
  2. def show_image(dataset, rows=1):
  3.     # 设置可视化的画布大小  
  4.     plt.figure(figsize=(10, 10))  
  5.     # 从数据集中提取rows * 10个样本  
  6.     num_samples = rows * 10
  7.     indices = np.random.choice(len(dataset), num_samples, replace=False)  
  8.     sample_dataset = [dataset[i] for i in indices]
  9.    
  10.     # 显示随机样本  
  11.     for i, (image, label) in enumerate(sample_dataset):   
  12.         # 转换为 numpy 数组并去除通道信息  
  13.         image = image.numpy().squeeze()  # (C, H, W) -> (H, W)  
  14.         plt.subplot(rows, 10, i + 1)  
  15.         plt.imshow(image, cmap='gray')  
  16.         plt.title(f'Label: {label}', fontsize=12)  
  17.         plt.axis('off')  
  18.    
  19.     plt.tight_layout()  
  20.     plt.show()
  21. show_image(train_dataset, 1)
复制代码

可以看到,每个图像都与一个标签label(0 到 9 的数字)相干联,表现图像中显示的数字。
4. 剪枝前评估

为了在剪枝前对模型的性能预先有一个了解,我们会测试数据集对AlexNet模型举行一个准确率测试。
先实现一个evaluate评估方法,逻辑大概如下。


  • 用模型model对输入的批量数据inputs作分类猜测,得到所有分类可能性的数值logits,并用argmax取可能性最大的值作为猜测分类结果outputs。
  • 将猜测分类结果outputs和目标分类targets举行比对,统计猜测正确的数量num_correct。
  • 最后,盘算正确猜测数量num_correct与总数量num_samples的比值,得到准确率。
  1. from tqdm.auto import tqdm
  2. def evaluate(model, dataloader):
  3.     model.eval()
  4.     num_samples, num_correct = 0, 0
  5.     for inputs, targets in tqdm(dataloader, desc="eval"):
  6.         logits = model(inputs)
  7.         outputs = logits.argmax(dim=1)
  8.         num_samples += targets.size(0)
  9.         num_correct += (outputs == targets).sum()
  10.     return (num_correct / num_samples * 100).item()        
复制代码
将数据集封装为小批量数据加载器,批量大小batch_size设为64,shuffle=True表现对数据集举行随机性打乱顺序。
  1. from torch.utils.data import DataLoader
  2. # 设置DataLoader
  3. batch_size = 64
  4. train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  5. test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
复制代码
对测试数据集上运行评估方法,得到准确率。
  1. origin_accuracy = evaluate(model, test_loader)
  2. origin_accuracy
复制代码
  1.     97.98999786376953
复制代码
5. 剪枝

这一部分我们将实现一个基于稀疏度的剪枝,例如:稀疏度0.8表现剪掉张量中80%的参数。
在剪枝前,首先必要盘算一个掩码,它决定了哪些权重必要被剪枝,哪些权重必要被生存。
  1. def calc_mask(tensor: torch.Tensor, sparsity: float) -> torch.Tensor:
  2.     sparsity = min(max(0.0, sparsity), 1.0)
  3.     if sparsity == 1.0:
  4.         return torch.zeros_like(tensor)
  5.     elif sparsity == 0.0:
  6.         return torch.ones_like(tensor)
  7.     # 计算张量中的元素总数和需要置零的个数
  8.     num_elements = tensor.numel()
  9.     num_zeros = round(num_elements * sparsity)
  10.     # 计算每个元素的绝对值,作为重要性度量
  11.     importance = tensor.abs()
  12.     # 根据需要置零的元素数量找到相应分位阀值
  13.     threshold = importance.view(-1).kthvalue(num_zeros).values
  14.     # 计算掩码:将大于阀值的置为1,小于阀值的置为0
  15.     mask = torch.gt(importance, threshold)
  16.     return mask
复制代码
创建一个剪枝器类来管理剪枝操纵,其中:


  • 在构造方法中,完成了模型每层掩码的盘算,并缓存在masks变量中;
  • 在prune方法中,在一个新克隆的模型实例上,对每一层权重应用剪枝掩码,得到剪枝后的张量。
  1. import copy
  2. class SparsePruner:
  3.     def __init__(self, model, sparsity_dict):
  4.         masks = dict()
  5.         for name, param in model.named_parameters():
  6.             if param.dim() > 1:
  7.                 masks[name] = calc_mask(param, sparsity_dict[name])
  8.         self.masks = masks
  9.     @torch.no_grad()
  10.     def prune(self, model):
  11.         # to_prune_model = copy.deepcopy(self.origin_model)
  12.         for name, param in model.named_parameters():
  13.             if name in self.masks:
  14.                 param.mul_(self.masks[name])
  15.         return model
复制代码
在sparsity_dict字典中定义每一层的稀疏度,并调用剪枝方法。
  1. sparsity_dict = {
  2.     'conv1.weight': 0.85,
  3.     'conv2.weight': 0.8,
  4.     'fc1.weight': 0.75,
  5.     'fc2.weight': 0.7,
  6.     'fc3.weight': 0.8,
  7. }
  8. pruner = SparsePruner(model, sparsity_dict)
  9. pruned_model = pruner.prune(copy.deepcopy(model))
  10. print(f"model sparsity after prune: {get_model_sparsity(pruned_model):.4f}")
复制代码
  1.     model sparsity after prune: 0.7387
复制代码
可以看到,剪枝后模型的稀疏度为0.7387, 这表现模型中有73.87%的参数都被置为了0,相应的模型大小已经变成原来的26.13%。
下面评估下剪枝操纵给模型的准确率带来有多洪流平的影响。
  1. accuracy_after_prune = evaluate(pruned_model, test_loader)
  2. accuracy_after_prune
复制代码
  1.     66.45999908447266
复制代码
剪枝后,模型的准确率从97.99% 降落到了 66.46%。
6. 微调

这部分我们将对上面剪枝后的模型举行微调,目标是尽可能将模型性能恢复到靠近剪枝前的程度。
首先,写一个训练函数train,功能是在指定的数据集上完成一轮训练。
  1. def train(model, dataloader, loss_fn, optimizer, pruner):
  2.     model.train()
  3.     for inputs, targets in tqdm(dataloader, desc="train"):
  4.         optimizer.zero_grad()  
  5.         
  6.         logits = model(inputs)
  7.         loss = loss_fn(logits, targets)
  8.         loss.backward()
  9.         optimizer.step()
  10.         
  11.         pruner.prune(model)
复制代码
  

  • optimizer.zero_grad():用于在每个小批量迭代前清零梯度;
  • model(inputs):前向传播完成分类猜测;
  • loss_fn:盘算猜测结果outputs与目标结果targets之间的损失;
  • loss.backward(): 损失反向传播盘算每层的梯度;
  • optimizer.step(): 根据梯度来更新权重参数值;
  • pruner.prune(model): 对模型参数举行剪枝,始终保证训练期间模型参数的稀疏度;
  1. num_epochs = 5
  2. optimizer = torch.optim.SGD(pruned_model.parameters(),  lr=0.01, momentum=0.5)
  3. loss_fn = nn.CrossEntropyLoss()  
  4. best_pruned_model_checkpoint = None
  5. best_accuracy = 0
  6. for i in range(num_epochs):
  7.     train(pruned_model, train_loader, loss_fn, optimizer, pruner)
  8.     accuracy = evaluate(pruned_model, test_loader)
  9.     if accuracy > best_accuracy:
  10.         best_accuracy = accuracy
  11.         best_pruned_model_checkpoint = copy.deepcopy(pruned_model.state_dict())
  12.     print(f"epoch: {i+1}, accuracy: {accuracy:.2f}%, best_accuracy: {best_accuracy:.2f}%")
  13.    
复制代码

通过微调,我们将剪枝后模型的准确率从66.46%恢复到了98.31%,比剪枝前97.99%还要略高一点,到达了预期目标。
小结:本文以一个比力经典的LeNet卷积神经网络作为开始,介绍了一种训练后剪枝的实施过程,通过先剪枝并评估性能损失,再通过微调来恢复模型性能。这个网络比力简单,所以只举行了一轮剪枝-微调步调,实际场景中对于参数量大的模型,可能必要用迭代剪枝的方法循环多次剪枝-微调步调,以便让剪枝的影响和结果更为可控。
参考阅读



  • 模型剪枝的粒度
  • 模型压缩概览

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

熊熊出没

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表