Implementing fine-tuning with hooks in TensorFlow Estimator


Posted in Python on January 21, 2020

There are two ways to implement fine-tuning:

Option 1: define the model in model_fn, then assign the pretrained values to it directly.

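import tensorflow as tf  # TensorFlow 1.x


def model_fn(features, labels, mode, params):
    # ... build the model here ...

    # Fine-tune: warm-start from a pretrained checkpoint, but only when
    # model_dir has no checkpoint of its own yet (otherwise the Estimator
    # resumes from model_dir as usual).
    if params.checkpoint_path and not tf.train.latest_checkpoint(params.model_dir):
        if tf.gfile.IsDirectory(params.checkpoint_path):
            # A directory was given: use the newest checkpoint inside it.
            checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
        else:
            # A checkpoint file prefix was given: use it as-is.
            checkpoint_path = params.checkpoint_path

        # Override the initializers of the matching variables so that they are
        # initialized from the checkpoint, e.g. {'OptimizeLoss/': 'OptimizeLoss/'}.
        tf.train.init_from_checkpoint(
            ckpt_dir_or_file=checkpoint_path,
            assignment_map={params.checkpoint_scope: params.checkpoint_scope}
        )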
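When both the key and the value of an assignment_map entry end in '/', tf.train.init_from_checkpoint treats the entry as a scope mapping: the key is a scope in the checkpoint, the value is a scope in the current graph, and every graph variable under the value scope is initialized from the tensor with the same relative name under the key scope (a missing tensor raises a ValueError). The pure-Python sketch below only illustrates that name mapping on plain strings; map_scope_names and the variable names are hypothetical, and the sketch ignores details such as partitioned variables and the '/' root scope. With a checkpoint_scope of 'OptimizeLoss/' used as both key and value, as in the code above, the mapping is the identity inside that scope.

```python
def map_scope_names(graph_var_names, ckpt_var_names, ckpt_scope, graph_scope):
    """Hypothetical illustration of an {'ckpt_scope/': 'graph_scope/'} entry.

    Returns {graph_variable_name: checkpoint_tensor_name} for every graph
    variable under graph_scope, and raises ValueError when the matching
    tensor is missing from the checkpoint.
    """
    if not (ckpt_scope.endswith("/") and graph_scope.endswith("/")):
        raise ValueError("scope mappings look like 'scope/': 'other_scope/'")
    ckpt_names = set(ckpt_var_names)
    mapping = {}
    for name in sorted(graph_var_names):
        if not name.startswith(graph_scope):
            continue  # variables outside the scope keep their own initializers
        # Same relative name, but looked up under the checkpoint scope.
        ckpt_name = ckpt_scope + name[len(graph_scope):]
        if ckpt_name not in ckpt_names:
            raise ValueError("%s not found in checkpoint" % ckpt_name)
        mapping[name] = ckpt_name
    return mapping


graph_vars = ["OptimizeLoss/dense/kernel", "OptimizeLoss/dense/bias", "logits/kernel"]
ckpt_vars = ["OptimizeLoss/dense/kernel", "OptimizeLoss/dense/bias", "global_step"]
# logits/kernel is outside the scope, so it is not part of the mapping.
print(map_scope_names(graph_vars, ckpt_vars, "OptimizeLoss/", "OptimizeLoss/"))
# Different scopes rename on the fly: new/w is initialized from old/w.
print(map_scope_names(["new/w"], ["old/w"], "old/", "new/"))
```

Variables that fall outside the mapped scope (for example a freshly added output layer) are simply left to their normal initializers, which is what makes the scope form convenient for fine-tuning.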

Option 2: use hooks.

Hooks can be specified through the train_monitors argument when defining a tf.contrib.learn.Experiment:

# Define the experiment
experiment = tf.contrib.learn.Experiment(
    estimator=estimator,                           # Estimator
    train_input_fn=train_input_fn,                 # First-class function
    eval_input_fn=eval_input_fn,                   # First-class function
    train_steps=params.train_steps,                # Minibatch steps
    min_eval_frequency=params.eval_min_frequency,  # Eval frequency
    # train_monitors=[],                           # Hooks for training
    # eval_hooks=[eval_input_hook],                # Hooks for evaluation
    eval_steps=params.eval_steps                   # Use evaluation feeder until it's empty
)

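They can also be specified through the training_chief_hooks argument when defining a tf.estimator.EstimatorSpec.

Personally, I think it is better to define them in the Estimator, so that the Experiment only has to control how the experiment runs (number of training steps, number of evaluation steps, and so on).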

def model_fn(features, labels, mode, params):

    # ....

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        # scaffold=get_scaffold(),
        # training_chief_hooks=None
    )

While we are at it, here is what a tf.estimator.EstimatorSpec object is for. It describes every aspect of a model, including:

The current mode:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

The computation graph:

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically it is a pure computation based on variables). For example, it should not trigger the update_op or require any input fetching.

The export strategy:

export_outputs: Describes the output signatures to be exported to SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

Chief hooks: hooks for saving the model during training (CheckpointSaverHook), restoring the model, and so on:

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

Worker hooks: monitoring hooks used during training, such as NanTensorHook and LoggingTensorHook:

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

Initialization and the saver:

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

Evaluation hooks:

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.
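All of the hook fields above take tf.train.SessionRunHook objects, which the training session calls at fixed points. The toy driver below is a pure-Python sketch (no TensorFlow involved) of that call order under simplifying assumptions: begin() once while the graph can still be modified, after_create_session() once the graph is finalized, before_run()/after_run() around every step, and end() when training finishes; training_chief_hooks are only used on the chief, while training_hooks run on every worker. ToyHook and run_training are hypothetical names; the real logic lives in tf.train.MonitoredTrainingSession and handles much more (default hooks, request_stop(), recovery).

```python
class ToyHook(object):
    """Hypothetical stand-in for tf.train.SessionRunHook that records its calls."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def begin(self):
        self.log.append((self.name, "begin"))

    def after_create_session(self, session, coord):
        self.log.append((self.name, "after_create_session"))

    def before_run(self, run_context):
        self.log.append((self.name, "before_run"))
        return None

    def after_run(self, run_context, run_values):
        self.log.append((self.name, "after_run"))

    def end(self, session):
        self.log.append((self.name, "end"))


def run_training(training_hooks, training_chief_hooks, is_chief, num_steps):
    """Toy loop calling the hook methods in the lifecycle order described above."""
    # Chief-only hooks (e.g. checkpoint saving or restoring) are used only on the chief.
    hooks = list(training_chief_hooks) if is_chief else []
    hooks += list(training_hooks)
    for hook in hooks:
        hook.begin()  # graph is still mutable: create savers, ops, ...
    session = object()  # placeholder for the tf.Session
    for hook in hooks:
        hook.after_create_session(session, None)  # graph is finalized now
    for _ in range(num_steps):
        for hook in hooks:
            hook.before_run(None)
        # ... session.run(train_op) would happen here ...
        for hook in hooks:
            hook.after_run(None, None)
    for hook in hooks:
        hook.end(session)
    return hooks


log = []
run_training([ToyHook("logging", log)], [ToyHook("restore", log)], is_chief=False, num_steps=1)
print(log)  # only the "logging" hook appears on a non-chief worker
```

This is why a restoring hook such as the one below belongs in training_chief_hooks: only the chief initializes and restores variables, and the other workers receive the values from it.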

The custom hook is defined as follows:

import os

import tensorflow as tf  # TensorFlow 1.x; tf.contrib is required

slim = tf.contrib.slim


class RestoreCheckpointHook(tf.train.SessionRunHook):
    """Restores a filtered subset of model variables from a pretrained checkpoint."""

    def __init__(self,
                 checkpoint_path,
                 exclude_scope_patterns,
                 include_scope_patterns
                 ):
        tf.logging.info("Create RestoreCheckpointHook.")
        super(RestoreCheckpointHook, self).__init__()
        # A checkpoint file prefix (e.g. ".../model.ckpt-10000"), not a directory.
        self.checkpoint_path = checkpoint_path

        # Comma-separated regex patterns such as "biases,Logits"; empty means no filtering.
        self.exclude_scope_patterns = (
            [p.strip() for p in exclude_scope_patterns.split(',')] if exclude_scope_patterns else None)
        self.include_scope_patterns = (
            [p.strip() for p in include_scope_patterns.split(',')] if include_scope_patterns else None)

    def begin(self):
        # You can add ops to the graph here.
        print('Before starting the session.')

        # 1. Create saver

        # Manual alternative to filter_variables (kept for reference):
        # exclusions = []
        # if self.checkpoint_exclude_scopes:
        #     exclusions = [scope.strip()
        #                   for scope in self.checkpoint_exclude_scopes.split(',')]
        #
        # variables_to_restore = []
        # for var in slim.get_model_variables():  # tf.global_variables():
        #     excluded = False
        #     for exclusion in exclusions:
        #         if var.op.name.startswith(exclusion):
        #             excluded = True
        #             break
        #     if not excluded:
        #         variables_to_restore.append(var)
        # inclusions:
        # [var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

        # get_model_variables() returns the MODEL_VARIABLES collection (e.g. slim layers).
        variables_to_restore = tf.contrib.framework.filter_variables(
            slim.get_model_variables(),
            include_patterns=self.include_scope_patterns,  # e.g. ['Conv']
            exclude_patterns=self.exclude_scope_patterns,  # e.g. ['biases', 'Logits']
            # If True (default), performs re.search to find matches
            # (i.e. pattern can match any substring of the variable name).
            # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
            reg_search=True
        )
        self.saver = tf.train.Saver(variables_to_restore)

    def after_create_session(self, session, coord):
        # When this is called, the graph is finalized and
        # ops can no longer be added to the graph.
        print('Session created.')

        # This restores unconditionally, so it also overwrites trained weights when
        # training resumes from model_dir; attach the hook only for the first run.
        tf.logging.info('Fine-tuning from %s', self.checkpoint_path)
        self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
        tf.logging.info('Finished fine-tuning from %s', self.checkpoint_path)

    def before_run(self, run_context):
        # print('Before calling session.run().')
        return None  # tf.train.SessionRunArgs(self.your_tensor)

    def after_run(self, run_context, run_values):
        # print('Done running one step. The value of my tensor: %s', run_values.results)
        # if you-need-to-stop-loop:
        #     run_context.request_stop()
        pass

    def end(self, session):
        # print('Done with the session.')
        pass
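tf.contrib.framework.filter_variables first keeps the variables whose name matches at least one include pattern (all of them when include_patterns is None), then drops those matching any exclude pattern (none when exclude_patterns is None); with reg_search=True the test is re.search, otherwise re.match. The pure-Python sketch below applies that rule to plain name strings, so the effect of the hook's include/exclude arguments can be checked without building a graph. filter_names, parse_patterns and the variable names are hypothetical; the real function works on Variable objects and matches against var.name, which carries the ':0' suffix.

```python
import re


def parse_patterns(text):
    """Hypothetical helper: 'biases, Logits' -> ['biases', 'Logits']; '' or None -> None."""
    if not text:
        return None
    return [p.strip() for p in text.split(",") if p.strip()]


def filter_names(names, include_patterns=None, exclude_patterns=None, reg_search=True):
    """String-only sketch of the include-then-exclude rule of filter_variables."""
    match = re.search if reg_search else re.match
    # First keep names matching any include pattern (all names if None) ...
    if include_patterns is None:
        included = list(names)
    else:
        included = [n for n in names if any(match(p, n) for p in include_patterns)]
    # ... then drop names matching any exclude pattern.
    if exclude_patterns is None:
        return included
    return [n for n in included if not any(match(p, n) for p in exclude_patterns)]


names = [
    "InceptionResnetV1/Conv2d_1a/weights:0",
    "InceptionResnetV1/Conv2d_1a/biases:0",
    "InceptionResnetV1/Logits/weights:0",
]
print(filter_names(names, parse_patterns("Conv"), parse_patterns("biases,Logits")))
# With reg_search=False, 'Conv' must match at the start of the name, so nothing is kept.
print(filter_names(names, ["Conv"], reg_search=False))
```

With include "Conv" and exclude "biases,Logits" only the convolution weights survive, which is the set the hook's Saver would restore; the new Logits layer keeps its fresh initialization.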
