keras 多gpu并行运行案例


Posted in Python onJune 10, 2020

一、多张gpu的卡上使用keras

有多张gpu卡时,推荐使用tensorflow 作为后端。使用多张gpu运行model,可以分为两种情况,一是数据并行,二是设备并行。

二、数据并行

数据并行将目标模型在多个设备上各复制一份,并使用每个设备上的复制品处理整个数据集的不同部分数据。

利用multi_gpu_model实现

keras.utils.multi_gpu_model(model, gpus=None, cpu_merge=True, cpu_relocation=False)

具体来说,该功能实现了单机多 GPU 数据并行性。 它的工作原理如下:

将模型的输入分成多个子批次。

在每个子批次上应用模型副本。 每个模型副本都在专用 GPU 上执行。

将结果(在 CPU 上)连接成一个大批量。

例如, 如果你的 batch_size 是 64,且你使用 gpus=2, 那么我们将把输入分为两个 32 个样本的子批次, 在 1 个 GPU 上处理 1 个子批次,然后返回完整批次的 64 个处理过的样本。

参数

model: 一个 Keras 模型实例。为了避免OOM错误,该模型可以建立在 CPU 上, 详见下面的使用样例。

gpus: 整数 >= 2 或整数列表,创建模型副本的 GPU 数量, 或 GPU ID 的列表。

cpu_merge: 一个布尔值,用于标识是否强制合并 CPU 范围内的模型权重。

cpu_relocation: 一个布尔值,用来确定是否在 CPU 的范围内创建模型的权重。如果模型没有在任何一个设备范围内定义,您仍然可以通过激活这个选项来拯救它。

返回

一个 Keras Model 实例,它可以像初始 model 参数一样使用,但它将工作负载分布在多个 GPU 上。

例子

import tensorflow as tf
from keras.applications import Xception
from keras.utils import multi_gpu_model
import numpy as np

num_samples = 1000
height = 224
width = 224
num_classes = 1000

# 实例化基础模型(或者「模版」模型)。
# 我们推荐在 CPU 设备范围内做此操作,
# 这样模型的权重就会存储在 CPU 内存中。
# 否则它们会存储在 GPU 上,而完全被共享。
with tf.device('/cpu:0'):
 model = Xception(weights=None,
   input_shape=(height, width, 3),
   classes=num_classes)

# 复制模型到 8 个 GPU 上。
# 这假设你的机器有 8 个可用 GPU。
parallel_model = multi_gpu_model(model, gpus=8)
parallel_model.compile(loss='categorical_crossentropy',
   optimizer='rmsprop')

# 生成虚拟数据
x = np.random.random((num_samples, height, width, 3))
y = np.random.random((num_samples, num_classes))

# 这个 `fit` 调用将分布在 8 个 GPU 上。
# 由于 batch size 是 256, 每个 GPU 将处理 32 个样本。
parallel_model.fit(x, y, epochs=20, batch_size=256)

# 通过模版模型存储模型(共享相同权重):
model.save('my_model.h5')

注意:

要保存多 GPU 模型,请通过模板模型(传递给 multi_gpu_model 的参数)调用 .save(fname) 或 .save_weights(fname) 以进行存储,而不是通过 multi_gpu_model 返回的模型。

即要用model来保存,而不是parallel_model来保存。

使用ModelCheckpoint() 遇到的问题

使用ModelCheckpoint()会遇到下面的问题:

TypeError: can't pickle ...(different text at different situation) objects

这个问题和保存问题类似,ModelCheckpoint() 会自动调用parallel_model.save()来保存,而不是model.save(),因此我们要自己写一个召回函数,使得ModelCheckpoint()用model.save()。

修改方法:

class ParallelModelCheckpoint(ModelCheckpoint):
 def __init__(self,model,filepath, monitor='val_loss', verbose=0,
   save_best_only=False, save_weights_only=False,
   mode='auto', period=1):
 self.single_model = model
 super(ParallelModelCheckpoint,self).__init__(filepath, monitor, verbose,save_best_only, save_weights_only,mode, period)

 def set_model(self, model):
 super(ParallelModelCheckpoint,self).set_model(self.single_model)

checkpoint = ParallelModelCheckpoint(original_model)

ParallelModelCheckpoint调用的时候,model应该为原来的model而不是parallel_model。

EarlyStopping 没有此类问题

二、设备并行

设备并行适用于多分支结构,一个分支用一个gpu。

这种并行方法可以通过使用TensorFlow device scopes实现,下面是一个例子:

# Model where a shared LSTM is used to encode two different sequences in parallel
input_a = keras.Input(shape=(140, 256))
input_b = keras.Input(shape=(140, 256))

shared_lstm = keras.layers.LSTM(64)

# Process the first sequence on one GPU
with tf.device_scope('/gpu:0'):
 encoded_a = shared_lstm(tweet_a)
# Process the next sequence on another GPU
with tf.device_scope('/gpu:1'):
 encoded_b = shared_lstm(tweet_b)

# Concatenate results on CPU
with tf.device_scope('/cpu:0'):
 merged_vector = keras.layers.concatenate([encoded_a, encoded_b],
      axis=-1)

三、分布式运行

keras的分布式是利用TensorFlow实现的,要想完成分布式的训练,你需要将Keras注册在连接一个集群的TensorFlow会话上:

server = tf.train.Server.create_local_server()
sess = tf.Session(server.target)

from keras import backend as K
K.set_session(sess)

以上这篇keras 多gpu并行运行案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python2.7简单连接与操作MySQL的方法
Apr 27 Python
Python实现简易端口扫描器代码实例
Mar 15 Python
python 调用c语言函数的方法
Sep 29 Python
利用python实现简单的邮件发送客户端示例
Dec 23 Python
python Opencv将图片转为字符画
Feb 19 Python
用python 批量更改图像尺寸到统一大小的方法
Mar 31 Python
Django中STATIC_ROOT和STATIC_URL及STATICFILES_DIRS浅析
May 08 Python
3个用于数据科学的顶级Python库
Sep 29 Python
wxPython多个窗口的基本结构
Nov 19 Python
python实现文法左递归的消除方法
May 22 Python
Python列表的深复制和浅复制示例详解
Feb 12 Python
python 进阶学习之python装饰器小结
Sep 04 Python
Keras自定义IOU方式
Jun 10 #Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
keras 获取某层的输入/输出 tensor 尺寸操作
Jun 10 #Python
Python 字典中的所有方法及用法
Jun 10 #Python
You might like
PHP syntax error, unexpected $end 错误的一种原因及解决
2008/10/25 PHP
非常精妙的PHP递归调用与静态变量使用
2012/12/16 PHP
php中删除、清空session的方式总结
2015/10/09 PHP
深入理解PHP JSON数组与对象
2016/07/19 PHP
Yii2框架中日志的使用方法分析
2017/05/22 PHP
PHP设计模式之单例模式原理与实现方法分析
2018/04/25 PHP
PHP使用XMLWriter读写xml文件操作详解
2018/07/31 PHP
详细讲解JS节点知识
2010/01/31 Javascript
jquery的live使用注意事项
2014/02/18 Javascript
js格式化时间的方法
2015/12/18 Javascript
深入理解(function(){... })();
2016/08/16 Javascript
JavaScript实现自动跳转文本功能
2017/05/25 Javascript
react-native-tab-navigator组件的基本使用示例代码
2017/09/07 Javascript
JavaScript变量声明var,let.const及区别浅析
2018/04/23 Javascript
JavaScript ES2019中的8个新特性详解
2019/02/20 Javascript
jquery实现手风琴案例
2020/05/04 jQuery
Python中staticmethod和classmethod的作用与区别
2018/10/11 Python
对python:循环定义多个变量的实例详解
2019/01/20 Python
python实现简单日期工具类
2019/04/24 Python
在Django下测试与调试REST API的方法详解
2019/08/29 Python
浅谈tensorflow 中tf.concat()的使用
2020/02/07 Python
Python Tornado之跨域请求与Options请求方式
2020/03/28 Python
Django 设置admin后台表和App(应用)为中文名的操作方法
2020/05/10 Python
详解Django ORM引发的数据库N+1性能问题
2020/10/12 Python
python画图时设置分辨率和画布大小的实现(plt.figure())
2021/01/08 Python
HTML5给汉字加拼音收起展开组件的实现代码
2020/04/08 HTML / CSS
美国百年历史早餐食品供应商:Wolferman’s
2017/01/18 全球购物
艺术设计专业个人求职信
2013/09/21 职场文书
数学系毕业生的自我评价
2014/01/10 职场文书
学校师德师风自我剖析材料
2014/09/29 职场文书
欢迎家长标语
2014/10/08 职场文书
感谢信格式范文
2015/01/22 职场文书
单位计划生育责任书
2015/05/09 职场文书
正规欠条模板
2015/07/03 职场文书
十一月早安语录:把心放轻,人生就是一朵自在的云
2019/11/04 职场文书
Oracle创建只读账号的详细步骤
2021/06/07 Oracle