qidao123.com技术社区-IT企服评测·应用市场

标题: 【深度学习】Pytorch:加载自界说数据集 [打印本页]

作者: 张国伟    时间: 2025-1-12 23:57
标题: 【深度学习】Pytorch:加载自界说数据集
本教程将利用 flower_photos 数据集演示如安在 PyTorch 中加载和导入自界说数据集。该数据集包罗不同花种的图像,每种花的图像存储在以混名定名的子文件夹中。我们将深入解说每个函数和对象的利用方法,使读者能够推广应用到其他数据集使命中。
  1. flower_photos/
  2. ├── daisy/
  3. │   ├── image1.jpg
  4. │   ├── image2.jpg
  5. └── rose/
  6.      ├── image1.jpg
  7.      ├── image2.jpg
  8. ...
复制代码
环境设置

所需工具和库

  1. pip install torch torchvision matplotlib
复制代码
导入必要的库

  1. import os
  2. import torch
  3. from torchvision import datasets, transforms
  4. from torch.utils.data import DataLoader
  5. import matplotlib.pyplot as plt
  6. from PIL import Image
  7. import pathlib
复制代码
数据集导入方法

界说数据转换

图像转换在计算机视觉使命中至关重要。通过 transforms 对象,我们可以实现图像大小调解、归一化、随机变换等预处理操作。
  1. # 定义图像转换  
  2. transform = transforms.Compose([  
  3.     transforms.Resize((150, 150)),  # 调整图像大小为 150x150  
  4.     transforms.ToTensor(),  # 将图像转换为 PyTorch 张量  
  5.     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化图像数据  
  6. ])  
  7. # 数据路径  
  8. data_dir = r"E:\CodeSpace\Deep\data\flower_photos"  
  9. # 使用 ImageFolder 加载数据  
  10. full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)  
  11. # 计算训练集和测试集的样本数量(80%和20%的划分)  
  12. train_size = int(0.8 * len(full_dataset))  
  13. test_size = len(full_dataset) - train_size  
  14. # 随机划分数据集  
  15. train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])  
  16. # 创建数据加载器  
  17. train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  
  18. test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)  
  19. # 获取类别名  
  20. class_names = full_dataset.classes  
  21. print("类别名:", class_names)
复制代码
体现部分样本图像

可视化样本数据有助于理解数据集结构和数据质量。
  1. # 定义函数以绘制样本图像
  2. def plot_images(images, labels, class_names):
  3.     plt.figure(figsize=(10, 10))
  4.     for i in range(9):  # 绘制前 9 张图像
  5.         plt.subplot(3, 3, i + 1)
  6.         img = images[i].permute(1, 2, 0)  # 将张量维度从 (C, H, W) 转为 (H, W, C)
  7.         plt.imshow(img * 0.5 + 0.5)  # 反归一化处理,恢复到原始像素范围 [0, 1]
  8.         plt.title(class_names[labels[i]])  # 显示类别标签
  9.         plt.axis('off')  # 去掉坐标轴
  10. # 获取部分样本数据用于展示
  11. sample_images, sample_labels = next(iter(train_loader))
  12. plot_images(sample_images, sample_labels, class_names)
复制代码
自界说数据加载方法

当数据结构复杂或需要额外处理时,可以通过继承 torch.utils.data.Dataset 创建自界说数据加载类。
Dataset 类详解

Dataset 是 PyTorch 中的一个抽象类,用户需要实现以下焦点方法:
代码实现

  1. class CustomFlowerDataset(torch.utils.data.Dataset):
  2.     def __init__(self, data_dir, transform=None):
  3.         # 初始化数据集路径和图像转换方法
  4.         self.data_dir = pathlib.Path(data_dir)
  5.         self.transform = transform
  6.         self.image_paths = list(self.data_dir.glob('*/*.jpg'))  # 获取所有图像路径
  7.         self.label_names = sorted(item.name for item in self.data_dir.glob('*/') if item.is_dir())
  8.         self.label_to_index = {name: idx for idx, name in enumerate(self.label_names)}  # 将类别名映射为索引
  9.     def __len__(self):
  10.         # 返回数据集大小
  11.         return len(self.image_paths)
  12.     def __getitem__(self, idx):
  13.         # 根据索引获取图像及其标签
  14.         img_path = self.image_paths[idx]
  15.         label = self.label_to_index[img_path.parent.name]  # 通过父文件夹名获取标签
  16.         image = Image.open(img_path).convert("RGB")  # 确保图像是 RGB 模式
  17.         if self.transform:
  18.             image = self.transform(image)  # 进行图像预处理
  19.         return image, label
  20. # 使用自定义数据集
  21. custom_dataset = CustomFlowerDataset(data_dir, transform=transform)
  22. custom_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True)
复制代码
随机分别数据集

假如你还希望在这个自界说数据集上随机分别训练集和测试集,可以利用 torch.utils.data.random_split。以下是示例代码:
  1. from torch.utils.data import random_split  
  2. # 获取数据集长度  
  3. full_dataset = CustomFlowerDataset(data_dir, transform=transform)  
  4. # 计算训练集和测试集的样本数量(80%和20%的划分)  
  5. train_size = int(0.8 * len(full_dataset))  
  6. test_size = len(full_dataset) - train_size  
  7. # 随机划分数据集  
  8. train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])  
  9. # 创建数据加载器  
  10. train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  
  11. test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)  
  12. print(f"训练集大小: {len(train_dataset)}, 测试集大小: {len(test_dataset)}")  
复制代码
数据加载性能优化


  1. custom_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True, num_workers=4, prefetch_factor=2)
复制代码
Dataset 类扩展发起

数据集的利用方法

遍历数据集

模子训练前需要遍历数据集以加载图像和标签:
  1. for images, labels in custom_loader:
  2.     # images 是图像张量,labels 是对应的类别标签
  3.     print(f"图像张量大小: {images.shape}, 标签: {labels}")
复制代码
模子输入

数据集加载完成后可直接用于模子训练:
  1. import torch.nn as nn
  2. import torch.optim as optim
  3. # 定义一个简单的神经网络模型
  4. model = nn.Sequential(
  5.     nn.Flatten(),  # 将输入张量展平成一维
  6.     nn.Linear(150*150*3, 128),  # 输入层到隐藏层的全连接层
  7.     nn.ReLU(),  # 激活函数
  8.     nn.Linear(128, len(class_names))  # 输出层,类别数量等于花的种类数
  9. )
  10. # 定义损失函数和优化器
  11. criterion = nn.CrossEntropyLoss()  # 交叉熵损失适用于多分类问题
  12. optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam 优化器
  13. # 示例训练过程
  14. for epoch in range(2):  # 简单训练两轮
  15.     for images, labels in custom_loader:
  16.         outputs = model(images)  # 前向传播计算输出
  17.         loss = criterion(outputs, labels)  # 计算损失
  18.         optimizer.zero_grad()  # 梯度清零
  19.         loss.backward()  # 反向传播计算梯度
  20.         optimizer.step()  # 更新模型参数
  21.     print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
复制代码
模子评估

加载后的数据集也可用于验证模子性能:
  1. correct = 0
  2. total = 0
  3. model.eval()  # 设置模型为评估模式
  4. with torch.no_grad():
  5.     for images, labels in test_loader:
  6.         outputs = model(images)
  7.         _, predicted = torch.max(outputs, 1)
  8.         total += labels.size(0)
  9.         correct += (predicted == labels).sum().item()
  10. accuracy = 100 * correct / total
  11. print(f"模型准确率: {accuracy:.2f}%")
复制代码
方法对比与扩展

ImageFolder vs 自界说 Dataset


进步模子泛化本领


总结

本教程具体解说了如安在 PyTorch 中加载和导入 flower_photos 数据集,结合不同方法的解说使你能根据项目需求灵活选择得当的数据加载方案。同时,我们探讨了优化和扩展方法,希望这些内容能为你的深度学习项目提供有力支持。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 qidao123.com技术社区-IT企服评测·应用市场 (https://dis.qidao123.com/) Powered by Discuz! X3.4