// Copyright (c) ONNX Project Contributors /* * SPDX-License-Identifier: Apache-2.0 */ // ATTENTION: The code in this file is highly EXPERIMENTAL. // Adventurous users should note that the APIs will probably change. #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "onnx/common/array_ref.h" #include "onnx/common/assertions.h" #include "onnx/common/common.h" #include "onnx/common/graph_node_list.h" #include "onnx/common/interned_strings.h" #include "onnx/common/tensor.h" #include "onnx/string_utils.h" #define ONNX_DISALLOW_COPY_AND_ASSIGN(TypeName) \ TypeName(const TypeName&) = delete; \ TypeName& operator=(const TypeName&) = delete namespace ONNX_NAMESPACE { namespace { // internal/private API std::string toVarName(size_t i) { std::ostringstream oss; oss << "_v_" << i; return oss.str(); } } // namespace // Graph represents one "function" of computation. // It uses a simple ownership model where the graph owns all the nodes inside it. // All references inside the graph are raw pointers. // Destroying the Graph will invalidate any pointers to nodes in the graph. struct Graph; // Node is the base class of the IR graph. It represents one computation // and dependencies on a list of Values. The "prim-ops", so to speak. struct Node; // A Value represents an input or output to node that is either a // Tensor or an opaque Handle object, as determined by type(). 
struct Value;

// NOTE(review): in this copy of the file every template argument list (the
// <...> after std::function, std::unique_ptr, std::vector, static_cast,
// std::forward, `template`, Attributes, and in the *Attr aliases) has been
// stripped, so the code below does not compile as shown. Restore the
// arguments from version control; the comments describe the intended code.

// Scope guard: ~ResourceGuard() invokes `destructor_` unless release() has
// been called first. Copying is disabled; the move operations are defaulted.
//
// NOTE(review): the defaulted move operations copy `released_` without
// setting it on the source, so a moved-from guard that was not released
// still calls its moved-from `destructor_` when it is destroyed. A moved-from
// std::function is in a valid but unspecified state; if it is empty, the call
// throws std::bad_function_call out of the (implicitly noexcept) destructor,
// which ends in std::terminate. The defaulted move assignment also discards
// the target's pending callback without running it. Consider hand-written
// move operations that mark the source as released.
class ResourceGuard final {
  std::function destructor_;
  bool released_;

 public:
  ONNX_DISALLOW_COPY_AND_ASSIGN(ResourceGuard);
  explicit ResourceGuard(std::function destructor) : destructor_(std::move(destructor)), released_(false) {}
  ResourceGuard(ResourceGuard&& other) = default;
  ResourceGuard& operator=(ResourceGuard&& other) = default;

  ~ResourceGuard() {
    if (!released_)
      destructor_();
  }

  // Disarms the guard: the destructor will no longer run the callback.
  void release() {
    released_ = true;
  }
};

// One dimension of a value's shape. The constructors produce one of three
// states:
//   - unknown:  default-constructed (is_unknown == true, dim == -1);
//   - named:    built from a string, which is kept in `param` (dim == -1);
//   - concrete: built from an int64_t, which is kept in `dim` (is_int == true).
// The converting constructors are non-explicit (the NOLINT markers silence
// the lint), so a std::string or int64_t converts implicitly to a Dimension.
struct Dimension final {
  Dimension() : is_unknown(true), is_int(false), dim(-1) {}
  Dimension(std::string param) : is_unknown(false), is_int(false), dim(-1), param(std::move(param)) {} // NOLINT
  Dimension(int64_t dim) : is_unknown(false), is_int(true), dim(dim) {} // NOLINT

  bool is_unknown;
  bool is_int;
  int64_t dim;
  std::string param;
};

// Kind tag of an attribute value. Each plural code (fs, is, ss, ts, gs, tps)
// is the list form of the scalar kind listed just before it.
enum class AttributeKind : uint8_t {
  // float, float list, int, int list, string, string list,
  // tensor, tensor list, subgraph, subgraph list. type proto, type proto list
  f, fs, i, is, s, ss, t, ts, g, gs, tp, tps
};

// Returns the short code of `kind` ("f", "fs", ..., "tps"). The table is
// indexed by the enumerator's numeric value, so it must list the codes in the
// same order as AttributeKind; the assertion rejects out-of-range values.
static inline const char* toString(AttributeKind kind) {
  static constexpr const char* names[] = {"f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs", "tp", "tps"};
  ONNX_ASSERT(size_t(kind) < sizeof(names) / sizeof(const char*));
  return names[int(kind)];
}

// Polymorphic base of a named attribute value. Subclasses report their kind
// and produce a heap-allocated copy of themselves through clone().
struct AttributeValue {
  explicit AttributeValue(Symbol name) : name(name) {}
  using Ptr = std::unique_ptr;
  Symbol name;
  virtual AttributeKind kind() const = 0;
  virtual Ptr clone() const = 0;
  virtual ~AttributeValue() = default;
};

// Attribute holding a single value of type T, tagged with kind `Kind`.
template struct ScalarAttributeValue final : public AttributeValue {
  using ConstructorType = const T&;
  using ValueType = T;
  // The parameter shadows the member of the same name; in `value_(value_)`
  // the member is initialized from the parameter.
  ScalarAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(value_) {}
  ValueType& value() {
    return value_;
  }
  virtual Ptr clone() const override {
    return Ptr(new ScalarAttributeValue(name, value_));
  }
  virtual AttributeKind kind() const override {
    return Kind;
  }

 private:
  ValueType value_;
};

// Attribute holding a list of T values, tagged with kind `Kind`.
//
// NOTE(review): ConstructorType is a *const* rvalue reference, so
// std::move(value_) in the constructor yields a const rvalue and the vector is
// copied rather than moved. As a result clone() copies the list twice (once
// into `copy`, then again in the constructor).
template struct VectorAttributeValue final : public AttributeValue {
  using ConstructorType = const std::vector&&;
  using ValueType = std::vector;
  VectorAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {}
  ValueType& value() {
    return value_;
  }
  virtual AttributeKind kind() const override {
    return Kind;
  }
  virtual std::unique_ptr clone() const override {
    auto copy = value_;
    return Ptr(new VectorAttributeValue(name, std::move(copy)));
  }

 private:
  ValueType value_;
};

// Concrete attribute types, one per AttributeKind (scalar form, then list form).
using FloatAttr = ScalarAttributeValue;
using FloatsAttr = VectorAttributeValue;
using IntAttr = ScalarAttributeValue;
using IntsAttr = VectorAttributeValue;
using StringAttr = ScalarAttributeValue;
using StringsAttr = VectorAttributeValue;
using TensorAttr = ScalarAttributeValue;
using TensorsAttr = VectorAttributeValue;
using GraphAttr = ScalarAttributeValue, AttributeKind::g>;
using GraphsAttr = VectorAttributeValue, AttributeKind::gs>;
using TypeProtoAttr = ScalarAttributeValue;
using TypeProtosAttr = VectorAttributeValue;

// CRTP base that gives `Derived` (e.g. Node) a store of named attributes.
// Setters return Derived* so that calls can be chained, e.g.:
//   Node* n = g->create(kSelect)->i_(kOffset, 3)->f_(kValue, 3.5);
// Derived* is returned (rather than a reference) because Nodes are normally
// held as pointers.
template struct Attributes {
  Attributes() {}

  // Replaces every attribute of this object with a clone of those in `rhs`.
  // NOTE(review): values_ is cleared before `rhs` is read, so calling this
  // with `rhs` being *this leaves the object with no attributes.
  void copyAttributes(const Attributes& rhs) {
    values_.clear();
    values_.reserve(rhs.values_.size());
    for (auto& i : rhs.values_) {
      values_.push_back(i->clone());
    }
  }
  bool hasAttribute(Symbol name) const {
    return find(name, false) != values_.end();
  }
  // Kind of attribute `name`; asserts that the attribute exists.
  AttributeKind kindOf(Symbol name) const {
    return (*find(name, true))->kind();
  }
  // Removes attribute `name` (asserting that it exists); returns this for chaining.
  Derived* removeAttribute(Symbol name) {
    values_.erase(find(name, true));
    return This();
  }
  bool hasAttributes() const {
    return !values_.empty();
  }
  // Returns the attribute names in storage order: new names are appended and
  // replaced values keep their slot (see set()).
  std::vector attributeNames() const {
    std::vector names;
    names.reserve(values_.size());
    for (auto& a : values_)
      names.push_back(a->name);
    return names;
  }

  // For each attribute kind, generates a chaining setter `method_(name, v)`
  // (adds `name`, or replaces its current value) and a getter `method(name)`
  // (asserts that `name` exists). E.g. CREATE_ACCESSOR(Int, i) defines i_()
  // and i() over IntAttr.
#define CREATE_ACCESSOR(Kind, method) \
  Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
    return set(name, std::forward(v)); \
  } \
  const Kind##Attr::ValueType& method(Symbol name) const { \
    return get(name); \
  }

  CREATE_ACCESSOR(Float, f)
  CREATE_ACCESSOR(Floats, fs)
  CREATE_ACCESSOR(String, s)
  CREATE_ACCESSOR(Strings, ss)
  CREATE_ACCESSOR(Int, i)
  CREATE_ACCESSOR(Ints, is)
  CREATE_ACCESSOR(Tensor, t)
  CREATE_ACCESSOR(Tensors, ts)
  CREATE_ACCESSOR(Graph, g)
  CREATE_ACCESSOR(Graphs, gs)
  CREATE_ACCESSOR(TypeProto, tp)
  CREATE_ACCESSOR(TypeProtos, tps)

#undef CREATE_ACCESSOR

 private:
  // This object as its most-derived type, for returning from chaining calls.
  Derived* This() {
    return static_cast(this);
  }
  // Stores a new T named `name`. An existing attribute with that name is
  // replaced in place; otherwise the new one is appended. Returns this for
  // chaining.
  template Derived* set(Symbol name, typename T::ConstructorType v) {
    auto it = find(name, false);
    auto nv = AVPtr(new T(name, std::forward(v)));
    if (it == values_.end()) {
      values_.push_back(std::move(nv));
    } else {
      *it = std::move(nv);
    }
    return This();
  }
  // Returns the stored value of attribute `name`; asserts that it exists.
  // NOTE(review): the downcast is an unchecked static_cast, so requesting the
  // wrong kind is undefined behavior. Also, although this method is const,
  // it returns a mutable reference.
  template typename T::ValueType& get(Symbol name) const {
    auto it = find(name, true);
    T* child = static_cast(it->get());
    return child->value();
  }
  using AVPtr = AttributeValue::Ptr;
  // NB: For determinism, we use a vector rather than a hash map. This does
  // mean that lookups are O(n), so you shouldn't use Attributes to store
  // a big pile of messages.
  std::vector values_;
  using iterator = std::vector::iterator;
  // Linear search for `name`; when `required` is true, asserts that it exists.
  iterator find(Symbol name, bool required) {
    auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; });
    ONNX_ASSERT(!required || it != values_.end());
    return it;
  }
  using const_iterator = std::vector::const_iterator;
  // Const overload of find(); when a required attribute is missing, the
  // assertion message names it.
  const_iterator find(Symbol name, bool required) const {
    auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; });
    ONNX_ASSERTM(
        !required || it != values_.end(),
        "%s:%u: %s: required undefined attribute '%s'",
        __FILE__,
        __LINE__,
        __func__,
        name.toString());
    return it;
  }
};

// One use of a Value (see Node::uses()): `user` is the consuming node and
// `offset` is the position in user's inputs at which the value appears.
struct Use final {
  Use(Node* user, size_t offset) : user(user), offset(offset) {}
  Node* user;
  size_t offset;
};

// Two uses are equal when they refer to the same user node and input position.
static inline bool operator==(const Use& a, const Use& b) {
  return a.user == b.user && a.offset == b.offset;
}

// The list types are intentionally simple, but we typedef them here so that
// if we need to change them, refactoring will be easier.
using node_list = std::vector;
using value_list = std::vector;
using use_list = std::vector;
using NodeKind = Symbol;

// An input or output of a Node. A Value records the node it belongs to
// (`node_`) and its position among that node's outputs (`offset_`), along
// with an element type, an optional shape, an optional unique name and the
// list of its uses. Copying is disabled; moves are defaulted.
struct Value final {
  ONNX_DISALLOW_COPY_AND_ASSIGN(Value);
  Value(Node* node_, size_t offset_);
  Value(Value&&) = default;
  Value& operator=(Value&&) = default;
  ~Value() = default;

 private:
  friend struct Node;
  friend struct Graph;
  Node* node_;
  size_t offset_;
  size_t unique_ = 0; // unique id
  size_t stage_ = 0; // 0-forward, 1-backward, 2-double-backward,...
  use_list uses_in_current_graph_;
  // NOTE(review): the scalar members below have no default initializers;
  // presumably the out-of-line constructor sets them. Confirm this, since
  // that definition is not visible here.
  bool has_unique_name_;
  std::string unique_name_;
  int32_t elem_type_;
  bool has_sizes_;
  std::vector sizes_;

 public:
  Value* setElemType(int32_t elem_type) {
    elem_type_ = elem_type;
    return this;
  }
  int32_t elemType() const {
    return elem_type_;
  }
  bool has_sizes() const {
    return has_sizes_;
  }
  // Sets the shape and marks it as known.
  Value* setSizes(std::vector sizes) {
    has_sizes_ = true;
    sizes_ = std::move(sizes);
    return this;
  }
  // Forgets the shape: marks it unknown and empties the stored dimensions.
  Value* wipeSizes() {
    has_sizes_ = false;
    sizes_ = std::vector();
    return this;
  }
  const std::vector& sizes() const {
    return sizes_;
  }
  size_t unique() const {
    return unique_;
  }
  bool has_unique_name() const {
    return has_unique_name_;
  }
  // The explicitly assigned unique name if there is one, otherwise
  // "_v_<unique()>".
  std::string uniqueName() const {
    if (has_unique_name())
      return unique_name_;
    return toVarName(unique());
  }
  Value* setUniqueName(const std::string& name, bool rename_subgraph_captured_nodes = true);
  Value* setStage(size_t s) {
    stage_ = s;
    return this;
  }
  size_t stage() const {
    return stage_;
  }
  Node* node() {
    return node_;
  }
  size_t offset() const {
    return offset_;
  }
  const Node* node() const {
    return node_;
  }
  Graph* owningGraph();
  const Graph* owningGraph() const;
  // TODO: make this more const correct
  const use_list uses() const;

  // Replaces all uses of this value with 'newValue'.
  //
  // Given:   %3 = f(%1, %2)
  //          %4 = g(%3)
  //          %5 = h(%3, %3)
  // Execute: %3.replaceAllUsesWith(%6)
  // Result:  %3 = f(%1, %2)
  //          %4 = g(%6)
  //          %5 = h(%6, %6)
  void replaceAllUsesWith(Value* newValue);

  // Copies the element type and shape of `from`, and also its unique name if
  // it has one. Returns this for chaining.
  // NOTE(review): setSizes() is called unconditionally, so afterwards
  // has_sizes() is true even when `from` had no known shape.
  Value* copyMetadata(Value* from) {
    setElemType(from->elemType());
    setSizes(from->sizes());
    if (from->has_unique_name()) {
      setUniqueName(from->uniqueName());
    }
    return this;
  }
};

struct Node : public Attributes { ONNX_DISALLOW_COPY_AND_ASSIGN(Node); friend struct Graph; friend struct Value; friend graph_node_list; friend const_graph_node_list; friend graph_node_list_iterator; friend const_graph_node_list_iterator; private: // each node but Return/Param // is associated with exactly one place in the node list...
// of the graph_ // this circular is a doubly-linked list, the Return node is used as the sentinel for the beginning and end of the // list such that the list never has null pointers next_in_graph[0] is next pointer next_in_graph[1] is prev pointer // using an array to allow the same iterator class for forward and reverse node lists // This list represents a topological sort Node* next_in_graph[2] = {nullptr, nullptr}; Node*& next() { return next_in_graph[kNextDirection]; } Node*& prev() { return next_in_graph[kPrevDirection]; } Node* const& next() const { return next_in_graph[kNextDirection]; } Node* const& prev() const { return next_in_graph[kPrevDirection]; } const NodeKind kind_; std::vector inputs_; std::vector outputs_; Graph* graph_; size_t stage_; bool has_name_; std::string name_; bool has_domain_; std::string domain_; bool has_doc_string_; std::string doc_string_; bool has_overload_; std::string overload_; protected: Node(Graph* graph_, NodeKind kind_); // defined after graph public: bool has_name() const { return has_name_; } const std::string& name() const { return name_; } void setName(std::string name) { has_name_ = true; name_ = std::move(name); } bool has_domain() const { return has_domain_; } const std::string& domain() const { return domain_; } void setDomain(std::string domain) { has_domain_ = true; domain_ = std::move(domain); } bool has_overload() const { return has_overload_; } const std::string& overload() const { return overload_; } void setOverload(std::string overload) { has_overload_ = true; overload_ = std::move(overload); } bool has_doc_string() const { return has_doc_string_; } const std::string& docString() const { return doc_string_; } void setDocString(std::string doc_string) { has_doc_string_ = true; doc_string_ = std::move(doc_string); } NodeKind kind() const { return kind_; } Graph* owningGraph() { return graph_; } const Graph* owningGraph() const { return graph_; } size_t stage() const { return stage_; } Node* setStage(size_t s) 
{ stage_ = s; return this; } // NB: This returns an ArrayRef; that means that it will // get invalidated if you resize inputs (e.g., using addInput) // We can't return a std::vector& because there's no // way to soundly cast to std::vector (an insane // implementation of std::vector could make this representationally // different.) ArrayRef inputs() { return inputs_; } ArrayRef inputs() const { // Vectors are not convertible in const-ness of elements, but // raw pointers are. return {inputs_.data(), inputs_.size()}; } // NB: This returns an ArrayRef; that means that it will // get invalidated if you resize inputs (e.g., using addInput) // We can't return a std::vector& because there's no // way to soundly cast to std::vector (an insane // implementation of std::vector could make this representationally // different.) ArrayRef outputs() { return outputs_; } ArrayRef outputs() const { // Vectors are not convertible in const-ness of elements, but // raw pointers are. return {outputs_.data(), outputs_.size()}; } bool hasUses() const { for (auto o : outputs()) { if (!o->uses().empty()) return true; } return false; } void replaceAllUsesWith(Node* n) { ONNX_ASSERT(outputs().size() == n->outputs().size()); size_t nOutputs = outputs().size(); for (size_t i = 0; i < nOutputs; i++) { outputs()[i]->replaceAllUsesWith(n->outputs()[i]); } } // lots of things like chunk have a single input or single output, so we have a // helper to make accessing it easier Value* input() { ONNX_ASSERT(inputs_.size() == 1); return inputs_.at(0); } Value* output() { ONNX_ASSERT(outputs_.size() == 1); return outputs_.at(0); } const Value* input() const { ONNX_ASSERT(inputs_.size() == 1); return inputs_.at(0); } Value* output() const { ONNX_ASSERT(outputs_.size() == 1); return outputs_.at(0); } // Access a particular input. This is a checked index. 
Value* input(size_t i) { return inputs_.at(i); } const Value* input(size_t i) const { return inputs_.at(i); } // Graphs // Note [Topological invariant] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // We always maintain an up-to-date topological ordering of all nodes via // the next()/prev() links. All transformations to graphs must preserve // this topological ordering: for example, it is only valid to 'addInput' // with an input which is topologically before the current node. // // Usually, it is obvious whether or not topological order is maintained; // for example, if you are adding nodes to the end of the topsort, it's // impossible for them to refer to inputs that are not in the topsort. // If it is not obvious, please comment accordingly. // Add 'node' as an input to 'this' at the end of existing // arguments. Returns the added node for ease of chaining. // // Given: %3 = f(%1, %2) // Execute: %3.addInput(%4) // Result: %3 = f(%1, %2, %4) Value* addInput(Value* node) { ONNX_ASSERT(graph_ == node->owningGraph()); node->uses_in_current_graph_.emplace_back(this, inputs_.size()); inputs_.push_back(node); return node; } // Replace the input of 'this' at position 'i' with // 'newValue', returning the old node. // // Given: %3 = f(%1, %2) // Execute: %3.replaceInput(1, %4) // Result: %3 = f(%1, %4) Value* replaceInput(size_t i, Value* newValue) { ONNX_ASSERT(newValue->owningGraph() == graph_); Value* old = dropInput(i); inputs_[i] = newValue; newValue->uses_in_current_graph_.emplace_back(this, i); return old; } // Replace all occurrences of 'from' in the inputs of this // node with 'to'. Corresponds to llvm's replaceUsesOfWith. 
// // Given: %3 = f(%1, %2, %1) // Execute: %3.replaceInputWith(%1, %4) // Result: %3 = f(%4, %2, %4) void replaceInputWith(Value* from, Value* to) { ONNX_ASSERT(from->owningGraph() == graph_); ONNX_ASSERT(to->owningGraph() == graph_); size_t i = 0; for (auto input : inputs()) { if (input == from) replaceInput(i, to); i++; } } Value* addOutput() { outputs_.push_back(new Value(this, outputs_.size())); return outputs_.back(); } void eraseOutput(size_t i); // Insert unattached 'this' node after 'n' in the topological order. // Returns this (for chaining). // // Given: %3 = f(%1, %2) // %4 = g(%3) // and unattached: %5 = h(%1) // Execute: %5.insertBefore(%4) // Result: %3 = f(%1, %2) // %5 = h(%1) // %4 = g(%3) Node* insertBefore(Node* n) { ONNX_ASSERT(n->inGraphList()); insertAfter(n->prev()); return this; } // Insert unattached 'this' node after 'n' in the topological order. // Returns this (for chaining). // // Given: %3 = f(%1, %2) // %4 = g(%3) // and unattached: %5 = h(%1) // Execute: %5.insertAfter(%4) // Result: %3 = f(%1, %2) // %4 = g(%3) // %5 = h(%1) Node* insertAfter(Node* n) { ONNX_ASSERT(!inGraphList() && n->inGraphList()); Node* next = n->next(); n->next() = this; this->prev() = n; this->next() = next; next->prev() = this; return this; } // Move 'this' (already in the graph) after 'n' in the topological order. // // Given: %2 = f(%1) // %3 = g(%1) // Execute: %2.moveAfter(%3) // Result: %3 = g(%1) // %2 = f(%1) // void moveAfter(Node* n) { removeFromList(); insertAfter(n); } // Move a node 'n' (already in the graph) before 'this' in the topological order. // // Given: %2 = f(%1) // %3 = g(%1) // Execute: %3.moveBefore(%2) // Result: %3 = g(%1) // %2 = f(%1) void moveBefore(Node* n) { removeFromList(); insertBefore(n); } // Remove the input at 'i' from this node. // // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling // removeInput. 
// // Given: %3 = f(%1, %2) // Execute: %3.removeInput(1) // Result: %3 = f(%1) void removeInput(size_t i) { dropInput(i); // everything after this input shifts left, // so we need to update their use offsets to match for (size_t j = i + 1; j < inputs_.size(); j++) { auto it = findUseForInput(j); it->offset--; } inputs_.erase(inputs_.begin() + i); } // Remove all inputs from a node. // // Given: %3 = f(%1, %2) // Execute: %3.removeAllInputs() // Result: %3 = f() void removeAllInputs() { for (size_t i = 0; i < inputs().size(); ++i) dropInput(i); inputs_.clear(); } // Check whether this node is before node n in the graph. bool isBefore(Node* n); // iterators of the node list starting at this node // useful for resuming a search starting at this node graph_node_list_iterator iterator(); graph_node_list_iterator reverseIterator(); const_graph_node_list_iterator iterator() const; const_graph_node_list_iterator reverseIterator() const; // Remove 'this' from the instruction list and deallocate it. // // Invariant: no outputs of 'this' may have any uses. // // Given: %2 = f(%1) // %3 = g(%1) // Execute: %2.destroy() // Result: %3 = g(%1) void destroy(); // Dynamically cast this node to the subclass indicated by the // template variable, returning nullptr if the cast is invalid.. // // Example usage: if(auto s = n.cast