Fine-tuning a TensorFlow Estimator with hooks


Posted in Python on January 21, 2020

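There are two ways to implement fine-tuning with an Estimator.

The first is to initialize the variables from the checkpoint directly inside model_fn, right after the model has been defined: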

import tensorflow as tf  # TensorFlow 1.x API


def model_fn(features, labels, mode, params):
    # ... build the model ...

    # Fine-tune: load pretrained weights only while model_dir has no checkpoint yet.
    if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
        checkpoint_path = None
        if tf.gfile.IsDirectory(params.checkpoint_path):
            # A directory was given: use the most recent checkpoint in it.
            checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
        else:
            checkpoint_path = params.checkpoint_path

        tf.train.init_from_checkpoint(
            ckpt_dir_or_file=checkpoint_path,
            assignment_map={params.checkpoint_scope: params.checkpoint_scope}  # e.g. 'OptimizeLoss/': 'OptimizeLoss/'
        )
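To make the assignment_map argument more concrete: when both the key and the value end in `/`, each variable under the current scope is paired with the same-named variable under the checkpoint scope. The snippet below is a minimal pure-Python sketch of that name matching only, not TensorFlow's implementation. The helper `match_by_scope` and all variable names are hypothetical, and raising on a missing name is an assumption made for this illustration.

```python
def match_by_scope(ckpt_names, graph_names, ckpt_scope, graph_scope):
    """Pair graph variable names with checkpoint names by swapping the scope prefix.

    Illustrative only: mimics the name matching implied by an assignment_map
    entry such as {'OptimizeLoss/': 'OptimizeLoss/'}.
    """
    ckpt_names = set(ckpt_names)
    pairs = {}
    for name in graph_names:
        if not name.startswith(graph_scope):
            continue  # variables outside the scope keep their normal initializer
        ckpt_name = ckpt_scope + name[len(graph_scope):]
        if ckpt_name not in ckpt_names:
            # Assumption for this sketch: a missing checkpoint entry is an error.
            raise ValueError("%s not found in checkpoint" % ckpt_name)
        pairs[name] = ckpt_name
    return pairs


# Hypothetical variable names in the current graph and in the checkpoint.
graph = ["encoder/conv1/weights", "encoder/conv1/biases", "logits/weights"]
ckpt = ["pretrained/conv1/weights", "pretrained/conv1/biases"]

mapping = match_by_scope(ckpt, graph, ckpt_scope="pretrained/", graph_scope="encoder/")
print(mapping)
```

In the article's code the key and the value are the same scope, so the names are matched unchanged; variables outside that scope are simply left to their default initializers.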

The second is to use hooks.

Hooks can be passed through the train_monitors argument when defining tf.contrib.learn.Experiment:

# Define the experiment
experiment = tf.contrib.learn.Experiment(
    estimator=estimator,  # Estimator
    train_input_fn=train_input_fn,  # First-class function
    eval_input_fn=eval_input_fn,  # First-class function
    train_steps=params.train_steps,  # Minibatch steps
    min_eval_frequency=params.eval_min_frequency,  # Eval frequency
    # train_monitors=[],  # Hooks for training
    # eval_hooks=[eval_input_hook],  # Hooks for evaluation
    eval_steps=params.eval_steps  # Use evaluation feeder until it is empty
)

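They can also be passed through the training_chief_hooks argument when defining tf.estimator.EstimatorSpec.

Personally, I think it is better to define them in the estimator, so that the experiment only has to control how the experiment is run (number of training steps, number of evaluation steps, and so on).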

def model_fn(features, labels, mode, params):

    # ....

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        # scaffold=get_scaffold(),
        # training_chief_hooks=None
    )
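One way to fill in the commented-out training_chief_hooks argument is to decide in a small helper whether a restore hook should be attached at all. The sketch below is pure Python so that the decision logic can be read on its own. The helper `chief_hooks_for`, the `make_hook` factory and the checkpoint path are hypothetical. It assumes that a restore hook should be skipped once model_dir already holds a checkpoint, mirroring the guard used in the model_fn approach above, so that resumed training is not overwritten by the pretrained weights.

```python
TRAIN = "train"  # the value of tf.estimator.ModeKeys.TRAIN


def chief_hooks_for(mode, checkpoint_path, model_dir_has_checkpoint, make_hook):
    """Return the list to pass as training_chief_hooks (possibly empty).

    make_hook is a hypothetical factory that builds the restore hook
    from a checkpoint path.
    """
    if mode != TRAIN:
        return []  # chief training hooks only matter in training mode
    if not checkpoint_path:
        return []  # no pretrained weights configured
    if model_dir_has_checkpoint:
        return []  # resuming: keep the weights already saved in model_dir
    return [make_hook(checkpoint_path)]


# Inside model_fn this could be used roughly as follows (TensorFlow 1.x):
#   training_chief_hooks=chief_hooks_for(
#       mode,
#       params.checkpoint_path,
#       tf.train.latest_checkpoint(params.model_dir) is not None,
#       make_hook)

# Stand-in factory so the example runs without TensorFlow.
hooks = chief_hooks_for(TRAIN, "/tmp/pretrained/model.ckpt", False,
                        lambda path: ("restore", path))
print(hooks)
```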

While we are here, a brief explanation of what the tf.estimator.EstimatorSpec object does. It describes every aspect of a model, including:

The current mode:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

The computation graph:

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

The export strategy:

export_outputs: Describes the output signatures to be exported to SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

Chief hooks: hooks that implement the model-saving strategy during training (CheckpointSaverHook), model restoring, and so on.

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

Worker hooks: hooks that implement the monitoring strategy during training, such as NanTensorHook and LoggingTensorHook.

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

Initialization and saver:

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

Evaluation hooks:

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.

The custom hook looks like this:

import os

import tensorflow as tf  # TensorFlow 1.x API
import tensorflow.contrib.slim as slim


class RestoreCheckpointHook(tf.train.SessionRunHook):
    """Restores a filtered subset of variables from a checkpoint once the session exists."""

    def __init__(self,
                 checkpoint_path,
                 exclude_scope_patterns,
                 include_scope_patterns):
        tf.logging.info("Create RestoreCheckpointHook.")
        super(RestoreCheckpointHook, self).__init__()
        self.checkpoint_path = checkpoint_path

        # Both arguments are comma-separated regex strings; empty or None means "no filter".
        self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
        self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')

    def begin(self):
        # You can add ops to the graph here.
        print('Before starting the session.')

        # 1. Create saver
        # Manual alternative to filter_variables, kept for reference:
        # exclusions = [scope.strip() for scope in checkpoint_exclude_scopes.split(',')]
        # variables_to_restore = []
        # for var in slim.get_model_variables():  # or tf.global_variables()
        #     if not any(var.op.name.startswith(exclusion) for exclusion in exclusions):
        #         variables_to_restore.append(var)
        # Inclusion by prefix:
        # [var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

        # Note: slim.get_model_variables() only returns variables registered as
        # model variables (e.g. created through slim layers).
        variables_to_restore = tf.contrib.framework.filter_variables(
            slim.get_model_variables(),
            include_patterns=self.include_scope_patterns,  # e.g. ['Conv']
            exclude_patterns=self.exclude_scope_patterns,  # e.g. ['biases', 'Logits']
            # If True (default), performs re.search to find matches
            # (i.e. pattern can match any substring of the variable name).
            # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
            reg_search=True
        )
        self.saver = tf.train.Saver(variables_to_restore)

    def after_create_session(self, session, coord):
        # When this is called, the graph is finalized and
        # ops can no longer be added to the graph.
        print('Session created.')

        tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
        self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
        tf.logging.info('Finished fine-tuning from %s' % self.checkpoint_path)

    def before_run(self, run_context):
        # print('Before calling session.run().')
        return None  # or SessionRunArgs(self.your_tensor)

    def after_run(self, run_context, run_values):
        # print('Done running one step. The value of my tensor: %s', run_values.results)
        # if you-need-to-stop-loop:
        #     run_context.request_stop()
        pass

    def end(self, session):
        # print('Done with the session.')
        pass
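
The include and exclude patterns decide which variables the hook's saver restores, so it helps to see how they interact. The sketch below mirrors, on plain name strings, the selection logic that filter_variables applies as described in the comments of begin: keep names matching any include pattern, then drop names matching any exclude pattern. It is an illustration under that assumption rather than the library code; `filter_names`, `split_patterns` and the variable names are hypothetical.

```python
import re


def filter_names(names, include_patterns=None, exclude_patterns=None, reg_search=True):
    """Keep names matching any include pattern, then drop names matching any exclude pattern."""
    match = re.search if reg_search else re.match
    if include_patterns is None:
        kept = list(names)  # no include filter: start from everything
    else:
        kept = [n for n in names if any(match(p, n) for p in include_patterns)]
    if exclude_patterns is not None:
        kept = [n for n in kept if not any(match(p, n) for p in exclude_patterns)]
    return kept


def split_patterns(text):
    # Same parsing as the hook's constructor: empty or None means "no filter".
    return text.split(",") if text else None


# Hypothetical variable names.
names = [
    "InceptionResnetV1/Conv2d_1a/weights:0",
    "InceptionResnetV1/Conv2d_1a/biases:0",
    "InceptionResnetV1/Logits/weights:0",
]

selected = filter_names(names, split_patterns("Conv"), split_patterns("biases,Logits"))
print(selected)

# With reg_search=False the pattern must match from the start of the name,
# so "Conv" no longer matches anything here.
anchored = filter_names(names, split_patterns("Conv"), None, reg_search=False)
print(anchored)
```

A typical fine-tuning setup excludes the final classification layer (here the hypothetical `Logits` scope) so that it is trained from scratch while the backbone starts from the pretrained weights.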
