pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用multiprocessing实现一个最简单的分布式作业调度系统
Mar 14 Python
Python基于QRCode实现生成二维码的方法【下载,安装,调用等】
Jul 11 Python
Python的mysql数据库的更新如何实现
Jul 31 Python
python中单例常用的几种实现方法总结
Oct 13 Python
python 使用正则表达式按照多个空格分割字符的实例
Dec 20 Python
python+opencv实现高斯平滑滤波
Jul 21 Python
PySide和PyQt加载ui文件的两种方法
Feb 27 Python
Python3分析处理声音数据的例子
Aug 27 Python
python爬虫中url管理器去重操作实例
Nov 30 Python
python 逐步回归算法
Apr 06 Python
pytorch 实现变分自动编码器的操作
May 24 Python
Python 阶乘详解
Oct 05 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
PHP读取汉字的点阵数据
2015/06/22 PHP
php多文件打包下载的实例代码
2017/07/12 PHP
jQuery根据纬度经度查看地图处理程序
2013/05/08 Javascript
JQuery设置和去除disabled属性的5种方法总结
2013/05/16 Javascript
写得不错的jquery table鼠标经过变色代码
2013/09/27 Javascript
阻止事件(取消浏览器对事件的默认行为并阻止其传播)
2013/11/03 Javascript
Jquery实现图片预加载与延时加载的方法
2014/12/22 Javascript
JavaScript中操作字符串之localeCompare()方法的使用
2015/06/06 Javascript
JS控制页面跳转时未请求要跳转的地址怎么回事
2016/10/14 Javascript
assert()函数用法总结(推荐)
2017/01/25 Javascript
JS基于贪心算法解决背包问题示例
2017/11/27 Javascript
微信小程序实现商城倒计时
2020/11/01 Javascript
原生js实现针对Dom节点的CRUD操作示例
2019/08/26 Javascript
vue 强制组件重新渲染(重置)的两种方案
2019/10/29 Javascript
Javascript实现关闭广告效果
2021/01/29 Javascript
[51:29]Alliance vs TNC 2019国际邀请赛小组赛 BO2 第二场 8.16
2019/08/18 DOTA
详解python3中zipfile模块用法
2018/06/18 Python
Django中reverse反转并且传递参数的方法
2019/08/06 Python
Python使用贪婪算法解决问题
2019/10/22 Python
详解字符串在Python内部是如何省内存的
2020/02/03 Python
python tkinter之 复选、文本、下拉的实现
2020/03/04 Python
python tkiner实现 一个小小的图片翻页功能的示例代码
2020/06/24 Python
Python pip使用超时问题解决方案
2020/08/03 Python
Python如何批量生成和调用变量
2020/11/21 Python
HTML5 图片悬停放大的实现代码示例
2019/12/04 HTML / CSS
澳大利亚相机之家:Camera House
2017/11/30 全球购物
服装机修工岗位职责
2013/12/26 职场文书
《难忘的泼水节》教学反思
2014/02/27 职场文书
小学中等生评语
2014/12/29 职场文书
公司庆典欢迎词
2015/01/26 职场文书
2015年世界无烟日活动总结
2015/02/10 职场文书
卫生院艾滋病宣传活动总结
2015/05/09 职场文书
解决SpringBoot跨域的三种方式
2021/06/26 Java/Android
Python List remove()实例用法详解
2021/08/02 Python
交互式可视化js库gojs使用介绍及技巧
2022/02/18 Javascript
剑指Offer之Java算法习题精讲二叉树的构造和遍历
2022/03/21 Java/Android