解决Tensorflow sess.run导致的内存溢出问题


Posted in Python onFebruary 05, 2020

下面是调用模型进行批量测试的代码(出现溢出),开始以为导致溢出的原因是数据读入方式问题引起的,用了tf , PIL和cv等方式读入图片数据,发现越来越慢,内存占用飙升,调试时发现是sess.run这里出了问题(随着for循环进行速度越来越慢)。

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
 
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
 
    raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    
 
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
 
 
    print(time.time() - start)
 
  print(">>>>>>Done")

下面是解决溢出问题的代码(将部分代码放在for循环外

# Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
  
##############################################################################################################
  raw_output_up = tf.image.resize_bilinear(output_tensor_name, size=[h, w], align_corners=True) 
  raw_output_up = tf.argmax(raw_output_up, axis=3)
##############################################################################################################
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir, filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data, (w, h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1, 0)
    
    predict_img = sess.run(raw_output_up, feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data, predict_img, voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]), masked_im)
    print(time.time() - start)
 
  print(">>>>>>Done")

总结:

在迭代过程中, 在sess.run的for循环中不要加入tensorflow一些op操作,会增加图节点,否则随着迭代的进行,tf的图会越来越大,最终导致溢出;

建议不要使用tf.gfile.FastGFile(image_path, 'rb').read()读入数据(有可能会造成溢出),用opencv之类读取。

以上这篇解决Tensoflow sess.run导致的内存溢出问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python装饰器验证配置文件示例
Feb 24 Python
在Mac OS系统上安装Python的Pillow库的教程
Nov 20 Python
python实现按任意键继续执行程序
Dec 30 Python
python 中if else 语句的作用及示例代码
Mar 05 Python
使用实现pandas读取csv文件指定的前几行
Apr 20 Python
python3中的md5加密实例
May 29 Python
详解python 模拟豆瓣登录(豆瓣6.0)
Apr 18 Python
Python 一键制作微信好友图片墙的方法
May 16 Python
Django框架静态文件使用/中间件/禁用ip功能实例详解
Jul 22 Python
详解python中各种文件打开模式
Jan 19 Python
python3中使用__slots__限定实例属性操作分析
Feb 14 Python
Python用来做Web开发的优势有哪些
Aug 05 Python
解决TensorFlow训练内存不断增长,进程被杀死问题
Feb 05 #Python
浅谈tensorflow之内存暴涨问题
Feb 05 #Python
对Tensorflow中Device实例的生成和管理详解
Feb 04 #Python
关于windows下Tensorflow和pytorch安装教程
Feb 04 #Python
django3.02模板中的超链接配置实例代码
Feb 04 #Python
tensorflow自定义激活函数实例
Feb 04 #Python
pytorch对梯度进行可视化进行梯度检查教程
Feb 04 #Python
You might like
php 动态添加记录
2009/03/10 PHP
php at(@)符号的用法简介
2009/07/11 PHP
PHP程序级守护进程的实现与优化的使用概述
2013/05/02 PHP
Symfony2实现从数据库获取数据的方法小结
2016/03/18 PHP
php通过curl添加cookie伪造登陆抓取数据的方法
2016/04/02 PHP
yii2.0整合阿里云oss的示例代码
2017/09/19 PHP
PHP中有关长整数的一些操作教程
2019/09/11 PHP
javascript实现轮显新闻标题链接
2007/08/13 Javascript
JQuery Tips(2) 关于$()包装集你不知道的
2009/12/14 Javascript
Jquery获取和修改img的src值的方法
2014/02/17 Javascript
js实现文章文字大小字号功能完整实例
2014/11/01 Javascript
基于jQuery实现文本框只能输入数字(小数、整数)
2016/01/14 Javascript
javascript实现无法关闭的弹框
2016/11/27 Javascript
JavaScript实现公历转农历功能示例
2017/02/13 Javascript
Javascript之图片的延迟加载的实例详解
2017/07/24 Javascript
浅谈在vue项目中如何定义全局变量和全局函数
2017/10/24 Javascript
js字符串处理之绝妙的代码
2019/04/05 Javascript
Vue组件通信的几种实现方法
2019/04/25 Javascript
js Array.slice的8种不同用法示例
2019/07/10 Javascript
Vue实现图片与文字混输效果
2019/12/04 Javascript
python 系统调用的实例详解
2017/07/11 Python
python实现最长公共子序列
2018/05/22 Python
python实现傅里叶级数展开的实现
2018/07/21 Python
Python 字符串与二进制串的相互转换示例
2018/07/23 Python
pandas 把数据写入txt文件每行固定写入一定数量的值方法
2018/12/28 Python
pycharm配置pyqt5-tools开发环境的方法步骤
2019/02/11 Python
python整小时 整天时间戳获取算法示例
2019/02/20 Python
pyqt5数据库使用详细教程(打包解决方案)
2020/03/25 Python
使用python从三个角度解决josephus问题的方法
2020/03/27 Python
Python参数传递实现过程及原理详解
2020/05/14 Python
来自Ocado的宠物商店:Fetch
2018/07/10 全球购物
自动化专业毕业生自荐信
2013/11/01 职场文书
2014年高考决心书
2014/03/11 职场文书
处级领导班子全部召开专题民主生活会情况汇报
2014/09/27 职场文书
党建工作目标管理责任书
2015/01/29 职场文书
python井字棋游戏实现人机对战
2022/04/28 Python