IT评测·应用市场-qidao123.com
标题:
【Pytorch常用模块总结】
[打印本页]
作者:
十念
时间:
2025-3-9 03:42
标题:
【Pytorch常用模块总结】
数据准备
数据集预处理
torchvision.transforms
transforms.ToTensor():将 PIL 图像或 NumPy 数组转换为张量
transforms.Normalize(mean, std):标准化数据,指定均值和标准差
transforms.Resize(size):调解图像大小
transforms.RandomCrop(size):随机裁剪
transforms.RandomHorizontalFlip():随机水平翻转,用于数据加强
transforms.Compose(transforms_list):组合多个变换
数据集的导入
自建数据集
torch.utils.data.Dataset
__init__:初始化数据集(如加载文件路径、标签)
__len__:返回数据集大小
__getitem__:定义如何获取单个样本及其标签
可搭配 torchvision.transforms 进行预处理
通用数据集
torchvision.datasets
示例:torchvision.datasets.MNIST(root='./data', train=True, download=True)
参数:train=True/False 区分练习集和测试集
数据集的加载
torch.utils.data.DataLoader
参数:
batch_size:批次大小
shuffle=True/False:是否打乱数据(练习 True,测试 False)
num_workers:多线程加载数据的线程数
drop_last=True:抛弃最后一个不完整批次
定义模型
torch.nn
nn.Module:自定义模型需继承并实现 forward 方法
常用层
nn.Linear(in_features, out_features):全连接层
nn.Conv2d(in_channels, out_channels, kernel_size):二维卷积层
nn.MaxPool2d(kernel_size):最大池化层
激活函数
nn.ReLU()、nn.Sigmoid()
正则化和归一化
nn.Dropout(p):随机抛弃,防止过拟合
nn.BatchNorm2d(num_features):批归一化
nn.Sequential:快速构建简单网络
定义损失函数
torch.nn
nn.CrossEntropyLoss():交叉熵损失(含 Softmax),多分类任务
nn.MSELoss():均方偏差,回归任务
nn.BCELoss() / nn.BCEWithLogitsLoss():二分类任务
根据任务选择合适的损失函数
定义优化器
torch.optim
optim.SGD(model.parameters(), lr, momentum):随机梯度下降
optim.Adam(model.parameters(), lr, weight_decay):Adam 优化器
参数:
lr:学习率
momentum:动量法参数(SGD)
weight_decay:L2 正则化参数
学习率调度器
torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma):按步长衰减
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer):根据指标调解
练习模型
torch.nn.Module
model.to(device):将模型移到 GPU/CPU,device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.train():进入练习模式
torch.no_grad():禁用梯度盘算(验证/测试时利用)
模型生存与加载
torch.save(model.state_dict(), 'path.pth'):生存模型参数
torch.save(model, 'path.pth'):生存整个模型
model.load_state_dict(torch.load('path.pth')):加载模型参数
练习流程
盘算损失 → 梯度置零 → 反向流传 → 更新参数
loss = loss_fn(output, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
示例:
for epoch in range(num_epochs):
model.train()
for batch_x, batch_y in data_loader:
batch_x, batch_y = batch_x.to(device), batch_y.to(device)
output = model(batch_x)
loss = loss_fn(output, batch_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
欢迎光临 IT评测·应用市场-qidao123.com (https://dis.qidao123.com/)
Powered by Discuz! X3.4