详解PyTorch手写数字识别(MNIST数据集)


Posted in Python onAugust 16, 2019

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络

# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #将梯度归零
      outputs = net(inputs) #将数据传入网络进行前向运算
      loss = criterion(outputs, labels) #得到损失函数
      loss.backward() #反向传播
      optimizer.step() #通过梯度做一步参数更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python通过DOM和SAX方式解析XML的应用实例分享
Nov 16 Python
常用python编程模板汇总
Feb 12 Python
Python画柱状统计图操作示例【基于matplotlib库】
Jul 04 Python
Django框架HttpRequest对象用法实例分析
Nov 01 Python
Django中密码的加密、验密、解密操作
Dec 19 Python
Python3标准库之dbm UNIX键-值数据库问题
Mar 24 Python
python 弧度与角度互转实例
Apr 15 Python
Python判断远程服务器上Excel文件是否被人打开的方法
Jul 13 Python
分享unittest单元测试框架中几种常用的用例加载方法
Dec 02 Python
Python基于Tkinter开发一个爬取B站直播弹幕的工具
May 06 Python
教你怎么用PyCharm为同一服务器配置多个python解释器
May 31 Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 #Python
Python 分发包中添加额外文件的方法
Aug 16 #Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 #Python
基于django传递数据到后端的例子
Aug 16 #Python
Django 拆分model和view的实现方法
Aug 16 #Python
利用Python实现kNN算法的代码
Aug 16 #Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
You might like
百事可乐也出咖啡了 双倍咖啡因双倍快乐
2021/03/03 咖啡文化
php面向对象全攻略 (六)__set() __get() __isset() __unset()的用法
2009/09/30 PHP
php实现使用正则将文本中的网址转换成链接标签
2014/12/03 PHP
PHP实现将浏览历史页面网址保存到cookie的方法
2015/01/26 PHP
php判断文件夹是否存在不存在则创建
2015/04/09 PHP
php实现xml转换数组的方法示例
2017/02/03 PHP
兼容IE与firefox火狐的回车事件(js与jquery)
2010/10/20 Javascript
js判断是否为数组的函数: isArray()
2011/10/30 Javascript
checkbox使用示例
2013/08/23 Javascript
jquery checkbox实现单选小例
2013/11/27 Javascript
jQuery实现Div拖动+键盘控制综合效果的方法
2015/03/10 Javascript
在JavaScript中用getMinutes()方法返回指定的分时刻
2015/06/10 Javascript
jQuery 3.0中存在问题及解决办法
2016/07/15 Javascript
javascript实现根据汉字获取简拼
2016/09/25 Javascript
DOM 事件的深入浅出(二)
2016/12/05 Javascript
Javascript中字符串replace方法的第二个参数探究
2016/12/05 Javascript
JavaScript实现滑动导航栏效果
2017/08/30 Javascript
nodejs项目windows下开机自启动的方法
2017/11/22 NodeJs
vue实现个人信息查看和密码修改功能
2018/05/06 Javascript
Python Trie树实现字典排序
2014/03/28 Python
Python中使用gflags实例及原理解析
2019/12/13 Python
Pytorch 高效使用GPU的操作
2020/06/27 Python
PyTorch中的拷贝与就地操作详解
2020/12/09 Python
python中doctest库实例用法
2020/12/31 Python
CSS3 @media的基本用法总结
2019/09/10 HTML / CSS
三陽商会官方网站:Sanyo iStore
2019/05/15 全球购物
90后毕业生的求职信范文
2013/09/21 职场文书
文明宿舍获奖感言
2014/02/07 职场文书
致跳远运动员广播稿
2014/02/11 职场文书
2014年感恩母亲演讲稿
2014/05/27 职场文书
企业员工爱岗敬业演讲稿
2014/08/26 职场文书
先进基层党组织主要事迹材料
2015/11/03 职场文书
SQLServer 日期函数大全(小结)
2021/04/08 SQL Server
python 命令行传参方法总结
2021/05/25 Python
React Fragment介绍与使用详解
2021/11/11 Javascript
springboot+rabbitmq实现智能家居实例详解
2022/07/23 Java/Android