IT评测·应用市场-qidao123.com

标题: Pytorch中的torch.utils.data.Dataset 类 [打印本页]

作者: 渣渣兔    时间: 2025-3-23 11:59
标题: Pytorch中的torch.utils.data.Dataset 类
1、使用方法

  1. from torch.utils.data import Dataset
复制代码
2、torch.utils.data.Dataset 类的界说

        使用以下操纵可以查看该类的界说:

  1. class Dataset(Generic[_T_co]):
  2.     r"""An abstract class representing a :class:`Dataset`.
  3.     All datasets that represent a map from keys to data samples should subclass
  4.     it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
  5.     data sample for a given key. Subclasses could also optionally overwrite
  6.     :meth:`__len__`, which is expected to return the size of the dataset by many
  7.     :class:`~torch.utils.data.Sampler` implementations and the default options
  8.     of :class:`~torch.utils.data.DataLoader`. Subclasses could also
  9.     optionally implement :meth:`__getitems__`, for speedup batched samples
  10.     loading. This method accepts list of indices of samples of batch and returns
  11.     list of samples.
  12.     .. note::
  13.       :class:`~torch.utils.data.DataLoader` by default constructs an index
  14.       sampler that yields integral indices.  To make it work with a map-style
  15.       dataset with non-integral indices/keys, a custom sampler must be provided.
  16.     """
  17.     def __getitem__(self, index) -> _T_co:
  18.         raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")
  19.     # def __getitems__(self, indices: List) -> List[_T_co]:
  20.     # Not implemented to prevent false-positives in fetcher check in
  21.     # torch.utils.data._utils.fetch._MapDatasetFetcher
  22.     def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
  23.         return ConcatDataset([self, other])
  24.     # No `def __len__(self)` default?
  25.     # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
  26.     # in pytorch/torch/utils/data/sampler.py
复制代码
原文解释:
3、示例

示例 1:简单的整数索引数据集
 假设我们有一个数据集,数据储存在一个列表中,我们可以通过整数索引来访问。
  1. from torch.utils.data import Dataset
  2. class SimpleDataset(Dataset):
  3.     def __init__(self, data):
  4.         self.data = data
  5.     def __getitem__(self, index):
  6.         return self.data[index]
  7.     def __len__(self):
  8.         return len(self.data)
  9. # 示例数据
  10. data = [1, 2, 3, 4, 5]
  11. # 创建数据集
  12. dataset = SimpleDataset(data)
  13. # 使用
  14. print(dataset[0])  # 输出:1
  15. print(len(dataset))  # 输出:5
复制代码
示例 2:字符串键的数据集

假设我们有一个数据集,数据以字典形式存储,键是字符串。
  1. from torch.utils.data import Dataset
  2. class StringKeyDataset(Dataset):
  3.     def __init__(self, data):
  4.         self.data = data
  5.         self.keys = list(data.keys())
  6.     def __getitem__(self, key):
  7.         return self.data[key]
  8.     def __len__(self):
  9.         return len(self.keys)
  10. # 示例数据
  11. data = {"a": 1, "b": 2, "c": 3}
  12. # 创建数据集
  13. dataset = StringKeyDataset(data)
  14. # 使用
  15. print(dataset["a"])  # 输出:1
  16. print(len(dataset))  # 输出:3
复制代码
留意:假如需要与 DataLoader 一起使用,必须提供一个自界说的采样器,由于默认的采样器生成整数索引。
示例 3:实现 __getitems__ 方法

为了实现批量加载数据,我们可以实现 __getitems__ 方法。
  1. from torch.utils.data import Dataset
  2. class BatchableDataset(Dataset):
  3.     def __init__(self, data):
  4.         self.data = data
  5.     def __getitem__(self, index):
  6.         return self.data[index]
  7.     def __getitems__(self, indices):
  8.         return [self.data[i] for i in indices]
  9.     def __len__(self):
  10.         return len(self.data)
  11. # 示例数据
  12. data = [10, 20, 30, 40, 50]
  13. # 创建数据集
  14. dataset = BatchableDataset(data)
  15. # 使用
  16. print(dataset[0])  # 输出:10
  17. print(dataset.__getitems__([1, 3]))  # 输出:[20, 40]
复制代码
示例 4:图像数据集

假设我们有一个图像数据集,图像路径存储在列表中。
  1. from torch.utils.data import Dataset
  2. from PIL import Image
  3. import os
  4. class ImageDataset(Dataset):
  5.     def __init__(self, img_dir):
  6.         self.img_dir = img_dir
  7.         self.img_names = os.listdir(img_dir)
  8.     def __getitem__(self, index):
  9.         img_name = self.img_names[index]
  10.         img_path = os.path.join(self.img_dir, img_name)
  11.         image = Image.open(img_path)
  12.         return image
  13.     def __len__(self):
  14.         return len(self.img_names)
  15. # 创建数据集
  16. img_dir = "path/to/images"
  17. dataset = ImageDataset(img_dir)
  18. # 使用
  19. print(len(dataset))  # 输出图像数量
  20. print(dataset[0])  # 输出第一张图像
复制代码
示例 5:自界说采样器

假如你的数据集使用非整数键(如字符串),而且你想与 DataLoader 一起使用,可以界说一个自界说采样器。
  1. from torch.utils.data import Dataset, DataLoader, Sampler
  2. import random
  3. class StringKeyDataset(Dataset):
  4.     def __init__(self, data):
  5.         self.data = data
  6.         self.keys = list(data.keys())
  7.     def __getitem__(self, key):
  8.         return self.data[key]
  9.     def __len__(self):
  10.         return len(self.keys)
  11. class StringSampler(Sampler):
  12.     def __init__(self, keys):
  13.         self.keys = keys
  14.     #每次调用时(如新的epoch开始),先打乱键的顺序,再返回迭代器。
  15.     #实现数据加载时的随机化顺序。
  16.     def __iter__(self):
  17.         random.shuffle(self.keys)
  18.         return iter(self.keys)
  19.     def __len__(self):
  20.         return len(self.keys)
  21. # 示例数据
  22. data = {"a": 1, "b": 2, "c": 3}
  23. # 创建数据集和采样器
  24. dataset = StringKeyDataset(data)
  25. sampler = StringSampler(dataset.keys)
  26. # 使用 DataLoader
  27. dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)
  28. for batch in dataloader:
  29.     print(batch)  # 输出批次数据
复制代码


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 IT评测·应用市场-qidao123.com (https://dis.qidao123.com/) Powered by Discuz! X3.4