ToB企服应用市场:ToB评测及商务社交产业平台

标题: 【机器学习】手写数字识别 [打印本页]

作者: 圆咕噜咕噜    时间: 2022-8-22 18:18
标题: 【机器学习】手写数字识别
前言

logistic回归,是一个分类算法,可以处理二元分类,多元分类。我们使用sklearn中的logistic对手写数字识别进行实践。
数据集

MNIST数据集来自美国国家标准与技术研究所,训练集由250个不同人手写数字构成,50%高中学生,50%来自人口普查局。
数据集展示


数据集下载

百度云盘:
链接:https://pan.baidu.com/s/1ZBU8XBsx7lp7gdN4ySSIWg
提取码:5mrf
关于使用pycharm图片不显示

pycharm默认会在右边进行绘图,由于某些原因导致图片不能显示,只能是白图的解决办法。
逻辑回归手写数字识别
  1. ## logistis回归,是一个分类算法,可以处理二元分类,多元分类。
  2. ## 首先逻辑回归构造冠以的线性回归函数,然后使用sigmoid函数将回归值映射到散列类别
  3. ## sklearn 分类算法与手写数字识别
  4. ## 数据介绍
  5. ## MNIST数据集来自美国国家标准与技术研究所,训练集由250个不同人手写数字构成,50%高中学生,50%来自人口普查局
  6. ## 导包
  7. import struct,os
  8. import numpy as np
  9. from array import array as pyarray
  10. from numpy import append,array,int8,uint8,zeros
  11. import matplotlib.pyplot as plt
  12. from sklearn.metrics import accuracy_score, classification_report
  13. from sklearn.linear_model import LogisticRegression
  14. ## 加载数据集
  15. def load_mnist(image_file,label_file,path="mnist"):
  16.     digits=np.arange(10)
  17.     fname_image = os.path.join(path,image_file)
  18.     fname_label = os.path.join(path, label_file)
  19.     flbl = open(fname_label,'rb')
  20.     magic_nr, size = struct.unpack(">II", flbl.read(8))
  21.     lbl = pyarray("b",flbl.read())
  22.     flbl.close()
  23.     fimg = open(fname_image,'rb')
  24.     magic_nr, size, rows, cols = struct.unpack(">IIII", fimg.read(16))
  25.     img = pyarray("B",fimg.read())
  26.     fimg.close()
  27.     ind = [ k for k in range(size) if lbl[k] in digits ]
  28.     N = len(ind)
  29.     images = zeros((N, rows*cols),dtype = uint8)
  30.     labels = zeros((N,1), dtype = int8)
  31.     for i in range(len(ind)):
  32.         images[i] = array(img[ind[i]*rows*cols : (ind[i]+1)*rows*cols]).reshape((1,rows*cols))
  33.         labels[i] = lbl[ind[i]]
  34.     return images,labels
  35. train_image, train_label = load_mnist('train-images.idx3-ubyte', 'train-labels.idx1-ubyte')
  36. test_image, test_label = load_mnist('t10k-images.idx3-ubyte','t10k-labels.idx1-ubyte')
  37. ## 数据展示
  38. ## 28*28
  39. def show_image(imgdata, imgtarget, show_column, show_row,titlename):
  40.     for index, (im, it) in enumerate(list(zip(imgdata, imgtarget))):
  41.         xx = im.reshape(28,28)
  42.         plt.subplots_adjust(left=1, bottom=None, right=3,top=2, wspace=None, hspace=None)
  43.         plt.subplot(show_row,show_column,index+1)
  44.         plt.axis('off')
  45.         plt.imshow(xx, cmap='gray', interpolation='nearest')
  46.         plt.title(titlename+':%i' % it)
  47.     # plt.savefig(titlename+'.png')
  48.     # 这个地方可能会有一个警告,可能因为图太大了,不过没关系,代码正常运行
  49.     plt.show()
  50. show_image(train_image[:50], train_label[:50],10,5,'label')
  51. ## sklearn 分类模型
  52. ## 数据归一化
  53. train_image = [im/255.0 for im in train_image]
  54. test_image = [im/255.0 for im in test_image]
  55. print(len(train_image))
  56. print(len(test_image))
  57. print(len(train_label))
  58. print(len(test_label))
  59. ## 模型分类
  60. ## 模型实例化
  61. lr = LogisticRegression(max_iter=1000)
  62. ## 模型训练
  63. lr.fit(train_image,train_label.ravel())
  64. ## 模型验证
  65. predict = lr.predict(test_image)
  66. print("accuracy score: %.4lf"% accuracy_score(predict,test_label))
  67. print("classfication report for %s:\n%s\n"%(lr, classification_report(test_label, predict)))
  68. show_image(test_image[:100],predict,10,10,'predict')
复制代码
结果展示


分析

我们展示了100张图片的识别效果,可以找到3张明显的识别错误,和模型的评估结果相似。

总结

我们可以多重复运行几次发现结果并没有变化,这可能也是logistic回归的缺点吧,我们也可以使用神经网络进行手写数字识别,但那是深度学习的内容,我们后续会对其进行实现。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4