水军大提督 发表于 2025-11-5 14:15:19

PyTorch利用教程(7)-数据集处置惩罚

1、根本概念

在PyTorch中,torch.utils.data模块是处置惩罚数据集和数据加载的焦点工具。以下是该模块中一些根本概念的明白:
https://dis.qidao123.com/imgproxy/aHR0cHM6Ly9pLWJsb2cuY3NkbmltZy5jbi9kaXJlY3QvZjFhYTE5YWM5ZGYzNDdhNzgzZGE0NmJhNjlhYWZmNzcucG5nI3BpY19jZW50ZXI=
1.1 Dataset


[*]界说:Dataset是一个抽象类,用于表现数据集。用户必要通过继续Dataset类并实现其__len__和__getitem__方法来创建自界说的数据集。
[*]功能:Dataset界说了数据集的内容,它相称于一个类似列表的数据结构,具有确定的长度,并可以或许用索引获取数据会集的元素。
[*]范例:Dataset紧张分为两种范例:map-style和iterable-style。map-style数据集必要实现__getitem__和__len__方法,而iterable-style数据集则必要实现__iter__方法。
from typing import Generic, TypeVar, List

_T_co = TypeVar('_T_co', covariant=True)

class Dataset(Generic):
   
    def __getitem__(self, index: int) -> _T_co:
      raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self) -> int:
      raise NotImplementedError("Subclasses of Dataset should implement __len__.")

    def __add__(self, other: "Dataset") -> "ConcatDataset":
      """
      Adds two datasets. This can be useful when you have two datasets with potentially
      overlapping elements and you want to treat the elements as distinct.
      """
      from .dataset_ops import ConcatDataset
      return ConcatDataset()
1.2 DataLoader


[*]界说:DataLoader是一个迭代器,用于封装Dataset,并提供一个可迭代对象,方便举行批量加载、数据打乱、并行加载等操纵。
[*]功能:DataLoader可以或许控制batch的巨细、batch中元素的采样方法,以及将batch效果整理成模子所需输入情势的方法。
[*]参数:常用的参数包罗dataset(表现要加载的数据集对象)、batch_size(表现每个batch的巨细)、shuffle(表现是否在每个epoch开始时打乱数据)、num_workers(表现用于数据加载的进程数)等。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
         batch_sampler=None, num_workers=0, collate_fn=None,
         pin_memory=False, drop_last=False, timeout=0,
         worker_init_fn=None, *, prefetch_factor=2,
         persistent_workers=False)
1.3 Sampler


[*]界说:Sampler是一个抽象类,用于从数据会集天生索引。
[*]功能:Sampler的作用是在Dataset上面举行抽样,抽样的方式有多种,如按序次抽样、随机抽样、在子聚集中随机抽样、带权重的抽样等。
[*]范例:包罗SequentialSampler、RandomSampler、SubsetRandomSampler、WeightedRandomSampler、BatchSampler等。
1.4 Batching


[*]界说:Batching是指将数据集分成多个小批次(batch)举行处置惩罚的过程。
[*]功能:Batching可以进步数据处置惩罚的服从,并有助于模子练习过程中的梯度更新和收敛。
[*]实现:通过DataLoader的batch_size参数来实现批量加载。
1.5 Shuffling


[*]界说:Shuffling是指在每个epoch开始时打乱数据会集的元素序次。
[*]功能:Shuffling有助于进步模子的泛化本领,防止模子对数据的序次产生依赖。
[*]实现:通过DataLoader的shuffle参数来启用数据打乱功能。
1.6 Multi-process Data Loading


[*]界说:Multi-process Data Loading是指利用多个进程来并行加载数据的过程。
[*]功能:Multi-process Data Loading可以明显进步数据加载的速率,尤其是在处置惩罚大规模数据集时。
[*]实现:通过DataLoader的num_workers参数来设置并行加载的进程数。
2、创建数据集

在PyTorch中,创建数据集通常涉及继续torch.utils.data.Dataset类并实现其必须的方法。以下是一个具体的步调指南,用于创建自界说数据集:

[*]导入须要的库
起首,确保你已经导入了PyTorch和其他大概必要的库。
import torch
from torch.utils.data import Dataset

[*]继续Dataset类
创建一个新的类,继续自Dataset。
class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
      # 初始化数据集,存储数据和标签
      self.data = data
      self.labels = labels
      self.transform = transform
      
      # 确保数据和标签的长度相同
      assert len(self.data) == len(self.labels), "Data and labels must have the same length"

    def __len__(self):
      # 返回数据集的大小
      return len(self.data)

    def __getitem__(self, idx):
      # 根据索引获取数据和标签
      sample = self.data
      label = self.labels
      
      # 如果定义了转换,则应用转换
      if self.transform:
            sample = self.transform(sample)
            
      return sample, label

[*]准备数据和标签
在创建CustomDataset实例之前,你必要准备好数据和标签。这些数据可以是图像、文本、数值等,具体取决于你的使命。
# 假设你有一些数据和标签(这里只是示例)
data = # 100个3x32x32的随机图像
labels =    # 100个标签,0或1

[*]创建数据集实例
利用你准备好的数据和标签来创建CustomDataset的实例。
dataset = CustomDataset(data, labels)

[*](可选)应用转换
假如你必要对数据举行预处置惩罚或加强,可以界说一个转换函数,并在创建数据集实例时通报给它。
# 定义一个简单的转换函数(例如,将图像数据标准化)
def normalize(sample):
    return (sample - sample.mean()) / sample.std()

# 创建数据集实例时应用转换
dataset = CustomDataset(data, labels, transform=normalize)

[*]利用DataLoader加载数据
末了,利用torch.utils.data.DataLoader来加载数据集,以便举行批量处置惩罚、打乱数据等。
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 现在你可以遍历dataloader来加载数据了
for batch_data, batch_labels in dataloader:
    # 在这里进行模型训练或评估
    pass
留意事项

[*]确保你的数据和标签是可索引的,通常它们应该是列表、NumPy数组或PyTorch张量。
[*]假如你的数据是图像,而且存储在文件体系中,你大概必要在__getitem__方法中实现图像读取和预处置惩罚逻辑。
[*]对于大型数据集,思量利用torchvision.datasets中提供的预界说数据集类,它们通常包罗了常见的图像数据集(如CIFAR、MNIST等)的加载逻辑。
[*]假如数据集太大无法全部加载到内存中,你可以思量利用torch.utils.data.IterableDataset来创建一个可迭代的数据集,如许你就可以按需加载数据了。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页: [1]
查看完整版本: PyTorch利用教程(7)-数据集处置惩罚