一、模型生存与加载
1、序列化方式
生存方式:torch.save(model, "model.pkl")
打开方式:model = torch.load("model.pkl", map_location="cpu")
-
- import torch
- import torch.nn as nn
- class MyModle(nn.Module):
- def __init__(self, input_size, output_size):
- super(MyModle, self).__init__()
- self.fc1 = nn.Linear(input_size, 128)
- self.fc2 = nn.Linear(128, 64)
- self.fc3 = nn.Linear(64, output_size)
- def forward(self, x):
- x = self.fc1(x)
- x = self.fc2(x)
- output = self.fc3(x)
- return output
- model = MyModle(input_size=128, output_size=32)
- # 序列化方式保存模型对象
- torch.save(model, "model.pkl")
- # 注意设备问题
- model = torch.load("model.pkl", map_location="cpu")
- print(model)
-
复制代码 2、生存模型参数
设置需要生存的模型参数:
save_dict = {
"init_params": {
"input_size": 128, # 输入特性数
"output_size": 32, # 输出特性数
},
"accuracy": 0.99, # 模型正确率
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
}
生存模型参数:torch.save(save_dict, "名称.pth"),一般使用 pth 作为后缀
创建新模型时调用生存的模型参数:
加载模型参数:torch.load("名称.pth")
input_size = save_dict["init_params"]["input_size"]
- import torch
- import torch.nn as nn
- import torch.optim as optim
- class MyModle(nn.Module):
- def __init__(self, input_size, output_size):
- super(MyModle, self).__init__()
- self.fc1 = nn.Linear(input_size, 128)
- self.fc2 = nn.Linear(128, 64)
- self.fc3 = nn.Linear(64, output_size)
- def forward(self, x):
- x = self.fc1(x)
- x = self.fc2(x)
- output = self.fc3(x)
- return output
- save_dict = torch.load("模型参数保存名称.pth")
- model = MyModle(
- input_size=save_dict["init_params"]["input_size"],
- output_size=save_dict["init_params"]["output_size"],
- )
- # 初始化模型参数
- model.load_state_dict(save_dict["model_state_dict"])
- optimizer = optim.SGD(model.parameters(), lr=0.01)
- # 初始化优化器参数
- optimizer.load_state_dict(save_dict["optimizer_state_dict"])
- # 打印模型信息
- print(save_dict["accuracy"])
- print(model)
复制代码 二、数据增强
具体参考官方文档:Illustration of transforms — Torchvision 0.20 documentation
1、官方代码-主体
- import matplotlib.pyplot as plt
- import numpy as np
- import torch
- from PIL import Image
- plt.rcParams["savefig.bbox"] = "tight"
- torch.manual_seed(0)
- orig_img = Image.open("../../data/1.png")
- def plot(imgs, title, with_orig=True, row_title=None, **imshow_kwargs):
- if not isinstance(imgs[0], list):
- # Make a 2d grid even if there's just 1 row
- imgs = [imgs]
- num_rows = len(imgs)
- num_cols = len(imgs[0]) + with_orig
- fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
- plt.title(title)
- for row_idx, row in enumerate(imgs):
- row = [orig_img] + row if with_orig else row
- for col_idx, img in enumerate(row):
- ax = axs[row_idx, col_idx]
- ax.imshow(np.asarray(img), **imshow_kwargs)
- ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
- if with_orig:
- axs[0, 0].set(title="Original image")
- axs[0, 0].title.set_size(8)
- if row_title is not None:
- for row_idx in range(num_rows):
- axs[row_idx, 0].set(ylabel=row_title[row_idx])
- plt.tight_layout()
- plt.show()
复制代码 2、固定转换
2.1、pad 边缘添补
就是在照片四周添加黑色框区域
padded_imgs = [v2.Pad(padding=padding)(orig_img) for padding in (3, 10, 30, 50)]
plot([orig_img] + padded_imgs, "v2.Pad")
2.2、resize 巨细调解
resized_imgs = [v2.Resize(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)]
plot([orig_img] + resized_imgs, "v2.Resize")
2.3、center crop 中央裁剪
center_crops = [
v2.CenterCrop(size=size)(orig_img) for size in (30, 50, 100, orig_img.size)
]
plot([orig_img] + center_crops, "v2.CenterCrop")
2.4、five crop 周边裁剪
(top_left, top_right, bottom_left, bottom_right, center) = v2.FiveCrop(size=(100, 100))(
orig_img
)
plot(
[orig_img] + [top_left, top_right, bottom_left, bottom_right, center], "v2.FiveCrop"
)
3、随机转换
3.1、RandomRotation 随机旋转
rotater = v2.RandomRotation(degrees=(0, 180)) # 随机从0-180获取一个数值
rotated_imgs = [rotater(orig_img) for _ in range(4)] # 根据随机数值得到角度转变
plot([orig_img] + rotated_imgs)
3.2、RandomAffine 随机仿射
affine_transfomer = v2.RandomAffine(degrees=(30, 70), translate=(0.1, 0.3), scale=(0.5, 0.75))
affine_imgs = [affine_transfomer(orig_img) for _ in range(4)]
plot([orig_img] + affine_imgs)
4、数据增强整合
- from PIL import Image
- from pathlib import Path
- import matplotlib.pyplot as plt
- import numpy as np
- import torch
- from torchvision import transforms, datasets, utils
- def test001():
- # 定义数据增强和预处理步骤
- transform = transforms.Compose(
- [transforms.RandomHorizontalFlip(), # 随机水平翻转
- transforms.RandomRotation(10), # 随机旋转 ±10 度
- transforms.RandomResizedCrop( 32, scale=(0.8, 1.0) ), # 随机裁剪到 32x32,缩放比例在0.8到1.0之间
- transforms.ColorJitter(
- brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
- ), # 随机调整亮度、对比度、饱和度、色调
- transforms.ToTensor(), # 转换为 Tensor
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化
- ]
- )
- # 加载 CIFAR-10 数据集,并应用数据增强
- trainset = datasets.CIFAR10(
- root="../../data", train=True, download=True, transform=transform
- )
- trainloader = torch.utils.data.DataLoader(
- trainset, batch_size=4, shuffle=True, num_workers=2
- )
- # 显示增强后的图像
- dataiter = iter(trainloader)
- images, labels = next(dataiter)
- def imshow(img):
- img = img / 2 + 0.5 # 反归一化
- npimg = img.numpy()
- plt.imshow(np.transpose(npimg, (1, 2, 0)))
- plt.show()
- imshow(utils.make_grid(images))
- test001()
复制代码 三、神经网络
1、人工神经元
吸收多个输入的信息并举行加权求和,使用激活函数处理得到最后效果。
人工神经元的设置方法是对比生物神经元的结构:
生物神经元人工神经元细胞核节点 (加权求和 + 激活函数)树突输入轴突带权重的连接突触输出 2、神经网络
由大量人工神经元按层次结构连接而成的盘算模型,上一层的神经元输出作为下一层神经元的输入,层之间的神经元并无连接。
2.1、结构
输入层:整个神经网络的第一层,负责吸收外部数据,不做任何盘算。
隐蔽层:位于神经网络输入层与输出层之间的内容,举行特性提取、转化、盘算等操作,一般为多层神经元组成。
输出层:吸收隐蔽层的盘算效果,产生预测效果或分类效果
2.2、全连接神经网络
每一层的单个神经元都与上一层的所有神经元连接,一般用于图像分类、文本等。
3、参数初始化(权重、偏置)
权重和偏置:model.weight、model.bias
初始化使用 torch.nn.init 库的方法
3.1、固定值初始化--全零化、全1化、常数化
将参数所有数据变成固定值,一般不消于初始化权重(会粉碎对称性),用于初始化偏置。
torch.nn.init.zeros_(model.weight) :参数为初始化对象,只能一个tensor。
torch.nn.init.ones_(model.weight) :参数为初始化对象,只能一个tensor。
torch.nn.init.constant_(model.weight) :参数1为初始化对象,只能一个tensor;参数2为设置的浮点数。
- from torch.nn import Linear
- import torch.nn
- model = Linear(4,1)
- # 参数只有一个,初始化对象,不能一次初始化多个对象
- torch.nn.init.zeros_(model.weight)
- torch.nn.init.zeros_(model.bias)
- print(model.weight,model.bias)
- torch.nn.init.ones_(model.weight)
- torch.nn.init.ones_(model.bias)
- print(model.weight,model.bias)
- model = Linear(4,1)
- # 使用叶子节点来初始化为0
- model.weight.detach().zero_()
- model.bias.detach().zero_()
- print(model.weight,model.bias)
- model = Linear(4,1)
- # 参数第一个为初始化对象,第二个为设定的浮点数;
- # 不能一次初始化多个对象
- torch.nn.init.constant_(model.weight,5.)
- torch.nn.init.constant_(model.bias,5.)
- print(model.weight,model.bias)
复制代码 3.2、随机初始化
normal_、uniform_:将权重初始化为随机的小值,通常从正态分布或均匀分布中采样;能避免对称性粉碎。
- from torch.nn import Linear
- import torch.nn
- model = Linear(4,1)
- # 参数第一个为初始化对象;
- # 参数2、3为均值和标准差,默认 0,1标准正太分布
- torch.nn.init.normal_(model.weight)
- torch.nn.init.normal_(model.bias)
- print(model.weight,model.bias)
- # 参数第二个和第三个为下界和上界,默认0-1
- torch.nn.init.uniform_(model.weight, 0,1)
- torch.nn.init.uniform_(model.bias)
- print(model.weight,model.bias)
复制代码 3.3、Xavier 初始化
对随机初始化添加取值限定。
平衡了输入和输出的方差,适合Sigmoid 和 Tanh 激活函数或浅层网络。
- from torch.nn import Linear
- import torch.nn
- model = Linear(4,1)
- # 参数第一个为初始化对象;
- # 第二个参数 gain 是缩放因子
- torch.nn.init.xavier_normal_(model.weight)
- print(model.weight)
- torch.nn.init.xavier_uniform_(model.weight)
- print(model.weight)
- """
- 常见的 gain 值:
- 线性激活函数:gain = 1.0(默认值)
- Sigmoid 激活函数:gain = 1.0
- Tanh 激活函数:gain = 5/3(约等于 1.653)
- ReLU 激活函数:gain = sqrt(2)(约等于 1.414)
- Leaky ReLU 激活函数:gain = sqrt(2 / (1 + negative_slope^2)),其中 negative_slope 是 Leaky ReLU 的负斜率,默认值为 0.01。
- """
复制代码 3.4、He初始化 (kaiming 初始化 )
专门为 ReLU 激活函数设计;权重从以下分布中采样 , 是当前层的输入神经元数目。
- from torch.nn import Linear
- import torch.nn
- model = Linear(4,1)
- # 参数第一个为初始化对象;
- # a 为负斜率的值(relu负数为0,所以此参数只有在relu衍生的函数有效 leaky_relu)
- # nonlinearity 默认 leaky_relu
- # mode 默认 fan-in 使用输入单元数量计算初始化值
- torch.nn.init.kaiming_normal_(model.weight)
- print(model.weight)
- torch.nn.init.kaiming_uniform_(model.weight)
- print(model.weight)
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |