tensorflow estimator 使用hook实现finetune方式


Posted in Python onJanuary 21, 2020

为了实现finetune有如下两种解决方案:

model_fn里面定义好模型之后直接赋值

def model_fn(features, labels, mode, params):
 # .....
 # finetune
 if params.checkpoint_path and (not tf.train.latest_checkpoint(params.model_dir)):
 checkpoint_path = None
 if tf.gfile.IsDirectory(params.checkpoint_path):
  checkpoint_path = tf.train.latest_checkpoint(params.checkpoint_path)
 else:
  checkpoint_path = params.checkpoint_path

 tf.train.init_from_checkpoint(
  ckpt_dir_or_file=checkpoint_path,
  assignment_map={params.checkpoint_scope: params.checkpoint_scope} # 'OptimizeLoss/':'OptimizeLoss/'
 )

使用钩子 hooks。

可以在定义tf.contrib.learn.Experiment的时候通过train_monitors参数指定

# Define the experiment
 experiment = tf.contrib.learn.Experiment(
 estimator=estimator, # Estimator
 train_input_fn=train_input_fn, # First-class function
 eval_input_fn=eval_input_fn, # First-class function
 train_steps=params.train_steps, # Minibatch steps
 min_eval_frequency=params.eval_min_frequency, # Eval frequency
 # train_monitors=[], # Hooks for training
 # eval_hooks=[eval_input_hook], # Hooks for evaluation
 eval_steps=params.eval_steps # Use evaluation feeder until its empty
 )

也可以在定义tf.estimator.EstimatorSpec 的时候通过training_chief_hooks参数指定。

不过个人觉得最好还是在estimator中定义,让experiment只专注于控制实验的模式(训练次数,验证次数等等)。

def model_fn(features, labels, mode, params):

 # ....

 return tf.estimator.EstimatorSpec(
 mode=mode,
 predictions=predictions,
 loss=loss,
 train_op=train_op,
 eval_metric_ops=eval_metric_ops,
 # scaffold=get_scaffold(),
 # training_chief_hooks=None
 )

这里顺便解释以下tf.estimator.EstimatorSpec对像的作用。该对象描述来一个模型的方方面面。包括:

当前的模式:

mode: A ModeKeys. Specifies if this is training, evaluation or prediction.

计算图

predictions: Predictions Tensor or dict of Tensor.

loss: Training loss Tensor. Must be either scalar, or with shape [1].

train_op: Op for the training step.

eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, namely a (metric_tensor, update_op) tuple. metric_tensor should be evaluated without any impact on state (typically is a pure computation results based on variables.). For example, it should not trigger the update_op or requires any input fetching.

导出策略

export_outputs: Describes the output signatures to be exported to

SavedModel and used during serving. A dict {name: output} where:

name: An arbitrary name for this output.

output: an ExportOutput object such as ClassificationOutput, RegressionOutput, or PredictOutput. Single-headed models only need to specify one entry in this dictionary. Multi-headed models should specify one entry for each head, one of which must be named using signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

chief钩子 训练时的模型保存策略钩子CheckpointSaverHook, 模型恢复等

training_chief_hooks: Iterable of tf.train.SessionRunHook objects to run on the chief worker during training.

worker钩子 训练时的监控策略钩子如: NanTensorHook LoggingTensorHook 等

training_hooks: Iterable of tf.train.SessionRunHook objects to run on all workers during training.

指定初始化和saver

scaffold: A tf.train.Scaffold object that can be used to set initialization, saver, and more to be used in training.

evaluation钩子

evaluation_hooks: Iterable of tf.train.SessionRunHook objects to run during evaluation.

自定义的钩子如下:

class RestoreCheckpointHook(tf.train.SessionRunHook):
 def __init__(self,
   checkpoint_path,
   exclude_scope_patterns,
   include_scope_patterns
   ):
 tf.logging.info("Create RestoreCheckpointHook.")
 #super(IteratorInitializerHook, self).__init__()
 self.checkpoint_path = checkpoint_path

 self.exclude_scope_patterns = None if (not exclude_scope_patterns) else exclude_scope_patterns.split(',')
 self.include_scope_patterns = None if (not include_scope_patterns) else include_scope_patterns.split(',')


 def begin(self):
 # You can add ops to the graph here.
 print('Before starting the session.')

 # 1. Create saver

 #exclusions = []
 #if self.checkpoint_exclude_scopes:
 # exclusions = [scope.strip()
 #  for scope in self.checkpoint_exclude_scopes.split(',')]
 #
 #variables_to_restore = []
 #for var in slim.get_model_variables(): #tf.global_variables():
 # excluded = False
 # for exclusion in exclusions:
 # if var.op.name.startswith(exclusion):
 # excluded = True
 # break
 # if not excluded:
 # variables_to_restore.append(var)
 #inclusions
 #[var for var in tf.trainable_variables() if var.op.name.startswith('InceptionResnetV1')]

 variables_to_restore = tf.contrib.framework.filter_variables(
  slim.get_model_variables(),
  include_patterns=self.include_scope_patterns, # ['Conv'],
  exclude_patterns=self.exclude_scope_patterns, # ['biases', 'Logits'],

  # If True (default), performs re.search to find matches
  # (i.e. pattern can match any substring of the variable name).
  # If False, performs re.match (i.e. regexp should match from the beginning of the variable name).
  reg_search = True
 )
 self.saver = tf.train.Saver(variables_to_restore)


 def after_create_session(self, session, coord):
 # When this is called, the graph is finalized and
 # ops can no longer be added to the graph.

 print('Session created.')

 tf.logging.info('Fine-tuning from %s' % self.checkpoint_path)
 self.saver.restore(session, os.path.expanduser(self.checkpoint_path))
 tf.logging.info('End fineturn from %s' % self.checkpoint_path)

 def before_run(self, run_context):
 #print('Before calling session.run().')
 return None #SessionRunArgs(self.your_tensor)

 def after_run(self, run_context, run_values):
 #print('Done running one step. The value of my tensor: %s', run_values.results)
 #if you-need-to-stop-loop:
 # run_context.request_stop()
 pass


 def end(self, session):
 #print('Done with the session.')
 pass

以上这篇tensorflow estimator 使用hook实现finetune方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于python编写的微博应用
Oct 17 Python
Python找出list中最常出现元素的方法
Jun 14 Python
python 实现红包随机生成算法的简单实例
Jan 04 Python
python+selenium实现自动抢票功能实例代码
Nov 23 Python
Python设计模式之代理模式实例详解
Jan 19 Python
对Python定时任务的启动和停止方法详解
Feb 19 Python
Python写一个基于MD5的文件监听程序
Mar 11 Python
python实现在函数中修改变量值的方法
Jul 16 Python
python实现拼图小游戏
Feb 22 Python
Python TKinter如何自动关闭主窗口
Feb 26 Python
python 串行执行和并行执行实例
Apr 30 Python
Python 没有main函数的原因
Jul 10 Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
关于tf.nn.dynamic_rnn返回值详解
Jan 20 #Python
You might like
PHP图片裁剪函数(保持图像不变形)
2014/05/04 PHP
测试JavaScript字符串处理性能的代码
2009/12/07 Javascript
JS模板实现方法
2013/04/03 Javascript
JS getAttribute和setAttribute(取得和设置属性)的使用介绍
2013/07/10 Javascript
js写的方法实现上传图片之后查看大图
2014/03/05 Javascript
上传图片js判断图片尺寸和格式兼容IE
2014/09/01 Javascript
Node.js中的模块机制学习笔记
2014/11/04 Javascript
浅谈jQuery事件绑定原理
2015/01/02 Javascript
JS+CSS实现感应鼠标渐变显示DIV层的方法
2015/02/20 Javascript
关于Bootstrap弹出框无法调用问题的解决办法
2016/03/10 Javascript
微信小程序使用第三方库Underscore.js步骤详解
2016/09/27 Javascript
js运动事件函数详解
2016/10/21 Javascript
理解JavaScript原型链
2016/10/25 Javascript
详解vue表单——小白速看
2018/04/08 Javascript
vee-validate vue 2.0自定义表单验证的实例
2018/08/28 Javascript
layui实现把数据表格时间戳转换为时间格式的例子
2019/09/12 Javascript
[02:12]Dota 2 推出全新英雄—— 电炎绝手
2019/08/23 DOTA
使用Python的Treq on Twisted来进行HTTP压力测试
2015/04/16 Python
Python中property属性实例解析
2018/02/10 Python
详解Python中如何写控制台进度条的整理
2018/03/07 Python
不知道这5种下划线的含义,你就不算真的会Python!
2018/10/09 Python
解决安装python库时windows error5 报错的问题
2018/10/21 Python
淘宝秒杀python脚本 扫码登录版
2019/09/19 Python
基于Python检测动态物体颜色过程解析
2019/12/04 Python
python上selenium的弹框操作实现
2020/07/13 Python
Pygame框架实现飞机大战
2020/08/07 Python
python 提高开发效率的5个小技巧
2020/10/19 Python
详解python中的异常捕获
2020/12/15 Python
英国家用电器折扣网站:Electrical Discount UK
2018/09/17 全球购物
Kipling意大利官网:世界著名的时尚休闲包袋品牌
2019/06/05 全球购物
中式餐厅创业计划书范文
2014/01/23 职场文书
创建服务型党组织实施方案
2014/02/25 职场文书
成品库仓管员岗位职责
2014/04/06 职场文书
优秀三好学生事迹材料
2014/08/31 职场文书
师德标兵事迹材料
2014/12/19 职场文书
python 判断字符串当中是否包含字符(str.contain)
2022/06/01 Python