Pytorch实现的手写数字mnist识别功能完整示例


Posted in Python onDecember 13, 2019

本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定义前向传播过程,输入为x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8  #遍历数据集次数
BATCH_SIZE = 64   #批处理尺寸(batch_size)
LR = 0.001    #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定义测试数据集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 数据读取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每训练100个batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch测试一下准确率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那个类
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
探究Python中isalnum()方法的使用
May 18 Python
详解python开发环境搭建
Dec 16 Python
关于python的bottle框架跨域请求报错问题的处理方法
Mar 19 Python
PyQt5每天必学之切换按钮
Aug 20 Python
python验证码识别教程之灰度处理、二值化、降噪与tesserocr识别
Jun 04 Python
python使用PIL模块获取图片像素点的方法
Jan 08 Python
详解Python3之数据指纹MD5校验与对比
Jun 11 Python
使用Rasterio读取栅格数据的实例讲解
Nov 26 Python
Python的形参和实参使用方式
Dec 24 Python
Python如何基于smtplib发不同格式的邮件
Dec 30 Python
使用python实现名片管理系统
Jun 18 Python
Python Selenium XPath根据文本内容查找元素的方法
Dec 07 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 #Python
Python实现将蓝底照片转化为白底照片功能完整实例
Dec 13 #Python
python多进程重复加载的解决方式
Dec 13 #Python
使用pyqt5 tablewidget 单元格设置正则表达式
Dec 13 #Python
Python代码块及缓存机制原理详解
Dec 13 #Python
Python3和pyqt5实现控件数据动态显示方式
Dec 13 #Python
python实现简单日志记录库glog的使用
Dec 13 #Python
You might like
整理的9个实用的PHP库简介和下载
2010/11/09 PHP
php switch语句多个值匹配同一代码块应用示例
2014/07/29 PHP
PHP连接Nginx服务器并解析Nginx日志的方法
2015/08/16 PHP
解决windows上php xdebug 无法调试的问题
2020/02/19 PHP
PHP 扩展Memcached命令用法实例总结
2020/06/04 PHP
javascript中的有名函数和无名函数
2007/10/17 Javascript
nodejs中的fiber(纤程)库详解
2015/03/24 NodeJs
jQuery插件datalist实现很好看的input下拉列表
2015/07/14 Javascript
基于jQuery仿淘宝产品图片放大镜特效
2020/10/19 Javascript
浅谈JavaScript对象与继承
2016/07/10 Javascript
使用Vue动态生成form表单的实例代码
2018/04/26 Javascript
在ES5与ES6环境下处理函数默认参数的实现方法
2018/05/13 Javascript
玩转vue的slot内容分发
2018/09/22 Javascript
React降级配置及Ant Design配置详解
2018/12/27 Javascript
vue的keep-alive中使用EventBus的方法
2019/04/23 Javascript
微信小程序HTTP接口请求封装代码实例
2019/09/05 Javascript
vue使用exif获取图片经纬度的示例代码
2020/12/11 Vue.js
[01:34:42]NAVI vs EG 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/17 DOTA
python中多个装饰器的执行顺序详解
2018/10/08 Python
python仿抖音表白神器
2019/04/08 Python
用Python+OpenCV对比图像质量的几种方法
2019/07/15 Python
pandas 空数据处理方法详解
2019/11/02 Python
Python os模块常用方法和属性总结
2020/02/20 Python
python GUI库图形界面开发之PyQt5信号与槽多窗口数据传递详细使用方法与实例
2020/03/08 Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
2020/05/26 Python
python右对齐的实例方法
2020/07/05 Python
Pycharm Available Package无法显示/安装包的问题Error Loading Package List解决
2020/09/18 Python
检测用户浏览器是否支持CSS3的方法
2009/08/29 HTML / CSS
Ralph Lauren英国官方网站:Ralph Lauren UK
2018/04/03 全球购物
金融专业应届生求职信
2013/11/02 职场文书
计算机求职信
2013/12/01 职场文书
防灾减灾活动总结
2014/08/30 职场文书
商铺门面租房协议书
2014/10/21 职场文书
2015年庆祝国庆节66周年演讲稿
2015/07/30 职场文书
2016年秋季新学期致辞
2015/07/30 职场文书
2016教师暑期培训学习心得体会
2016/01/09 职场文书