【问题标题】:How to use Custom OP to build TensorFlow Graph in C++?如何使用自定义 OP 在 C++ 中构建 TensorFlow Graph?
【发布时间】:2018-11-20 00:14:55
【问题描述】:

从 TensorFlow 文档中,可以执行以下操作来使用固有 OP 构建图形

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor.h"

int main() {
  using namespace tensorflow;
  using namespace tensorflow::ops;
  Scope root = Scope::NewRootScope();
  // Matrix A = [3 2; -1 0]
  auto A = Const(root, { {3.f, 2.f}, {-1.f, 0.f} });
  // Vector b = [3 5]
  auto b = Const(root, { {3.f, 5.f} });
  // v = Ab^T
  auto v = MatMul(root.WithOpName("v"), A, b, MatMul::TransposeB(true));
  std::vector<Tensor> outputs;
  ClientSession session(root);
  // Run and fetch v
  TF_CHECK_OK(session.Run({v}, &outputs));
  // Expect outputs[0] == [19; -3]
  LOG(INFO) << outputs[0].matrix<float>();
  return 0;
}

似乎MatMul 类是自动生成的,因为 github 源代码中没有tensorflow/cc/ops/math_ops.h。 如何为 here 的 ZeroOut OP 等自定义操作做同样的事情

【问题讨论】:

    标签: c++ tensorflow machine-learning


    【解决方案1】:

    here中的ZeroOut为例,你要做到以下几点

    class ZeroOut {
     public:
      ZeroOut(const ::tensorflow::Scope& scope, ::tensorflow::Input x);
      operator ::tensorflow::Output() const { return y; }
      operator ::tensorflow::Input() const { return y; }
      ::tensorflow::Node* node() const { return y.node(); }
    
      ::tensorflow::Output y;
    };
    
    ZeroOut::ZeroOut(const ::tensorflow::Scope& scope, ::tensorflow::Input x) {
      if (!scope.ok()) return;
      auto _x = ::tensorflow::ops::AsNodeOut(scope, x);
      if (!scope.ok()) return;
      ::tensorflow::Node* ret;
      const auto unique_name = scope.GetUniqueNameForOp("ZeroOut");
      auto builder = ::tensorflow::NodeBuilder(unique_name, "ZeroOut")
                         .Input(_x)
      ;
      scope.UpdateBuilder(&builder);
      scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
      if (!scope.ok()) return;
      scope.UpdateStatus(scope.DoShapeInference(ret));
      this->y = Output(ret, 0);
    }
    

    然后你就可以用它来构建图了

    Scope root = Scope::NewRootScope();
    // Matrix A = [3 2; -1 0]
    auto A = Const(root, { {3, 2}, {-1, 0} });
    auto v = ZeroOut(root.WithOpName("v"), A);
    std::vector<Tensor> outputs;
    ClientSession session(root);
    // Run and fetch v
    TF_CHECK_OK(session.Run({v}, &outputs));
    LOG(INFO) << outputs[0].matrix<int>();
    

    注意:对于 TensorFlow 固有的 OP,ZeroOut class 之类的代码是由 bazel 规则自动生成的。如果我们只有几个自定义 OP,我们可以模仿那些代码(例如tensorflow/cc/ops/math_ops.h)来手写我们自己的类。

    【讨论】:

      猜你喜欢
      • 2021-04-12
      • 1970-01-01
      • 1970-01-01
      • 2017-11-02
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多