Spaces:
Sleeping
Sleeping
| /* | |
| * SPDX-License-Identifier: Apache-2.0 | |
| */ | |
| namespace ONNX_NAMESPACE { | |
| void ConstantOpInference(InferenceContext& ctx) { | |
| auto* value = ctx.getAttribute("value"); | |
| auto* sparse_value = ctx.getAttribute("sparse_value"); | |
| auto* value_int = ctx.getAttribute("value_int"); | |
| auto* value_ints = ctx.getAttribute("value_ints"); | |
| auto* value_float = ctx.getAttribute("value_float"); | |
| auto* value_floats = ctx.getAttribute("value_floats"); | |
| auto* value_string = ctx.getAttribute("value_string"); | |
| auto* value_strings = ctx.getAttribute("value_strings"); | |
| std::vector<bool> non_null_attr = { | |
| (nullptr != value), | |
| (nullptr != sparse_value), | |
| (nullptr != value_int), | |
| (nullptr != value_ints), | |
| (nullptr != value_float), | |
| (nullptr != value_floats), | |
| (nullptr != value_string), | |
| (nullptr != value_strings)}; | |
| if (std::count(non_null_attr.begin(), non_null_attr.end(), true) != 1) { | |
| fail_shape_inference( | |
| "One and only one of the attributes 'value', 'value_*' or 'sparse_value' must be specified for a Constant node."); | |
| } | |
| if (nullptr != value) { | |
| // OpSchema::Verify check ensures that the attribute value has_t(): | |
| const TensorProto& tensor_proto = value->t(); | |
| updateOutputElemType(ctx, 0, tensor_proto.data_type()); | |
| updateOutputShape(ctx, 0, tensor_proto); | |
| return; | |
| } | |
| if (nullptr != value_int) { | |
| // OpSchema::Verify check ensures that the attribute value has_i(): | |
| if (!value_int->has_i()) { | |
| fail_shape_inference("Attribute 'value_int' expect an integer.") | |
| } | |
| updateOutputElemType(ctx, 0, TensorProto::INT64); | |
| updateOutputShape(ctx, 0, TensorShapeProto()); | |
| return; | |
| } | |
| if (nullptr != value_ints) { | |
| updateOutputElemType(ctx, 0, TensorProto::INT64); | |
| appendDim(getOutputShape(ctx, 0), value_ints->ints_size()); | |
| return; | |
| } | |
| if (nullptr != value_float) { | |
| // OpSchema::Verify check ensures that the attribute value has_i(): | |
| if (!value_float->has_f()) { | |
| fail_shape_inference("Attribute 'value_float' expect a float."); | |
| } | |
| updateOutputElemType(ctx, 0, TensorProto::FLOAT); | |
| updateOutputShape(ctx, 0, TensorShapeProto()); | |
| return; | |
| } | |
| if (nullptr != value_floats) { | |
| updateOutputElemType(ctx, 0, TensorProto::FLOAT); | |
| appendDim(getOutputShape(ctx, 0), value_floats->floats_size()); | |
| return; | |
| } | |
| if (nullptr != value_string) { | |
| // OpSchema::Verify check ensures that the attribute value has_i(): | |
| if (!value_string->has_s()) { | |
| fail_shape_inference("Attribute 'value_string' expect a string."); | |
| } | |
| updateOutputElemType(ctx, 0, TensorProto::STRING); | |
| updateOutputShape(ctx, 0, TensorShapeProto()); | |
| return; | |
| } | |
| if (nullptr != value_strings) { | |
| updateOutputElemType(ctx, 0, TensorProto::STRING); | |
| appendDim(getOutputShape(ctx, 0), value_strings->strings_size()); | |
| return; | |
| } | |
| if (nullptr != sparse_value) { | |
| // OpSchema::Verify check ensures that the attribute value | |
| // has_sparse_tensor(): | |
| const SparseTensorProto& sparse = sparse_value->sparse_tensor(); | |
| // checker.cc::check_sparse_tensor checks that the sparse-value is | |
| // well-formed | |
| updateOutputElemType(ctx, 0, sparse.values().data_type()); | |
| auto* output_shape = getOutputShape(ctx, 0); | |
| for (int i = 0; i < sparse.dims_size(); ++i) | |
| appendDim(output_shape, sparse.dims(i)); | |
| return; | |
| } | |
| fail_shape_inference( | |
| "TypeAndShapeInferenceFunction implementation incomplete: " | |
| "this line should never be reached."); | |
| } | |
| } // namespace ONNX_NAMESPACE | |