用Pytorch训练CNN(数据集MNIST,使用GPU的方法)


Posted in Python onAugust 19, 2019

听说pytorch使用比TensorFlow简单,加之pytorch现已支持windows,所以今天装了pytorch玩玩,第一件事还是写了个简单的CNN在MNIST上实验,初步体验的确比TensorFlow方便。

参考代码(在莫烦python的教程代码基础上修改)如下:

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import time
#import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = False 
if_use_gpu = 1
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
#plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
#plt.title('%i' % training_data.train_labels[0]) 
#plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 

test_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=False, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
# 取前全部10000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1).float(), requires_grad=False)
#test_x = test_x.cuda()
## (~, 28, 28) to (~, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels
#test_y = test_y.cuda() 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
if if_use_gpu:
  cnn = cnn.cuda()

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 


for epoch in range(EPOCH): 
  start = time.time() 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x, requires_grad=False) 
    b_y = Variable(y, requires_grad=False) 
    if if_use_gpu:
      b_x = b_x.cuda()
      b_y = b_y.cuda()
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0]) 
  duration = time.time() - start 
  print('Training duation: %.4f'%duration)
  
cnn = cnn.cpu()
test_output = cnn(test_x) 
pred_y = torch.max(test_output, 1)[1].data.squeeze()
accuracy = sum(pred_y == test_y) / test_y.size(0) 
print('Test Acc: %.4f'%accuracy)

以上这篇用Pytorch训练CNN(数据集MNIST,使用GPU的方法)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中用Ctrl+C终止多线程程序的问题解决
Mar 30 Python
Python time模块详解(常用函数实例讲解,非常好)
Apr 24 Python
Python脚本实现DNSPod DNS动态解析域名
Feb 14 Python
简单介绍Python中的filter和lambda函数的使用
Apr 07 Python
解析Python中的异常处理
Apr 28 Python
python爬虫基本知识
Mar 05 Python
python读取图片并修改格式与大小的方法
Jul 24 Python
python三方库之requests的快速上手
Mar 04 Python
通过自学python能找到工作吗
Jun 21 Python
python 利用openpyxl读取Excel表格中指定的行或列教程
Feb 06 Python
python实现三次密码验证的示例
Apr 29 Python
实例详解Python的进程,线程和协程
Mar 13 Python
python opencv实现证件照换底功能
Aug 19 #Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 #Python
将Pytorch模型从CPU转换成GPU的实现方法
Aug 19 #Python
pytorch 使用单个GPU与多个GPU进行训练与测试的方法
Aug 19 #Python
在pytorch中为Module和Tensor指定GPU的例子
Aug 19 #Python
pytorch使用指定GPU训练的实例
Aug 19 #Python
关于pytorch多GPU训练实例与性能对比分析
Aug 19 #Python
You might like
mysql4.1以上版本连接时出现Client does not support authentication protocol问题解决办法
2007/03/15 PHP
用PHP程序实现支持页面后退的两种方法
2008/06/30 PHP
PHP将字符分解为多个字符串的方法
2014/11/22 PHP
PHP跨平台获取服务器IP地址自定义函数分享
2014/12/29 PHP
PHP实现的简单mock json脚本分享
2015/02/10 PHP
golang与PHP输出excel示例
2016/07/22 PHP
PHP实现图片批量打包下载功能
2017/03/01 PHP
Thinkphp5行为使用方法汇总
2017/12/21 PHP
解决PhpStorm64不能启动的问题
2020/06/20 PHP
用Javascript 获取页面元素的位置的代码
2009/09/25 Javascript
jQuery 打造动态下滑菜单实现说明
2010/04/15 Javascript
javascript中定义类的方法详解
2015/02/10 Javascript
最常见和最有用的字符串相关的方法详解
2017/02/06 Javascript
vue中计算属性(computed)、methods和watched之间的区别
2017/07/27 Javascript
vue.js异步上传文件前后端实现代码
2017/08/22 Javascript
限时抢购-倒计时的完整实例(分享)
2017/09/17 Javascript
详解Node.js模板引擎Jade入门
2018/01/19 Javascript
nodejs搭建本地服务器轻松解决跨域问题
2018/03/21 NodeJs
vue底部加载更多的实例代码
2018/06/29 Javascript
Vue利用History记录上一页面的数据方法实例
2018/11/02 Javascript
vue实现瀑布流组件滑动加载更多
2020/03/10 Javascript
JQuery Ajax如何实现注册检测用户名
2020/09/25 jQuery
如何在vue中使用百度地图添加自定义覆盖物(水波纹)
2020/11/03 Javascript
在vue中嵌入外部网站的实现
2020/11/13 Javascript
[01:09]模型精美,特效酷炫!TI9不朽宝藏Ⅰ鉴赏
2019/05/10 DOTA
Python3实现的反转单链表算法示例
2019/03/08 Python
Python 实现Numpy中找出array中最大值所对应的行和列
2019/11/26 Python
python解析多层json操作示例
2019/12/30 Python
jupyter lab的目录调整及设置默认浏览器为chrome的方法
2020/04/10 Python
快速了解Python开发环境Spyder
2020/06/29 Python
关于多种方式完美解决Python pip命令下载第三方库的问题
2020/12/21 Python
致跳远运动员广播稿
2014/02/11 职场文书
嘉宾邀请函
2015/01/31 职场文书
逃课检讨书范文
2015/05/06 职场文书
总结一下关于在Java8中使用stream流踩过的一些坑
2021/06/24 Java/Android
MySQL 条件查询的常用操作
2022/04/28 MySQL