Fine-tuning a TensorFlow Estimator with a hook


Posted in Python on January 21, 2020

There are two ways to implement fine-tuning.

Option 1: initialize the variables directly inside model_fn, once the model has been defined.

import tensorflow as tf  # TensorFlow 1.x API


def model_fn(features, labels, mode, params):
    # ... build the model ...

    # Fine-tune: warm-start only while model_dir has no checkpoint of its own.
    if params.checkpoint_path and not tf.train.latest_checkpoint(params.model_dir):
        if tf.gfile.IsDirectory(params.checkpoint_path):
            # A directory was given: use its most recent checkpoint.
            checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
        else:
            checkpoint_path = params.checkpoint_path

        tf.train.init_from_checkpoint(
            ckpt_dir_or_file=checkpoint_path,
            # e.g. {'OptimizeLoss/': 'OptimizeLoss/'}
            assignment_map={params.checkpoint_scope: params.checkpoint_scope},
        )
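To make the assignment_map above more concrete, here is a minimal sketch of how a scope-prefix entry such as {'OptimizeLoss/': 'OptimizeLoss/'} pairs graph variables with checkpoint tensors by name. It does not call TensorFlow: the helper match_by_scope and the variable names are hypothetical, and the real tf.train.init_from_checkpoint performs additional validation and supports other mapping forms that are not modelled here.

```python
def match_by_scope(checkpoint_names, graph_names, assignment_map):
    """Pair graph variable names with checkpoint tensor names by scope prefix.

    Hypothetical helper: it mimics only the 'ckpt_scope/': 'graph_scope/'
    form of an assignment map and works on plain strings.
    """
    pairs = {}
    for ckpt_scope, graph_scope in assignment_map.items():
        for graph_name in graph_names:
            if not graph_name.startswith(graph_scope):
                continue
            # Swap the graph scope prefix for the checkpoint scope prefix.
            ckpt_name = ckpt_scope + graph_name[len(graph_scope):]
            if ckpt_name in checkpoint_names:
                pairs[graph_name] = ckpt_name
    return pairs


# Made-up names, for illustration only.
checkpoint_names = {'OptimizeLoss/dense/kernel', 'OptimizeLoss/dense/bias', 'global_step'}
graph_names = ['OptimizeLoss/dense/kernel', 'OptimizeLoss/dense/bias', 'Logits/kernel']

pairs = match_by_scope(checkpoint_names, graph_names,
                       {'OptimizeLoss/': 'OptimizeLoss/'})
for graph_name in sorted(pairs):
    print(graph_name, '<-', pairs[graph_name])
```

Variables outside the mapped scope (here the made-up Logits/kernel) are left alone, which is what lets a new output layer keep its fresh initialization while the rest of the network is loaded.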

Option 2: use hooks.

Hooks can be passed through the train_monitors argument when the tf.contrib.learn.Experiment is constructed:

# Define the experiment
experiment = tf.contrib.learn.Experiment(
    estimator=estimator,                           # Estimator
    train_input_fn=train_input_fn,                 # First-class function
    eval_input_fn=eval_input_fn,                   # First-class function
    train_steps=params.train_steps,                # Minibatch steps
    min_eval_frequency=params.eval_min_frequency,  # Eval frequency
    # train_monitors=[],                           # Hooks for training
    # eval_hooks=[eval_input_hook],                # Hooks for evaluation
    eval_steps=params.eval_steps                   # Use evaluation feeder until it is empty
)

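They can also be passed through the training_chief_hooks argument when the tf.estimator.EstimatorSpec is constructed.

Personally, I think it is better to define them in the estimator, so that the experiment only has to control how the experiment is run (number of training steps, number of evaluation steps, and so on).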

def model_fn(features, labels, mode, params):

    # ....

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        # scaffold=get_scaffold(),
        # training_chief_hooks=None
    )

While we are here, a quick explanation of what the tf.estimator.EstimatorSpec object is for. It describes every aspect of a model, including:

The current mode:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

The computation graph:

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

The export strategy:

export_outputs: Describes the output signatures to be exported to SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

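Chief hooks: hooks that run only on the chief worker during training, such as the model-saving hook CheckpointSaverHook, model restoring, and so on.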

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

Worker hooks: monitoring hooks that run during training, such as NanTensorHook and LoggingTensorHook.

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

Initialization and saver:

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

Evaluation hooks:

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.
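
All three hook lists hold SessionRunHook objects, whose methods are called at fixed points of the run loop. The following TensorFlow-free sketch illustrates that call order, assuming the usual sequence of begin, after_create_session, then before_run and after_run around every step, and finally end. RecordingHook and run_loop are hypothetical stand-ins written for this illustration, not TensorFlow classes.

```python
class RecordingHook(object):
    """Hypothetical stand-in for a SessionRunHook that records each callback."""

    def __init__(self):
        self.calls = []

    def begin(self):
        self.calls.append('begin')

    def after_create_session(self, session, coord):
        self.calls.append('after_create_session')

    def before_run(self, run_context):
        self.calls.append('before_run')

    def after_run(self, run_context, run_values):
        self.calls.append('after_run')

    def end(self, session):
        self.calls.append('end')


def run_loop(hooks, num_steps):
    """Toy loop that calls hooks in the order assumed above."""
    for hook in hooks:
        hook.begin()                      # the graph could still be modified here
    session = object()                    # placeholder for a real session
    for hook in hooks:
        hook.after_create_session(session, coord=None)
    for _ in range(num_steps):
        for hook in hooks:
            hook.before_run(run_context=None)
        # ... the real session.run(train_op) would happen here ...
        for hook in hooks:
            hook.after_run(run_context=None, run_values=None)
    for hook in hooks:
        hook.end(session)


hook = RecordingHook()
run_loop([hook], num_steps=2)
print(hook.calls)
```

Under this ordering, anything that adds ops to the graph (such as creating a saver) belongs in begin, and anything that needs a live session (such as restoring weights) belongs in after_create_session.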

The custom hook looks like this:

import os

import tensorflow as tf  # TensorFlow 1.x API

slim = tf.contrib.slim


class RestoreCheckpointHook(tf.train.SessionRunHook):
    """Restores a filtered set of model variables from a checkpoint."""

    def __init__(self,
                 checkpoint_path,
                 exclude_scope_patterns,
                 include_scope_patterns):
        tf.logging.info("Create RestoreCheckpointHook.")
        super(RestoreCheckpointHook, self).__init__()
        self.checkpoint_path = checkpoint_path

        # Patterns arrive as comma-separated strings; empty or None means "no filter".
        self.exclude_scope_patterns = (
            exclude_scope_patterns.split(',') if exclude_scope_patterns else None)
        self.include_scope_patterns = (
            include_scope_patterns.split(',') if include_scope_patterns else None)

    def begin(self):
        # You can add ops to the graph here.
        print('Before starting the session.')

        # 1. Create saver
        #
        # A hand-written alternative to filter_variables, kept for reference:
        # exclusions = [p.strip() for p in (self.exclude_scope_patterns or [])]
        # variables_to_restore = []
        # for var in slim.get_model_variables():  # or tf.global_variables()
        #     if not any(var.op.name.startswith(e) for e in exclusions):
        #         variables_to_restore.append(var)
        #
        # Inclusion by prefix:
        # [var for var in tf.trainable_variables()
        #  if var.op.name.startswith('InceptionResnetV1')]

        variables_to_restore = tf.contrib.framework.filter_variables(
            slim.get_model_variables(),
            include_patterns=self.include_scope_patterns,  # e.g. ['Conv']
            exclude_patterns=self.exclude_scope_patterns,  # e.g. ['biases', 'Logits']
            # If True (default), performs re.search to find matches
            # (i.e. a pattern can match any substring of the variable name).
            # If False, performs re.match (i.e. the regexp must match from
            # the beginning of the variable name).
            reg_search=True,
        )
        self.saver = tf.train.Saver(variables_to_restore)

    def after_create_session(self, session, coord):
        # When this is called, the graph is finalized and
        # ops can no longer be added to the graph.
        print('Session created.')

        tf.logging.info('Fine-tuning from %s', self.checkpoint_path)
        self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
        tf.logging.info('Finished restoring from %s', self.checkpoint_path)

    def before_run(self, run_context):
        # print('Before calling session.run().')
        return None  # or SessionRunArgs(self.your_tensor)

    def after_run(self, run_context, run_values):
        # print('Done running one step. The value of my tensor: %s', run_values.results)
        # if you-need-to-stop-loop:
        #     run_context.request_stop()
        pass

    def end(self, session):
        # print('Done with the session.')
        pass
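
The effect of include_patterns, exclude_patterns and reg_search in begin() can be tried out without TensorFlow. The sketch below applies the same kind of filtering to plain name strings, assuming the behaviour described in the comments above (include first, then exclude, with re.search or re.match depending on reg_search). filter_names and the sample variable names are made up for illustration.

```python
import re


def filter_names(names, include_patterns=None, exclude_patterns=None,
                 reg_search=True):
    """Hypothetical string-only analogue of tf.contrib.framework.filter_variables."""
    matcher = re.search if reg_search else re.match

    # Step 1: keep names matching any include pattern (everything if None).
    if include_patterns is None:
        included = list(names)
    else:
        included = [n for n in names
                    if any(matcher(p, n) for p in include_patterns)]

    # Step 2: drop names matching any exclude pattern.
    if exclude_patterns is None:
        return included
    return [n for n in included
            if not any(matcher(p, n) for p in exclude_patterns)]


# Made-up variable names, for illustration only.
names = [
    'InceptionResnetV1/Conv2d_1a/weights',
    'InceptionResnetV1/Conv2d_1a/biases',
    'InceptionResnetV1/Logits/weights',
]

# Substring matching: 'Conv' is found in the middle of the name.
kept_search = filter_names(names, include_patterns=['Conv'],
                           exclude_patterns=['biases', 'Logits'])
# Anchored matching: no name starts with 'Conv', so nothing is kept.
kept_match = filter_names(names, include_patterns=['Conv'], reg_search=False)

print(kept_search)
print(kept_match)
```

With reg_search=False a pattern has to match from the start of the name, so a pattern like 'InceptionResnetV1/Conv' would be needed to select the same layers.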
