keras多显卡训练方式


Posted in Python onJune 10, 2020

使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。

要使用多张显卡,需要按如下步骤:

(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model

(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:

model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model

通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。

(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:

model_parallel.fit(...) #注意是model_parallel

(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:

model.save(...) #注意是model,不是model_parallel

如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:

class OwnCheckpoint(keras.callbacks.Callback):
 def __init__(self,model):
  self.model_to_save=model
 def on_epoch_end(self,epoch,logs=None): #这里logs必须写
  self.model_to_save.save('model_advanced/model_%d.h5' % epoch)

定以后具体使用如下:

checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])

这样就没问题了!

补充知识:keras.fit_generator及多卡训练记录

1.环境问题

使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6

2.代码

2.1

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'

2.2 自定义generator函数

def img_image_generator(path_img, path_lab, batch_size, data_list):
 while True:
 # 'train_list.csv'
 file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
 file_list = [i[0] for i in file_list]
 cnt = 0
 X = []
 Y1 = []
 for file_i in file_list:
 x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 x = x.astype('float32')
 x /= 255.
 y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 y = y.astype('float32')
 y /= 255.
 X.append(x.reshape(256, 256, 1))
 Y1.append(y.reshape(256, 256, 1))
 cnt += 1
 if cnt == batch_size:
 cnt = 0
 yield (np.array(X), [np.array(Y1), np.array(Y1)])
 X = []
 Y1 = []

2.3 函数调用及训练

generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
 generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
 model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)

3. 多卡训练

3.1 复制model

model_parallel = multi_gpu_model(model, gpus=2)

3.2 checkpoint 定义

class ParallelModelCheckpoint(ModelCheckpoint):
  def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
   save_best_only=False, save_weights_only=False, mode='auto', period=1):
   self.single_model = model 
   super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
  
  def set_model(self, model):
   super(ParallelModelCheckpoint, self).set_model(self.single_model)

使用

model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')

3.3 注意的问题

保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存

以上这篇keras多显卡训练方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
剖析Django中模版标签的解析与参数传递
Jul 21 Python
Python Requests安装与简单运用
Apr 07 Python
Python实现的维尼吉亚密码算法示例
Apr 12 Python
python并发和异步编程实例
Nov 15 Python
python实现QQ空间自动点赞功能
Apr 09 Python
利用Python裁切tiff图像且读取tiff,shp文件的实例
Mar 10 Python
Python气泡提示与标签的实现
Apr 01 Python
pycharm远程连接vagrant虚拟机中mariadb数据库
Jun 05 Python
Python 实现 T00ls 自动签到脚本代码(邮件+钉钉通知)
Jul 06 Python
Python Django路径配置实现过程解析
Nov 05 Python
python 通过使用Yolact训练数据集
Apr 06 Python
Python利用zhdate模块实现农历日期处理
Mar 31 Python
使用SQLAlchemy操作数据库表过程解析
Jun 10 #Python
keras 多gpu并行运行案例
Jun 10 #Python
Keras自定义IOU方式
Jun 10 #Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
You might like
FireFox浏览器使用Javascript上传大文件
2013/10/30 PHP
PHP两种实现无级递归分类的方法
2017/03/02 PHP
ThinkPHP5.0框架使用build 自动生成模块操作示例
2019/04/11 PHP
PHP Cli 模式设置进程名称的方法
2019/06/12 PHP
用Div仿showModalDialog模式菜单的效果的代码
2007/03/05 Javascript
javascript之典型高阶函数应用介绍
2013/01/10 Javascript
JavaScript前端图片加载管理器imagepool使用详解
2014/12/29 Javascript
js实现带圆角的多级下拉菜单效果
2015/08/28 Javascript
基于canvas实现的钟摆效果完整实例
2016/01/26 Javascript
浅析Javascript中bind()方法的使用与实现
2016/05/30 Javascript
浅谈JavaScript函数的四种存在形态
2016/06/08 Javascript
微信小程序 滚动到某个位置添加class效果实现代码
2017/04/19 Javascript
Angular.js中控制器之间的传值详解
2017/04/24 Javascript
原生js实现简单的模态框示例
2017/09/08 Javascript
使用nvm管理不同版本的node与npm的方法
2017/10/31 Javascript
基于jQuery使用Ajax动态执行模糊查询功能
2018/07/05 jQuery
图片文字识别(OCR)插件Ocrad.js教程
2018/11/26 Javascript
ES6知识点整理之数组解构和字符串解构的应用示例
2019/04/17 Javascript
JavaScript RegExp 对象用法详解
2019/09/24 Javascript
JS实现拖动模糊框特效
2020/08/25 Javascript
vue3 watch和watchEffect的使用以及有哪些区别
2021/01/26 Vue.js
python增加矩阵维度的实例讲解
2018/04/04 Python
10个Python小技巧你值得拥有
2018/09/29 Python
对python多线程中Lock()与RLock()锁详解
2019/01/11 Python
tensorflow实现对张量数据的切片操作方式
2020/01/19 Python
Python 面向对象之类class和对象基本用法示例
2020/02/02 Python
pyautogui自动化控制鼠标和键盘操作的步骤
2020/04/01 Python
用opencv给图片换背景色的示例代码
2020/07/08 Python
pip/anaconda修改镜像源,加快python模块安装速度的操作
2021/03/04 Python
html5是什么_动力节点Java学院整理
2017/07/07 HTML / CSS
敬老院院长事迹材料
2014/05/21 职场文书
2014年服装销售工作总结
2014/11/27 职场文书
导游词之澳门妈祖庙
2019/12/19 职场文书
JavaScript 防篡改对象的用法示例
2021/04/24 Javascript
Win11应用商店打开闪退怎么解决? win11应用商店打不开的多种解决办法
2022/04/05 数码科技
德生BCL3000抢先使用感受和评价
2022/04/07 无线电