Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python实现模拟按键,自动翻页看u17漫画
Mar 17 Python
研究Python的ORM框架中的SQLAlchemy库的映射关系
Apr 25 Python
python对DICOM图像的读取方法详解
Jul 17 Python
Python实现PS滤镜的万花筒效果示例
Jan 23 Python
Python中数组,列表:冒号的灵活用法介绍(np数组,列表倒序)
Apr 18 Python
浅谈JupyterNotebook导出pdf解决中文的问题
Apr 22 Python
python爬虫scrapy图书分类实例讲解
Nov 23 Python
Python爬虫之Selenium实现键盘事件
Dec 04 Python
Python实现Word文档转换Markdown的示例
Dec 22 Python
Pandas数据分析的一些常用小技巧
Feb 07 Python
python上下文管理器异常问题解决方法
Feb 07 Python
如何用 Python 子进程关闭 Excel 自动化中的弹窗
May 07 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
php之static静态属性与静态方法实例分析
2015/07/30 PHP
php数据序列化测试实例详解
2017/08/12 PHP
JQuery slideshow的一个小问题(如何发现及解决过程)
2013/02/06 Javascript
ie 7/8不支持trim的属性的解决方案
2014/05/23 Javascript
jQuery中removeData()方法用法实例
2014/12/27 Javascript
JavaScript插件化开发教程 (三)
2015/01/27 Javascript
javascript中局部变量和全局变量的区别详解
2015/02/27 Javascript
JS+CSS模拟可以无刷新显示内容的留言板实例
2015/03/03 Javascript
简介JavaScript中的unshift()方法的使用
2015/06/09 Javascript
jQuery+css实现的蓝色水平二级导航菜单效果代码
2015/09/11 Javascript
详解js中构造流程图的核心技术JsPlumb(2)
2015/12/08 Javascript
JS组件Bootstrap导航条使用方法详解
2016/04/29 Javascript
深入理解JavaScript单体内置对象
2016/06/06 Javascript
举例讲解jQuery对DOM元素的向上遍历、向下遍历和水平遍历
2016/07/07 Javascript
Node.js利用debug模块打印出调试日志的方法
2017/04/25 Javascript
JavaScript制作简单的框选图表
2017/05/15 Javascript
关于react-router的几种配置方式详解
2017/07/24 Javascript
JS实现倒计时图文效果
2018/11/17 Javascript
加快Vue项目的开发速度的方法
2018/12/12 Javascript
js中数组对象去重的两种方法
2019/01/18 Javascript
非常实用的jQuery代码段集锦【检测浏览器、滚动、复制、淡入淡出等】
2019/08/08 jQuery
django自定义Field实现一个字段存储以逗号分隔的字符串
2014/04/27 Python
利用python如何处理百万条数据(适用java新手)
2018/06/06 Python
浅析python中while循环和for循环
2019/11/19 Python
浅谈Python中的字符串
2020/06/10 Python
法国时尚童装网站:Melijoe
2016/08/10 全球购物
萨克斯第五大道的折扣店:Saks Fifth Avenue OFF 5TH
2016/08/25 全球购物
早晨薰衣草在线女性精品店:Morning Lavender
2021/01/04 全球购物
库房保管员岗位职责
2014/04/07 职场文书
营销经理工作检讨书
2014/11/03 职场文书
给男朋友的道歉短信
2015/05/12 职场文书
黑暗中的舞者观后感
2015/06/18 职场文书
《搭石》教学反思
2016/02/18 职场文书
Python爬虫之爬取哔哩哔哩热门视频排行榜
2021/04/28 Python
React forwardRef的使用方法及注意点
2021/06/13 Javascript
Redis实现短信验证码登录的示例代码
2022/06/14 Redis