Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python实现从百度API获取天气的方法
Mar 11 Python
python中字符串前面加r的作用
Jun 04 Python
Python中的descriptor描述器简明使用指南
Jun 02 Python
Python3 模块、包调用&路径详解
Oct 25 Python
Python实现合并两个有序链表的方法示例
Jan 31 Python
python解压TAR文件至指定文件夹的实例
Jun 10 Python
Python+Selenium使用Page Object实现页面自动化测试
Jul 14 Python
浅谈Python3实现两个矩形的交并比(IoU)
Jan 18 Python
python用WxPython库实现无边框窗体和透明窗体实现方法详解
Feb 21 Python
如何查看Django ORM执行的SQL语句的实现
Apr 20 Python
Python如何绘制日历图和热力图
Aug 07 Python
Python datetime 如何处理时区信息
Sep 02 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
php中关于codeigniter的xmlrpc的类在进行数据交换时的类型问题
2011/07/03 PHP
克隆一个新项目的快捷方式
2013/04/10 PHP
php版微信开发之接收消息,自动判断及回复相应消息的方法
2016/09/23 PHP
PHP实现的62进制转10进制,10进制转62进制函数示例
2019/06/06 PHP
扩展JS Date对象时间格式化功能的小例子
2013/12/02 Javascript
js函数内变量的作用域分析
2015/01/12 Javascript
javascript实现简单的省市区三级联动
2015/05/14 Javascript
jQuery hover事件简单实现同时绑定2个方法
2016/06/07 Javascript
JS代码实现根据时间变换页面背景效果
2016/06/16 Javascript
jQuery EasyUI基础教程之EasyUI常用组件(推荐)
2016/07/15 Javascript
js Canvas实现圆形时钟教程
2016/09/19 Javascript
AngularJS实践之使用NgModelController进行数据绑定
2016/10/08 Javascript
jq stop()和:is(:animated)的用法及区别(详解)
2017/02/12 Javascript
解决OneThink中无法异步提交kindeditor文本框中修改后的内容方法
2017/05/05 Javascript
微信小程序自定义导航教程(兼容各种手机)
2018/12/12 Javascript
微信小程序保存多张图片的实现方法
2019/03/05 Javascript
关于ligerui子页面关闭后,父页面刷新,重新加载的方法
2019/09/27 Javascript
Python内置函数dir详解
2015/04/14 Python
详解python 拆包可迭代数据如tuple, list
2017/12/29 Python
Django项目中用JS实现加载子页面并传值的方法
2018/05/28 Python
pytorch程序异常后删除占用的显存操作
2020/01/13 Python
Python爬取YY评级分数并保存数据实现过程解析
2020/06/01 Python
Python classmethod装饰器原理及用法解析
2020/10/17 Python
使用Html5中的cavas画一面国旗
2019/09/25 HTML / CSS
Vichy薇姿加拿大官网:法国药妆,全球专业敏感肌护肤领先品牌
2018/07/11 全球购物
施华洛世奇巴西官网:SWAROVSKI巴西
2019/12/03 全球购物
Unix如何添加新的用户
2014/08/20 面试题
群众路线批评与自我批评
2014/02/06 职场文书
《散步》教学反思
2014/03/02 职场文书
宣传口号大全
2014/06/16 职场文书
学校安全工作汇报材料
2014/08/16 职场文书
2014年度培训工作总结
2014/11/27 职场文书
初中优秀学生评语
2014/12/29 职场文书
副总经理岗位职责
2015/02/02 职场文书
2015年小学远程教育工作总结
2015/07/28 职场文书
浅谈mysql哪些情况会导致索引失效
2021/11/20 MySQL