用Pytorch训练CNN(数据集MNIST,使用GPU的方法)


Posted in Python onAugust 19, 2019

听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便。

参考代码(在莫烦python的教程代码基础上修改)如下:

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import time
#import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = False 
if_use_gpu = 1
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
#plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
#plt.title('%i' % training_data.train_labels[0]) 
#plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 

test_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=False, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
# 取前全部10000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
#test_x = test_x.cuda()
## (~, 28, 28) to (~, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels
#test_y = test_y.cuda() 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
if if_use_gpu:
  cnn = cnn.cuda()

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 


for epoch in range(EPOCH): 
  start = time.time() 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x, requires_grad=False) 
    b_y = Variable(y, requires_grad=False) 
    if if_use_gpu:
      b_x = b_x.cuda()
      b_y = b_y.cuda()
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0]) 
  duration = time.time() - start 
  print('Training duation: %.4f'%duration)
  
cnn = cnn.cpu()
test_output = cnn(test_x) 
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / test_y.size(0) 
print('Test Acc: %.4f'%accuracy)

以上这篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python操作MongoDB数据库PyMongo库使用方法
Apr 27 Python
python 捕获 shell/bash 脚本的输出结果实例
Jan 04 Python
python机器学习理论与实战(五)支持向量机
Jan 19 Python
Python OpenCV处理图像之图像像素点操作
Jul 10 Python
python3利用tcp实现文件夹远程传输
Jul 28 Python
python ddt数据驱动最简实例代码
Feb 22 Python
基于python的ini配置文件操作工具类
Apr 24 Python
python中pip的使用和修改下载源的方法
Jul 08 Python
ansible动态Inventory主机清单配置遇到的坑
Jan 19 Python
python脚本监控logstash进程并邮件告警实例
Apr 28 Python
Python Pillow(PIL)库的用法详解
Sep 19 Python
Python简易开发之制作计算器
Apr 28 Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 #Python
pytorch使用指定GPU训练的实例
Aug 19 #Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 #Python
You might like
一个php短网址的生成代码(仿微博短网址)
2014/05/07 PHP
PHP实现获取FLV文件的时间
2015/02/10 PHP
PHP+jQuery翻板抽奖功能实现
2015/10/19 PHP
php轻松实现文件上传功能
2016/03/03 PHP
基于PHP微信红包的算法探讨
2016/07/21 PHP
探究Laravel使用env函数读取环境变量为null的问题
2016/12/06 PHP
简单选项卡 js和jquery制作方法分享
2014/02/26 Javascript
JavaScript 学习笔记之操作符(续)
2015/01/14 Javascript
使用jQuery实现图片遮罩半透明坠落遮挡
2015/03/16 Javascript
JavaScript如何实现组合列表框中元素移动效果
2016/03/01 Javascript
如何更好的编写js async函数
2018/05/13 Javascript
使用iView Upload 组件实现手动上传图片的示例代码
2018/10/01 Javascript
jQuery使用$.extend(true,object1, object2);实现深拷贝对象的方法分析
2019/03/06 jQuery
JS学习笔记之原型链和利用原型实现继承详解
2019/05/29 Javascript
JS数组方法join()用法实例分析
2020/01/18 Javascript
浅谈python为什么不需要三目运算符和switch
2016/06/17 Python
Python生成随机密码的方法
2017/06/16 Python
python实现决策树分类(2)
2018/08/30 Python
Python利用字典破解WIFI密码的方法
2019/02/27 Python
基于python全局设置id 自动化测试元素定位过程解析
2019/09/04 Python
python 初始化一个定长的数组实例
2019/12/02 Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
2019/12/31 Python
Django单元测试中Fixtures的使用方法
2020/02/26 Python
Python使用pdb调试代码的技巧
2020/05/03 Python
python矩阵运算,转置,逆运算,共轭矩阵实例
2020/05/11 Python
基于python 将列表作为参数传入函数时的测试与理解
2020/06/05 Python
编译 pycaffe时报错:fatal error: numpy/arrayobject.h没有那个文件或目录
2020/11/29 Python
CSS3毛玻璃效果(blur)有白边问题的解决方法
2016/11/15 HTML / CSS
护理专业应届毕业生推荐信
2013/11/15 职场文书
12岁生日感言
2014/01/21 职场文书
科级干部考察材料
2014/02/15 职场文书
学雷锋活动总结范文
2014/04/25 职场文书
新闻传播专业求职信
2014/07/22 职场文书
教师自我剖析材料(四风问题)
2014/09/30 职场文书
小学三年级数学教学反思
2016/02/16 职场文书
解析CSS 提取图片主题色功能(小技巧)
2021/05/12 HTML / CSS