深度学习:迁徙学习 [复制链接]
发表于 2025-9-21 22:23:01 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

×
迁徙学习

标题1.什么是迁徙学习

迁徙学习(Transfer Learning)是一种呆板学习方法,就是把为使命 A 开发 的模子作为初始点,重新利用在为使命 B 开发模子的过程中。迁徙学习是通过 从已学习的相干使命中转移知识来改进学习的新使命,固然大多数呆板学习算 法都是为了办理单个使命而计划的,但是促进迁徙学习的算法的开发是呆板学 习社区一连关注的话题。 迁徙学习对人类来说很常见,比方,我们可能会发现 学习辨认苹果可能有助于辨认梨,大概学习弹奏电子琴可能有助于学习钢琴。
找到目的题目的相似性,迁徙学习使命就是从相似性出发,将旧范畴 (domain)学习过的模子应用在新范畴上
标题2.迁徙学习的步骤

1、选择预练习的模子和适当的层
通常,我们会选择在大规模图像数据集(如ImageNet)上预练习的模子,如VGG、ResNet等。然后,根据新数据集的特点,选择必要微调的模子层。对于低级特征的使命(如边缘检测),最好利用浅层模子的层,而对于高级特征的使命(如分类),则应选择更深条理的模子。
2、冻结预练习模子的参数
保持预练习模子的权重稳定,只练习新增长的层大概微调一些层,避免由于在数据集中过拟合导致预练习模子过分拟合。
3、在新数据集上练习新增长的层
在冻结预练习模子的参数情况下,练习新增长的层。如许,可以使新模子顺应新的使命,从而得到更高的性能
4、微调预练习模子的层
在新层上举行练习后,可以解冻一些已经练习过的层,而且将它们作为微调的目的。如许做可以进步模子在新数据集上的性能
5、评估和测试
在练习完成之后,利用测试集对模子举行评估。如果模子的性能仍然不够好,可以实验调解超参数大概更改微调层。
标题3.迁徙学习实例

该实例利用的模子是ResNet-18残差神经网络模子
###1. 导入必要的库
  1. 在import torch
  2. import torchvision.models as models
  3. from torch import nn
  4. from torch.utils.data import Dataset,DataLoader
  5. from torchvision import transforms
  6. from PIL import Image
  7. import numpy as np
复制代码
这里导入了后续代码会用到的库,具体如下:
torch:PyTorch 深度学习框架的焦点库。
torchvision.models:包罗了预练习的模子,这里会用到 ResNet-18。
torch.nn:用于构建神经网络的模块。
torch.utils.data.Dataset 和 torch.utils.data.DataLoader:用于自界说数据集和加载数据。
torchvision.transforms:用于图像的预处理惩罚。
PIL.Image:用于读取图像。
numpy:用于数值盘算。
###2. 加载预练习模子并修改全连接层
  1. resnet_model= models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
  2. for param in resnet_model.parameters():
  3.     print(param)
  4.     param.requires_grad=False
  5. in_features=resnet_model.fc.in_features
  6. resnet_model.fc=nn.Linear(in_features,20)
  7. params_to_update=[]
  8. for param in resnet_model.parameters():
  9.     if param.requires_grad==True:
  10.         params_to_update.append(param)
复制代码
加载预练习的 ResNet-18 模子。
把模子中全部参数的 requires_grad 设置为 False,也就是冻结这些参数,使其在练习时不更新。
获取原模子全连接层的输入特征数,然后将全连接层更换为一个新的全连接层,输出维度为 20。
网络全部 requires_grad 为 True 的参数,这些参数会在练习时更新。
###3. 界说图像预处理惩罚变换
  1. data_transforms = {
  2.     'train':
  3.         transforms.Compose([
  4.         transforms.Resize([300,300]),
  5.         transforms.RandomRotation(45),
  6.         transforms.CenterCrop(224),
  7.         transforms.RandomHorizontalFlip(p=0.5),
  8.         transforms.RandomVerticalFlip(p=0.5),
  9.         transforms.RandomGrayscale(p=0.1),
  10.         transforms.ToTensor(),
  11.         transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
  12.     ]),
  13.     'valid':
  14.         transforms.Compose([
  15.         transforms.Resize([224,224]),
  16.         transforms.ToTensor(),
  17.         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  18.         ]),
  19. }
复制代码
界说了两个图像预处理惩罚的组合变换,分别用于练习集和验证集。
练习集的变换包罗了数据增强操纵,像随机旋转、程度翻转、垂直翻转等。
验证集的变换只包罗了调解巨细、转换为张量和标准化操纵。
4. 自界说数据集类

  1. class food_dataset(Dataset):
  2.     def __init__(self,file_path,transform=None):
  3.         self.file_path = file_path
  4.         self.imgs = []
  5.         self.labels = []
  6.         self.transform = transform
  7.         with open(self.file_path) as f:
  8.             samples = [x.strip().split(' ') for x in f.readlines()]
  9.             for img_path,label in samples:
  10.                 self.imgs.append(img_path)
  11.                 self.labels.append(label)
  12.     def __len__(self):
  13.         return  len(self.imgs)
  14.     def __getitem__(self, idx):
  15.         image = Image.open(self.imgs[idx])
  16.         if self.transform:
  17.             image = self.transform(image)
  18.         label = self.labels[idx]
  19.         label = torch.from_numpy(np.array(label,dtype=np.int64))
  20.         return image,label
复制代码
自界说了一个 food_dataset 类,继承自 torch.utils.data.Dataset。 init 方法:分析包罗图像路径和标签的文本文件,把图像路径和标签分别存到 self.imgs 和 self.labels 中。
len 方法:返回数据集的巨细。
getitem 方法:根据索引读取图像,对图像举行预处理惩罚,将标签转换为张量,然后返回图像和标签。
5. 创建数据集和数据加载器

  1. training_data = food_dataset(file_path='./trainbig.txt',transform=data_transforms['train'])
  2. test_data = food_dataset(file_path='./testbig.txt',transform=data_transforms['valid'])
  3. train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
  4. test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)
复制代码
创建练习集和测试集的数据集对象。
创建练习集和测试集的数据加载器,设置批量巨细为 64,而且打乱数据
###6. 设置练习装备、丧失函数、优化器和学习率调理器
  1. device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
  2. print(f"Using {device} device")
  3. model=resnet_model.to(device)
  4. loss_fn = nn.CrossEntropyLoss()
  5. optimizer = torch.optim.Adam(params_to_update,lr=0.001)
  6. scheduler=torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.5)
复制代码
选择符合的练习装备(GPU 或 CPU)。
把模子移动到所选装备上。
界说交织熵丧失函数。
界说 Adam 优化器,只对之前网络的必要更新的参数举行优化。
界说学习率调理器,每 5 个 epoch 将学习率乘以 0.5。
###7. 界说练习和测试函数
  1. def train(dataloader,model,loss_fn,optimizer):
  2.     model.train()
  3.     batch_size_num = 1
  4.     for X,y in dataloader:
  5.         X,y = X.to(device),y.to(device)
  6.         pred = model.forward(X)
  7.         loss = loss_fn(pred,y)
  8.         optimizer.zero_grad()
  9.         loss.backward()
  10.         optimizer.step()
  11. def test(dataloader, model,loss_fn):
  12.     global best_acc
  13.     size = len(dataloader.dataset)
  14.     num_batches =len(dataloader)
  15.     model.eval()
  16.     test_loss,correct =0,0
  17.     with torch.no_grad():
  18.         for X, y in dataloader:
  19.             X,y = X.to(device),y.to(device)
  20.             pred = model.forward(X)
  21.             test_loss+=loss_fn(pred,y).item()
  22.             correct += (pred.argmax(1) == y).type(torch.float).sum().item()
  23.     test_loss /= num_batches
  24.     correct /= size
  25.     print(f"Test result:\n Accuracy:{(100 * correct)}%, Avg loss: {test_loss}")
  26.     acc_s.append(correct)
  27.     loss_s.append(test_loss)
  28.     if correct>best_acc:
  29.         best_acc=correct
复制代码
train 函数:将模子设置为练习模式,遍历练习数据加载器,盘算丧失,反向传播并更新模子参数。
test 函数:将模子设置为评估模式,遍历测试数据加载器,盘算测试集的精确率和匀称丧失,记录最佳精确率。
8. 练习模子并保存
  1. epochs = 20
  2. acc_s = []
  3. loss_s =[]
  4. for t in range(epochs):
  5.     print(f"Epoch {t + 1}\n-----------")
  6.     train(train_dataloader, model,loss_fn, optimizer)
  7.     scheduler.step()
  8.     test(test_dataloader,model,loss_fn)
  9. print('最优训练结果为:',best_acc)
  10. torch.save(model.state_dict(), 'food_classification_model.pt')
复制代码
练习模子 20 个 epoch。
每个 epoch 竣事后,更新学习率并举行测试。
打印最优练习效果。
保存模子的参数到 food_classification_model.pt 文件中。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

登录后关闭弹窗

登录参与点评抽奖  加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表