详解tensorflow实现迁移学习实例


Posted in Python onFebruary 10, 2018

本文主要是总结利用tensorflow实现迁移学习的基本步骤。

所谓迁移学习,就是将上一个问题上训练好的模型通过简单的调整使其适用于一个新的问题。比如说,我们可以保留训练好的Inception-v3模型中所有的参数,只替换最后一层全连接层。在最后一层全连接层之前的网络称之为瓶颈层(bottleneck)。

持久化

首先需要简单介绍下tensorflow中的持久化:在tensorflow中提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是tf.train.Saver类。当采用该方法保存时会生成三个文件,一个文件是model.ckpt.meta,它保存了Tensorflow计算图的结构;第二个文件是model.ckpt,它保存了程序中每一个变量的取值;最后一个文件是checkpoint文件,这个文件中保存了一个目录下所有模型文件列表。

保存图

init_op = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init_op)
  saver.save(sess, "model.ckpt")

加载图

saver = tf.train.import_meta_graph("model.ckpt.meta")
with tf.Session() as sess:
  saver.restore(sess, "model.ckpt")

迁移学习

第一步: 读取加载已经训练好的模型

在inception-v3模型代表瓶颈层结果的张量名称是'pool3/_reshape:0',图像输入张量对应的名称'DecodeJpeg/contents:0'

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
#读取已经训练好的模型
  with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

第二步:利用读取的模型,定义新的神经网络输入,这个输入就是新的图片经过Inception-v3模型前向传播到达瓶颈层的取值,是一种特征提取过程。

def run_bottlenect_on_images(sess, image_data, image_data_tensor, bottlenect_tensor):
  bottlenect_values = sess.run(bottlenect_tensor, {image_data_tensor: image_data})

  # 经过卷积网络处理后的是一个思维数组,压缩成一个特征,一维向量输出
  bottlenect_values = np.squeeze(bottlenect_values)
  return bottlenect_values

该过程实际上利用获取的tensor计算图片的特征向量,完成特征提取的过程。

第三步:利用获取的图像的特征向量完成接下来的任务(比如分类)

以上是仅关键代码。希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python内置的字符串处理函数整理
Jan 29 Python
Python实现二叉搜索树
Feb 03 Python
python如何统计序列中元素
Jul 31 Python
Python 创建空的list,以及append用法讲解
May 04 Python
对numpy.append()里的axis的用法详解
Jun 28 Python
将Python字符串生成PDF的实例代码详解
May 17 Python
python实现Dijkstra算法的最短路径问题
Jun 21 Python
Python 图像处理: 生成二维高斯分布蒙版的实例
Jul 04 Python
Django MEDIA的配置及用法详解
Jul 25 Python
python中的列表与元组的使用
Aug 08 Python
python爬虫实现爬取同一个网站的多页数据的实例讲解
Jan 18 Python
pytorch中的model=model.to(device)使用说明
May 24 Python
Python学习之Django的管理界面代码示例
Feb 10 #Python
Tensorflow 自带可视化Tensorboard使用方法(附项目代码)
Feb 10 #Python
tensorflow训练中出现nan问题的解决
Feb 10 #Python
用Eclipse写python程序
Feb 10 #Python
tensorflow建立一个简单的神经网络的方法
Feb 10 #Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
You might like
PHP按行读取、处理较大CSV文件的代码实例
2014/04/09 PHP
PHP表单提交后引号前自动加反斜杠的原因及三种办法关闭php魔术引号
2015/09/30 PHP
php编程中echo用逗号和用点号连接的区别
2016/03/26 PHP
Yii2实现ajax上传图片插件用法
2016/04/28 PHP
phpinfo()中Loaded Configuration File(none)的解决方法
2017/01/16 PHP
PHP实现的数据对象映射模式详解
2019/03/20 PHP
javascript引用对象的方法
2007/01/11 Javascript
jQuery弹出层插件简化版代码下载
2008/10/16 Javascript
主页面中的两个iframe实现鼠标拖动改变其大小
2013/04/16 Javascript
查找页面中所有类为test的结点的方法
2014/03/28 Javascript
jQuery中复合属性选择器用法实例
2014/12/31 Javascript
jquery事件preventDefault()方法用法实例
2015/01/16 Javascript
javascript日期操作详解(脚本之家整理)
2015/09/05 Javascript
GitHub上一些实用的JavaScript的文件压缩解压缩库推荐
2016/03/13 Javascript
实例解析jQuery中proxy()函数的用法
2016/05/24 Javascript
vue中Npm run build 根据环境传递参数方法来打包不同域名
2018/03/29 Javascript
微信小程序使用scroll-view标签实现自动滑动到底部功能的实例代码
2018/11/09 Javascript
解决layui追加或者动态修改的表单元素“没效果”的问题
2019/09/18 Javascript
JS eval代码快速解密实例解析
2020/04/23 Javascript
[42:36]DOTA2上海特级锦标赛B组败者赛 VG VS Spirit第二局
2016/02/26 DOTA
python基础教程之字典操作详解
2014/03/25 Python
python生成式的send()方法(详解)
2017/05/08 Python
高效使用Python字典的清单
2018/04/04 Python
如何使用django的MTV开发模式返回一个网页
2019/07/22 Python
用Python将Excel数据导入到SQL Server的例子
2019/08/24 Python
vue常用指令代码实例总结
2020/03/16 Python
HTML5对手机页面长按会粘贴复制禁用的解决方法
2016/07/19 HTML / CSS
西班牙英格列斯百货官网:El Corte Inglés
2016/09/25 全球购物
北大青鸟学生求职信
2013/09/24 职场文书
教师评优事迹材料
2014/01/10 职场文书
大学生通用个人自我评价
2014/04/27 职场文书
社团活动总结
2014/04/28 职场文书
2014银行授权委托书样本
2014/10/04 职场文书
2016银行招聘自荐信
2016/01/28 职场文书
小学作文之描写天气
2019/08/15 职场文书
Python实现Hash算法
2022/03/18 Python