Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
6.29 kB
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "attr_proto_util.h"
#include "onnx/common/constants.h"
#include "onnx/common/status.h"
#include "onnx/defs/parser.h"
#include "onnx/defs/schema.h"
#include "tensor_proto_util.h"
namespace ONNX_NAMESPACE {
// Helper function to expand a function node given the function proto.
// Only the declaration is visible here; the definition lives outside this header.
//   node        - the node to be expanded.
//   func        - the function proto used for the expansion.
//   g           - graph taken by mutable reference; presumably receives the
//                 expanded nodes -- TODO confirm against the definition.
//   node_prefix - optional string, defaults to ""; presumably used to prefix
//                 names generated during expansion -- TODO confirm.
void FunctionExpandHelper(
const NodeProto& node,
const FunctionProto& func,
GraphProto& g,
const std::string& node_prefix = "");
class FunctionBodyHelper {
public:
struct AttributeProtoWrapper {
AttributeProto proto;
AttributeProtoWrapper() {}
AttributeProtoWrapper(const AttributeProto& attr_prot) {
proto = attr_prot;
}
template <typename T>
AttributeProtoWrapper(const std::string& attr_name, const T& value) {
proto = MakeAttribute(attr_name, value);
}
};
struct NodeDef {
NodeDef(
std::vector<std::string> outputs,
std::string op_type,
std::vector<std::string> inputs,
std::vector<AttributeProtoWrapper> attributes = {},
std::string domain = "")
: outputs(std::move(outputs)),
op_type(std::move(op_type)),
inputs(std::move(inputs)),
attributes(std::move(attributes)),
domain(std::move(domain)) {}
std::vector<std::string> outputs;
std::string op_type;
std::vector<std::string> inputs;
std::vector<AttributeProtoWrapper> attributes;
std::string domain;
};
/*
BuildNodes() is an utility function for easily define a Function Body.
To build a simple node:
{{"Z"}, "Add", {"X", "Y"}} represents Z = Add(X,Y)
To build a node with attribute:
{{"Y"}, "Concat", {"X1", "X2", "X3"}, {{"axis", 1}}}
represents Y = Concat(X1,X2,X3) with axis = 1
The attribute type are infered from the attribute value's c++ type
Supported value types are
int64_t -> int, vector<int64_t> -> ints
float -> float, vector<float> -> floats
string -> string, vector<string> ->strings
For refering an attribute from parent, use:
{MakeRefAttribute("axes", AttributeProto::INTS)}}
To build a node which belongs to a domain other than onnx standard domain:
{{"Z"}, "Foo", {"X", "Y"}, "customdomain"} represents Z = customdomain.Foo(X,Y)
or
{{"Y"}, "Bar", {"X1", "X2", "X3"}, {{"axis", 1}}, "customdomain"}
represents Y = customdomain.Bar(X1,X2,X3) with axis = 1
For more examples, please find the references of this function
*/
static std::vector<NodeProto> BuildNodes(const std::vector<NodeDef>& node_defs);
static void BuildNodes(FunctionProto& functionProto, const std::vector<NodeDef>& node_defs);
static bool BuildFunctionProto(
FunctionProto& functionProto,
const OpSchema& schema,
const std::vector<NodeDef>& node_defs,
const std::vector<OperatorSetIdProto>& relied_opsets);
template <typename T>
static NodeDef Const(const std::string& name, const T& value) {
return NodeDef{{name}, "Constant", {}, {{"value", ToTensor<T>(value)}}};
}
template <typename T>
static NodeDef Const(const std::string& name, const std::vector<T>& values) {
return NodeDef{{name}, "Constant", {}, {{"value", ToTensor<T>(values)}}};
}
};
class FunctionBuilder {
public:
FunctionBuilder(FunctionProto& funProto_) : funProto(funProto_) {}
FunctionBuilder& Add(const char* nodes_txt) {
OnnxParser parser(nodes_txt);
auto& nodes = *funProto.mutable_node();
while (!parser.EndOfInput()) {
auto status = parser.Parse(*nodes.Add());
if (!status.IsOK())
ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
}
return *this;
}
FunctionBuilder& Add(const char* node_txt, const AttributeProto& attr) {
OnnxParser parser(node_txt);
auto& node = *funProto.add_node();
auto status = parser.Parse(node);
if (!status.IsOK()) {
ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
}
if (!parser.EndOfInput()) {
ONNX_THROW_EX(std::logic_error("Error unexpected extra input in node:" + status.ErrorMessage()));
}
*node.add_attribute() = attr;
return *this;
}
template <typename T>
FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, const T& attr_value) {
return Add(node_txt, MakeAttribute(attr_name, attr_value));
}
FunctionBuilder& Const(const std::string& name, const TensorProto& tensor) {
std::string constant_op(name);
constant_op += " = Constant()";
return Add(constant_op.c_str(), MakeAttribute("value", tensor));
}
// Creates a scalar constant (a tensor of rank zero).
template <typename T>
FunctionBuilder& Const(const std::string& name, T const_value) {
std::string constant_op(name);
constant_op += " = Constant()";
return Add(constant_op.c_str(), MakeAttribute("value", ToTensor(const_value)));
}
// Creates a 1D tensor constant consisting of a single value.
template <typename T>
FunctionBuilder& Const1D(const std::string& name, T const_value) {
std::string constant_op(name);
constant_op += " = Constant()";
auto tensor = ToTensor(const_value);
tensor.add_dims(1);
return Add(constant_op.c_str(), MakeAttribute("value", tensor));
}
// Creates a 1D tensor constant consisting of zero or more values.
template <typename T>
FunctionBuilder& Const(const std::string& name, const std::vector<T>& values) {
std::string constant_op(name);
constant_op += " = Constant()";
auto tensor = ToTensor(values);
tensor.add_dims(values.size()); // Treat as 1D tensor.
return Add(constant_op.c_str(), MakeAttribute("value", tensor));
}
FunctionBuilder& AddOpset(const char* domain, int version) {
auto* opset = funProto.add_opset_import();
opset->set_domain(domain);
opset->set_version(version);
return *this;
}
private:
FunctionProto& funProto;
};
} // namespace ONNX_NAMESPACE