IT评测·应用市场-qidao123.com

标题: 图像分类数据集 [打印本页]

作者: 十念    时间: 2025-3-15 15:14
标题: 图像分类数据集
《动手学深度学习》-3.5-学习笔记
  1. # 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
  2. # 并除以255使得所有像素的数值均在0~1之间
  3. trans = transforms.ToTensor()#用于将图像数据从 PIL 图像格式(Python Imaging Library,Python 的图像处理库)转换为 PyTorch 张量(Tensor)。
  4. mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)#加载训练数据集
  5. mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)#加载测试数据集
复制代码

  1. def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
  2.     """绘制图像列表"""
  3.     figsize = (num_cols * scale, num_rows * scale)
  4.     _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
  5.     axes = axes.flatten()
  6.     for i, (ax, img) in enumerate(zip(axes, imgs)):
  7.         if torch.is_tensor(img):
  8.             # 图片张量
  9.             ax.imshow(img.numpy())
  10.         else:
  11.             # PIL图片
  12.             ax.imshow(img)
  13.         ax.axes.get_xaxis().set_visible(False)
  14.         ax.axes.get_yaxis().set_visible(False)
  15.         if titles:
  16.             ax.set_title(titles[i])
  17.     return axes
复制代码
 show_images 是一个用于批量显示图像的工具函数,
  1. X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
  2. show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
复制代码
从 FashionMNIST 数据集中加载一批图像,使用 show_images 函数将图像以 2 行 9 列的网格形式显示,并为每张图像添加文本标签。

 
创建Dataloader
  1. batch_size = 256
  2. def get_dataloader_workers():  
  3.     """使用4个进程来读取数据"""
  4.     return 4
  5. train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())
复制代码
  1. def load_data_fashion_mnist(batch_size, resize=None):
  2.     """下载Fashion-MNIST数据集"""
  3.     trans = [transforms.ToTensor()]
  4.     if resize:
  5.         trans.insert(0, transforms.Resize(resize))
  6.     trans = transforms.Compose(trans)
  7.     mnist_train = torchvision.datasets.FashionMNIST(
  8.         root="../data", train=True, transform=trans, download=True)
  9.     mnist_test = torchvision.datasets.FashionMNIST(
  10.         root="../data", train=False, transform=trans, download=True)
  11.     return (data.DataLoader(mnist_train, batch_size, shuffle=True,
  12.                             num_workers=get_dataloader_workers()),
  13.             data.DataLoader(mnist_test, batch_size, shuffle=False,
  14.                             num_workers=get_dataloader_workers()))
复制代码
用于下载并加载 FashionMNIST 数据集,并将其转换为适合训练和测试的 DataLoader 对象。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 IT评测·应用市场-qidao123.com (https://dis.qidao123.com/) Powered by Discuz! X3.4