File size: 3,738 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
/*

 * SPDX-License-Identifier: Apache-2.0

 */

#include "onnx/defs/generator/utils.h"

#include <algorithm>
#include <cmath>

namespace ONNX_NAMESPACE {

// Type and shape inference for a Constant node.
//
// Reads the attributes 'value', 'sparse_value', 'value_int', 'value_ints',
// 'value_float', 'value_floats', 'value_string' and 'value_strings' from
// `ctx`. Exactly one of them must be present; otherwise inference fails via
// fail_shape_inference. The element type and shape of output 0 are then
// derived from whichever attribute is present:
//   - value:                element type and shape come from the TensorProto.
//   - value_int/float/string: INT64 / FLOAT / STRING with a
//                           default-constructed (dimension-less) shape.
//   - value_ints/floats/strings: INT64 / FLOAT / STRING with one dimension
//                           appended whose size is the number of list entries.
//   - sparse_value:         element type of the sparse values tensor, with
//                           one dimension appended per entry of dims().
void ConstantOpInference(InferenceContext& ctx) {
  auto* value = ctx.getAttribute("value");
  auto* sparse_value = ctx.getAttribute("sparse_value");
  auto* value_int = ctx.getAttribute("value_int");
  auto* value_ints = ctx.getAttribute("value_ints");
  auto* value_float = ctx.getAttribute("value_float");
  auto* value_floats = ctx.getAttribute("value_floats");
  auto* value_string = ctx.getAttribute("value_string");
  auto* value_strings = ctx.getAttribute("value_strings");

  // Count how many of the mutually exclusive attributes are present. A plain
  // stack array is enough for eight flags; no container allocation is needed.
  const bool is_specified[] = {
      (nullptr != value),
      (nullptr != sparse_value),
      (nullptr != value_int),
      (nullptr != value_ints),
      (nullptr != value_float),
      (nullptr != value_floats),
      (nullptr != value_string),
      (nullptr != value_strings)};

  int specified_count = 0;
  for (const bool flag : is_specified) {
    if (flag) {
      ++specified_count;
    }
  }

  if (specified_count != 1) {
    fail_shape_inference(
        "One and only one of the attributes 'value', 'value_*' or 'sparse_value' must be specified for a Constant node.");
  }

  if (nullptr != value) {
    // OpSchema::Verify check ensures that the attribute value has_t():
    const TensorProto& tensor_proto = value->t();
    updateOutputElemType(ctx, 0, tensor_proto.data_type());
    updateOutputShape(ctx, 0, tensor_proto);
    return;
  }

  if (nullptr != value_int) {
    // Reject an attribute that does not actually carry an integer (has_i()).
    if (!value_int->has_i()) {
      fail_shape_inference("Attribute 'value_int' expect an integer.");
    }
    updateOutputElemType(ctx, 0, TensorProto::INT64);
    // Default-constructed shape: no dimensions.
    updateOutputShape(ctx, 0, TensorShapeProto());
    return;
  }

  if (nullptr != value_ints) {
    updateOutputElemType(ctx, 0, TensorProto::INT64);
    // One dimension whose size is the number of integers in the list.
    appendDim(getOutputShape(ctx, 0), value_ints->ints_size());
    return;
  }

  if (nullptr != value_float) {
    // Reject an attribute that does not actually carry a float (has_f()).
    if (!value_float->has_f()) {
      fail_shape_inference("Attribute 'value_float' expect a float.");
    }
    updateOutputElemType(ctx, 0, TensorProto::FLOAT);
    // Default-constructed shape: no dimensions.
    updateOutputShape(ctx, 0, TensorShapeProto());
    return;
  }

  if (nullptr != value_floats) {
    updateOutputElemType(ctx, 0, TensorProto::FLOAT);
    // One dimension whose size is the number of floats in the list.
    appendDim(getOutputShape(ctx, 0), value_floats->floats_size());
    return;
  }

  if (nullptr != value_string) {
    // Reject an attribute that does not actually carry a string (has_s()).
    if (!value_string->has_s()) {
      fail_shape_inference("Attribute 'value_string' expect a string.");
    }
    updateOutputElemType(ctx, 0, TensorProto::STRING);
    // Default-constructed shape: no dimensions.
    updateOutputShape(ctx, 0, TensorShapeProto());
    return;
  }

  if (nullptr != value_strings) {
    updateOutputElemType(ctx, 0, TensorProto::STRING);
    // One dimension whose size is the number of strings in the list.
    appendDim(getOutputShape(ctx, 0), value_strings->strings_size());
    return;
  }

  if (nullptr != sparse_value) {
    // OpSchema::Verify check ensures that the attribute value
    // has_sparse_tensor():
    const SparseTensorProto& sparse = sparse_value->sparse_tensor();
    // checker.cc::check_sparse_tensor checks that the sparse-value is
    // well-formed
    updateOutputElemType(ctx, 0, sparse.values().data_type());
    auto* output_shape = getOutputShape(ctx, 0);
    // Append one output dimension per entry of the sparse tensor's dims().
    for (int i = 0; i < sparse.dims_size(); ++i) {
      appendDim(output_shape, sparse.dims(i));
    }
    return;
  }

  // Every attribute counted above has a matching branch, so getting here
  // means the two lists have drifted apart.
  fail_shape_inference(
      "TypeAndShapeInferenceFunction implementation incomplete: "
      "this line should never be reached.");
}

} // namespace ONNX_NAMESPACE