pytorch 利用lstm做mnist手写数字识别分类的实例


Posted in Python onJanuary 10, 2020

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python IDLE 错误:IDLE''s subprocess didn''t make connection 的解决方案
Feb 13 Python
python 以16进制打印输出的方法
Jul 09 Python
Python3利用Dlib实现摄像头实时人脸检测和平铺显示示例
Feb 21 Python
Python跳出多重循环的方法示例
Jul 03 Python
新手入门Python编程的8个实用建议
Jul 12 Python
Python 过滤错误log并导出的实例
Dec 26 Python
Python try except异常捕获机制原理解析
Apr 18 Python
Django 用户登陆访问限制实例 @login_required
May 13 Python
python3:excel操作之读取数据并返回字典 + 写入的案例
Sep 01 Python
python中random.randint和random.randrange的区别详解
Sep 20 Python
selenium框架中driver.close()和driver.quit()关闭浏览器
Dec 08 Python
python中温度单位转换的实例方法
Dec 27 Python
Tensorflow Summary用法学习笔记
Jan 10 #Python
TENSORFLOW变量作用域(VARIABLE SCOPE)
Jan 10 #Python
python numpy数组复制使用实例解析
Jan 10 #Python
关于Pytorch的MNIST数据集的预处理详解
Jan 10 #Python
详解pycharm连接不上mysql数据库的解决办法
Jan 10 #Python
pycharm双击无响应(打不开问题解决办法)
Jan 10 #Python
python ubplot使用方法解析
Jan 10 #Python
You might like
全世界最小的php网页木马一枚 附PHP木马的防范方法
2009/10/09 PHP
php fputcsv命令 写csv文件遇到的小问题(多维数组连接符)
2011/05/24 PHP
关于PHP 如何用 curl 读取 HTTP chunked 数据
2016/02/26 PHP
Javascript 各浏览器的 Javascript 效率对比
2008/01/23 Javascript
javascript数字数组去重复项的实现代码
2010/12/30 Javascript
js中有关IE版本检测
2012/01/04 Javascript
js 获取页面高度和宽度兼容 ie firefox chrome等
2014/05/14 Javascript
使用js画图之饼图
2015/01/12 Javascript
webpack中引用jquery的简单实现
2016/06/08 Javascript
jQuery UI结合Ajax创建可定制的Web界面
2016/06/22 Javascript
使用原生js写ajax实例(推荐)
2017/05/31 Javascript
extjs简介_动力节点Java学院整理
2017/07/17 Javascript
Vue2.0学习之详解Vue 组件及父子组件通信
2017/12/12 Javascript
vue 2.5.1 源码学习 之Vue.extend 和 data的合并策略
2019/06/04 Javascript
vue中组件通信的八种方式(值得收藏!)
2019/08/09 Javascript
JavaScript实现文件下载并重命名代码实例
2019/12/12 Javascript
flexible.js实现移动端rem适配方案
2020/04/07 Javascript
微信小程序开发打开另一个小程序的实现方法
2020/05/17 Javascript
vue实现两个组件之间数据共享和修改操作
2020/11/12 Javascript
[01:19:33]DOTA2-DPC中国联赛 正赛 iG vs VG BO3 第一场 2月2日
2021/03/11 DOTA
[50:44]DOTA2-DPC中国联赛 正赛 SAG vs Dragon BO3 第二场 2月22日
2021/03/11 DOTA
全面解析Python的While循环语句的使用方法
2015/10/13 Python
Python的Tornado框架实现异步非阻塞访问数据库的示例
2016/06/30 Python
使用Python如何测试InnoDB与MyISAM的读写性能
2018/09/18 Python
利用scikitlearn画ROC曲线实例
2020/07/02 Python
Python开发.exe小工具的详细步骤
2021/01/27 Python
HTML5是否真的可以取代Flash
2010/02/10 HTML / CSS
高中生物教学反思
2014/02/05 职场文书
课外活动总结范文
2014/07/09 职场文书
股份转让协议书范本
2015/01/27 职场文书
试用期旷工辞退通知书
2015/04/17 职场文书
2015年村计划生育工作总结
2015/04/28 职场文书
毕业晚宴祝酒词
2015/08/11 职场文书
小学语文继续教育研修日志
2015/11/13 职场文书
小学生优秀作文范文(六篇)
2019/07/10 职场文书
win10如何快速切换窗口 win10切换窗口快捷键分享
2022/07/23 数码科技