File size: 3,215 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// Copyright (c) ONNX Project Contributors

/*

 * SPDX-License-Identifier: Apache-2.0

 */

#pragma once
#include "onnx/common/common.h"
#include "onnx/onnx_pb.h"

namespace ONNX_NAMESPACE {
namespace internal {

// Visitor: A readonly visitor class for ONNX Proto objects.
// This class is restricted to Nodes, Graphs, Attributes, and Functions.
// The VisitX methods invoke ProcessX, and if that returns true, will
// continue to visit all children of the X.

struct Visitor {
  virtual void VisitGraph(const GraphProto& graph) {
    if (ProcessGraph(graph))
      for (auto& node : graph.node())
        VisitNode(node);
  }

  virtual void VisitFunction(const FunctionProto& function) {
    if (ProcessFunction(function))
      for (auto& node : function.node())
        VisitNode(node);
  }

  virtual void VisitNode(const NodeProto& node) {
    if (ProcessNode(node)) {
      for (auto& attr : node.attribute()) {
        VisitAttribute(attr);
      }
    }
  }

  virtual void VisitAttribute(const AttributeProto& attr) {
    if (ProcessAttribute(attr)) {
      if (attr.has_g()) {
        VisitGraph(attr.g());
      }
      for (auto& graph : attr.graphs())
        VisitGraph(graph);
    }
  }

  virtual bool ProcessGraph(const GraphProto& graph) {
    ONNX_UNUSED_PARAMETER(graph);
    return true;
  }

  virtual bool ProcessFunction(const FunctionProto& function) {
    ONNX_UNUSED_PARAMETER(function);
    return true;
  }

  virtual bool ProcessNode(const NodeProto& node) {
    ONNX_UNUSED_PARAMETER(node);
    return true;
  }

  virtual bool ProcessAttribute(const AttributeProto& attr) {
    ONNX_UNUSED_PARAMETER(attr);
    return true;
  }

  virtual ~Visitor() {}
};

// MutableVisitor: A version of Visitor that allows mutation of the visited objects.
struct MutableVisitor {
  virtual void VisitGraph(GraphProto* graph) {
    if (ProcessGraph(graph))
      for (auto& node : *(graph->mutable_node()))
        VisitNode(&node);
  }

  virtual void VisitFunction(FunctionProto* function) {
    if (ProcessFunction(function))
      for (auto& node : *(function->mutable_node()))
        VisitNode(&node);
  }

  virtual void VisitNode(NodeProto* node) {
    if (ProcessNode(node)) {
      for (auto& attr : *(node->mutable_attribute())) {
        VisitAttribute(&attr);
      }
    }
  }

  virtual void VisitAttribute(AttributeProto* attr) {
    if (ProcessAttribute(attr)) {
      if (attr->has_g()) {
        VisitGraph(attr->mutable_g());
      }
      for (auto& graph : *(attr->mutable_graphs()))
        VisitGraph(&graph);
    }
  }

  virtual bool ProcessGraph(GraphProto* graph) {
    ONNX_UNUSED_PARAMETER(graph);
    return true;
  }

  virtual bool ProcessFunction(FunctionProto* function) {
    ONNX_UNUSED_PARAMETER(function);
    return true;
  }

  virtual bool ProcessNode(NodeProto* node) {
    ONNX_UNUSED_PARAMETER(node);
    return true;
  }

  virtual bool ProcessAttribute(AttributeProto* attr) {
    ONNX_UNUSED_PARAMETER(attr);
    return true;
  }

  virtual ~MutableVisitor() {}
};

} // namespace internal
} // namespace ONNX_NAMESPACE