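The Dataset Class
Purpose and Role
- In PyTorch, the Dataset class is part of the torch.utils.data module. It is an abstract base class that defines the standard interface for loading and processing datasets. By subclassing it and implementing its methods, you can create custom datasets for a wide range of machine learning tasks.
Basic Structure
- Abstract base class definition: Dataset is a generic class. It uses Generic[_T_co] to indicate that it accepts a covariant type parameter _T_co, the type of a single sample returned by __getitem__. It is the base class of all dataset classes and defines the basic interface that a dataset should follow.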
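To make the type parameter concrete, here is a minimal sketch, assuming PyTorch is installed. `PairDataset` is a hypothetical class invented for this illustration; subscripting `Dataset[...]` only informs static type checkers and does not change runtime behaviour.

```python
from typing import List, Tuple

import torch
from torch import Tensor
from torch.utils.data import Dataset


class PairDataset(Dataset[Tuple[Tensor, int]]):
    """Hypothetical dataset whose samples are (feature tensor, int label) pairs."""

    def __init__(self, features: Tensor, labels: List[int]) -> None:
        self.features = features
        self.labels = labels

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, index: int) -> Tuple[Tensor, int]:
        # The return annotation matches the type passed to Dataset[...].
        return self.features[index], self.labels[index]


pairs = PairDataset(torch.zeros(3, 2), [0, 1, 0])
feature, label = pairs[1]
```

A type checker can now infer that `pairs[1]` is a `Tuple[Tensor, int]`, which is the main practical benefit of the generic parameter.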
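Main components of the Dataset class:
- Docstring: gives detailed notes on how to use and implement the class. It states that every dataset that maps keys to data samples should subclass Dataset and override __getitem__ so that a sample can be fetched for a given key. It also says that subclasses may optionally override __len__ to return the size of the dataset, which many Sampler implementations and the default options of DataLoader expect. In addition, subclasses may optionally implement __getitems__ to speed up batched sample loading.
- __getitem__ method: subclasses must override this method so that it returns the data sample for the given index. It is abstract by convention only (it is not decorated with @abstractmethod); the base implementation simply raises NotImplementedError, so a subclass that does not override it fails with that error as soon as a sample is requested.
- __getitems__ method: commented out in the base class (to prevent false positives in the fetcher check in torch.utils.data._utils.fetch._MapDatasetFetcher), but subclasses may implement it to speed up batched sample loading. If implemented, it should accept a list of sample indices and return a list of samples.
- __add__ method: allows two Dataset objects to be added together. It returns a new ConcatDataset that joins the two datasets into one contiguous dataset.
- Comment about __len__: explains why Dataset provides no default __len__ method. Because there is no default, calling len() on a subclass that does not implement __len__ raises TypeError, which effectively forces any subclass that needs a size to provide its own implementation.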
    class Dataset(Generic[_T_co]):
        r"""An abstract class representing a :class:`Dataset`.

        All datasets that represent a map from keys to data samples should subclass
        it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
        data sample for a given key. Subclasses could also optionally overwrite
        :meth:`__len__`, which is expected to return the size of the dataset by many
        :class:`~torch.utils.data.Sampler` implementations and the default options
        of :class:`~torch.utils.data.DataLoader`. Subclasses could also
        optionally implement :meth:`__getitems__`, for speedup batched samples
        loading. This method accepts list of indices of samples of batch and returns
        list of samples.

        .. note::
          :class:`~torch.utils.data.DataLoader` by default constructs an index
          sampler that yields integral indices. To make it work with a map-style
          dataset with non-integral indices/keys, a custom sampler must be provided.
        """

        def __getitem__(self, index) -> _T_co:
            raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

        # def __getitems__(self, indices: List) -> List[_T_co]:
        # Not implemented to prevent false-positives in fetcher check in
        # torch.utils.data._utils.fetch._MapDatasetFetcher

        def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
            return ConcatDataset([self, other])

        # No `def __len__(self)` default?
        # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        # in pytorch/torch/utils/data/sampler.py
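The behaviour of the base class described in this section can be checked directly. The following is a minimal sketch, assuming PyTorch is installed; `Bare` and `Sized` are hypothetical classes made up for the illustration.

```python
from torch.utils.data import ConcatDataset, Dataset


class Bare(Dataset):
    """Hypothetical subclass that overrides nothing."""


class Sized(Dataset):
    """Hypothetical subclass that implements both core methods."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]


bare = Bare()

# The inherited __getitem__ only raises NotImplementedError.
try:
    bare[0]
    getitem_error = None
except NotImplementedError as exc:
    getitem_error = exc

# There is no default __len__, so len() fails with TypeError.
try:
    len(bare)
    len_error = None
except TypeError as exc:
    len_error = exc

# __add__ wraps both operands in a ConcatDataset.
combined = Sized([1, 2, 3]) + Sized([10, 20])
print(type(combined).__name__, len(combined), combined[3])
```

Indexing `combined` past the end of the first operand transparently continues into the second one.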
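- In practice, a custom map-style dataset implements the following two core methods (__getitem__ is required; __len__ is technically optional, but the default DataLoader samplers expect it):
- __len__(self): returns the total number of samples in the dataset.
- __getitem__(self, idx): returns the sample at the given index idx. A sample can be a single data point, or a data point together with its label.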
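Alongside the two core methods, the optional __getitems__ hook from the docstring can be added to the same class. The sketch below assumes PyTorch is installed, and a version recent enough for DataLoader to look for __getitems__ (on versions that do not, the loader falls back to __getitem__ and yields the same batches). `SquaresDataset` is a hypothetical example class.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical dataset: sample i is the 0-d tensor i * i."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return torch.tensor(idx * idx)

    def __getitems__(self, indices):
        # Optional batched hook: receives the indices of one batch and
        # returns one sample per index, in the same order.
        return [torch.tensor(i * i) for i in indices]


squares = SquaresDataset(10)
loader = DataLoader(squares, batch_size=4, shuffle=False)
batches = [batch.tolist() for batch in loader]
print(batches)
```

In this toy case the hook saves nothing; it pays off when a batch can be fetched with a single read, for example one database query or one slice of a memory-mapped array.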
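- Other commonly used classes that inherit from Dataset:
- TensorDataset: wraps one or more tensors that share the same size along the first dimension, for example an input tensor and a target tensor. Each sample is a tuple obtained by indexing every tensor along that dimension.
- ImageFolder: loads an image dataset from the file system. It assumes that each subdirectory represents one class and treats each image file as one sample. Note that it lives in torchvision.datasets, not in torch.utils.data.
- ConcatDataset: concatenates several map-style datasets into one larger dataset.
- Subset: selects a subset of a larger dataset by a sequence of indices.
- ChainDataset: chains several IterableDataset objects so that they can be iterated as if they were a single dataset.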
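A short sketch of how three of these classes compose, assuming PyTorch is installed. The tensors are made-up illustration data; ImageFolder and ChainDataset are left out because they need image files and IterableDataset inputs respectively.

```python
import torch
from torch.utils.data import ConcatDataset, Subset, TensorDataset

# Six made-up samples with two features each, plus binary targets.
inputs = torch.arange(12, dtype=torch.float32).reshape(6, 2)
targets = torch.tensor([0, 1, 0, 1, 0, 1])

first = TensorDataset(inputs, targets)         # 6 (input, target) pairs
second = TensorDataset(inputs + 100, targets)  # 6 more pairs

both = ConcatDataset([first, second])          # 12 samples in total
evens = Subset(both, [0, 2, 4])                # samples 0, 2 and 4 of `both`

x, y = evens[1]  # the same sample as both[2], which is first[2]
```

Each wrapper only stores references and index bookkeeping, so none of these steps copies the underlying tensors.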
Usage
    from torch.utils.data import Dataset, DataLoader

    class CustomDataset(Dataset):
        def __init__(self, data, labels):
            self.data = data
            self.labels = labels

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            return self.data[idx], self.labels[idx]

    # Suppose we have some custom data and labels (placeholders)
    custom_data = [...]
    custom_labels = [...]

    # Create the custom dataset
    custom_dataset = CustomDataset(custom_data, custom_labels)

    # Use a DataLoader to iterate over the custom dataset
    custom_data_loader = DataLoader(custom_dataset, batch_size=20, shuffle=True)
    for batch_idx, (data, target) in enumerate(custom_data_loader):
        # Process your data and targets here
        pass
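For a version that runs end to end, the placeholders can be swapped for concrete tensors. This is a sketch under assumptions: PyTorch is installed, and the random data (50 samples, 3 features, binary labels) is invented purely for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


# Made-up data: 50 samples with 3 features each, and binary labels.
custom_data = torch.randn(50, 3)
custom_labels = torch.randint(0, 2, (50,))

custom_dataset = CustomDataset(custom_data, custom_labels)
custom_data_loader = DataLoader(custom_dataset, batch_size=20, shuffle=True)

# Record the shape of every batch the loader yields.
batch_shapes = [
    (tuple(data.shape), tuple(target.shape))
    for data, target in custom_data_loader
]
print(batch_shapes)
```

With 50 samples and batch_size=20 the last batch holds only 10 samples; passing drop_last=True to DataLoader would discard it instead.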
Dataset class source code
- Source: https://github.com/pytorch/pytorch/blob/main/torch/utils/data/dataset.py
    # mypy: allow-untyped-defs
    import bisect
    import itertools
    import math
    import warnings
    from typing import (
        cast,
        Dict,
        Generic,
        Iterable,
        List,
        Optional,
        Sequence,
        Tuple,
        TypeVar,
        Union,
    )
    from typing_extensions import deprecated

    # No 'default_generator' in torch/__init__.pyi
    from torch import default_generator, Generator, randperm, Tensor

    __all__ = [
        "Dataset",
        "IterableDataset",
        "TensorDataset",
        "StackDataset",
        "ConcatDataset",
        "ChainDataset",
        "Subset",
        "random_split",
    ]

    _T = TypeVar("_T")
    _T_co = TypeVar("_T_co", covariant=True)
    _T_dict = Dict[str, _T_co]
    _T_tuple = Tuple[_T_co, ...]
    _T_stack = TypeVar("_T_stack", _T_tuple, _T_dict)


    class Dataset(Generic[_T_co]):
        r"""An abstract class representing a :class:`Dataset`.

        All datasets that represent a map from keys to data samples should subclass
        it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
        data sample for a given key. Subclasses could also optionally overwrite
        :meth:`__len__`, which is expected to return the size of the dataset by many
        :class:`~torch.utils.data.Sampler` implementations and the default options
        of :class:`~torch.utils.data.DataLoader`. Subclasses could also
        optionally implement :meth:`__getitems__`, for speedup batched samples
        loading. This method accepts list of indices of samples of batch and returns
        list of samples.

        .. note::
          :class:`~torch.utils.data.DataLoader` by default constructs an index
          sampler that yields integral indices. To make it work with a map-style
          dataset with non-integral indices/keys, a custom sampler must be provided.
        """

        def __getitem__(self, index) -> _T_co:
            raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

        # def __getitems__(self, indices: List) -> List[_T_co]:
        # Not implemented to prevent false-positives in fetcher check in
        # torch.utils.data._utils.fetch._MapDatasetFetcher

        def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
            return ConcatDataset([self, other])

        # No `def __len__(self)` default?
        # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        # in pytorch/torch/utils/data/sampler.py


    class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
        r"""An iterable Dataset.

        All datasets that represent an iterable of data samples should subclass it.
        Such form of datasets is particularly useful when data come from a stream.

        All subclasses should overwrite :meth:`__iter__`, which would return an
        iterator of samples in this dataset.

        When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
        item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
        iterator. When :attr:`num_workers > 0`, each worker process will have a
        different copy of the dataset object, so it is often desired to configure
        each copy independently to avoid having duplicate data returned from the
        workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
        process, returns information about the worker. It can be used in either the
        dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
        :attr:`worker_init_fn` option to modify each copy's behavior.

        Example 1: splitting workload across all workers in :meth:`__iter__`::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
            >>> # xdoctest: +SKIP("Fails on MacOS12")
            >>> class MyIterableDataset(torch.utils.data.IterableDataset):
            ...     def __init__(self, start, end):
            ...         super(MyIterableDataset).__init__()
            ...         assert end > start, "this example code only works with end >= start"
            ...         self.start = start
            ...         self.end = end
            ...
            ...     def __iter__(self):
            ...         worker_info = torch.utils.data.get_worker_info()
            ...         if worker_info is None:  # single-process data loading, return the full iterator
            ...             iter_start = self.start
            ...             iter_end = self.end
            ...         else:  # in a worker process
            ...             # split workload
            ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            ...             worker_id = worker_info.id
            ...             iter_start = self.start + worker_id * per_worker
            ...             iter_end = min(iter_start + per_worker, self.end)
            ...         return iter(range(iter_start, iter_end))
            ...
            >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
            >>> ds = MyIterableDataset(start=3, end=7)

            >>> # Single-process loading
            >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
            [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

            >>> # xdoctest: +REQUIRES(POSIX)
            >>> # Mult-process loading with two worker processes
            >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
            >>> # xdoctest: +IGNORE_WANT("non deterministic")
            >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
            [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

            >>> # With even more workers
            >>> # xdoctest: +IGNORE_WANT("non deterministic")
            >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
            [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
            >>> class MyIterableDataset(torch.utils.data.IterableDataset):
            ...     def __init__(self, start, end):
            ...         super(MyIterableDataset).__init__()
            ...         assert end > start, "this example code only works with end >= start"
            ...         self.start = start
            ...         self.end = end
            ...
            ...     def __iter__(self):
            ...         return iter(range(self.start, self.end))
            ...
            >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
            >>> ds = MyIterableDataset(start=3, end=7)

            >>> # Single-process loading
            >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
            [3, 4, 5, 6]
            >>>
            >>> # Directly doing multi-process loading yields duplicate data
            >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
            [3, 3, 4, 4, 5, 5, 6, 6]

            >>> # Define a `worker_init_fn` that configures each dataset copy differently
            >>> def worker_init_fn(worker_id):
            ...     worker_info = torch.utils.data.get_worker_info()
            ...     dataset = worker_info.dataset  # the dataset copy in this worker process
            ...     overall_start = dataset.start
            ...     overall_end = dataset.end
            ...     # configure the dataset to only process the split workload
            ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
            ...     worker_id = worker_info.id
            ...     dataset.start = overall_start + worker_id * per_worker
            ...     dataset.end = min(dataset.start + per_worker, overall_end)
            ...

            >>> # Mult-process loading with the custom `worker_init_fn`
            >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
            >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
            [3, 5, 4, 6]

            >>> # With even more workers
            >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
            [3, 4, 5, 6]
        """

        def __add__(self, other: Dataset[_T_co]):
            return ChainDataset([self, other])

        # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
        # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


    class TensorDataset(Dataset[Tuple[Tensor, ...]]):
        r"""Dataset wrapping tensors.

        Each sample will be retrieved by indexing tensors along the first dimension.

        Args:
            *tensors (Tensor): tensors that have the same size of the first dimension.
        """

        tensors: Tuple[Tensor, ...]

        def __init__(self, *tensors: Tensor) -> None:
            assert all(
                tensors[0].size(0) == tensor.size(0) for tensor in tensors
            ), "Size mismatch between tensors"
            self.tensors = tensors

        def __getitem__(self, index):
            return tuple(tensor[index] for tensor in self.tensors)

        def __len__(self):
            return self.tensors[0].size(0)


    class StackDataset(Dataset[_T_stack]):
        r"""Dataset as a stacking of multiple datasets.

        This class is useful to assemble different parts of complex input data, given as datasets.

        Example:
            >>> # xdoctest: +SKIP
            >>> images = ImageDataset()
            >>> texts = TextDataset()
            >>> tuple_stack = StackDataset(images, texts)
            >>> tuple_stack[0] == (images[0], texts[0])
            >>> dict_stack = StackDataset(image=images, text=texts)
            >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}

        Args:
            *args (Dataset): Datasets for stacking returned as tuple.
            **kwargs (Dataset): Datasets for stacking returned as dict.
        """

        datasets: Union[tuple, dict]

        def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None:
            if args:
                if kwargs:
                    raise ValueError(
                        "Supported either ``tuple``- (via ``args``) or"
                        "``dict``- (via ``kwargs``) like input/output, but both types are given."
                    )
                self._length = len(args[0])  # type: ignore[arg-type]
                if any(self._length != len(dataset) for dataset in args):  # type: ignore[arg-type]
                    raise ValueError("Size mismatch between datasets")
                self.datasets = args
            elif kwargs:
                tmp = list(kwargs.values())
                self._length = len(tmp[0])  # type: ignore[arg-type]
                if any(self._length != len(dataset) for dataset in tmp):  # type: ignore[arg-type]
                    raise ValueError("Size mismatch between datasets")
                self.datasets = kwargs
            else:
                raise ValueError("At least one dataset should be passed")

        def __getitem__(self, index):
            if isinstance(self.datasets, dict):
                return {k: dataset[index] for k, dataset in self.datasets.items()}
            return tuple(dataset[index] for dataset in self.datasets)

        def __getitems__(self, indices: list):
            # add batched sampling support when parent datasets supports it.
            if isinstance(self.datasets, dict):
                dict_batch: List[_T_dict] = [{} for _ in indices]
                for k, dataset in self.datasets.items():
                    if callable(getattr(dataset, "__getitems__", None)):
                        items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                        if len(items) != len(indices):
                            raise ValueError(
                                "Nested dataset's output size mismatch."
                                f" Expected {len(indices)}, got {len(items)}"
                            )
                        for data, d_sample in zip(items, dict_batch):
                            d_sample[k] = data
                    else:
                        for idx, d_sample in zip(indices, dict_batch):
                            d_sample[k] = dataset[idx]
                return dict_batch

            # tuple data
            list_batch: List[list] = [[] for _ in indices]
            for dataset in self.datasets:
                if callable(getattr(dataset, "__getitems__", None)):
                    items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                    if len(items) != len(indices):
                        raise ValueError(
                            "Nested dataset's output size mismatch."
                            f" Expected {len(indices)}, got {len(items)}"
                        )
                    for data, t_sample in zip(items, list_batch):
                        t_sample.append(data)
                else:
                    for idx, t_sample in zip(indices, list_batch):
                        t_sample.append(dataset[idx])
            tuple_batch: List[_T_tuple] = [tuple(sample) for sample in list_batch]
            return tuple_batch

        def __len__(self):
            return self._length


    class ConcatDataset(Dataset[_T_co]):
        r"""Dataset as a concatenation of multiple datasets.

        This class is useful to assemble different existing datasets.

        Args:
            datasets (sequence): List of datasets to be concatenated
        """

        datasets: List[Dataset[_T_co]]
        cumulative_sizes: List[int]

        @staticmethod
        def cumsum(sequence):
            r, s = [], 0
            for e in sequence:
                l = len(e)
                r.append(l + s)
                s += l
            return r

        def __init__(self, datasets: Iterable[Dataset]) -> None:
            super().__init__()
            self.datasets = list(datasets)
            assert len(self.datasets) > 0, "datasets should not be an empty iterable"  # type: ignore[arg-type]
            for d in self.datasets:
                assert not isinstance(
                    d, IterableDataset
                ), "ConcatDataset does not support IterableDataset"
            self.cumulative_sizes = self.cumsum(self.datasets)

        def __len__(self):
            return self.cumulative_sizes[-1]

        def __getitem__(self, idx):
            if idx < 0:
                if -idx > len(self):
                    raise ValueError(
                        "absolute value of index should not exceed dataset length"
                    )
                idx = len(self) + idx
            dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
            if dataset_idx == 0:
                sample_idx = idx
            else:
                sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
            return self.datasets[dataset_idx][sample_idx]

        @property
        @deprecated(
            "`cummulative_sizes` attribute is renamed to `cumulative_sizes`",
            category=FutureWarning,
        )
        def cummulative_sizes(self):
            return self.cumulative_sizes


    class ChainDataset(IterableDataset):
        r"""Dataset for chaining multiple :class:`IterableDataset` s.

        This class is useful to assemble different existing dataset streams. The
        chaining operation is done on-the-fly, so concatenating large-scale
        datasets with this class will be efficient.

        Args:
            datasets (iterable of IterableDataset): datasets to be chained together
        """

        def __init__(self, datasets: Iterable[Dataset]) -> None:
            super().__init__()
            self.datasets = datasets

        def __iter__(self):
            for d in self.datasets:
                assert isinstance(
                    d, IterableDataset
                ), "ChainDataset only supports IterableDataset"
                yield from d

        def __len__(self):
            total = 0
            for d in self.datasets:
                assert isinstance(
                    d, IterableDataset
                ), "ChainDataset only supports IterableDataset"
                total += len(d)  # type: ignore[arg-type]
            return total


    class Subset(Dataset[_T_co]):
        r"""
        Subset of a dataset at specified indices.

        Args:
            dataset (Dataset): The whole Dataset
            indices (sequence): Indices in the whole set selected for subset
        """

        dataset: Dataset[_T_co]
        indices: Sequence[int]

        def __init__(self, dataset: Dataset[_T_co], indices: Sequence[int]) -> None:
            self.dataset = dataset
            self.indices = indices

        def __getitem__(self, idx):
            if isinstance(idx, list):
                return self.dataset[[self.indices[i] for i in idx]]
            return self.dataset[self.indices[idx]]

        def __getitems__(self, indices: List[int]) -> List[_T_co]:
            # add batched sampling support when parent dataset supports it.
            # see torch.utils.data._utils.fetch._MapDatasetFetcher
            if callable(getattr(self.dataset, "__getitems__", None)):
                return self.dataset.__getitems__([self.indices[idx] for idx in indices])  # type: ignore[attr-defined]
            else:
                return [self.dataset[self.indices[idx]] for idx in indices]

        def __len__(self):
            return len(self.indices)


    def random_split(
        dataset: Dataset[_T],
        lengths: Sequence[Union[int, float]],
        generator: Optional[Generator] = default_generator,
    ) -> List[Subset[_T]]:
        r"""
        Randomly split a dataset into non-overlapping new datasets of given lengths.

        If a list of fractions that sum up to 1 is given,
        the lengths will be computed automatically as
        floor(frac * len(dataset)) for each fraction provided.

        After computing the lengths, if there are any remainders, 1 count will be
        distributed in round-robin fashion to the lengths
        until there are no remainders left.

        Optionally fix the generator for reproducible results, e.g.:

        Example:
            >>> # xdoctest: +SKIP
            >>> generator1 = torch.Generator().manual_seed(42)
            >>> generator2 = torch.Generator().manual_seed(42)
            >>> random_split(range(10), [3, 7], generator=generator1)
            >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

        Args:
            dataset (Dataset): Dataset to be split
            lengths (sequence): lengths or fractions of splits to be produced
            generator (Generator): Generator used for the random permutation.
        """
        if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
            subset_lengths: List[int] = []
            for i, frac in enumerate(lengths):
                if frac < 0 or frac > 1:
                    raise ValueError(f"Fraction at index {i} is not between 0 and 1")
                n_items_in_split = int(
                    math.floor(len(dataset) * frac)  # type: ignore[arg-type]
                )
                subset_lengths.append(n_items_in_split)
            remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
            # add 1 to all the lengths in round-robin fashion until the remainder is 0
            for i in range(remainder):
                idx_to_add_at = i % len(subset_lengths)
                subset_lengths[idx_to_add_at] += 1
            lengths = subset_lengths
            for i, length in enumerate(lengths):
                if length == 0:
                    warnings.warn(
                        f"Length of split at index {i} is 0. "
                        f"This might result in an empty dataset."
                    )

        # Cannot verify that dataset is Sized
        if sum(lengths) != len(dataset):  # type: ignore[arg-type]
            raise ValueError(
                "Sum of input lengths does not equal the length of the input dataset!"
            )

        indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[arg-type, call-overload]
        lengths = cast(Sequence[int], lengths)
        return [
            Subset(dataset, indices[offset - length : offset])
            for offset, length in zip(itertools.accumulate(lengths), lengths)
        ]
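Two pieces of arithmetic in the listing are easier to follow in isolation: how ConcatDataset maps a global index onto one of its datasets, and how random_split turns fractions into integer lengths. The sketch below re-implements both in plain Python with no PyTorch dependency; `locate` and `fractions_to_lengths` are hypothetical helper names, not functions of the library.

```python
import bisect
import itertools
import math


def locate(cumulative_sizes, idx):
    """Map a global index to (dataset index, index inside that dataset)."""
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    if dataset_idx == 0:
        return 0, idx
    return dataset_idx, idx - cumulative_sizes[dataset_idx - 1]


def fractions_to_lengths(total, fractions):
    """Floor each share, then hand out the remainder one item at a time."""
    lengths = [int(math.floor(total * frac)) for frac in fractions]
    remainder = total - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


# Three datasets of sizes 3, 2 and 4 give cumulative sizes [3, 5, 9].
cumulative = list(itertools.accumulate([3, 2, 4]))
print(locate(cumulative, 0))  # (0, 0): first sample of the first dataset
print(locate(cumulative, 3))  # (1, 0): first sample of the second dataset
print(locate(cumulative, 8))  # (2, 3): last sample of the third dataset

print(fractions_to_lengths(10, [0.25, 0.75]))  # [3, 7]: floors are 2 and 7, the spare item goes to split 0
print(fractions_to_lengths(11, [0.5, 0.5]))    # [6, 5]
```

Using bisect_right rather than bisect_left is what sends an index equal to a cumulative size (such as 3 above) to the next dataset instead of past the end of the previous one.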