分类任务实现模子集成代码模版

打印 上一主题 下一主题

主题 844|帖子 844|积分 2532

分类任务实现模子(投票式)集成代码模版

简介

本实验使用上一博客的深度学习分类模子训练代码模板-CSDN博客,自定义投票式集成,手动实现模子集成(投票法)的代码。末了通过tensorboard举行可视化,对每个基学习器的性能举行对比,直观的看出模子集成的作用。
代码

  1. # -*- coding:utf-8 -*-
  2. import os
  3. import torch
  4. import torchvision
  5. import torchmetrics
  6. import torch.nn as nn
  7. import my_utils as utils
  8. import torchvision.transforms as transforms
  9. from torch.utils.tensorboard import SummaryWriter
  10. from torch.utils.data import DataLoader
  11. from torchensemble.utils import set_module
  12. from torchensemble.voting import VotingClassifier
  13. classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
  14. def get_args_parser(add_help=True):
  15.     import argparse
  16.     parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)
  17.     parser.add_argument("--data-path", default=r"E:\Pytorch-Tutorial-2nd\data\datasets\cifar10-office", type=str,
  18.                         help="dataset path")
  19.     parser.add_argument("--model", default="resnet8", type=str, help="model name")
  20.     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
  21.     parser.add_argument(
  22.         "-b", "--batch-size", default=128, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
  23.     )
  24.     parser.add_argument("--epochs", default=200, type=int, metavar="N", help="number of total epochs to run")
  25.     parser.add_argument(
  26.         "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 16)"
  27.     )
  28.     parser.add_argument("--opt", default="SGD", type=str, help="optimizer")
  29.     parser.add_argument("--random-seed", default=42, type=int, help="random seed")
  30.     parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
  31.     parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
  32.     parser.add_argument(
  33.         "--wd",
  34.         "--weight-decay",
  35.         default=1e-4,
  36.         type=float,
  37.         metavar="W",
  38.         help="weight decay (default: 1e-4)",
  39.         dest="weight_decay",
  40.     )
  41.     parser.add_argument("--lr-step-size", default=80, type=int, help="decrease lr every step-size epochs")
  42.     parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
  43.     parser.add_argument("--print-freq", default=80, type=int, help="print frequency")
  44.     parser.add_argument("--output-dir", default="./Result", type=str, help="path to save outputs")
  45.     parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
  46.     parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
  47.     return parser
  48. def main():
  49.     args = get_args_parser().parse_args()
  50.     utils.setup_seed(args.random_seed)
  51.     args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  52.     device = args.device
  53.     data_dir = args.data_path
  54.     result_dir = args.output_dir
  55.     # ------------------------------------  log ------------------------------------
  56.     logger, log_dir = utils.make_logger(result_dir)
  57.     writer = SummaryWriter(log_dir=log_dir)
  58.     # ------------------------------------ step1: dataset ------------------------------------
  59.     normMean = [0.4948052, 0.48568845, 0.44682974]
  60.     normStd = [0.24580306, 0.24236229, 0.2603115]
  61.     normTransform = transforms.Normalize(normMean, normStd)
  62.     train_transform = transforms.Compose([
  63.         transforms.Resize(32),
  64.         transforms.RandomCrop(32, padding=4),
  65.         transforms.ToTensor(),
  66.         normTransform
  67.     ])
  68.     valid_transform = transforms.Compose([
  69.         transforms.ToTensor(),
  70.         normTransform
  71.     ])
  72.     # root变量下需要存放cifar-10-python.tar.gz 文件
  73.     # cifar-10-python.tar.gz可从 "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 下载
  74.     train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, transform=train_transform, download=True)
  75.     test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False, transform=valid_transform, download=True)
  76.     # 构建DataLoder
  77.     train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
  78.     valid_loader = DataLoader(dataset=test_set, batch_size=args.batch_size, num_workers=args.workers)
  79.     # ------------------------------------ tep2: model ------------------------------------
  80.     model_base = utils.resnet20()
  81.     # model_base = utils.LeNet5()
  82.     model = MyEnsemble(estimator=model_base, n_estimators=3, logger=logger, device=device, args=args,
  83.                        classes=classes, writer=writer, save_dir=log_dir)
  84.     model.set_optimizer(args.opt, lr=args.lr, weight_decay=args.weight_decay)
  85.     model.fit(train_loader, test_loader=valid_loader, epochs=args.epochs)
  86. class MyEnsemble(VotingClassifier):
  87.     def __init__(self, **kwargs):
  88.         # logger, device, args, classes, writer
  89.         super(VotingClassifier, self).__init__(kwargs["estimator"], kwargs["n_estimators"])
  90.         self.logger = kwargs["logger"]
  91.         self.writer = kwargs["writer"]
  92.         self.device = kwargs["device"]
  93.         self.args = kwargs["args"]
  94.         self.classes = kwargs["classes"]
  95.         self.save_dir = kwargs["save_dir"]
  96.     @staticmethod
  97.     def save(model, save_dir, logger):
  98.         """Implement model serialization to the specified directory."""
  99.         if save_dir is None:
  100.             save_dir = "./"
  101.         if not os.path.isdir(save_dir):
  102.             os.mkdir(save_dir)
  103.         # Decide the base estimator name
  104.         if isinstance(model.base_estimator_, type):
  105.             base_estimator_name = model.base_estimator_.__name__
  106.         else:
  107.             base_estimator_name = model.base_estimator_.__class__.__name__
  108.         # {Ensemble_Model_Name}_{Base_Estimator_Name}_{n_estimators}
  109.         filename = "{}_{}_{}_ckpt.pth".format(
  110.             type(model).__name__,
  111.             base_estimator_name,
  112.             model.n_estimators,
  113.         )
  114.         # The real number of base estimators in some ensembles is not same as
  115.         # `n_estimators`.
  116.         state = {
  117.             "n_estimators": len(model.estimators_),
  118.             "model": model.state_dict(),
  119.             "_criterion": model._criterion,
  120.         }
  121.         save_dir = os.path.join(save_dir, filename)
  122.         logger.info("Saving the model to `{}`".format(save_dir))
  123.         # Save
  124.         torch.save(state, save_dir)
  125.         return
  126.     def fit(self, train_loader, epochs=100, log_interval=100, test_loader=None, save_model=True, save_dir=None, ):
  127.         # 模型、优化器、学习率调整器、评估器 列表创建
  128.         estimators = []
  129.         for _ in range(self.n_estimators):
  130.             estimators.append(self._make_estimator())
  131.         optimizers = []
  132.         schedulers = []
  133.         for i in range(self.n_estimators):
  134.             optimizers.append(set_module.set_optimizer(estimators[i],
  135.                                                        self.optimizer_name, **self.optimizer_args))
  136.             scheduler_ = torch.optim.lr_scheduler.MultiStepLR(optimizers[i], milestones=[100, 150],
  137.                                                               gamma=self.args.lr_gamma)  # 设置学习率下降策略
  138.             # scheduler_ = torch.optim.lr_scheduler.StepLR(optimizers[i], step_size=self.args.lr_step_size,
  139.             #                                             gamma=self.args.lr_gamma)  # 设置学习率下降策略
  140.             schedulers.append(scheduler_)
  141.         acc_metrics = []
  142.         for i in range(self.n_estimators):
  143.             # task类型与任务一致
  144.             # num_classes与分类任务的类别数一致
  145.             acc_metrics.append(torchmetrics.Accuracy(task="multiclass", num_classes=len(self.classes)))
  146.         self._criterion = nn.CrossEntropyLoss()
  147.         # epoch循环迭代
  148.         best_acc = 0.
  149.         for epoch in range(epochs):
  150.             # training
  151.             for model_idx, (estimator, optimizer, scheduler) in enumerate(zip(estimators, optimizers, schedulers)):
  152.                 loss_m_train, acc_m_train, mat_train = \
  153.                     utils.ModelTrainerEnsemble.train_one_epoch(
  154.                         train_loader, estimator, self._criterion, optimizer, scheduler, epoch,
  155.                         self.device, self.args, self.logger, self.classes)
  156.                 # 学习率更新
  157.                 scheduler.step()
  158.                 # 记录
  159.                 self.writer.add_scalars('Loss_group', {'train_loss_{}'.format(model_idx):
  160.                                                            loss_m_train.avg}, epoch)
  161.                 self.writer.add_scalars('Accuracy_group', {'train_acc_{}'.format(model_idx):
  162.                                                                acc_m_train.avg}, epoch)
  163.                 self.writer.add_scalar('learning rate', scheduler.get_last_lr()[0], epoch)
  164.                 # 训练混淆矩阵图
  165.                 conf_mat_figure_train = utils.show_conf_mat(mat_train, classes, "train", save_dir, epoch=epoch,
  166.                                                             verbose=epoch == epochs - 1, save=False)
  167.                 self.writer.add_figure('confusion_matrix_train', conf_mat_figure_train, global_step=epoch)
  168.             # validate
  169.             loss_valid_meter, acc_valid, top1_group, mat_valid = \
  170.                 utils.ModelTrainerEnsemble.evaluate(test_loader, estimators, self._criterion, self.device, self.classes)
  171.             # 日志
  172.             self.writer.add_scalars('Loss_group', {'valid_loss':
  173.                                                        loss_valid_meter.avg}, epoch)
  174.             self.writer.add_scalars('Accuracy_group', {'valid_acc':
  175.                                                            acc_valid * 100}, epoch)
  176.             # 验证混淆矩阵图
  177.             conf_mat_figure_valid = utils.show_conf_mat(mat_valid, classes, "valid", save_dir, epoch=epoch,
  178.                                                         verbose=epoch == epochs - 1, save=False)
  179.             self.writer.add_figure('confusion_matrix_valid', conf_mat_figure_valid, global_step=epoch)
  180.             self.logger.info(
  181.                 'Epoch: [{:0>3}/{:0>3}]  '
  182.                 'Train Loss avg: {loss_train:>6.4f}  '
  183.                 'Valid Loss avg: {loss_valid:>6.4f}  '
  184.                 'Train Acc@1 avg:  {top1_train:>7.2f}%   '
  185.                 'Valid Acc@1 avg: {top1_valid:>7.2%}    '
  186.                 'LR: {lr}'.format(
  187.                     epoch, self.args.epochs, loss_train=loss_m_train.avg, loss_valid=loss_valid_meter.avg,
  188.                     top1_train=acc_m_train.avg, top1_valid=acc_valid, lr=schedulers[0].get_last_lr()[0]))
  189.             for model_idx, top1_meter in enumerate(top1_group):
  190.                 self.writer.add_scalars('Accuracy_group',
  191.                                         {'valid_acc_{}'.format(model_idx): top1_meter.compute() * 100}, epoch)
  192.             if acc_valid > best_acc:
  193.                 best_acc = acc_valid
  194.                 self.estimators_ = nn.ModuleList()
  195.                 self.estimators_.extend(estimators)
  196.                 if save_model:
  197.                     self.save(self, self.save_dir, self.logger)
  198. if __name__ == "__main__":
  199.     main()
复制代码
效果图

本实验采用3个学习器举行投票式集成,因此绘制了7条曲线,此中各学习器在训练和验证各有2条曲线,集成模子的效果通过 valid_acc输出(蓝色),通过下图可发现,集成模子与三个基学习器相比,分类准确率都能提高3-4百分点左右,是非常高的提升了。



参考

7.7 TorchEnsemble 模子集成库 · PyTorch实用教程(第二版) (tingsongyu.github.io)

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

张春

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表