1、使用方法
- from torch.utils.data import Dataset
复制代码 2、torch.utils.data.Dataset 类的界说
使用以下操纵可以查看该类的界说:
- “ctrl”+左键点击"Dataset”
- 实行代码:help(Dataset),需要先导入该类(见1、使用方法)
- class Dataset(Generic[_T_co]):
- r"""An abstract class representing a :class:`Dataset`.
- All datasets that represent a map from keys to data samples should subclass
- it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
- data sample for a given key. Subclasses could also optionally overwrite
- :meth:`__len__`, which is expected to return the size of the dataset by many
- :class:`~torch.utils.data.Sampler` implementations and the default options
- of :class:`~torch.utils.data.DataLoader`. Subclasses could also
- optionally implement :meth:`__getitems__`, for speedup batched samples
- loading. This method accepts list of indices of samples of batch and returns
- list of samples.
- .. note::
- :class:`~torch.utils.data.DataLoader` by default constructs an index
- sampler that yields integral indices. To make it work with a map-style
- dataset with non-integral indices/keys, a custom sampler must be provided.
- """
- def __getitem__(self, index) -> _T_co:
- raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")
- # def __getitems__(self, indices: List) -> List[_T_co]:
- # Not implemented to prevent false-positives in fetcher check in
- # torch.utils.data._utils.fetch._MapDatasetFetcher
- def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
- return ConcatDataset([self, other])
- # No `def __len__(self)` default?
- # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
- # in pytorch/torch/utils/data/sampler.py
复制代码 原文解释:
- 全部表示从键到数据样本映射的数据集都应该继承自这个类。
这意味着,假如你有一个数据集,它通过某些键(大概是整数、字符串等)来访问数据样本,那么你应该从 Dataset 类继承来创建你的数据集类。
- 全部的子类都应该重写 __getitem__ 方法,支持通过给定的键获取数据样本。
__getitem__ 是 Python 的特殊方法,用于通过 dataset[key] 这样的语法来获取数据。在你的子类中,你需要实现这个方法,确保它可以或许返回与给定键对应的数据样本。
- 子类也可以选择性地重写 __len__ 方法,该方法通常被很多 Sampler 实现和 DataLoader 的默认选项所使用,用于返回数据集的巨细。
__len__ 方法用于返回数据集中样本的总数。虽然它不是强制要求的,但假如你希望使用 PyTorch 的 Sampler 或 DataLoader,通常需要实现这个方法。
- 子类还可以选择性地实现 __getitems__ 方法,以加快批量数据加载。这个方法担当一个包罗批次样本索引的列表,并返回一个样本列表。
__getitems__ 是一个可选的优化方法。假如你需要批量加载数据,实现这个方法可以进步效率。它担当一个索引列表,并返回对应的样本列表。
- DataLoader 默认构造一个生成整数索引的采样器(sampler)。要使它可以或许与具有非整数索引/键的 map-style 数据集一起工作,必须提供一个自界说的采样器。
DataLoader 默认情况下假设你的数据集是可以通过整数索引访问的(即 dataset[0], dataset[1] 等)。假如你的数据集使用非整数键(比如字符串或其他范例),你需要提供一个自界说的采样器来生成这些键。
3、示例
示例 1:简单的整数索引数据集
假设我们有一个数据集,数据储存在一个列表中,我们可以通过整数索引来访问。
- from torch.utils.data import Dataset
- class SimpleDataset(Dataset):
- def __init__(self, data):
- self.data = data
- def __getitem__(self, index):
- return self.data[index]
- def __len__(self):
- return len(self.data)
- # 示例数据
- data = [1, 2, 3, 4, 5]
- # 创建数据集
- dataset = SimpleDataset(data)
- # 使用
- print(dataset[0]) # 输出:1
- print(len(dataset)) # 输出:5
复制代码 示例 2:字符串键的数据集
假设我们有一个数据集,数据以字典形式存储,键是字符串。
- from torch.utils.data import Dataset
- class StringKeyDataset(Dataset):
- def __init__(self, data):
- self.data = data
- self.keys = list(data.keys())
- def __getitem__(self, key):
- return self.data[key]
- def __len__(self):
- return len(self.keys)
- # 示例数据
- data = {"a": 1, "b": 2, "c": 3}
- # 创建数据集
- dataset = StringKeyDataset(data)
- # 使用
- print(dataset["a"]) # 输出:1
- print(len(dataset)) # 输出:3
复制代码 留意:假如需要与 DataLoader 一起使用,必须提供一个自界说的采样器,由于默认的采样器生成整数索引。
示例 3:实现 __getitems__ 方法
为了实现批量加载数据,我们可以实现 __getitems__ 方法。
- from torch.utils.data import Dataset
- class BatchableDataset(Dataset):
- def __init__(self, data):
- self.data = data
- def __getitem__(self, index):
- return self.data[index]
- def __getitems__(self, indices):
- return [self.data[i] for i in indices]
- def __len__(self):
- return len(self.data)
- # 示例数据
- data = [10, 20, 30, 40, 50]
- # 创建数据集
- dataset = BatchableDataset(data)
- # 使用
- print(dataset[0]) # 输出:10
- print(dataset.__getitems__([1, 3])) # 输出:[20, 40]
复制代码 示例 4:图像数据集
假设我们有一个图像数据集,图像路径存储在列表中。
- from torch.utils.data import Dataset
- from PIL import Image
- import os
- class ImageDataset(Dataset):
- def __init__(self, img_dir):
- self.img_dir = img_dir
- self.img_names = os.listdir(img_dir)
- def __getitem__(self, index):
- img_name = self.img_names[index]
- img_path = os.path.join(self.img_dir, img_name)
- image = Image.open(img_path)
- return image
- def __len__(self):
- return len(self.img_names)
- # 创建数据集
- img_dir = "path/to/images"
- dataset = ImageDataset(img_dir)
- # 使用
- print(len(dataset)) # 输出图像数量
- print(dataset[0]) # 输出第一张图像
复制代码 示例 5:自界说采样器
假如你的数据集使用非整数键(如字符串),而且你想与 DataLoader 一起使用,可以界说一个自界说采样器。
- from torch.utils.data import Dataset, DataLoader, Sampler
- import random
- class StringKeyDataset(Dataset):
- def __init__(self, data):
- self.data = data
- self.keys = list(data.keys())
- def __getitem__(self, key):
- return self.data[key]
- def __len__(self):
- return len(self.keys)
- class StringSampler(Sampler):
- def __init__(self, keys):
- self.keys = keys
- #每次调用时(如新的epoch开始),先打乱键的顺序,再返回迭代器。
- #实现数据加载时的随机化顺序。
- def __iter__(self):
- random.shuffle(self.keys)
- return iter(self.keys)
- def __len__(self):
- return len(self.keys)
- # 示例数据
- data = {"a": 1, "b": 2, "c": 3}
- # 创建数据集和采样器
- dataset = StringKeyDataset(data)
- sampler = StringSampler(dataset.keys)
- # 使用 DataLoader
- dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)
- for batch in dataloader:
- print(batch) # 输出批次数据
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |