详解PyTorch手写数字识别(MNIST数据集)


Posted in Python onAugust 16, 2019

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络

# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #将梯度归零
      outputs = net(inputs) #将数据传入网络进行前向运算
      loss = criterion(outputs, labels) #得到损失函数
      loss.backward() #反向传播
      optimizer.step() #通过梯度做一步参数更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python读取txt某几列绘图的方法
Oct 14 Python
Python numpy.array()生成相同元素数组的示例
Nov 12 Python
python生成以及打开json、csv和txt文件的实例
Nov 16 Python
基于Numpy.convolve使用Python实现滑动平均滤波的思路详解
May 16 Python
Python实现12306火车票抢票系统
Jul 04 Python
python读取并写入mat文件的方法
Jul 12 Python
python批量处理txt文件的实例代码
Jan 13 Python
解决TensorFlow GPU版出现OOM错误的问题
Feb 03 Python
PyCharm 2020.2 安装详细教程
Sep 25 Python
pandas 按日期范围筛选数据的实现
Feb 20 Python
python基于turtle绘制几何图形
Jun 15 Python
Python学习之os包使用教程详解
Mar 21 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 #Python
Python 分发包中添加额外文件的方法
Aug 16 #Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 #Python
基于django传递数据到后端的例子
Aug 16 #Python
Django 拆分model和view的实现方法
Aug 16 #Python
利用Python实现kNN算法的代码
Aug 16 #Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
You might like
php daddslashes()和 saddslashes()有哪些区别分析
2012/10/26 PHP
PHP中实现crontab代码分享
2015/03/26 PHP
Java中final关键字详解
2015/08/10 PHP
适用于初学者的简易PHP文件上传类
2015/10/29 PHP
再谈Jquery Ajax方法传递到action(补充)
2014/05/12 Javascript
深入探寻seajs的模块化与加载方式
2015/04/14 Javascript
浅谈Node.js:理解stream
2016/12/08 Javascript
jquery自定义插件结合baiduTemplate.js实现异步刷新(附源码)
2016/12/22 Javascript
jQuery插件HighCharts实现的2D堆条状图效果示例【附demo源码下载】
2017/03/14 Javascript
Vue异步组件使用详解
2017/04/08 Javascript
bootstrap multiselect 多选功能实现方法
2017/06/05 Javascript
Vue内部渲染视图的方法
2019/09/02 Javascript
Python数据类型详解(四)字典:dict
2016/05/12 Python
教你用 Python 实现微信跳一跳(Mac+iOS版)
2018/01/04 Python
python字符串与url编码的转换实例
2018/05/10 Python
Django+Ajax+jQuery实现网页动态更新的实例
2018/05/28 Python
python 寻找list中最大元素对应的索引方法
2018/06/28 Python
Python基础学习之类与实例基本用法与注意事项详解
2019/06/17 Python
opencv3/C++图像像素操作详解
2019/12/10 Python
Python + Requests + Unittest接口自动化测试实例分析
2019/12/12 Python
使用python turtle画高达
2020/01/19 Python
python词云库wordcloud的使用方法与实例详解
2020/02/17 Python
在Django中自定义filter并在template中的使用详解
2020/05/19 Python
keras的ImageDataGenerator和flow()的用法说明
2020/07/03 Python
Python通过getattr函数获取对象的属性值
2020/10/16 Python
利用python+request通过接口实现人员通行记录上传功能
2021/01/13 Python
Ryderwear澳洲官网:澳大利亚高端健身训练装备品牌
2018/09/18 全球购物
德国的各种媒体在线商店:Thalia.de(书籍、电子书、玩具等)
2020/10/08 全球购物
岗位职责的构建方法
2014/02/01 职场文书
喜之郎果冻广告词
2014/03/20 职场文书
临时工聘用合同协议书
2014/10/29 职场文书
2015年卫生监督工作总结
2015/05/21 职场文书
教师纪律作风整顿心得体会
2016/01/23 职场文书
36个正则表达式(开发效率提高80%)
2021/11/17 Javascript
上个世纪50年代的可穿戴技术:无线电帽子
2022/02/18 无线电
win10清理dns缓存
2022/04/19 数码科技