PyTorch-GPU加速实例


Posted in Python onJune 23, 2020

硬件:NVIDIA-GTX1080

软件:Windows7、python3.6.5、pytorch-gpu-0.4.1

一、基础知识

将数据和网络都推到GPU,接上.cuda()

二、代码展示

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
# torch.manual_seed(1)
 
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False
 
train_data = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST,)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
 
# !!!!!!!! Change in here !!!!!!!!! #
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000].cuda()/255. # Tensor on GPU
test_y = test_data.test_labels[:2000].cuda()
 
class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2,),
         nn.ReLU(), nn.MaxPool2d(kernel_size=2),)
  self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), nn.ReLU(), nn.MaxPool2d(2),)
  self.out = nn.Linear(32 * 7 * 7, 10)
 
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1)
  output = self.out(x)
  return output
 
cnn = CNN()
 
# !!!!!!!! Change in here !!!!!!!!! #
cnn.cuda()  # Moves all model parameters and buffers to the GPU.
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
for epoch in range(EPOCH):
 for step, (x, y) in enumerate(train_loader):
 
  # !!!!!!!! Change in here !!!!!!!!! #
  b_x = x.cuda() # Tensor on GPU
  b_y = y.cuda() # Tensor on GPU
 
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
 
  if step % 50 == 0:
   test_output = cnn(test_x)
 
   # !!!!!!!! Change in here !!!!!!!!! #
   pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
   accuracy = torch.sum(pred_y == test_y).type(torch.FloatTensor) / test_y.size(0)
   print('Epoch: ', epoch, '| train loss: %.4f' % loss, '| test accuracy: %.2f' % accuracy)
 
test_output = cnn(test_x[:10])
 
# !!!!!!!! Change in here !!!!!!!!! #
pred_y = torch.max(test_output, 1)[1].cuda().data # move the computation in GPU
 
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')

三、结果展示

PyTorch-GPU加速实例

补充知识:pytorch使用gpu对网络计算进行加速

1.基本要求

你的电脑里面有合适的GPU显卡(NVIDA),并且需要支持CUDA模块

你必须安装GPU版的Torch,(详细安装方法请移步pytorch官网)

2.使用GPU训练CNN

利用pytorch使用GPU进行加速方法主要就是将数据的形式变成GPU能读的形式,然后将CNN也变成GPU能读的形式,具体办法就是在后面加上.cuda()。

例如:

#如何检查自己电脑是否支持cuda
print torch.cuda.is_available()
# 返回True代表支持,False代表不支持
'''
注意在进行某种运算的时候使用.cuda()
'''
test_data=test_data.test_labels[:2000].cuda()
'''
对于CNN与损失函数利用cuda加速
'''
class CNN(nn.Module):
 ...
cnn=CNN()
cnn.cuda()
loss_f = t.nn.CrossEntropyLoss()
loss_f = loss_f.cuda()

而在train时,对于train_data训练过程进行GPU加速。也同样+.cuda()。

for epoch ..:
 for step, ...:
 1
'''
若你的train_data在训练时需要进行操作
若没有其他操作仅仅只利用cnn()则无需另加.cuda()
'''
#eg
 train_data = torch.max(teain_data, 1)[1].cuda()

补充:取出数据需要从GPU切换到CPU上进行操作

eg:

loss = loss.cpu()
acc = acc.cpu()

理解并不全,如有纰漏或者错误还望各位大佬指点迷津

以上这篇PyTorch-GPU加速实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进阶教程之文本文件的读取和写入
Aug 29 Python
跟老齐学Python之不要红头文件(1)
Sep 28 Python
python脚本替换指定行实现步骤
Jul 11 Python
python reduce 函数使用详解
Dec 05 Python
如何使用 Pylint 来规范 Python 代码风格(来自IBM)
Apr 06 Python
python的继承知识点总结
Dec 10 Python
解决python3.5 正常安装 却不能直接使用Tkinter包的问题
Feb 22 Python
Python实现的登录验证系统完整案例【基于搭建的MVC框架】
Apr 12 Python
Django框架模型简单介绍与使用分析
Jul 18 Python
Django 对象关系映射(ORM)源码详解
Aug 06 Python
Java Spring项目国际化(i18n)详细方法与实例
Mar 20 Python
Python: glob匹配文件的操作
Dec 11 Python
Python基于yaml文件配置logging日志过程解析
Jun 23 #Python
解决pytorch多GPU训练保存的模型,在单GPU环境下加载出错问题
Jun 23 #Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 #Python
浅谈pytorch中的BN层的注意事项
Jun 23 #Python
Python3与fastdfs分布式文件系统如何实现交互
Jun 23 #Python
踩坑:pytorch中eval模式下结果远差于train模式介绍
Jun 23 #Python
pytorch掉坑记录:model.eval的作用说明
Jun 23 #Python
You might like
PHP闭包(Closure)使用详解
2013/05/02 PHP
PHP程序员基本要求和必备技能
2014/05/09 PHP
php生成随机颜色的方法
2014/11/13 PHP
php集成环境xampp中apache无法启动问题解决方案
2014/11/18 PHP
PHP实现的观察者模式实例
2017/06/21 PHP
JS 类型转换常见方法小结
2010/05/31 Javascript
jQuery弹出层插件Lightbox_me使用指南
2015/04/21 Javascript
javascript中checkbox使用方法简单实例演示
2015/11/17 Javascript
详解JavaScript 中的 replace 方法
2016/01/01 Javascript
实例详解JSON数据格式及json格式数据域字符串相互转换
2016/01/07 Javascript
微信小程序 vidao实现视频播放和弹幕的功能
2016/11/02 Javascript
jQuery如何跳转到另一个网页 就这么简单
2016/12/28 Javascript
vue2.0父子组件及非父子组件之间的通信方法
2017/01/21 Javascript
vue实现一个移动端屏蔽滑动的遮罩层实例
2017/06/08 Javascript
JavaScript实现打印星型金字塔功能实例分析
2017/09/27 Javascript
vue 动态修改a标签的样式的方法
2018/01/18 Javascript
详解js类型判断
2018/05/22 Javascript
详解如何配置vue-cli3.0的vue.config.js
2018/08/23 Javascript
Python和C/C++交互的几种方法总结
2017/05/11 Python
python编写微信远程控制电脑的程序
2018/01/05 Python
Python+PIL实现支付宝AR红包
2018/02/09 Python
Django项目实战之用户头像上传与访问的示例
2018/04/21 Python
pyhanlp安装介绍和简单应用
2019/02/22 Python
PythonWeb项目Django部署在Ubuntu18.04腾讯云主机上
2019/04/01 Python
Python中asyncio模块的深入讲解
2019/06/10 Python
解决python中使用PYQT时中文乱码问题
2019/06/17 Python
浅谈Pandas Series 和 Numpy array中的相同点
2019/06/28 Python
python中sys模块是做什么用的
2020/08/16 Python
python 使用tkinter+you-get实现视频下载器
2020/11/17 Python
CSS3按钮鼠标悬浮实现光圈效果源码
2016/09/11 HTML / CSS
关于老式浏览器兼容HTML5和CSS3的问题
2016/06/01 HTML / CSS
建筑人员岗位职责
2013/12/25 职场文书
孟佩杰观后感
2015/06/17 职场文书
关于运动会的广播稿
2015/08/19 职场文书
Python爬虫数据的分类及json数据使用小结
2021/03/29 Python
利用nginx搭建RTMP视频点播、直播、HLS服务器
2022/05/25 Servers