pytorch实现MNIST手写体识别


Posted in Python onFebruary 14, 2020

本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下

实验环境

pytorch 1.4
Windows 10
python 3.7
cuda 10.1(我笔记本上没有可以使用cuda的显卡)

实验过程

1. 确定我们要加载的库

import torch
import torch.nn as nn
import torchvision #这里面直接加载MNIST数据的方法
import torchvision.transforms as transforms # 将数据转为Tensor
import torch.optim as optim 
import torch.utils.data.dataloader as dataloader

2. 加载数据

这里使用所有数据进行训练,再使用所有数据进行测试

train_set = torchvision.datasets.MNIST(
 root='./data', # 文件存储位置
 train=True,
 transform=transforms.ToTensor(),
 download=True
)

train_dataloader = dataloader.DataLoader(dataset=train_set,shuffle=False,batch_size=100)# dataset可以省

'''
dataloader返回(images,labels)
其中,
images维度:[batch_size,1,28,28]
labels:[batch_size],即图片对应的
'''

test_set = torchvision.datasets.MNIST(
 root='./data',
 train=False,
 transform=transforms.ToTensor(),
 download=True
)

test_dataloader = dataloader.DataLoader(test_set,batch_size=100,shuffle=False) # dataset可以省

3. 定义神经网络模型

这里使用全神经网络作为模型

class NeuralNet(nn.Module):
 def __init__(self,in_num,h_num,out_num):
 super(NeuralNet,self).__init__()
 self.ln1 = nn.Linear(in_num,h_num)
 self.ln2 = nn.Linear(h_num,out_num)
 self.relu = nn.ReLU()
 
 def forward(self,x):
 return self.ln2(self.relu(self.ln1(x)))

4. 模型训练

in_num = 784 # 输入维度
h_num = 500 # 隐藏层维度
out_num = 10 # 输出维度
epochs = 30 # 迭代次数
learning_rate = 0.001
USE_CUDA = torch.cuda.is_available() # 定义是否可以使用cuda

model = NeuralNet(in_num,h_num,out_num) # 初始化模型
optimizer = optim.Adam(model.parameters(),lr=learning_rate) # 使用Adam
loss_fn = nn.CrossEntropyLoss() # 损失函数

for e in range(epochs):
 for i,data in enumerate(train_dataloader):
 (images,labels) = data
 images = images.reshape(-1,28*28) # [batch_size,784]
 if USE_CUDA:
  images = images.cuda() # 使用cuda
  labels = labels.cuda() # 使用cuda
  
 y_pred = model(images) # 预测
 loss = loss_fn(y_pred,labels) # 计算损失
 
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 
 n = e * i +1
 if n % 100 == 0:
  print(n,'loss:',loss.item())

训练模型的loss部分截图如下:

pytorch实现MNIST手写体识别

5. 测试模型

with torch.no_grad():
 total = 0
 correct = 0
 for (images,labels) in test_dataloader:
 images = images.reshape(-1,28*28)
 if USE_CUDA:
  images = images.cuda()
  labels = labels.cuda()
  
 result = model(images)
 prediction = torch.max(result, 1)[1] # 这里需要有[1],因为它返回了概率还有标签
 total += labels.size(0)
 correct += (prediction == labels).sum().item()
 
 print("The accuracy of total {} images: {}%".format(total, 100 * correct/total))

实验结果

最终实验的正确率达到:98.22%

pytorch实现MNIST手写体识别

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python两种遍历字典(dict)的方法比较
May 29 Python
用Python将动态GIF图片倒放播放的方法
Nov 02 Python
使用python爬虫实现网络股票信息爬取的demo
Jan 05 Python
kaggle+mnist实现手写字体识别
Jul 26 Python
Python2和Python3之间的str处理方式导致乱码的讲解
Jan 03 Python
python networkx 包绘制复杂网络关系图的实现
Jul 10 Python
python Django编写接口并用Jmeter测试的方法
Jul 31 Python
使用python和pygame制作挡板弹球游戏
Dec 03 Python
Python中os模块功能与用法详解
Feb 26 Python
python实现测试工具(二)——简单的ui测试工具
Oct 19 Python
python批量提取图片信息并保存的实现
Feb 05 Python
教你怎么用python爬取爱奇艺热门电影
May 20 Python
Python3.7实现验证码登录方式代码实例
Feb 14 #Python
Python逐行读取文件内容的方法总结
Feb 14 #Python
Python3和PyCharm安装与环境配置【图文教程】
Feb 14 #Python
python对Excel的读取的示例代码
Feb 14 #Python
Python安装依赖(包)模块方法详解
Feb 14 #Python
python 项目目录结构设置
Feb 14 #Python
wxpython自定义下拉列表框过程图解
Feb 14 #Python
You might like
php之字符串变相相减的代码
2007/03/19 PHP
PHP高级OOP技术演示
2009/08/27 PHP
php处理json格式数据经典案例总结
2016/05/19 PHP
php实现保存周期为1天的购物车类
2017/07/07 PHP
js鼠标滑过弹出层的定位IE6bug解决办法
2012/12/26 Javascript
JS字符串处理实例代码
2013/08/05 Javascript
setInterval()和setTimeout()的用法和区别示例介绍
2013/11/17 Javascript
同域jQuery(跨)iframe操作DOM(实例讲解)
2013/12/19 Javascript
JavaScript中0和""比较引发的问题
2016/05/26 Javascript
jQuery插件扩展extend的简单实现原理
2016/06/24 Javascript
Bootstrap在线电子商务网站实战项目5
2016/10/14 Javascript
js实现文本上下来回滚动
2017/02/03 Javascript
详解如何使用Node.js编写命令工具——以vue-cli为例
2017/06/29 Javascript
基于Jquery Ajax type的4种类型(详解)
2017/08/02 jQuery
vue-router路由懒加载和权限控制详解
2017/12/13 Javascript
vue.js通过路由实现经典的三栏布局实例代码
2018/07/08 Javascript
详解bootstrap-fileinput文件上传控件的亲身实践
2019/03/21 Javascript
React如何实现浏览器打印部分内容详析
2019/05/19 Javascript
JavaScript实现横版菜单栏
2020/03/17 Javascript
React服务端渲染原理解析与实践
2021/03/04 Javascript
python subprocess 杀掉全部派生的子进程方法
2017/01/16 Python
解决Spyder中图片显示太小的问题
2018/04/27 Python
基于python log取对数详解
2018/06/08 Python
完美解决Python 2.7不能正常使用pip install的问题
2018/06/12 Python
Python中最好用的命令行参数解析工具(argparse)
2019/08/23 Python
Python matplotlib绘制饼状图功能示例
2019/09/10 Python
Python自带的IDE在哪里
2020/07/01 Python
Python 添加文件注释和函数注释操作
2020/08/09 Python
利用HTML5中的Canvas绘制一张笑脸的教程
2015/05/07 HTML / CSS
H5 video poster属性设置视频封面的方法
2020/05/25 HTML / CSS
澳大利亚最受欢迎的美发和美容在线商店:Catwalk
2018/12/12 全球购物
联想阿根廷官方网站:Lenovo Argentina
2019/10/14 全球购物
乌克兰在线电子产品商店:MTA
2019/11/14 全球购物
中职招生先进个人材料
2014/08/31 职场文书
2014年工会工作总结
2014/11/12 职场文书
2014年班组建设工作总结
2014/12/01 职场文书