解决Tensorflow sess.run导致的内存溢出问题


Posted in Python onFebruary 05, 2020

下面是调用模型进行批量测试的代码(出现溢出),开始以为导致溢出的原因是数据读入方式问题引起的,用了tf , PIL和cv等方式读入图片数据,发现越来越慢,内存占用飙升,调试时发现是sess.run这里出了问题(随着for循环进行速度越来越慢)。

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
 
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
 
    raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    
 
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
 
 
    print(time.time() - start)
 
  print(">>>>>>Done")

下面是解决溢出问题的代码(将部分代码放在for循环外

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
  
##############################################################################################################
  raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
  raw_output_up = tf.argmax(raw_output_up, axis=3)
##############################################################################################################
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
    
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
    print(time.time() - start)
 
  print(">>>>>>Done")

总结:

在迭代过程中, 在sess.run的for循环中不要加入tensorflow一些op操作,会增加图节点,否则随着迭代的进行,tf的图会越来越大,最终导致溢出;

建议不要使用tf.gfile.FastGFile(image_path, 'rb').read()读入数据(有可能会造成溢出),用opencv之类读取。

以上这篇解决Tensoflow sess.run导致的内存溢出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python检测手机QQ在线状态的脚本代码
Feb 10 Python
python中常用的各种数据库操作模块和连接实例
May 29 Python
python 远程统计文件代码分享
May 14 Python
Python获取运行目录与当前脚本目录的方法
Jun 01 Python
Python中的字典与成员运算符初步探究
Oct 13 Python
Python中字典和集合学习小结
Jul 07 Python
python实现自动发送邮件
Jun 20 Python
python实现内存监控系统
Mar 07 Python
python内存管理机制原理详解
Aug 12 Python
Python PyQt5运行程序把输出信息展示到GUI图形界面上
Apr 27 Python
深入了解Python 变量作用域
Jul 24 Python
python文件排序的方法总结
Sep 13 Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
tensorflow自定义激活函数实例
Feb 04 #Python
pytorch对梯度进行可视化进行梯度检查教程
Feb 04 #Python
You might like
一周学会PHP(视频)Http下载
2006/12/12 PHP
解析PHP正则提取或替换img标记属性
2013/06/26 PHP
检测codeigniter脚本消耗内存情况的方法
2015/03/21 PHP
PHP变量赋值、代入给JavaScript中的变量
2015/06/29 PHP
php封装的单文件(图片)上传类完整实例
2016/10/18 PHP
Thinkphp5.0框架的Db操作实例分析【连接、增删改查、链式操作等】
2019/10/11 PHP
select组合框option的捕捉实例代码
2008/09/30 Javascript
Javascript 汉字字节判断
2009/08/01 Javascript
JavaScript 大数据相加的问题
2011/08/03 Javascript
20款效果非常棒的 jQuery 插件小结分享
2011/11/18 Javascript
javascript中定义私有方法说明(private method)
2014/01/27 Javascript
jquery中子元素和后代元素的区别示例介绍
2014/04/02 Javascript
点击标签切换和自动切换DIV选项卡
2014/08/10 Javascript
AngularJS iframe跨域打开内容时报错误的解决办法
2015/01/26 Javascript
jquery实现表单输入时提示文字滑动向上效果
2015/08/10 Javascript
JS实现侧边栏鼠标经过弹出框+缓冲效果
2017/03/29 Javascript
es6学习笔记之Async函数基本教程
2017/05/11 Javascript
bootstrap multiselect下拉列表功能
2017/08/22 Javascript
解决easyui日期时间框ie的兼容的问题
2018/03/01 Javascript
vue.js自定义组件directives的实例代码
2018/11/09 Javascript
Layui实现带查询条件的分页
2019/07/27 Javascript
JavaScript实现轮播图片完整代码
2020/03/07 Javascript
JavaScript装箱及拆箱boxing及unBoxing用法解析
2020/06/15 Javascript
Js利用正则表达式去除字符串的中括号
2020/11/23 Javascript
[06:44]2018DOTA2亚洲邀请赛4.5 SOLO赛 MidOne vs Sumail
2018/04/06 DOTA
[48:31]完美世界DOTA2联赛PWL S3 DLG vs Phoenix 第二场 12.17
2020/12/19 DOTA
python连接sql server乱码的解决方法
2013/01/28 Python
Python使用paramiko连接远程服务器执行Shell命令的实现
2021/03/04 Python
Html5页面在微信端的分享的实现方法
2018/08/30 HTML / CSS
草莓网化妆品加拿大网站:Strawberrynet Canada
2016/09/20 全球购物
班级读书活动总结
2014/06/30 职场文书
小学班主任工作总结2015
2015/04/07 职场文书
2016年学校禁毒宣传活动工作总结
2016/04/05 职场文书
pytorch 实现多个Dataloader同时训练
2021/05/29 Python
nginx的zabbix 5.0安装部署的方法步骤
2021/07/16 Servers
聊聊Python String型列表求最值的问题
2022/01/18 Python