用Pytorch训练CNN(数据集MNIST,使用GPU的方法)


Posted in Python onAugust 19, 2019

听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便。

参考代码(在莫烦python的教程代码基础上修改)如下:

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import time
#import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = False 
if_use_gpu = 1
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
#plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
#plt.title('%i' % training_data.train_labels[0]) 
#plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 

test_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=False, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
# 取前全部10000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
#test_x = test_x.cuda()
## (~, 28, 28) to (~, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels
#test_y = test_y.cuda() 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
if if_use_gpu:
  cnn = cnn.cuda()

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 


for epoch in range(EPOCH): 
  start = time.time() 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x, requires_grad=False) 
    b_y = Variable(y, requires_grad=False) 
    if if_use_gpu:
      b_x = b_x.cuda()
      b_y = b_y.cuda()
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0]) 
  duration = time.time() - start 
  print('Training duation: %.4f'%duration)
  
cnn = cnn.cpu()
test_output = cnn(test_x) 
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / test_y.size(0) 
print('Test Acc: %.4f'%accuracy)

以上这篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python2.6版本中实现字典推导 PEP 274(Dict Comprehensions)
Apr 28 Python
Python面向对象编程中关于类和方法的学习笔记
Jun 30 Python
使用Python脚本实现批量网站存活检测遇到问题及解决方法
Oct 11 Python
python中defaultdict的用法详解
Jun 07 Python
浅谈python之新式类
Aug 12 Python
用Python将结果保存为xlsx的方法
Jan 28 Python
ERLANG和PYTHON互通实现过程详解
Jul 05 Python
使用python将excel数据导入数据库过程详解
Aug 27 Python
Tensorflow 定义变量,函数,数值计算等名字的更新方式
Feb 10 Python
使用npy转image图像并保存的实例
Jul 01 Python
python 日志模块logging的使用场景及示例
Jan 04 Python
神经网络训练采用gpu设置的方式
Mar 03 Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 #Python
pytorch使用指定GPU训练的实例
Aug 19 #Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 #Python
You might like
Ajax PHP分页演示
2007/01/02 PHP
Session保存到数据库的php类分享
2011/10/24 PHP
php flush无效,IIS7下php实时输出的方法
2016/08/25 PHP
php使用QueryList轻松采集js动态渲染页面方法
2018/09/11 PHP
Flash对联广告的关闭按钮讨论
2007/01/30 Javascript
基于jquery的获取浏览器窗口大小的代码
2011/03/28 Javascript
JavaScript代码编写中各种各样的坑和填坑方法
2014/06/06 Javascript
JavaScript中的DSL元编程介绍
2015/03/15 Javascript
JavaScript动态加载样式表的方法
2015/03/21 Javascript
jQuery插件Validate实现自定义表单验证
2016/01/18 Javascript
基于JavaScript实现图片剪切效果
2017/03/07 Javascript
微信浏览器禁止页面下拉查看网址实例详解
2017/06/28 Javascript
js实现1,2,3,5数字按照概率生成
2017/09/12 Javascript
解决循环中setTimeout执行顺序的问题
2018/06/20 Javascript
使用Angular-CLI构建NPM包的方法
2018/09/07 Javascript
angularJs中orderBy筛选以及filter过滤数据的方法
2018/09/30 Javascript
微信小程序封装多张图片上传api代码实例
2019/12/30 Javascript
Python随机数用法实例详解【基于random模块】
2017/04/18 Python
pyqt5的QWebEngineView 使用模板的方法
2018/08/18 Python
python三大神器之fabric使用教程
2019/06/10 Python
python实现低通滤波器代码
2020/02/26 Python
python pip如何手动安装二进制包
2020/09/30 Python
html5音频_动力节点Java学院整理
2018/08/22 HTML / CSS
web页面录屏实现
2019/02/12 HTML / CSS
美国领先的医疗警报服务:Philips Lifeline
2018/03/12 全球购物
美国班级戒指、帽子和礼服、毕业产品、年鉴:Balfour
2018/11/01 全球购物
如何理解transaction事务的概念
2015/05/27 面试题
创业计划书的主要内容有哪些
2014/01/29 职场文书
汽车运用工程专业求职信
2014/06/18 职场文书
纪念九一八事变演讲稿:勿忘国耻
2014/09/14 职场文书
公安局负责人查摆问题及整改方案
2014/09/27 职场文书
毕业生自荐材料范文
2014/12/30 职场文书
银行自荐信怎么写
2015/03/05 职场文书
汽车质检员岗位职责
2015/04/08 职场文书
2016感恩母亲节校园广播稿
2015/12/17 职场文书
Win11查看设备管理器
2022/04/19 数码科技