论坛
潜水/灌水快乐,沉淀知识,认识更多同行。
ToB圈子
加入IT圈,遇到更多同好之人。
朋友圈
看朋友圈动态,了解ToB世界。
ToB门户
了解全球最新的ToB事件
博客
Blog
排行榜
Ranklist
文库
业界最专业的IT文库,上传资料也可以赚钱
下载
分享
Share
导读
Guide
相册
Album
记录
Doing
应用中心
搜索
本版
文章
帖子
ToB圈子
用户
免费入驻
产品入驻
解决方案入驻
公司入驻
案例入驻
登录
·
注册
只需一步,快速开始
账号登录
立即注册
找回密码
用户名
Email
自动登录
找回密码
密码
登录
立即注册
首页
找靠谱产品
找解决方案
找靠谱公司
找案例
找对的人
专家智库
悬赏任务
圈子
SAAS
IT评测·应用市场-qidao123.com技术社区
»
论坛
›
数据库
›
SQL-Server
›
从零编写一个神经网络完成手写数字的识别分类(pytorch ...
从零编写一个神经网络完成手写数字的识别分类(pytorch实现) ...
农民
论坛元老
|
2024-7-19 14:19:49
|
显示全部楼层
|
阅读模式
楼主
主题
1595
|
帖子
1595
|
积分
4785
马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要
登录
才可以下载或查看,没有账号?
立即注册
x
1. 前言
很多人都有这样的困惑:
“我已经看过很多有关神经网络的书和视频了,但为什么感觉还是似懂非懂呢?”
那是因为,你从来都没有完整的、从头编写并训练过一个神经网络
学习AI相关的算法,尤其是深度学习方向;
真的不是学几个公式,相识几个名词概念就可以的。
因为深度学习,是一门实践课程!
举个例子:
激活函数、丧失函数、前向流传和反向流传,这些概念,相信大家都听过。
几个相关的问题,大家看看能不能答复出来。
激活函数:
激活函数必须要有吗?一般要放在哪里?
是放在线性层盘算后,还是放在线性层盘算前?
或者有没有都可以、放在哪都可以?
丧失函数:
丧失函数是用来做什么的?
有哪些常用的丧失函数?
分类问题用什么丧失函数?回归问题用什么丧失函数?
前向流传和反向流传:
什么是前向流传、什么又是反向流传?
为什么使用Pytorch,要定义前向流传forward函数?
梯度盘算是在前向流传的过程中,还是在反向流传的过程中?
对于上述这些简单问题,如果你觉得很模糊;
好像能答复出来,又好像答复不出来;
那只能阐明一个问题:
就是你从来没动手编写过神经网络。
本文,我会详细解说一个神经网络的编程案例,并附上代码。
大家看完这个案例后,动手写一写,然后再想一想;
就会发现,前面的那些问题,都能迎刃而解。
2. 问题简述
我要使用手写数字识别,这个例子:
来阐明到底怎样设计、实现并训练一个标准的前馈神经网络。
详细来说,我们要设计并训练一个3层的神经网络:
这个神经网络会以数字图像作为输入。经过神经网络的盘算,就会识别出图像中的数字是几,从而实现数字图像的分类。
在这个过程中,重点解说3个块内容:
1)神经网络的设计和实现
2)训练数据的预备和处理
3)模子的训练和测试流程。
3.
神经网络的设计和实现
首先需要观察数据的样子?
为了设计一个处理数字图像的神经网络,首先要弄清晰输入图像的巨细和格式。
此中,分辨率就是是图像的高和宽
可以发现,我们要处理的图片是,28*28像素的灰色单通道图像。
这样的灰色图像,包括了28*28=784个数据点。
每次在处理数字图像时,输入给神经网络的,就是这784个数据点。
在将它输入给神经网络前,这个28*28的二维图片向量,会被展平为1*784巨细的一维线性向量【
因为我们使用的是线性模子,而非卷积模子
】
比如这张图,左侧代表了28*28个像素对应的图像;
右侧是一个展平后的一维向量,包括了x0到x783,一共784个像素点。
这样这个向量才能被神经网络的输入层所接收和处理。
3.1 输入层的设计
我们会使用一个3层神经网络,来处理图片对应的向量x:
如图
输入层需要接收784维的图片向量x。图中的红色箭头,就代表了数据的输入。
x中的每一个维度的数据,都有一个神经元来接收。
因此,输入层就要包含784个神经元。
3.2 隐藏层的设计
隐藏层是指除了输入层的背面层数,也有的是说包含权重的层数,只需要记住,隐藏层的个数即是神经网络层数-1即可。例如本文实现的是3层神经网络,那么隐藏层的个数就是2
隐藏层用于特征提取,它将输入的特征向量,处理为更高级的特征向量。
由于手写数字图像并不复杂,这里就将隐藏层的神经元个数,设置为256。
256就是个履历值,大家也可以设置为128、512,甚至999。
对于手写数字这个问题,并没有太大影响。
这样输入层与隐藏层之间,就会有一个784*256巨细的线性层。
它可以将一个784维的输入向量,转换为256维的输出向量。
该输出向量会继承向前流传,到达输出层。
3.3 输出层的设计
由于终极要将数字图像,识别为0到9,10种可能的数字;
因此,输出层需要定义
10个神经元
,对应这10种数字。
256维的向量,再经过隐藏层和输出层之间的线性层盘算后,就得到了10维的输出效果。
这个10维的向量,就是代表了10个数字的预测得分。
不要忘了还得有softmax层!
为了继承得到10个数字的预测概率,我们还要将输出层的输出,输入到softmax层。
softmax层会将10维的向量,转换为10个概率值,p0到p9。
每个概率值,都对应一个数字,也就是输入图片,是某一个数字的可能性。
另外,p0到p9这10个概率值,相加到一起的总和是1。
这是由softmax函数的性子决定的。
以上就是神经网络的设计思路。
3.4 代码实现
对于初学者,我知道很难直接按照这个设计思路,将代码编写出来。
大家最开始的时间,可以先模仿着写,举行练习;
慢慢的本身就会写出完整的模子了。
下面我会基于刚刚的思路,实现Pytorch代码。
如果想进一步理解代码,最好的方式还是将代码编写出来后,然后再将代码跑起来。
首先,定义神经网络Network。
在init函数中:
定义两个线性层layer1和layer2。
layer1和layer2分别是输入层和隐藏层、隐藏层和输出层之间的线性层。
它们的巨细分别是784*256和256*10。
也就是右侧图中,红色标志的layer1和layer2。
在前向流传,forward函数中:
函数的输入为图像x。
这个x就是1个或者多个,28*28像素数字图像。
在函数中,需要先将输入的图像x,使用view函数,将x展平。
也就是将n*28*28的数据,展平成n*784的数据。
然后将x输入至layer1;
接着使用relu激活;
最后输入至layer2盘算效果,再返回。
另外,需要注意的是:
我没有在forward中直接定义softmax层,
这是因为背面会使用CrossEntropyLoss丧失函数。
在这个丧失函数中,会实现softmax的盘算。
4.
训练数据的预备和处理
如果想要理解一个模子,我们要先理解给它输入的数据。
理解了数据定义和读取,再去看模子,会事半功倍。
4.1 训练数据哪里来?
手写数字识别的训练数据,可以直接使用MNIST数据集。
这个数据集可以从
torchvision.datasets
中获取。
这里会将数据分别保存到train和test两个目录中,此中:
1) train有60000个数据
2)test有10000个数据
它们分别用来模子的训练和测试。
在train和test,这两个目录中,都包括了10个子目录:
子目录的名字就对应了图像中的数字。例如,在名为3的文件夹中,就保存了数字3的图像。
此中图像的名称是随机的字符串签名。
4.2 如那边理和读取这些数据?
完成数据的预备后,实现数据的读取功能,我会基于这一部门的代码举行解说。
初学者在学习这一部门时,只要知道大致的数据处理流程就可以了。
数据的处理包括三块内容。
第1步,图像数据预处理:
需要实现图像的预处理pipeline,transform。
它包括了将图像转为灰度图和转张量两个功能。
这一步可以简单的理解为,将数组数据处理为训练时所用的张量数据。
第2步,构建数据集对象:
数据集对象的作用,就是用来整体操纵训练数据,可以更方便的访问这些数据。
详细来说,使用ImageFolder函数,读取数据文件夹,构建数据集dataset。这个函数会将保存数据的文件夹的名字,作为数据的标签,组织数据。
例如,对于名字为“3”的文件夹,就会将“3”就会作为文件夹中的图像数据的标签。
标签和图像配对,用于后续的训练,ImageFolder使用起来非常方便。
这里我们分别读取训练数据文件夹train和测试数据文件夹test;
这样会得到train_dataset和test_dataset,两个数据集对象。
如果我们此时运行程序,会打印出它们的长度;
会看到,train_dataset是60000,test_dataset是10000。
这就代表了在训练集有60000个数据,测试会合有10000个数据。
第3步,小批量加载数据:
小批量加载数据直接和模子的训练有关。
小批量的数据读取,是训练各类深度学习模子的前提!
以下是创建小批量读取器dataloader的样例代码:
我们会使用train_loader,实现小批量的数据读取。
这里设置小批量的巨细,batch_size=64。
也就是每个批次,包括64个数据,一次盘算64个数据的梯度!
这时如果运行程序,会打印train_loader的长度,然后看到效果是938。
详细来说,60000个训练数据,如果每个小批量,读入64个样本;
那么60000个数据会被分成938组。
我们可以盘算938*64=60032,不敷60000;
这就阐明最后一组,会不够64个数据。
小批量的遍历数据,是训练的关键前提
我们可以通过循环遍历train_loader来获取每个小批量数据。
这里的每一次循环,都会取出64个图像数据,作为一个小批量batch。
此时如果,打印前3个batch观察:
可以看到数据的尺寸data.shape是64*1*28*28:
它表现了每组数据包括64个图像;
每个图像有1个灰色通道;
图像的尺寸是28*28。
接着打印图像的标签label:
可以看到64个图片对应的数字。
此中保存的数值是0到9,对应了10个数字。
5. 模子训练
实际上,对于训练一个深度学习模子,训练后再测试这个深度学习模子;
这两个过程,都是定式。
也就是,无论你训练的模子简单还是复杂,是前馈神经网络还是Transformer,都是哪几个步骤。
当然,对于一些特殊的神经网络,可能会做一些专门的训练优化。
但本质还是那几个步骤,大家在看下面的解说时,重点是相识这些步骤;
对于每句代码的详细含义,如果真相搞懂,最好还是将代码写出来,然后举行运行和调试。
雷同的数据读入步骤
关于模子的训练,前半部门是图像数据的读入。
包括:
1)图像的预处理transform
2)读入并构造数据集train_dataset
3)使用train_loader举行小批量的数据读入。
这一块和刚刚讲的是一样的。
创建核心对象(变量)
在使用Pytorch训练模子时,需要创建三个核心对象(变量)。
大家要记住,无论训练哪种深度学习模子;
下面说的这三个对象,都要创建!
第1个是:
模子本身model,它就是我们设计的神经网络。
第2个是:
优化器optimizer,它用来优化模子中的参数。
初学的时间,直接使用Adam优化器就可以了。
第3个是:
丧失函数criterion,对于分类问题,就直接使用CrossEntropyLoss,交叉熵丧失误差;
进入模子的循环迭代
模子的循环迭代,同样是定式!
大家记住,迭代深度学习模子,就是两层循环。
这两层循环,分别是:
表现训练轮数的外层循环;
表现梯度下降的内层循环!
详细来说:
外层循环,代表了整个训练数据集的遍历次数。
整个训练集要循环多少轮,是10次、20次或者100次都是可能的。
这里根据履历,设置为10次。
内层循环使用train_loader,举行小批量的数据读取。
内层循环,每循环一次,就会举行一次梯度下降算法。
梯度下降算法
内层循环所包含的梯度下降算法,包括了5个步骤。
这5个步骤,又是使用pytorch框架训练模子的定式。
初学的时间,可以先记住。
详细来说:
1)盘算神经网络的前向流传效果output。
2)盘算output和标签label之间的丧失loss。
3)使用backward盘算梯度。
4)使用optimizer.step更新参数。
5)最后将梯度清零。
另外,我们每迭代100个小批量,就打印一次模子的丧失,观察训练的过程。
运行程序,就会观察到,模子的丧失loss,不断变小。
最后使用torch.save保存模子,模子名字为mnist.pth。
这个“mnist.pth”就是我们最后得到的神经网络模子;
将来再举行数字图片的预测时,就要用它来识别图像。
6. 模子测试
完成模子训练后,需要对模子举行测试。
测试的流程与训练差不多,我们要测试出模子的效果。
测试的过程,也相当于模子的“使用过程”了。
前面是雷同的数据读入和模子定义:
首先需要读取测试数据集test_dataset。
然后定义神经网络模子,并加载刚刚训练好的模子文件mnist.pth。
然后是遍历测试数据集,举行预测,统计正确率:
定义变量right,保存正确识别的数量。
遍历test_dataset,将此中的数据x输入到模子model中,盘算效果output。
然后从output中,使用argmax,选择概率最大标签的作为预测效果,保存到predict。
接着对比预测值predict和真实标签y。
这里将识别错误的样本打印了出来。
可以看到错误case的预测值predict、真实值y和文件路径。
终极盘算出的测试效果为0.978。
也就是
10000个数据,有9779个数据识别正确
。
以上就是从零设计并训练神经网络的过程。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复
使用道具
举报
0 个回复
倒序浏览
返回列表
快速回复
高级模式
B
Color
Image
Link
Quote
Code
Smilies
您需要登录后才可以回帖
登录
or
立即注册
本版积分规则
发表回复
回帖并转播
回帖后跳转到最后一页
发新帖
回复
农民
论坛元老
这个人很懒什么都没写!
楼主热帖
数据库入门
肝了五万字把SQL数据库从基础到高级所 ...
java反射大白话
iOS WebRTC 点对点实时音视频流程介绍 ...
Java中set集合简介说明
【R语言数据科学】(十二):有趣的概 ...
每日算法之数组中的逆序对
消息队列常见的使用场景
flume基本安装与使用
CentOS 7.9 安装 rocketmq-4.9.2
标签云
集成商
AI
运维
CIO
存储
服务器
浏览过的版块
物联网
登录参与点评抽奖加入IT实名职场社区
下次自动登录
忘记密码?点此找回!
登陆
新用户注册
用其它账号登录:
关闭
快速回复
返回顶部
返回列表