Pytorch 多块GPU的使用详解


Posted in Python onDecember 31, 2019

注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。

在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。

1. 设置需要使用的GPU编号

import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]

比如我们需要使用第0和第4块GPU,只用上述三行代码即可。

其中第二行指程序只能看到第1块和第4块GPU;

第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。

2.更改网络,可以理解为将网络放入GPU

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    self.conv1 = nn.Sequential(
    ......
    )
    
    ......
    
    self.out = nn.Linear(Liner_input,2)
 
  ......
    
  def forward(self,x):
    x = self.conv1(x)
    ......
    output = self.out(x)
    return output,x
  
cnn = CNN()
 
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()

3. 更改输出数据(如向量/矩阵/张量):

for epoch in range(EPOCH):
  epoch_loss = 0.
  for i, data in enumerate(train_loader2):
    image = data['image'] # data是字典,我们需要改的是其中的image
 
    #############更改!!!##################
    image = Variable(image).float().cuda()
    ############################################
 
    label = inputs['label']
    #############更改!!!##################
    label = Variable(label).type(torch.LongTensor).cuda()
    ############################################
    label = label.resize(BATCH_SIZE)
    output = cnn(image)[0]
    loss = loss_func(output, label)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step() 
    ... ...

4. 更改其他CPU与GPU冲突的地方

有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。

若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:

... ...
    #################################################
    pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
    
    accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))

假如不加.cpu()便会报错,此时再改即可。

5. 更改前向传播函数,从而使用多块GPU

以VGG为例:

class VGG(nn.Module):
 
  def __init__(self, features, num_classes=2, init_weights=True):
    super(VGG, self).__init__()
... ...
 
  def forward(self, x):
    #x = self.features(x)
    #################Multi GPUS#############################
    x = nn.parallel.data_parallel(self.features,x,ids)
    x = x.view(x.size(0), -1)
    # x = self.classifier(x)
    x = nn.parallel.data_parallel(self.classifier,x,ids)
    return x
... ...

然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:

Pytorch 多块GPU的使用详解

可以看到0和4都被使用啦

以上这篇Pytorch 多块GPU的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中一些自然语言工具的使用的入门教程
Apr 13 Python
使用基于Python的Tornado框架的HTTP客户端的教程
Apr 24 Python
python实现可将字符转换成大写的tcp服务器实例
Apr 29 Python
老生常谈python之鸭子类和多态
Jun 13 Python
python构建自定义回调函数详解
Jun 20 Python
Sanic框架Cookies操作示例
Jul 17 Python
Django 日志配置按日期滚动的方法
Jan 31 Python
python利用多种方式来统计词频(单词个数)
May 27 Python
对python中的控制条件、循环和跳出详解
Jun 24 Python
django-rest-framework 自定义swagger过程详解
Jul 18 Python
详解python uiautomator2 watcher的使用方法
Sep 09 Python
Python日志模块logging用法
Jun 05 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
You might like
phpmyadmin导入(import)文件限制的解决办法
2009/12/11 PHP
PHP 实现类似js中alert() 提示框
2015/03/18 PHP
php中实现用数组妩媚地生成要执行的sql语句
2015/07/10 PHP
Linux php 中文乱码的快速解决方法
2016/05/13 PHP
header与缓冲区之间的深层次分析
2016/07/30 PHP
10个值得深思的PHP面试题
2016/11/14 PHP
php文件上传 你真的掌握了吗
2016/11/28 PHP
ThinkPHP3.2.3框架实现执行原生SQL语句的方法示例
2019/04/03 PHP
从JavaScript 到 JQuery (1)学习小结
2009/02/12 Javascript
js apply/call/caller/callee/bind使用方法与区别分析
2009/10/28 Javascript
JSQL SQLProxy 的 php 版本代码
2010/05/05 Javascript
利用jquery制作滚动到指定位置触发动画
2016/03/26 Javascript
Bootstrap布局组件教程之Bootstrap下拉菜单
2016/06/12 Javascript
JavaScript中的splice方法用法详解
2016/07/20 Javascript
JS获取和修改元素样式的实例代码
2016/08/06 Javascript
nodejs+express实现文件上传下载管理网站
2017/03/15 NodeJs
Extjs 中的 Treepanel 实现菜单级联选中效果及实例代码
2017/08/22 Javascript
vue中的$emit 与$on父子组件与兄弟组件的之间通信方式
2018/05/13 Javascript
Angular中sweetalert弹框的基本使用教程
2018/07/22 Javascript
Vue自定义弹窗指令的实现代码
2018/08/13 Javascript
查找Vue中下标的操作(some和findindex)
2020/08/12 Javascript
jQuery-App输入框实现实时搜索
2020/11/19 jQuery
Python代码实现KNN算法
2017/12/20 Python
Python 静态方法和类方法实例分析
2019/11/21 Python
Python之Sklearn使用入门教程
2021/02/19 Python
如何查看在weblogic中已经发布的EJB
2012/06/01 面试题
房产公证书格式
2015/01/26 职场文书
国博复兴之路观后感
2015/06/02 职场文书
小型婚礼主持词
2015/06/30 职场文书
安全教育培训制度
2015/08/06 职场文书
导游词之重庆钓鱼城
2019/09/19 职场文书
Python实现文本文件拆分写入到多个文本文件的方法
2021/04/18 Python
mysql外连接与内连接查询的不同之处
2021/06/03 MySQL
redis数据一致性的实现示例
2022/03/18 Redis
Python实现Excel文件的合并(以新冠疫情数据为例)
2022/03/20 Python
Windows Server 2012 R2 磁盘分区教程
2022/04/29 Servers