// Web file-viewer header residue: Spaces (Running); file size: 3,499 bytes; dc2106c.
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include <algorithm>
#include <numeric>
#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
// Operator documentation text for OptionalHasElement; handed to SetDoc() in the
// schema registration that follows.
// NOTE(review): the `_ver1` suffix differs from the version (15) used in that
// registration — confirm this is the intended naming convention.
static const char* OptionalHasElement_ver1_doc = R"DOC(
Returns true if the optional-type input contains an element. If it is an empty optional-type, this op returns false.
)DOC";
// Schema registration for OptionalHasElement, registered with version 15.
//   input  0 ("input",  type constraint "O"): any of OpSchema::all_optional_types().
//   output 0 ("output", type constraint "B"): tensor(bool).
// The inference function rejects nodes that do not have exactly one input and
// one output, then marks the output as a BOOL tensor whose shape is set but
// holds no dimensions.
ONNX_OPERATOR_SET_SCHEMA(
    OptionalHasElement,
    15,
    OpSchema()
        .SetDoc(OptionalHasElement_ver1_doc)
        .Input(0, "input", "The optional input.", "O")
        .Output(
            0,
            "output",
            "A scalar boolean tensor. If true, it indicates that optional-type input contains an element. Otherwise, it is empty.",
            "B")
        .TypeConstraint(
            "O",
            OpSchema::all_optional_types(),
            "Constrain input type to optional tensor and optional sequence types.")
        .TypeConstraint("B", {"tensor(bool)"}, "Constrain output to a boolean tensor.")
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // Arity checks come first, inputs before outputs.
          if (ctx.getNumInputs() != 1) {
            fail_type_inference("OptionalHasElement is expected to have 1 input.");
          }
          if (ctx.getNumOutputs() != 1) {
            fail_type_inference("OptionalHasElement is expected to have 1 output.");
          }

          // The result type never depends on the input: always a bool tensor.
          auto& result_tensor = *ctx.getOutputType(0)->mutable_tensor_type();
          result_tensor.set_elem_type(TensorProto::BOOL);
          // mutable_shape() materialises the shape field; Clear() leaves it
          // without any dimensions.
          result_tensor.mutable_shape()->Clear();
        }));
// Operator documentation text for OptionalGetElement; handed to SetDoc() in the
// schema registration that follows.
// NOTE(review): the `_ver1` suffix differs from the version (15) used in that
// registration — confirm this is the intended naming convention.
static const char* OptionalGetElement_ver1_doc = R"DOC(
Outputs the element in the optional-type input. It is an error if the input value does not have an element
and the behavior is undefined in this case.
)DOC";
// Schema registration for OptionalGetElement, registered with version 15.
//   input  0 ("input",  type constraint "O"): any of OpSchema::all_optional_types().
//   output 0 ("output", type constraint "V"): any tensor type or tensor-sequence type.
// The inference function requires exactly one input with type information that
// is an optional type carrying an element type, and copies that element type
// to output 0.
ONNX_OPERATOR_SET_SCHEMA(
    OptionalGetElement,
    15,
    OpSchema()
        .SetDoc(OptionalGetElement_ver1_doc)
        .Input(0, "input", "The optional input.", "O")
        .Output(0, "output", "Output element in the optional input.", "V")
        .TypeConstraint(
            "O",
            OpSchema::all_optional_types(),
            "Constrain input type to optional tensor and optional sequence types.")
        .TypeConstraint(
            "V",
            []() {
              // Tensor types first, followed by the tensor-sequence types.
              auto allowed = OpSchema::all_tensor_types();
              auto sequences = OpSchema::all_tensor_sequence_types();
              allowed.insert(allowed.end(), sequences.begin(), sequences.end());
              return allowed;
            }(),
            "Constrain output type to all tensor or sequence types.")
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          if (ctx.getNumInputs() != 1) {
            fail_type_inference("OptionalGetElement must have an input element.");
          }

          auto optional_input = ctx.getInputType(0);
          if (optional_input == nullptr) {
            fail_type_inference("Input type is null. Input must have Type information.");
          }

          // Short-circuits: optional_type() is only inspected once the input
          // is known to be an optional type.
          const bool carries_elem_type =
              optional_input->has_optional_type() && optional_input->optional_type().has_elem_type();
          if (!carries_elem_type) {
            fail_type_inference("Input must be an optional-type value containing an element with type information.");
          }

          // The output type is exactly the type wrapped by the optional.
          ctx.getOutputType(0)->CopyFrom(optional_input->optional_type().elem_type());
        }));
} // namespace ONNX_NAMESPACE
|