keras多显卡训练方式


Posted in Python onJune 10, 2020

使用keras进行训练,默认使用单显卡,即使设置了os.environ['CUDA_VISIBLE_DEVICES']为两张显卡,也只是占满了显存,再设置tf.GPUOptions(allow_growth=True)之后可以清楚看到,只占用了第一张显卡,第二张显卡完全没用。

要使用多张显卡,需要按如下步骤:

(1)import multi_gpu_model函数:from keras.utils import multi_gpu_model

(2)在定义好model之后,使用multi_gpu_model设置模型由几张显卡训练,如下:

model=Model(...) #定义模型结构
model_parallel=multi_gpu_model(model,gpu=n) #使用几张显卡n等于几
model_parallel.compile(...) #注意是model_parallel,不是model

通过以上代码,model将作为CPU上的原始模型,而model_parallel将作为拷贝模型被复制到各个GPU上进行梯度计算。如果batchsize为128,显卡n=2,则每张显卡单独计算128/2=64张图像,然后在CPU上将两张显卡计算得到的梯度进行融合更新,并对模型权重进行更新后再将新模型拷贝到GPU再次训练。

(3)从上面可以看出,进行训练时,仍然在model_parallel上进行:

model_parallel.fit(...) #注意是model_parallel

(4)保存模型时,model_parallel保存了训练时显卡数量的信息,所以如果直接保存model_parallel的话,只能将模型设置为相同数量的显卡调用,否则训练的模型将不能调用。因此,为了之后的调用方便,只保存CPU上的模型,即model:

model.save(...) #注意是model,不是model_parallel

如果用到了callback函数,则默认保存的也是model_parallel(因为训练函数是针对model_parallel的),所以要用回调函数保存model的话需要自己对回调函数进行定义:

class OwnCheckpoint(keras.callbacks.Callback):
 def __init__(self,model):
  self.model_to_save=model
 def on_epoch_end(self,epoch,logs=None): #这里logs必须写
  self.model_to_save.save('model_advanced/model_%d.h5' % epoch)

定以后具体使用如下:

checkpoint=OwnCheckpoint(model)
model_parallel.fit_generator(...,callbacks=[checkpoint])

这样就没问题了!

补充知识:keras.fit_generator及多卡训练记录

1.环境问题

使用keras,以tensorflow为背景,tensorflow1.14多卡训练会出错 python3.6

2.代码

2.1

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'

2.2 自定义generator函数

def img_image_generator(path_img, path_lab, batch_size, data_list):
 while True:
 # 'train_list.csv'
 file_list = pd.read_csv(data_list, sep=',',usecols=[1]).values.tolist()
 file_list = [i[0] for i in file_list]
 cnt = 0
 X = []
 Y1 = []
 for file_i in file_list:
 x = cv2.imread(path_img+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 x = x.astype('float32')
 x /= 255.
 y = cv2.imread(path_lab+'/'+file_i, cv2.IMREAD_GRAYSCALE)
 y = y.astype('float32')
 y /= 255.
 X.append(x.reshape(256, 256, 1))
 Y1.append(y.reshape(256, 256, 1))
 cnt += 1
 if cnt == batch_size:
 cnt = 0
 yield (np.array(X), [np.array(Y1), np.array(Y1)])
 X = []
 Y1 = []

2.3 函数调用及训练

generator_train = img_image_generator(path1, path2, 4, pathcsv_train)
 generator_test= img_image_generator(path1, path2, 4, pathcsv_test)
 model.fit_generator(generator_train, steps_per_epoch=237*2, epochs=50, callbacks=callbacks_list, validation_data=generator_test, validation_steps=60*2)

3. 多卡训练

3.1 复制model

model_parallel = multi_gpu_model(model, gpus=2)

3.2 checkpoint 定义

class ParallelModelCheckpoint(ModelCheckpoint):
  def __init__(self, model, filepath, monitor='val_out_final_score', verbose=0,\
   save_best_only=False, save_weights_only=False, mode='auto', period=1):
   self.single_model = model 
   super(ParallelModelCheckpoint, self).__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, period)
  
  def set_model(self, model):
   super(ParallelModelCheckpoint, self).set_model(self.single_model)

使用

model_checkpoint = ParallelModelCheckpoint(model=model, filepath=filepath, monitor='val_loss',verbose=1, save_best_only=True, mode='min')

3.3 注意的问题

保存模型是时候需要使用以原来的模型保存,不能使用model_parallel保存

以上这篇keras多显卡训练方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中测试访问同一数据的竞争条件的方法
Apr 23 Python
分析用Python脚本关闭文件操作的机制
Jun 28 Python
Python 爬虫多线程详解及实例代码
Oct 08 Python
Python实现繁体中文与简体中文相互转换的方法示例
Dec 18 Python
python多线程调用exit无法退出的解决方法
Feb 18 Python
python基于递归解决背包问题详解
Jul 03 Python
python装饰器代替set get方法实例
Dec 19 Python
tensorflow 只恢复部分模型参数的实例
Jan 06 Python
30行Python代码实现高分辨率图像导航的方法
May 22 Python
python 绘制场景热力图的示例
Sep 23 Python
在python中实现导入一个需要传参的模块
May 12 Python
python3中apply函数和lambda函数的使用详解
Feb 28 Python
使用SQLAlchemy操作数据库表过程解析
Jun 10 #Python
keras 多gpu并行运行案例
Jun 10 #Python
Keras自定义IOU方式
Jun 10 #Python
Python实现在线批量美颜功能过程解析
Jun 10 #Python
浅谈keras中的目标函数和优化函数MSE用法
Jun 10 #Python
keras 解决加载lstm+crf模型出错的问题
Jun 10 #Python
使用Keras加载含有自定义层或函数的模型操作
Jun 10 #Python
You might like
虚拟主机中对PHP的特殊设置
2006/10/09 PHP
php实现单链表的实例代码
2013/03/22 PHP
PHP生成指定随机字符串的简单实现方法
2015/04/01 PHP
使用PHP连接多种数据库的实现代码(mysql,access,sqlserver,Oracle)
2016/12/21 PHP
在laravel-admin中列表中禁止某行编辑、删除的方法
2019/10/03 PHP
捕获关闭窗口的脚本
2009/01/10 Javascript
善用事件代理,警惕闭包的性能陷阱。
2011/01/20 Javascript
70+漂亮且极具亲和力的导航菜单设计国外网站推荐
2011/09/20 Javascript
输入框的字数时时统计—关于 onpropertychange 和 oninput 使用
2011/10/21 Javascript
jQuery 仿百度输入标签插件附效果图
2014/07/04 Javascript
JavaScript判断表单中多选框checkbox选中个数的方法
2015/08/17 Javascript
JavaScript函数的一些注意要点小结及js匿名函数
2015/11/10 Javascript
jQuery中使用animate自定义动画的方法
2016/05/29 Javascript
javascript类型系统_正则表达式RegExp类型详解
2016/06/24 Javascript
微信小程序之小豆瓣图书实例
2016/11/30 Javascript
jQuery实现导航回弹效果
2017/02/27 Javascript
js实现图片粘贴上传到服务器并展示的实例
2017/11/08 Javascript
[01:37]TI4西雅图DOTA2前线报道 VG拿下首胜教练357给出获胜秘诀
2014/07/10 DOTA
[43:03]LGD vs Newbee 2019国际邀请赛小组赛 BO2 第一场 8.16
2019/08/19 DOTA
Eclipse中Python开发环境搭建简单教程
2016/03/23 Python
python3.5 tkinter实现页面跳转
2018/01/30 Python
Python中关于logging模块的学习笔记
2020/06/03 Python
英国著名的化妆品折扣网站:Allbeauty.com
2016/07/21 全球购物
德国内衣、泳装和睡衣网上商店:Bigsize Dessous
2018/07/09 全球购物
J2EE相关知识面试题
2013/08/26 面试题
高校毕业生登记表自我鉴定
2013/11/03 职场文书
幼儿园教师节活动方案
2014/02/02 职场文书
读群众路线心得体会
2014/03/07 职场文书
见习期自我鉴定范文
2014/03/19 职场文书
小学班主任培训方案
2014/06/04 职场文书
幼儿学前班评语
2014/12/29 职场文书
让世界充满爱观后感
2015/06/10 职场文书
圣贤教育改变命运观后感
2015/06/16 职场文书
Python函数对象与闭包函数
2022/04/13 Python
MySQL 数据表操作
2022/05/04 MySQL
Go中使用gjson来操作JSON数据的实现
2022/08/14 Golang