Pytorch框架实现mnist手写库识别(与tensorflow对比)


Posted in Python onJuly 20, 2020

前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。

Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://3water.com/article/191157.htm

Pytorch实战mnist手写数字识别

#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集

#下载数据集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#构建网络(网络结构对应tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU运行

#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#训练过程
for epoch in range(1):
  net.train() ##在进行训练时加上train(),测试时加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向传播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##这里得到的train_correct是一个longtensor型,需要转换为float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
    loss.backward() #loss反向传播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

运行结果:

Pytorch框架实现mnist手写库识别(与tensorflow对比)

到此这篇关于Pytorch框架实现mnist手写库识别(与tensorflow对比)的文章就介绍到这了,更多相关Pytorch框架实现mnist手写库识别(与tensorflow对比)内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
合并Excel工作薄中成绩表的VBA代码,非常适合教育一线的朋友
Apr 09 Python
在Python中使用SimpleParse模块进行解析的教程
Apr 11 Python
Python读取图片为16进制表示简单代码
Jan 19 Python
python使用筛选法计算小于给定数字的所有素数
Mar 19 Python
对python同一个文件夹里面不同.py文件的交叉引用方法详解
Dec 15 Python
如何基于Python + requests实现发送HTTP请求
Jan 13 Python
动态设置django的model field的默认值操作步骤
Mar 30 Python
python读取配置文件方式(ini、yaml、xml)
Apr 09 Python
基于Python绘制个人足迹地图
Jun 01 Python
Python中BeautifulSoup通过查找Id获取元素信息
Dec 07 Python
Django视图类型总结
Feb 17 Python
matplotlib画混淆矩阵与正确率曲线的实例代码
Jun 01 Python
python集合能干吗
Jul 19 #Python
python如何建立全零数组
Jul 19 #Python
解决python中0x80072ee2错误的方法
Jul 19 #Python
python给视频添加背景音乐并改变音量的具体方法
Jul 19 #Python
python中加背景音乐如何操作
Jul 19 #Python
python实现最短路径的实例方法
Jul 19 #Python
python等待10秒执行下一命令的方法
Jul 19 #Python
You might like
全国FM电台频率大全 - 19 广东省
2020/03/11 无线电
深入PHP变量存储的详解
2013/06/13 PHP
PHP对接微信公众平台消息接口开发流程教程
2014/03/25 PHP
windows系统php环境安装swoole具体步骤
2021/03/04 PHP
select 控制网页内容隐藏于显示的实现代码
2010/05/25 Javascript
javascript强大的日期函数代码分享
2013/09/04 Javascript
JavaScript动态创建div属性和样式示例代码
2013/10/09 Javascript
JavaScript中字面量与函数的基本使用知识
2015/10/20 Javascript
利用CSS3在Angular中实现动画
2016/01/15 Javascript
vue实现移动端图片裁剪上传功能
2020/08/18 Javascript
详解Vue组件插槽的使用以及调用组件内的方法
2018/11/13 Javascript
深入了解JavaScript 的 WebAssembly
2019/06/15 Javascript
使用layui+ajax实现简单的菜单权限管理及排序的方法
2019/09/10 Javascript
JavaScript, select标签元素左右移动功能实现
2020/05/14 Javascript
JavaScript实现图片放大预览效果
2020/11/02 Javascript
[01:01:22]VGJ.S vs OG 2018国际邀请赛淘汰赛BO3 第一场 8.22
2018/08/23 DOTA
11个并不被常用但对开发非常有帮助的Python库
2015/03/31 Python
使用python实现rsa算法代码
2016/02/17 Python
pandas series序列转化为星期几的实例
2018/04/11 Python
python中的turtle库函数简单使用教程
2018/07/23 Python
Python给图像添加噪声具体操作
2019/03/03 Python
python3+selenium自动化测试框架详解
2019/03/17 Python
Python3内置模块之json编解码方法小结【推荐】
2020/12/09 Python
python中if及if-else如何使用
2020/06/02 Python
多个版本的python共存时使用pip的正确做法
2020/10/26 Python
用canvas实现图片滤镜效果附演示
2013/11/05 HTML / CSS
美丽的现代设计家具:2Modern
2018/07/26 全球购物
.net面试题
2015/12/22 面试题
linux面试题参考答案(2)
2015/12/06 面试题
感恩节活动方案
2014/01/27 职场文书
关于廉洁的广播稿
2014/01/30 职场文书
企业元宵节主持词
2014/03/25 职场文书
英语教育专业毕业生求职信
2014/08/28 职场文书
先进学校事迹材料
2014/12/30 职场文书
信息简报范文
2015/07/21 职场文书
Redis基于Bitmap实现用户签到功能
2021/06/20 Redis