pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python分析git log日志示例
Feb 27 Python
Python统计列表中的重复项出现的次数的方法
Aug 18 Python
Python3实现的Mysql数据库操作封装类
Jun 06 Python
django如何连接已存在数据的数据库
Aug 14 Python
python爬虫之爬取百度音乐的实现方法
Aug 24 Python
pytorch常见的Tensor类型详解
Jan 15 Python
解决安装新版PyQt5、PyQT5-tool后打不开并Designer.exe提示no Qt platform plugin的问题
Apr 24 Python
python函数map()和partial()的知识点总结
May 26 Python
python调用私有属性的方法总结
Jul 24 Python
python 5个实用的技巧
Sep 27 Python
opencv实现图像几何变换
Mar 24 Python
Python激活Anaconda环境变量的详细步骤
Jun 08 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
PHP一些有意思的小区别
2006/12/06 PHP
PHP项目开发中最常用的自定义函数整理
2010/12/02 PHP
php 搜索框提示(自动完成)实例代码
2012/02/05 PHP
Yii中render和renderPartial的区别
2014/09/03 PHP
php连接与操作PostgreSQL数据库的方法
2014/12/25 PHP
十幅图告诉你什么是PHP引用
2015/02/22 PHP
基于Laravel实现的用户动态模块开发
2017/09/21 PHP
JQuery的Validation插件中Remote验证的中文问题
2010/07/26 Javascript
Jquery练习之表单验证实现代码
2010/12/14 Javascript
ASP.NET jQuery 实例1(在TextBox里面创建一个默认提示)
2012/01/13 Javascript
js弹出的对话窗口永远保持居中显示
2012/12/15 Javascript
打开新窗口关闭当前页面不弹出关闭提示js代码
2013/03/18 Javascript
JavaScript 判断用户输入的邮箱及手机格式是否正确
2013/12/08 Javascript
jQuery获取当前对象标签名称的方法
2014/02/07 Javascript
Javascript添加监听与删除监听用法详解
2014/12/19 Javascript
Jquery实现仿腾讯娱乐频道焦点图(幻灯片)特效
2015/03/06 Javascript
jquery实现的V字形显示效果代码
2015/10/27 Javascript
Javascript类型转换的规则实例解析
2016/02/23 Javascript
jquery div模态窗口的简单实例
2016/05/28 Javascript
jQuery Easyui DataGrid点击某个单元格即进入编辑状态焦点移开后保存数据
2016/08/15 Javascript
基于react组件之间的参数传递(详解)
2017/09/05 Javascript
vue项目实现表单登录页保存账号和密码到cookie功能
2018/08/31 Javascript
详解Vue前端生产环境发布配置实战篇
2019/05/07 Javascript
Vue 实现登录界面验证码功能
2020/01/03 Javascript
python根据url地址下载小文件的实例
2018/12/18 Python
Python实现的矩阵转置与矩阵相乘运算示例
2019/03/26 Python
Python GUI编程 文本弹窗的实例
2019/06/11 Python
Python matplotlib生成图片背景透明的示例代码
2019/08/30 Python
瑞典网上购买现代和复古家具:Reforma
2019/10/21 全球购物
事业单位请假制度
2014/01/13 职场文书
鼓励运动员的广播稿
2014/02/08 职场文书
中药学专业求职信
2014/05/31 职场文书
大学生党员个人剖析材料
2014/10/08 职场文书
中学综治宣传月活动总结
2015/05/07 职场文书
联欢会开场白
2015/06/01 职场文书
详解运行Python的神器Jupyter Notebook
2021/06/03 Python