Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python3编写C/S网络程序实例教程
Aug 25 Python
浅要分析Python程序与C程序的结合使用
Apr 07 Python
使用 Python 实现简单的 switch/case 语句的方法
Sep 17 Python
Python批量生成幻影坦克图片实例代码
Jun 04 Python
如何使用Python实现斐波那契数列
Jul 02 Python
Django单元测试中Fixtures用法详解
Feb 25 Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 Python
python集合能干吗
Jul 19 Python
Python基于opencv的简单图像轮廓形状识别(全网最简单最少代码)
Jan 28 Python
python实现杨辉三角的几种方法代码实例
Mar 02 Python
python 判断文件或文件夹是否存在
Mar 18 Python
Python正则表达式中flags参数的实例详解
Apr 01 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
一个简单的自动发送邮件系统(一)
2006/10/09 PHP
MySQL相关说明
2007/01/15 PHP
CodeIgniter多语言实现方法详解
2016/01/20 PHP
Yii2.0 Basic代码中路由链接被转义的处理方法
2016/09/21 PHP
PHP使用GD库制作验证码的方法(点击验证码或看不清会刷新验证码)
2017/08/15 PHP
laravel 实现根据字段不同值做不同查询
2019/10/23 PHP
限制复选框的最大可选数
2006/07/01 Javascript
js 数组克隆方法 小结
2010/03/20 Javascript
event.X和event.clientX的区别分析
2011/10/06 Javascript
Javascript继承(上)——对象构建介绍
2012/11/08 Javascript
uploadify在Firefox下丢失session问题的解决方法
2013/08/07 Javascript
纯JavaScript代码实现移动设备绘图解锁
2015/10/16 Javascript
js省市联动效果完整实例代码
2015/12/09 Javascript
实例讲解JS中setTimeout()的用法
2016/01/28 Javascript
很棒的vue弹窗组件
2017/05/24 Javascript
axios封装,使用拦截器统一处理接口,超详细的教程(推荐)
2019/05/02 Javascript
微信小程序之几种常见的弹框提示信息实现详解
2019/07/11 Javascript
JavaScript简易计算器制作
2020/01/17 Javascript
JavaScript 闭包的使用场景
2020/09/17 Javascript
uni-app 自定义底部导航栏的实现
2020/12/11 Javascript
[01:14:41]DOTA2-DPC中国联赛定级赛 iG vs Magma BO3第一场 1月8日
2021/03/11 DOTA
Python tkinter模块弹出窗口及传值回到主窗口操作详解
2017/07/28 Python
Python如何优雅获取本机IP方法
2019/11/10 Python
解决Tensorflow占用GPU显存问题
2020/02/03 Python
python数据库编程 ODBC方式实现通讯录
2020/03/27 Python
Python 爬虫的原理
2020/07/30 Python
CSS3实现水平居中、垂直居中、水平垂直居中的实例代码
2020/02/27 HTML / CSS
泰国的头号网上婴儿用品店:Motherhood.co.th
2019/04/09 全球购物
C#如何允许一个类被继承但是避免这个类的方法被重载?
2015/02/24 面试题
linux面试题参考答案(7)
2012/10/29 面试题
优秀求职自荐信怎样写
2013/12/18 职场文书
婚庆司仪主持词
2014/03/15 职场文书
小学爱国卫生月活动总结
2014/06/30 职场文书
创先争优承诺书
2015/01/20 职场文书
离职信范本
2015/06/23 职场文书
五年级作文之学校的四季
2019/12/05 职场文书