keras多显卡训练方式


Posted in Python onJune 10, 2020

使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。

要使用多张显卡,需要按如下步骤:

(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model

(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:

model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model

通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。

(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:

model_parallel.fit(...) #注意是model_parallel

(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:

model.save(...) #注意是model,不是model_parallel

如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:

class OwnCheckpoint(keras.callbacks.Callback):
 def __init__(self,model):
  self.model_to_save=model
 def on_epoch_end(self,epoch,logs=None): #这里logs必须写
  self.model_to_save.save('model_advanced/model_%d.h5' % epoch)

定以后具体使用如下:

checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])

这样就没问题了!

补充知识:keras.fit_generator及多卡训练记录

1.环境问题

使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6

2.代码

2.1

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'

2.2 自定义generator函数

def img_image_generator(path_img, path_lab, batch_size, data_list):
 while True:
 # 'train_list.csv'
 file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
 file_list = [i[0] for i in file_list]
 cnt = 0
 X = []
 Y1 = []
 for file_i in file_list:
 x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 x = x.astype('float32')
 x /= 255.
 y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 y = y.astype('float32')
 y /= 255.
 X.append(x.reshape(256, 256, 1))
 Y1.append(y.reshape(256, 256, 1))
 cnt += 1
 if cnt == batch_size:
 cnt = 0
 yield (np.array(X), [np.array(Y1), np.array(Y1)])
 X = []
 Y1 = []

2.3 函数调用及训练

generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
 generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
 model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)

3. 多卡训练

3.1 复制model

model_parallel = multi_gpu_model(model, gpus=2)

3.2 checkpoint 定义

class ParallelModelCheckpoint(ModelCheckpoint):
  def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
   save_best_only=False, save_weights_only=False, mode='auto', period=1):
   self.single_model = model 
   super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
  
  def set_model(self, model):
   super(ParallelModelCheckpoint, self).set_model(self.single_model)

使用

model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')

3.3 注意的问题

保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存

以上这篇keras多显卡训练方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作xml文件示例
Apr 07 Python
python检查字符串是否是正确ISBN的方法
Jul 11 Python
Python实现Linux的find命令实例分享
Jun 04 Python
Python 3.x 安装opencv+opencv_contrib的操作方法
Apr 02 Python
简单谈谈Python的pycurl模块
Apr 07 Python
python绘制立方体的方法
Jul 02 Python
解决vscode python print 输出窗口中文乱码的问题
Dec 03 Python
python实现求特征选择的信息增益
Dec 18 Python
Python正则表达式匹配日期与时间的方法
Jul 07 Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 Python
Django 构建模板form表单的两种方法
Jun 14 Python
解决Pytorch半精度浮点型网络训练的问题
May 24 Python
使用SQLAlchemy操作数据库表过程解析
Jun 10 #Python
keras 多gpu并行运行案例
Jun 10 #Python
Keras自定义IOU方式
Jun 10 #Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
You might like
一个用于mysql的数据库抽象层函数库
2006/10/09 PHP
php利用header函数实现文件下载时直接提示保存
2009/11/12 PHP
PHP 实现页面静态化的几种方法
2017/07/23 PHP
使两个iframe的高度与内容自适应,且相等
2006/11/20 Javascript
JavaScript 学习笔记(十二) dom
2010/01/21 Javascript
javascript页面倒计时实例
2015/07/25 Javascript
用js动态添加html元素,以及属性的简单实例
2016/07/19 Javascript
jQuery Easyui DataGrid点击某个单元格即进入编辑状态焦点移开后保存数据
2016/08/15 Javascript
javascript实现页面滚屏效果
2017/01/17 Javascript
基于Vue实现timepicker
2017/04/25 Javascript
详解微信小程序 登录获取unionid
2017/06/27 Javascript
JS原生带小白点轮播图实例讲解
2017/07/22 Javascript
详解关于Vue版本不匹配问题(Vue packages version mismatch)
2018/09/17 Javascript
浅谈Vue2.4.0 $attrs与inheritAttrs的具体使用
2020/03/08 Javascript
简单了解Vue + ElementUI后台管理模板
2020/04/07 Javascript
详解使用python crontab设置linux定时任务
2016/12/08 Python
浅谈Python类的__getitem__和__setitem__特殊方法
2016/12/25 Python
基于Python数据结构之递归与回溯搜索
2020/02/26 Python
Tensorflow中的图(tf.Graph)和会话(tf.Session)的实现
2020/04/22 Python
python os.listdir()乱码解决方案
2021/01/31 Python
CSS3的常见transformation图形变化用法小结
2016/05/13 HTML / CSS
涂鸦板简单实现 Html5编写属于自己的画画板
2016/07/05 HTML / CSS
基于HTML5+CSS3实现简单的时钟效果
2017/09/11 HTML / CSS
网购亚洲时装、美容产品和生活百货:YesStyle
2016/09/15 全球购物
iHerb中文官网:维生素、保健品和健康产品
2018/11/01 全球购物
Charles & Keith欧盟:新加坡时尚品牌
2019/08/01 全球购物
一道输出判断型Java面试题
2014/10/01 面试题
生产文员岗位职责
2014/04/05 职场文书
抗震救灾标语
2014/06/26 职场文书
合同和协议有什么区别?
2014/10/08 职场文书
2014年置业顾问工作总结
2014/11/17 职场文书
2014年审计人员工作总结
2014/12/19 职场文书
中班下学期个人工作总结
2015/02/12 职场文书
先进工作者主要事迹材料
2015/11/03 职场文书
MySQL中存储时间的最佳实践指南
2021/07/01 MySQL
vue el-table实现递归嵌套的示例代码
2022/08/14 Vue.js