TensorFlow实现自定义Op方式


Posted in Python on February 04, 2020

『写在前面』

以CTC Beam search decoder为例,简单整理一下TensorFlow实现自定义Op的操作流程。

基本的流程

1. 定义Op接口

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

// Declares the "Custom" op: one int32 input, one int32 output.
// A shape function is required for the op to be usable when building graphs
// (without one TensorFlow reports "No shape inference function exists for
// op"). The CPU kernel allocates its output with the input's shape, so
// output 0 simply mirrors input 0.
REGISTER_OP("Custom")
    .Input("custom_input: int32")
    .Output("custom_output: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return ::tensorflow::Status::OK();
    });

2. 为Op实现Compute操作(CPU)或实现kernel(GPU)

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

// CPU kernel for the "Custom" op: reads an int32 tensor and produces an int32
// tensor of the same shape.
class CustomOp : public OpKernel {
 public:
  explicit CustomOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Fetch the input tensor and view it as a flat int32 vector.
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Allocate the output with the same shape as the input.
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output = output_tensor->flat<int32>();

    // The buffer returned by allocate_output is not guaranteed to be
    // initialized, so every element must be written. Placeholder computation:
    // an identity copy -- replace the loop body with the real per-element
    // operation on input and output.
    const int64 n = input.size();
    for (int64 i = 0; i < n; ++i) {
      output(i) = input(i);
    }
  }
};

3. 将实现的kernel注册到TensorFlow系统中

// Bind the CustomOp class as the CPU implementation of the "Custom" op.
REGISTER_KERNEL_BUILDER(
    Name("Custom").Device(DEVICE_CPU),
    CustomOp);

CTCBeamSearchDecoder自定义

该Op对应TensorFlow中的源码部分

Op接口的定义:

tensorflow-master/tensorflow/core/ops/ctc_ops.cc

CTCBeamSearchDecoder本身的定义:

tensorflow-master/tensorflow/core/util/ctc/ctc_beam_search.cc

Op-Class的封装与Op注册:

tensorflow-master/tensorflow/core/kernels/ctc_decoder_ops.cc

基于源码修改的Op

#include <algorithm>
#include <vector>
#include <cmath>
 
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
 
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/kernels/bounds_check.h"
 
namespace tf = tensorflow;
using tf::shape_inference::DimensionHandle;
using tf::shape_inference::InferenceContext;
using tf::shape_inference::ShapeHandle;
 
using namespace tensorflow;
 
// Op definition: the stock CTC beam search decoder interface plus two extra
// attributes (label_selection_size, label_selection_margin) that the kernel
// forwards to CTCBeamSearchDecoder::SetLabelSelectionParameters.
REGISTER_OP("CTCBeamSearchDecoderWithParam")
    .Input("inputs: float")
    .Input("sequence_length: int32")
    .Attr("beam_width: int >= 1")
    .Attr("top_paths: int >= 1")
    .Attr("merge_repeated: bool = true")
    // The two attributes added on top of the stock decoder.
    .Attr("label_selection_size: int >= 0 = 0")
    .Attr("label_selection_margin: float")
    .Output("decoded_indices: top_paths * int64")
    .Output("decoded_values: top_paths * int64")
    .Output("decoded_shape: top_paths * int64")
    .Output("log_probability: float")
    .SetShapeFn([](InferenceContext* c) {
      // inputs must be rank 3 and sequence_length rank 1.
      ShapeHandle logits_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &logits_shape));
      ShapeHandle lengths_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &lengths_shape));

      // Batch size = inputs dim 1, which must agree with len(sequence_length).
      DimensionHandle batch;
      TF_RETURN_IF_ERROR(
          c->Merge(c->Dim(logits_shape, 1), c->Dim(lengths_shape, 0), &batch));

      int32 num_paths;
      TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &num_paths));

      // Output slots are laid out as
      //   [0, P)     decoded_indices  : [?, 2]
      //   [P, 2P)    decoded_values   : [?]
      //   [2P, 3P)   decoded_shape    : [2]
      //   3P         log_probability  : [batch, P]
      // Each unknown dimension is created separately so different paths are
      // not inferred to share a length.
      const ShapeHandle dense_shape = c->Vector(2);
      for (int p = 0; p < num_paths; ++p) {
        c->set_output(p, c->Matrix(InferenceContext::kUnknownDim, 2));
        c->set_output(num_paths + p, c->Vector(InferenceContext::kUnknownDim));
        c->set_output(2 * num_paths + p, dense_shape);
      }
      c->set_output(3 * num_paths, c->Matrix(batch, num_paths));
      return Status::OK();
    });
 
// Alias for Eigen's thread-pool CPU device.
// NOTE(review): not referenced by the code in this file.
using CPUDevice = Eigen::ThreadPoolDevice;

// Returns the largest value in row `r` of `m` and writes its column index to
// `*c`; on ties the first (lowest) column wins because only a strictly greater
// value replaces the current best. CHECK-fails if `m` has no columns.
// NOTE(review): not referenced by the code in this file.
inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
                    int* c) {
  CHECK_LT(0, m.dimension(1));
  int best_col = 0;
  float best = m(r, 0);
  for (int col = 1; col < m.dimension(1); ++col) {
    const float candidate = m(r, col);
    if (candidate > best) {
      best = candidate;
      best_col = col;
    }
  }
  *c = best_col;
  return best;
}
 
// Helper shared by the CTC decoder kernel: validates the op inputs, allocates
// the `log_probability` output, and packs decoded label sequences into the
// sparse (indices, values, shape) output lists.
class CTCDecodeHelper {
 public:
  CTCDecodeHelper() : top_paths_(1) {}

  inline int GetTopPaths() const { return top_paths_; }
  void SetTopPaths(int tp) { top_paths_ = tp; }

  // Fetches the "inputs" and "sequence_length" tensors from `ctx` and checks
  // that:
  //   * inputs is rank 3 with a non-zero dimension 0 (max_time),
  //   * sequence_length is a vector whose length equals inputs dimension 1
  //     (batch_size),
  //   * every sequence_length entry is <= max_time.
  // On success allocates `log_probability` as [batch_size, top_paths] and
  // binds the three decoded_* output lists; their per-path tensors are
  // allocated later by StoreAllDecodedSequences.
  Status ValidateInputsGenerateOutputs(
      OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
      Tensor** log_prob, OpOutputList* decoded_indices,
      OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
    TF_RETURN_IF_ERROR(ctx->input("inputs", inputs));
    TF_RETURN_IF_ERROR(ctx->input("sequence_length", seq_len));

    const TensorShape& inputs_shape = (*inputs)->shape();
    if (inputs_shape.dims() != 3) {
      return errors::InvalidArgument("inputs is not a 3-Tensor");
    }

    const int64 max_time = inputs_shape.dim_size(0);
    const int64 batch_size = inputs_shape.dim_size(1);

    if (max_time == 0) {
      return errors::InvalidArgument("max_time is 0");
    }
    if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
      return errors::InvalidArgument("sequence_length is not a vector");
    }
    if (batch_size != (*seq_len)->dim_size(0)) {
      return errors::FailedPrecondition(
          "len(sequence_length) != batch_size. ", "len(sequence_length): ",
          (*seq_len)->dim_size(0), " batch_size: ", batch_size);
    }

    auto seq_len_t = (*seq_len)->vec<int32>();
    for (int64 b = 0; b < batch_size; ++b) {
      if (!(seq_len_t(b) <= max_time)) {
        return errors::FailedPrecondition("sequence_length(", b, ") <= ",
                                          max_time);
      }
    }

    TF_RETURN_IF_ERROR(ctx->allocate_output(
        "log_probability", TensorShape({batch_size, top_paths_}), log_prob));
    TF_RETURN_IF_ERROR(ctx->output_list("decoded_indices", decoded_indices));
    TF_RETURN_IF_ERROR(ctx->output_list("decoded_values", decoded_values));
    TF_RETURN_IF_ERROR(ctx->output_list("decoded_shape", decoded_shape));
    return Status::OK();
  }

  // Writes decoded paths into the output lists. sequences[b][p][ix] is the
  // ix-th decoded value of path p for batch entry b; every sequences[b] must
  // hold exactly top_paths_ paths (CHECK-enforced). For each path p this
  // allocates and fills:
  //   decoded_indices[p] : [n_p, 2] (batch, position) pairs,
  //   decoded_values[p]  : [n_p] decoded values,
  //   decoded_shape[p]   : [2] = {batch_size, longest decoded length},
  // where n_p is the total number of decoded values of path p over the batch.
  Status StoreAllDecodedSequences(
      const std::vector<std::vector<std::vector<int> > >& sequences,
      OpOutputList* decoded_indices, OpOutputList* decoded_values,
      OpOutputList* decoded_shape) const {
    const int64 batch_size = sequences.size();

    // Total number of decoded values per path, summed over the batch.
    std::vector<int64> num_entries(top_paths_, 0);
    for (const auto& batch_s : sequences) {
      CHECK_EQ(batch_s.size(), static_cast<size_t>(top_paths_));
      for (int p = 0; p < top_paths_; ++p) {
        num_entries[p] += batch_s[p].size();
      }
    }

    for (int p = 0; p < top_paths_; ++p) {
      Tensor* p_indices = nullptr;
      Tensor* p_values = nullptr;
      Tensor* p_shape = nullptr;

      const int64 p_num = num_entries[p];

      TF_RETURN_IF_ERROR(
          decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices));
      TF_RETURN_IF_ERROR(
          decoded_values->allocate(p, TensorShape({p_num}), &p_values));
      TF_RETURN_IF_ERROR(
          decoded_shape->allocate(p, TensorShape({2}), &p_shape));

      auto indices_t = p_indices->matrix<int64>();
      auto values_t = p_values->vec<int64>();
      auto shape_t = p_shape->vec<int64>();

      int64 max_decoded = 0;
      int64 offset = 0;

      for (int64 b = 0; b < batch_size; ++b) {
        const auto& p_batch = sequences[b][p];
        const int64 num_decoded = p_batch.size();
        max_decoded = std::max(max_decoded, num_decoded);
        // Only touch values_t when there is something to copy: for an empty
        // path `offset` can equal p_num (one past the end), and when p_num is
        // 0 the tensor has no elements at all, so indexing it would be out of
        // bounds.
        if (num_decoded > 0) {
          std::copy_n(p_batch.begin(), num_decoded, values_t.data() + offset);
        }
        for (int64 t = 0; t < num_decoded; ++t, ++offset) {
          indices_t(offset, 0) = b;
          indices_t(offset, 1) = t;
        }
      }

      shape_t(0) = batch_size;
      shape_t(1) = max_decoded;
    }
    return Status::OK();
  }

 private:
  int top_paths_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};
 
// CTC beam search
class CTCBeamSearchDecoderWithParamOp : public OpKernel {
 public:
 explicit CTCBeamSearchDecoderWithParamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
  //从参数列表中读取新添的两个参数
  OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_size", &label_selection_size));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_margin", &label_selection_margin));
  int top_paths;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
  decode_helper_.SetTopPaths(top_paths);
 }
 
 void Compute(OpKernelContext* ctx) override {
  const Tensor* inputs;
  const Tensor* seq_len;
  Tensor* log_prob = nullptr;
  OpOutputList decoded_indices;
  OpOutputList decoded_values;
  OpOutputList decoded_shape;
  OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
              ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
              &decoded_values, &decoded_shape));
 
  auto inputs_t = inputs->tensor<float, 3>();
  auto seq_len_t = seq_len->vec<int32>();
  auto log_prob_t = log_prob->matrix<float>();
 
  const TensorShape& inputs_shape = inputs->shape();
 
  const int64 max_time = inputs_shape.dim_size(0);
  const int64 batch_size = inputs_shape.dim_size(1);
  const int64 num_classes_raw = inputs_shape.dim_size(2);
  OP_REQUIRES(
    ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
    errors::InvalidArgument("num_classes cannot exceed max int"));
  const int num_classes = static_cast<const int>(num_classes_raw);
 
  log_prob_t.setZero();
 
  std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
 
  for (std::size_t t = 0; t < max_time; ++t) {
   input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                batch_size, num_classes);
  }
 
  ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
                      &beam_scorer_, 1 /* batch_size */,
                      merge_repeated_);
  //使用传入的两个参数进行Set
  beam_search.SetLabelSelectionParameters(label_selection_size, label_selection_margin);
  Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
  auto input_chip_t = input_chip.flat<float>();
 
  std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
  std::vector<float> log_probs;
 
  // Assumption: the blank index is num_classes - 1
  for (int b = 0; b < batch_size; ++b) {
   auto& best_paths_b = best_paths[b];
   best_paths_b.resize(decode_helper_.GetTopPaths());
   for (int t = 0; t < seq_len_t(b); ++t) {
    input_chip_t = input_list_t[t].chip(b, 0);
    auto input_bi =
      Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
    beam_search.Step(input_bi);
   }
   OP_REQUIRES_OK(
     ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
                  &log_probs, merge_repeated_));
 
   beam_search.Reset();
 
   for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
    log_prob_t(b, bp) = log_probs[bp];
   }
  }
 
  OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
              best_paths, &decoded_indices, &decoded_values,
              &decoded_shape));
 }
 
 private:
 CTCDecodeHelper decode_helper_;
 ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
 bool merge_repeated_;
 int beam_width_;
 //新添两个数据成员,用于存储新加的参数
 int label_selection_size;
 float label_selection_margin;
 TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderWithParamOp);
};
 
// Bind CTCBeamSearchDecoderWithParamOp as the CPU kernel of the
// "CTCBeamSearchDecoderWithParam" op.
REGISTER_KERNEL_BUILDER(
    Name("CTCBeamSearchDecoderWithParam").Device(DEVICE_CPU),
    CTCBeamSearchDecoderWithParamOp);

将自定义的Op编译成.so文件

在tensorflow-master目录下新建一个文件夹custom_op

cd custom_op

将上面"基于源码修改的Op"中的代码保存为 custom_op/new_beamsearch.cc,然后在同一目录下新建一个BUILD文件,并在其中添加如下代码:

# Builds the custom CTC decoder op. For this cc_library bazel emits
# bazel-bin/custom_op/libctc_decoder_with_param.so, which is the file later
# passed to tf.load_op_library.
# new_beamsearch.cc only uses TensorFlow and Eigen headers, so the former
# boost_locale header glob and include path were dropped: they matched
# nothing, and an empty glob is an error when
# --incompatible_disallow_empty_glob is enabled.
# NOTE(review): TensorFlow's documented rule for loadable op libraries is
# tf_custom_op_library from //tensorflow:tensorflow.bzl; consider migrating.
# NOTE(review): newer TensorFlow versions require a newer C++ standard than
# c++11; drop or bump this copt when building against them.
cc_library(
    name = "ctc_decoder_with_param",
    srcs = ["new_beamsearch.cc"],
    copts = ["-std=c++11"],
    deps = [
        "//tensorflow/core:core",
        "//tensorflow/core/util/ctc",
        "//third_party/eigen3",
    ],
)

编译过程:

1. cd 到 tensorflow-master 目录下

2. bazel build -c opt --copt=-O3 //tensorflow:libtensorflow_cc.so //custom_op:ctc_decoder_with_param

3. bazel-bin/custom_op 目录下生成 libctc_decoder_with_param.so

在训练(预测)程序中使用自定义的Op

在程序中(需先 import tensorflow as tf)定义如下的方法:

# Load the compiled op library. tf.load_op_library hands the name to the
# platform dynamic loader, and a bare file name (no "/") is searched for on the
# library search path, not in the current directory -- so an explicit path is
# used. Adjust it to wherever libctc_decoder_with_param.so was copied
# (bazel writes it to bazel-bin/custom_op/).
decode_param_op_module = tf.load_op_library('./libctc_decoder_with_param.so')


def decode_with_param(inputs, sequence_length, beam_width=100,
                      top_paths=1, merge_repeated=True,
                      label_selection_size=40, label_selection_margin=0.99):
    """CTC beam search decoding through the CTCBeamSearchDecoderWithParam op.

    Mirrors the call pattern of ``tf.nn.ctc_beam_search_decoder`` and
    additionally exposes the op's two label-selection attributes.

    Args:
        inputs: float tensor for the op's ``inputs`` (the op requires rank 3
            with batch on dimension 1).
        sequence_length: int32 vector, one length per batch entry.
        beam_width: beam width (op attribute, must be >= 1).
        top_paths: number of decoded paths to return (must be >= 1).
        merge_repeated: forwarded to the op's ``merge_repeated`` attribute.
        label_selection_size: forwarded to the op's ``label_selection_size``
            attribute (must be >= 0). Defaults to the previously hard-coded 40.
        label_selection_margin: forwarded to the op's
            ``label_selection_margin`` attribute. Defaults to the previously
            hard-coded 0.99.

    Returns:
        ``(decoded, log_probabilities)``: ``decoded`` is a list of
        ``top_paths`` ``tf.SparseTensor`` objects built from the op's
        indices/values/shape outputs; ``log_probabilities`` is the op's
        ``[batch_size, top_paths]`` float output.
    """
    decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
        decode_param_op_module.ctc_beam_search_decoder_with_param(
            inputs, sequence_length, beam_width=beam_width,
            top_paths=top_paths, merge_repeated=merge_repeated,
            label_selection_size=label_selection_size,
            label_selection_margin=label_selection_margin))
    return (
        [tf.SparseTensor(ix, val, shape) for (ix, val, shape)
         in zip(decoded_ixs, decoded_vals, decoded_shapes)],
        log_probabilities)

然后就可以像使用tf.nn.ctc_beam_search_decoder一样使用该Op了。

以上这篇TensorFlow实现自定义Op方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
王纯业的Python学习笔记 下载
Feb 10 Python
python实现倒计时的示例
Feb 14 Python
Django查找网站项目根目录和对正则表达式的支持
Jul 15 Python
Python爬虫爬验证码实现功能详解
Apr 14 Python
Python编程中flask的简介与简单使用
Dec 28 Python
解决Python计算矩阵乘向量,矩阵乘实数的一些小错误
Aug 26 Python
python实现KNN分类算法
Oct 16 Python
Python中的四种交换数值的方法解析
Nov 18 Python
python爬虫开发之urllib模块详细使用方法与实例全解
Mar 09 Python
python 实现图片裁剪小工具
Feb 02 Python
Python数据分析之pandas函数详解
Apr 21 Python
Django模型层实现多表关系创建和多表操作
Jul 21 Python
tensorflow使用指定gpu的方法
Feb 04 #Python
TensorFlow梯度求解tf.gradients实例
Feb 04 #Python
基于TensorFlow中自定义梯度的2种方式
Feb 04 #Python
tensorflow 查看梯度方式
Feb 04 #Python
opencv python图像梯度实例详解
Feb 04 #Python
TensorFlow设置日志级别的几种方式小结
Feb 04 #Python
Python 实现加密过的PDF文件转WORD格式
Feb 04 #Python
You might like
php Undefined index和Undefined variable的解决方法
2008/03/27 PHP
PHP stream_context_create()作用和用法分析
2011/03/29 PHP
php防止站外远程提交表单的方法
2014/10/20 PHP
浅析PHP中call user func()函数及如何使用call user func调用自定义函数
2015/11/05 PHP
PHP使用xpath解析XML的方法详解
2017/05/20 PHP
baidu博客的编辑友情链接的新的层窗口!经典~支持【FF】
2007/02/09 Javascript
用JavaScript对JSON进行模式匹配 (Part 2 - 实现)
2010/07/17 Javascript
JQuery对checkbox操作 (循环获取)
2011/05/20 Javascript
php,js,css字符串截取的办法集锦
2014/09/26 Javascript
jquery中EasyUI实现异步树
2015/03/01 Javascript
详解JavaScript对Date对象的操作问题(生成一个倒数7天的数组)
2015/10/01 Javascript
实现非常简单的js双向数据绑定
2015/11/06 Javascript
移动端js触摸事件详解
2016/09/18 Javascript
seajs学习教程之基础篇
2016/10/20 Javascript
Bootstrap Table使用整理(二)
2017/06/09 Javascript
用原生 JS 实现 innerHTML 功能实例详解
2019/04/03 Javascript
Vue实现腾讯云点播视频上传功能的实现代码
2020/08/17 Javascript
实例讲解Python中的私有属性
2014/08/21 Python
Python Web框架Flask中使用七牛云存储实例
2015/02/08 Python
Django使用Celery异步任务队列的使用
2018/03/13 Python
用scikit-learn和pandas学习线性回归的方法
2019/06/21 Python
深入了解Python iter() 方法的用法
2019/07/11 Python
Python实现网页截图(PyQT5)过程解析
2019/08/12 Python
对python中的装包与解包实例详解
2019/08/24 Python
Python selenium页面加载慢超时的解决方案
2020/03/18 Python
解决安装新版PyQt5、PyQT5-tool后打不开并Designer.exe提示no Qt platform plugin的问题
2020/04/24 Python
美国在线乐器和设备商店:Musician’s Friend
2018/07/06 全球购物
美国牙科折扣计划:DentalPlans.com
2019/08/26 全球购物
学校经典推荐信
2013/10/30 职场文书
播音主持专业个人自我评价
2014/01/09 职场文书
委托证明的格式
2014/01/10 职场文书
党员一句话承诺大全
2014/03/28 职场文书
小学大队干部竞选稿
2015/11/20 职场文书
倡议书怎么写?
2019/04/11 职场文书
ORACLE查看当前账号的相关信息
2021/06/18 Oracle
Java Socket实现多人聊天系统
2021/07/15 Java/Android