Pytorch 多块GPU的使用详解


Posted in Python onDecember 31, 2019

注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。

在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。

1. 设置需要使用的GPU编号

import os
 
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]

比如我们需要使用第0和第4块GPU,只用上述三行代码即可。

其中第二行指程序只能看到第1块和第4块GPU;

第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。

2.更改网络,可以理解为将网络放入GPU

class CNN(nn.Module):
  def __init__(self):
    super(CNN,self).__init__()
    self.conv1 = nn.Sequential(
    ......
    )
    
    ......
    
    self.out = nn.Linear(Liner_input,2)
 
  ......
    
  def forward(self,x):
    x = self.conv1(x)
    ......
    output = self.out(x)
    return output,x
  
cnn = CNN()
 
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()

3. 更改输出数据(如向量/矩阵/张量):

for epoch in range(EPOCH):
  epoch_loss = 0.
  for i, data in enumerate(train_loader2):
    image = data['image'] # data是字典,我们需要改的是其中的image
 
    #############更改!!!##################
    image = Variable(image).float().cuda()
    ############################################
 
    label = inputs['label']
    #############更改!!!##################
    label = Variable(label).type(torch.LongTensor).cuda()
    ############################################
    label = label.resize(BATCH_SIZE)
    output = cnn(image)[0]
    loss = loss_func(output, label)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step() 
    ... ...

4. 更改其他CPU与GPU冲突的地方

有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。

若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:

... ...
    #################################################
    pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
    
    accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))

假如不加.cpu()便会报错,此时再改即可。

5. 更改前向传播函数,从而使用多块GPU

以VGG为例:

class VGG(nn.Module):
 
  def __init__(self, features, num_classes=2, init_weights=True):
    super(VGG, self).__init__()
... ...
 
  def forward(self, x):
    #x = self.features(x)
    #################Multi GPUS#############################
    x = nn.parallel.data_parallel(self.features,x,ids)
    x = x.view(x.size(0), -1)
    # x = self.classifier(x)
    x = nn.parallel.data_parallel(self.classifier,x,ids)
    return x
... ...

然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:

Pytorch 多块GPU的使用详解

可以看到0和4都被使用啦

以上这篇Pytorch 多块GPU的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编程中的异常处理教程
Aug 21 Python
获取Django项目的全部url方法详解
Oct 26 Python
Python中eval带来的潜在风险代码分析
Dec 11 Python
Python+matplotlib绘制不同大小和颜色散点图实例
Jan 19 Python
Python实现中一次读取多个值的方法
Apr 22 Python
python 使用re.search()筛选后 选取部分结果的方法
Nov 28 Python
Python获取网段内ping通IP的方法
Jan 31 Python
django框架CSRF防护原理与用法分析
Jul 22 Python
关于Python中的向量相加和numpy中的向量相加效率对比
Aug 26 Python
使用python实现微信小程序自动签到功能
Apr 27 Python
Python unittest单元测试框架实现参数化
Apr 29 Python
matplotlib绘制鼠标的十字光标的实现(自定义方式,官方实例)
Jan 10 Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
You might like
一步一步学习PHP(3) php 函数
2010/02/15 PHP
那些年我们错过的魔术方法(Magic Methods)
2014/01/14 PHP
Yii2.0中的COOKIE和SESSION用法
2016/08/12 PHP
PHP超低内存遍历目录文件和读取超大文件的方法
2019/05/01 PHP
Laravel模糊查询区分大小写的实例
2019/09/29 PHP
Laravel中GraphQL接口请求频率实战记录
2020/09/01 PHP
javascript iframe编程相关代码
2009/12/28 Javascript
jQuery Selector选择器小结
2010/05/06 Javascript
关于捕获用户何时点击window.onbeforeunload的取消事件
2011/03/06 Javascript
学习Bootstrap组件之下拉菜单
2015/07/28 Javascript
JS实现新浪微博效果带遮罩层的弹出框代码
2015/10/12 Javascript
BootStrap selectpicker
2016/06/20 Javascript
轮播的简单实现方法
2016/07/28 Javascript
JS正则匹配URL网址的方法(可匹配www,http开头的一切网址)
2017/01/06 Javascript
PHP实现记录代码运行时间封装类实例教程
2017/05/08 Javascript
基于es6三点运算符的使用方法(实例讲解)
2017/10/12 Javascript
JS实现的哈夫曼编码示例【原始版与修改版】
2018/04/22 Javascript
基于Vue实现拖拽效果
2018/04/27 Javascript
js中的数组对象排序分析
2018/12/11 Javascript
python中安装模块包版本冲突问题的解决
2017/05/02 Python
Python实现注册登录系统
2017/08/08 Python
利用python操作SQLite数据库及文件操作详解
2017/09/22 Python
python requests 测试代理ip是否生效
2018/07/25 Python
一文带你了解Python中的字符串是什么
2018/11/20 Python
Pyinstaller打包.py生成.exe的方法和报错总结
2019/04/02 Python
pycharm下配置pyqt5的教程(anaconda虚拟环境下+tensorflow)
2020/03/25 Python
python由已知数组快速生成新数组的方法
2020/04/08 Python
浅析Python 多行匹配模式
2020/07/24 Python
iframe在移动端的缩放的示例代码
2018/10/12 HTML / CSS
HTML5中的网络存储实现方式
2020/04/28 HTML / CSS
施华洛世奇日本官网:SWAROVSKI日本
2018/05/04 全球购物
可持续未来的时尚基础:Alternative Apparel
2019/05/06 全球购物
马耳他航空公司官方网站:Air Malta
2019/05/15 全球购物
电子商务专业推荐信范文
2013/12/02 职场文书
体育课课后反思
2014/04/24 职场文书
在Windows Server 2012上安装 .NET Framework 3.5 所遇到的问题
2022/04/29 Servers