tensorflow estimator 使用hook实现finetune方式


Posted in Python onJanuary 21, 2020

为了实现finetune有如下两种解决方案:

model_fn里面定义好模型之后直接赋值

def model_fn(features, labels, mode, params):
 # .....
 # finetune
 if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
 checkpoint_path = None
 if tf.gfile.IsDirectory(params.checkpoint_path):
  checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
 else:
  checkpoint_path = params.checkpoint_path

 tf.train.init_from_checkpoint(
  ckpt_dir_or_file=checkpoint_path,
  assignment_map={params.checkpoint_scope: params.checkpoint_scope} # 'OptimizeLoss/':'OptimizeLoss/'
 )

使用钩子 hooks。

可以在定义tf.contrib.learn.Experiment的时候通过train_monitors参数指定

# Define the experiment
 experiment = tf.contrib.learn.Experiment(
 estimator=estimator, # Estimator
 train_input_fn=train_input_fn, # First-class function
 eval_input_fn=eval_input_fn, # First-class function
 train_steps=params.train_steps, # Minibatch steps
 min_eval_frequency=params.eval_min_frequency, # Eval frequency
 # train_monitors=[], # Hooks for training
 # eval_hooks=[eval_input_hook], # Hooks for evaluation
 eval_steps=params.eval_steps # Use evaluation feeder until its empty
 )

也可以在定义tf.estimator.EstimatorSpec 的时候通过training_chief_hooks参数指定。

不过个人觉得最好还是在estimator中定义,让experiment只专注于控制实验的模式(训练次数,验证次数等等)。

def model_fn(features, labels, mode, params):

 # ....

 return tf.estimator.EstimatorSpec(
 mode=mode,
 predictions=predictions,
 loss=loss,
 train_op=train_op,
 eval_metric_ops=eval_metric_ops,
 # scaffold=get_scaffold(),
 # training_chief_hooks=None
 )

这里顺便解释以下tf.estimator.EstimatorSpec对像的作用。该对象描述来一个模型的方方面面。包括:

当前的模式:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

计算图

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

导出策略

export_outputs: Describes the output signatures to be exported to

SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

chief钩子 训练时的模型保存策略钩子CheckpointSaverHook, 模型恢复等

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

worker钩子 训练时的监控策略钩子如: NanTensorHook LoggingTensorHook 等

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

指定初始化和saver

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

evaluation钩子

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.

自定义的钩子如下:

class RestoreCheckpointHook(tf.train.SessionRunHook):
 def __init__(self,
   checkpoint_path,
   exclude_scope_patterns,
   include_scope_patterns
   ):
 tf.logging.info("Create RestoreCheckpointHook.")
 #super(IteratorInitializerHook, self).__init__()
 self.checkpoint_path = checkpoint_path

 self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
 self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')


 def begin(self):
 # You can add ops to the graph here.
 print('Before starting the session.')

 # 1. Create saver

 #exclusions = []
 #if self.checkpoint_exclude_scopes:
 # exclusions = [scope.strip()
 #  for scope in self.checkpoint_exclude_scopes.split(',')]
 #
 #variables_to_restore = []
 #for var in slim.get_model_variables(): #tf.global_variables():
 # excluded = False
 # for exclusion in exclusions:
 # if var.op.name.startswith(exclusion):
 # excluded = True
 # break
 # if not excluded:
 # variables_to_restore.append(var)
 #inclusions
 #[var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

 variables_to_restore = tf.contrib.framework.filter_variables(
  slim.get_model_variables(),
  include_patterns=self.include_scope_patterns, # ['Conv'],
  exclude_patterns=self.exclude_scope_patterns, # ['biases', 'Logits'],

  # If True (default), performs re.search to find matches
  # (i.e. pattern can match any substring of the variable name).
  # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
  reg_search = True
 )
 self.saver = tf.train.Saver(variables_to_restore)


 def after_create_session(self, session, coord):
 # When this is called, the graph is finalized and
 # ops can no longer be added to the graph.

 print('Session created.')

 tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
 self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
 tf.logging.info('End fineturn from %s' % self.checkpoint_path)

 def before_run(self, run_context):
 #print('Before calling session.run().')
 return None #SessionRunArgs(self.your_tensor)

 def after_run(self, run_context, run_values):
 #print('Done running one step. The value of my tensor: %s', run_values.results)
 #if you-need-to-stop-loop:
 # run_context.request_stop()
 pass


 def end(self, session):
 #print('Done with the session.')
 pass

以上这篇tensorflow estimator 使用hook实现finetune方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现国外赌场热门游戏Craps(双骰子)
Mar 31 Python
Python基础入门之seed()方法的使用
May 15 Python
Django中对通过测试的用户进行限制访问的方法
Jul 23 Python
使用Python操作excel文件的实例代码
Oct 15 Python
python实战之实现excel读取、统计、写入的示例讲解
May 02 Python
Python除法之传统除法、Floor除法及真除法实例详解
May 23 Python
Python生态圈图像格式转换问题(推荐)
Dec 02 Python
Python监控服务器实用工具psutil使用解析
Dec 19 Python
pytorch快速搭建神经网络_Sequential操作
Jun 17 Python
如何使用Python处理HDF格式数据及可视化问题
Jun 24 Python
python 下载文件的多种方法汇总
Nov 17 Python
python 如何获取页面所有a标签下href的值
May 06 Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
You might like
PHP查找与搜索数组元素方法总结
2015/06/12 PHP
PHP之预定义接口详解
2015/07/29 PHP
php实现的表单验证类完整示例
2019/08/13 PHP
PHP中用Trait封装单例模式的实现
2019/12/18 PHP
使用prototype.js进行异步操作
2007/02/07 Javascript
在Javascript里访问SharePoint列表数据的实现方法
2011/05/22 Javascript
简单漂亮的js弹窗可自由拖拽且兼容大部分浏览器
2013/10/22 Javascript
javascript获取url上某个参数的方法
2013/11/08 Javascript
自己动手手写jQuery插件总结
2015/01/20 Javascript
js实现百度联盟中一款不错的图片切换效果完整实例
2015/03/04 Javascript
逐一介绍Jquery data()、Jquery stop()、jquery delay()函数(详)
2015/11/04 Javascript
浅谈javascript中onbeforeunload与onunload事件
2015/12/10 Javascript
悬浮广告方法日常收集整理
2016/03/18 Javascript
原生js编写autoComplete插件
2016/04/13 Javascript
Angular 4 指令快速入门教程
2017/06/07 Javascript
基于JS实现网页中的选项卡(两种方法)
2017/06/16 Javascript
Webpack常见静态资源处理-模块加载器(Loaders)+ExtractTextPlugin插件
2017/06/29 Javascript
vue-cli启动本地服务局域网不能访问的原因分析
2018/01/22 Javascript
jQuery 实现批量提交表格多行数据的方法
2018/08/09 jQuery
Vue中this.$nextTick的作用及用法
2020/02/04 Javascript
pygame学习笔记(2):画点的三种方法和动画实例
2015/04/15 Python
python获取本地计算机名字的方法
2015/04/29 Python
简单介绍Python的Django框架加载模版的方式
2015/07/20 Python
CentOS 7下安装Python 3.5并与Python2.7兼容并存详解
2017/07/07 Python
python 动态生成变量名以及动态获取变量的变量名方法
2019/01/20 Python
python实现堆排序的实例讲解
2020/02/21 Python
CSS3 flex布局之快速实现BorderLayout布局
2015/12/03 HTML / CSS
Bootstrap 学习分享
2012/11/12 HTML / CSS
时尚孕妇装:HATCH Collection
2019/09/24 全球购物
大学新生军训个人的自我评价
2013/10/03 职场文书
总经理职责范文
2013/11/08 职场文书
战略合作意向书范本
2014/04/01 职场文书
中秋节国旗下演讲稿
2014/09/05 职场文书
12.4全国法制宣传日活动方案
2014/11/02 职场文书
先进党组织事迹材料
2014/12/26 职场文书
Oracle锁表解决方法的详细记录
2022/06/05 Oracle