TensorFlow实现自定义Op方式


Posted in Python onFebruary 04, 2020

『写在前面』

以CTC Beam search decoder为例,简单整理一下TensorFlow实现自定义Op的操作流程。

基本的流程

1. 定义Op接口

#include "tensorflow/core/framework/op.h"
 
REGISTER_OP("Custom")  
  .Input("custom_input: int32")
  .Output("custom_output: int32");

2. 为Op实现Compute操作(CPU)或实现kernel(GPU)

#include "tensorflow/core/framework/op_kernel.h"
 
using namespace tensorflow;
 
class CustomOp : public OpKernel{
  public:
  explicit CustomOp(OpKernelConstruction* context) : OpKernel(context) {}
  void Compute(OpKernelContext* context) override {
  // 获取输入 tensor.
  const Tensor& input_tensor = context->input(0);
  auto input = input_tensor.flat<int32>();
  // 创建一个输出 tensor.
  Tensor* output_tensor = NULL;
  OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                           &output_tensor));
  auto output = output_tensor->template flat<int32>();
  //进行具体的运算,操作input和output
  //……
 }
};

3. 将实现的kernel注册到TensorFlow系统中

REGISTER_KERNEL_BUILDER(Name("Custom").Device(DEVICE_CPU), CustomOp);

CTCBeamSearchDecoder自定义

该Op对应TensorFlow中的源码部分

Op接口的定义:

tensorflow-master/tensorflow/core/ops/ctc_ops.cc

CTCBeamSearchDecoder本身的定义:

tensorflow-master/tensorflow/core/util/ctc/ctc_beam_search.cc

Op-Class的封装与Op注册:

tensorflow-master/tensorflow/core/kernels/ctc_decoder_ops.cc

基于源码修改的Op

#include <algorithm>
#include <vector>
#include <cmath>
 
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
 
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/kernels/bounds_check.h"
 
namespace tf = tensorflow;
using tf::shape_inference::DimensionHandle;
using tf::shape_inference::InferenceContext;
using tf::shape_inference::ShapeHandle;
 
using namespace tensorflow;
 
REGISTER_OP("CTCBeamSearchDecoderWithParam")
  .Input("inputs: float")
  .Input("sequence_length: int32")
  .Attr("beam_width: int >= 1")
  .Attr("top_paths: int >= 1")
  .Attr("merge_repeated: bool = true")
  //新添加了两个参数
  .Attr("label_selection_size: int >= 0 = 0") 
  .Attr("label_selection_margin: float") 
  .Output("decoded_indices: top_paths * int64")
  .Output("decoded_values: top_paths * int64")
  .Output("decoded_shape: top_paths * int64")
  .Output("log_probability: float")
  .SetShapeFn([](InferenceContext* c) {
   ShapeHandle inputs;
   ShapeHandle sequence_length;
 
   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length));
 
   // Get batch size from inputs and sequence_length.
   DimensionHandle batch_size;
   TF_RETURN_IF_ERROR(
     c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));
 
   int32 top_paths;
   TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &top_paths));
 
   // Outputs.
   int out_idx = 0;
   for (int i = 0; i < top_paths; ++i) { // decoded_indices
    c->set_output(out_idx++, c->Matrix(InferenceContext::kUnknownDim, 2));
   }
   for (int i = 0; i < top_paths; ++i) { // decoded_values
    c->set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim));
   }
   ShapeHandle shape_v = c->Vector(2);
   for (int i = 0; i < top_paths; ++i) { // decoded_shape
    c->set_output(out_idx++, shape_v);
   }
   c->set_output(out_idx++, c->Matrix(batch_size, top_paths));
   return Status::OK();
  });
 
typedef Eigen::ThreadPoolDevice CPUDevice;
 
inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
          int* c) {
 *c = 0;
 CHECK_LT(0, m.dimension(1));
 float p = m(r, 0);
 for (int i = 1; i < m.dimension(1); ++i) {
  if (m(r, i) > p) {
   p = m(r, i);
   *c = i;
  }
 }
 return p;
}
 
class CTCDecodeHelper {
 public:
 CTCDecodeHelper() : top_paths_(1) {}
 
 inline int GetTopPaths() const { return top_paths_; }
 void SetTopPaths(int tp) { top_paths_ = tp; }
 
 Status ValidateInputsGenerateOutputs(
   OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len,
   Tensor** log_prob, OpOutputList* decoded_indices,
   OpOutputList* decoded_values, OpOutputList* decoded_shape) const {
  Status status = ctx->input("inputs", inputs);
  if (!status.ok()) return status;
  status = ctx->input("sequence_length", seq_len);
  if (!status.ok()) return status;
 
  const TensorShape& inputs_shape = (*inputs)->shape();
 
  if (inputs_shape.dims() != 3) {
   return errors::InvalidArgument("inputs is not a 3-Tensor");
  }
 
  const int64 max_time = inputs_shape.dim_size(0);
  const int64 batch_size = inputs_shape.dim_size(1);
 
  if (max_time == 0) {
   return errors::InvalidArgument("max_time is 0");
  }
  if (!TensorShapeUtils::IsVector((*seq_len)->shape())) {
   return errors::InvalidArgument("sequence_length is not a vector");
  }
 
  if (!(batch_size == (*seq_len)->dim_size(0))) {
   return errors::FailedPrecondition(
     "len(sequence_length) != batch_size. ", "len(sequence_length): ",
     (*seq_len)->dim_size(0), " batch_size: ", batch_size);
  }
 
  auto seq_len_t = (*seq_len)->vec<int32>();
 
  for (int b = 0; b < batch_size; ++b) {
   if (!(seq_len_t(b) <= max_time)) {
    return errors::FailedPrecondition("sequence_length(", b, ") <= ",
                     max_time);
   }
  }
 
  Status s = ctx->allocate_output(
    "log_probability", TensorShape({batch_size, top_paths_}), log_prob);
  if (!s.ok()) return s;
 
  s = ctx->output_list("decoded_indices", decoded_indices);
  if (!s.ok()) return s;
  s = ctx->output_list("decoded_values", decoded_values);
  if (!s.ok()) return s;
  s = ctx->output_list("decoded_shape", decoded_shape);
  if (!s.ok()) return s;
 
  return Status::OK();
 }
 
 // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
 Status StoreAllDecodedSequences(
   const std::vector<std::vector<std::vector<int> > >& sequences,
   OpOutputList* decoded_indices, OpOutputList* decoded_values,
   OpOutputList* decoded_shape) const {
  // Calculate the total number of entries for each path
  const int64 batch_size = sequences.size();
  std::vector<int64> num_entries(top_paths_, 0);
 
  // Calculate num_entries per path
  for (const auto& batch_s : sequences) {
   CHECK_EQ(batch_s.size(), top_paths_);
   for (int p = 0; p < top_paths_; ++p) {
    num_entries[p] += batch_s[p].size();
   }
  }
 
  for (int p = 0; p < top_paths_; ++p) {
   Tensor* p_indices = nullptr;
   Tensor* p_values = nullptr;
   Tensor* p_shape = nullptr;
 
   const int64 p_num = num_entries[p];
 
   Status s =
     decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices);
   if (!s.ok()) return s;
   s = decoded_values->allocate(p, TensorShape({p_num}), &p_values);
   if (!s.ok()) return s;
   s = decoded_shape->allocate(p, TensorShape({2}), &p_shape);
   if (!s.ok()) return s;
 
   auto indices_t = p_indices->matrix<int64>();
   auto values_t = p_values->vec<int64>();
   auto shape_t = p_shape->vec<int64>();
 
   int64 max_decoded = 0;
   int64 offset = 0;
 
   for (int64 b = 0; b < batch_size; ++b) {
    auto& p_batch = sequences[b][p];
    int64 num_decoded = p_batch.size();
    max_decoded = std::max(max_decoded, num_decoded);
    std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
    for (int64 t = 0; t < num_decoded; ++t, ++offset) {
     indices_t(offset, 0) = b;
     indices_t(offset, 1) = t;
    }
   }
 
   shape_t(0) = batch_size;
   shape_t(1) = max_decoded;
  }
  return Status::OK();
 }
 
 private:
 int top_paths_;
 TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};
 
// CTC beam search
class CTCBeamSearchDecoderWithParamOp : public OpKernel {
 public:
 explicit CTCBeamSearchDecoderWithParamOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
  //从参数列表中读取新添的两个参数
  OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_size", &label_selection_size));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("label_selection_margin", &label_selection_margin));
  int top_paths;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
  decode_helper_.SetTopPaths(top_paths);
 }
 
 void Compute(OpKernelContext* ctx) override {
  const Tensor* inputs;
  const Tensor* seq_len;
  Tensor* log_prob = nullptr;
  OpOutputList decoded_indices;
  OpOutputList decoded_values;
  OpOutputList decoded_shape;
  OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
              ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
              &decoded_values, &decoded_shape));
 
  auto inputs_t = inputs->tensor<float, 3>();
  auto seq_len_t = seq_len->vec<int32>();
  auto log_prob_t = log_prob->matrix<float>();
 
  const TensorShape& inputs_shape = inputs->shape();
 
  const int64 max_time = inputs_shape.dim_size(0);
  const int64 batch_size = inputs_shape.dim_size(1);
  const int64 num_classes_raw = inputs_shape.dim_size(2);
  OP_REQUIRES(
    ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
    errors::InvalidArgument("num_classes cannot exceed max int"));
  const int num_classes = static_cast<const int>(num_classes_raw);
 
  log_prob_t.setZero();
 
  std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
 
  for (std::size_t t = 0; t < max_time; ++t) {
   input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                batch_size, num_classes);
  }
 
  ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
                      &beam_scorer_, 1 /* batch_size */,
                      merge_repeated_);
  //使用传入的两个参数进行Set
  beam_search.SetLabelSelectionParameters(label_selection_size, label_selection_margin);
  Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
  auto input_chip_t = input_chip.flat<float>();
 
  std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
  std::vector<float> log_probs;
 
  // Assumption: the blank index is num_classes - 1
  for (int b = 0; b < batch_size; ++b) {
   auto& best_paths_b = best_paths[b];
   best_paths_b.resize(decode_helper_.GetTopPaths());
   for (int t = 0; t < seq_len_t(b); ++t) {
    input_chip_t = input_list_t[t].chip(b, 0);
    auto input_bi =
      Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
    beam_search.Step(input_bi);
   }
   OP_REQUIRES_OK(
     ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
                  &log_probs, merge_repeated_));
 
   beam_search.Reset();
 
   for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
    log_prob_t(b, bp) = log_probs[bp];
   }
  }
 
  OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
              best_paths, &decoded_indices, &decoded_values,
              &decoded_shape));
 }
 
 private:
 CTCDecodeHelper decode_helper_;
 ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
 bool merge_repeated_;
 int beam_width_;
 //新添两个数据成员,用于存储新加的参数
 int label_selection_size;
 float label_selection_margin;
 TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderWithParamOp);
};
 
REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoderWithParam").Device(DEVICE_CPU),
            CTCBeamSearchDecoderWithParamOp);

将自定义的Op编译成.so文件

在tensorflow-master目录下新建一个文件夹custom_op

cd custom_op

新建一个BUILD文件,并在其中添加如下代码:

cc_library(
  name = "ctc_decoder_with_param",
  srcs = [
      "new_beamsearch.cc"
      ] +
      glob(["boost_locale/**/*.hpp"]),
  includes = ["boost_locale"],
  copts = ["-std=c++11"],
  deps = ["//tensorflow/core:core",
      "//tensorflow/core/util/ctc",
      "//third_party/eigen3",
  ],
)

编译过程:

1. cd 到 tensorflow-master 目录下

2. bazel build -c opt --copt=-O3 //tensorflow:libtensorflow_cc.so //custom_op:ctc_decoder_with_param

3. bazel-bin/custom_op 目录下生成 libctc_decoder_with_param.so

在训练(预测)程序中使用自定义的Op

在程序中定义如下的方法:

decode_param_op_module = tf.load_op_library('libctc_decoder_with_param.so')
def decode_with_param(inputs, sequence_length, beam_width=100,
          top_paths=1, merge_repeated=True):
  decoded_ixs, decoded_vals, decoded_shapes, log_probabilities = (
    decode_param_op_module.ctc_beam_search_decoder_with_param(
      inputs, sequence_length, beam_width=beam_width,
      top_paths=top_paths, merge_repeated=merge_repeated,
      label_selection_size=40, label_selection_margin=0.99))
  return (
    [tf.SparseTensor(ix, val, shape) for (ix, val, shape)
     in zip(decoded_ixs, decoded_vals, decoded_shapes)],
    log_probabilities)

然后就可以像使用tf.nn.ctc_beam_search_decoder一样使用该Op了。

以上这篇TensorFlow实现自定义Op方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的文件夹清理程序分享
Nov 22 Python
使用Python程序抓取新浪在国内的所有IP的教程
May 04 Python
浅谈numpy库的常用基本操作方法
Jan 09 Python
Python深度优先算法生成迷宫
Jan 22 Python
Windows下anaconda安装第三方包的方法小结(tensorflow、gensim为例)
Apr 05 Python
解决pycharm py文件运行后停止按钮变成了灰色的问题
Nov 29 Python
python 爬取学信网登录页面的例子
Aug 13 Python
Django stark组件使用及原理详解
Aug 22 Python
python聚类算法解决方案(rest接口/mpp数据库/json数据/下载图片及数据)
Aug 28 Python
Pytorch之view及view_as使用详解
Dec 31 Python
TensorFlow固化模型的实现操作
May 26 Python
Pycharm快捷键配置详细整理
Oct 13 Python
tensorflow使用指定gpu的方法
Feb 04 #Python
TensorFlow梯度求解tf.gradients实例
Feb 04 #Python
基于TensorFlow中自定义梯度的2种方式
Feb 04 #Python
tensorflow 查看梯度方式
Feb 04 #Python
opencv python图像梯度实例详解
Feb 04 #Python
TensorFlow设置日志级别的几种方式小结
Feb 04 #Python
Python 实现加密过的PDF文件转WORD格式
Feb 04 #Python
You might like
PHP中防止SQL注入实现代码
2011/02/19 PHP
yii2整合百度编辑器umeditor及umeditor图片上传问题的解决办法
2016/04/20 PHP
php合并数组并保留键值的实现方法
2018/03/12 PHP
laravel数据库查询结果自动转数组修改实例
2021/02/27 PHP
javascript模版引擎-tmpl的bug修复与性能优化分析
2011/10/23 Javascript
js获取GridView中行数据的两种方法 分享
2013/07/13 Javascript
javascript获取隐藏元素(display:none)的高度和宽度的方法
2014/06/06 Javascript
javascript框架设计读书笔记之数组的扩展与修复
2014/12/02 Javascript
jQuery实现鼠标滑过Div层背景变颜色的方法
2015/02/17 Javascript
举例讲解jQuery中可见性过滤选择器的使用
2016/04/18 Javascript
js判断手机浏览器操作系统和微信浏览器的方法
2016/04/30 Javascript
JS中数组重排序方法
2016/11/11 Javascript
JS简单判断字符在另一个字符串中出现次数的2种常用方法
2017/04/20 Javascript
在百度搜索结果中去除掉一些网站的资料(通过js控制不让显示)
2017/05/02 Javascript
js实现登录与注册界面
2017/11/01 Javascript
jQuery中库的引用方法
2018/01/06 jQuery
jquery+ajaxform+springboot控件实现数据更新功能
2018/01/22 jQuery
Vue数据监听方法watch的使用
2018/03/28 Javascript
Vue中使用Sortable的示例代码
2018/04/07 Javascript
Bootstrap table表格初始化表格数据的方法
2018/07/25 Javascript
NodeJS如何实现同步的方法示例
2018/08/24 NodeJs
js module大战
2019/04/19 Javascript
json解析大全 双引号、键值对不在一起的情况
2019/12/06 Javascript
在Django框架中编写Context处理器的方法
2015/07/20 Python
Python爬虫之模拟知乎登录的方法教程
2017/05/25 Python
儿童学习python的一些小技巧
2018/05/27 Python
Pytorch中的自动求梯度机制和Variable类实例
2020/02/29 Python
django haystack实现全文检索的示例代码
2020/06/24 Python
HTML5 本地存储之如果没有数据库究竟会怎样
2013/04/25 HTML / CSS
Revolution Beauty美国官网:英国知名化妆品网站
2018/07/23 全球购物
德国黑胶唱片、街头服装及运动鞋网上商店:HHV
2018/08/24 全球购物
奥地利领先的在线药房:SHOP APOTHEKE
2019/10/07 全球购物
《雷鸣电闪波尔卡》教学反思
2014/02/23 职场文书
五一劳动节演讲稿
2014/09/12 职场文书
导游词之凤凰古城
2019/10/22 职场文书
python基础之停用词过滤详解
2021/04/21 Python