Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python高效编程技巧
Jan 07 Python
简单讲解Python中的字符串与字符串的输入输出
Mar 13 Python
Python通过future处理并发问题
Oct 17 Python
在python中pandas读文件,有中文字符的方法
Dec 12 Python
Python发展简史 Python来历
May 14 Python
python 非线性规划方式(scipy.optimize.minimize)
Feb 11 Python
在python中利用dict转json按输入顺序输出内容方式
Feb 27 Python
PyQt中使用QtSql连接MySql数据库的方法
Jul 28 Python
通过案例解析python鸭子类型相关原理
Oct 10 Python
python源码剖析之PyObject详解
May 18 Python
Python爬取用户观影数据并分析用户与电影之间的隐藏信息!
Jun 29 Python
Python字符串格式化方式
Apr 07 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
PHP 魔术函数使用说明
2010/05/14 PHP
php实现文件下载简单示例(代码实现文件下载)
2014/03/10 PHP
ThinkPHP使用getlist方法实现数据搜索功能示例
2017/05/08 PHP
PHP+ajax实现二级联动菜单功能示例
2018/08/10 PHP
支持ie与FireFox的剪切板操作代码
2009/09/28 Javascript
JavaScript的eval JSON object问题
2009/11/15 Javascript
基于jsTree的无限级树JSON数据的转换代码
2010/07/27 Javascript
网站页面自动跳转实现方法PHP、JSP(下)
2010/08/01 Javascript
理解Javascript_02_理解undefined和null
2010/10/11 Javascript
JavaScript修改css样式style动态改变元素样式
2013/12/16 Javascript
js中的getAttribute方法使用示例
2014/08/01 Javascript
原生js仿jq判断当前浏览器是否为ie,精确到ie6~8
2014/08/30 Javascript
jquery中使用循环下拉菜单示例代码
2014/09/24 Javascript
如何调试异步加载页面里包含的js文件
2014/10/30 Javascript
jquery解决客户端跨域访问问题
2015/01/06 Javascript
jquery点击切换背景色的简单实例
2016/08/25 Javascript
js实现炫酷的左右轮播图
2017/01/18 Javascript
利用vue + koa2 + mockjs模拟数据的方法教程
2017/11/22 Javascript
JS实现小星星特效
2019/12/24 Javascript
vue 二维码长按保存和复制内容操作
2020/09/22 Javascript
Openlayers学习之加载鹰眼控件
2020/09/28 Javascript
[55:45]LGD vs OG 2019国际邀请赛淘汰赛 胜者组 BO3 第三场 8.24
2019/09/10 DOTA
Python实现1-9数组形成的结果为100的所有运算式的示例
2017/11/03 Python
Pycharm的Available Packages为空的解决方法
2020/09/18 Python
CSS3字体效果的设置方法小结
2016/06/13 HTML / CSS
CSS3中的@keyframes关键帧动画的选择器绑定
2016/06/13 HTML / CSS
使用html2canvas.js实现页面截图并显示或上传的示例代码
2018/12/18 HTML / CSS
Melissa香港官网:MDreams
2016/07/01 全球购物
电气自动化大学生求职信
2013/10/16 职场文书
上课迟到检讨书100字
2014/01/11 职场文书
学生感冒英文请假条
2014/02/04 职场文书
学校联谊活动方案
2014/02/15 职场文书
新学期国旗下演讲稿
2014/05/08 职场文书
委托书的格式
2014/08/01 职场文书
2021-4-5课程——SQL Server查询【3】
2021/04/05 SQL Server
详解OpenCV获取高动态范围(HDR)成像
2022/04/29 Python