/*
 * SPDX-License-Identifier: Apache-2.0
 */

#include "onnx/defs/controlflow/utils.h"

#include <cassert>
#include <string>
#include <vector>

namespace ONNX_NAMESPACE {

void ClearShape(TypeProto& input_type) {
  if (input_type.has_tensor_type()) {
    input_type.mutable_tensor_type()->clear_shape();
  } else if (input_type.has_sequence_type()) {
    auto& seq_type = *input_type.mutable_sequence_type();
    if (seq_type.has_elem_type()) {
      ClearShape(*(seq_type.mutable_elem_type()));
    }
  } else if (input_type.has_optional_type()) {
    auto& opt_type = *input_type.mutable_optional_type();
    if (opt_type.has_elem_type()) {
      ClearShape(*(opt_type.mutable_elem_type()));
    }
  }
}

void IfInferenceFunction(InferenceContext& ctx) {
  // The then/else subgraphs have no inputs, so we just run subgraph
  // inferencing on each branch and apply the resulting types to the outputs.
  std::vector<const TypeProto*> subgraph_input_types; // none
  std::vector<const TensorProto*> input_data; // none

  std::vector<const TypeProto*> then_output_types;
  std::vector<const TypeProto*> else_output_types;

  // Run inferencing on the subgraphs
  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("then_branch");
  if (graphInferencer) {
    then_output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  graphInferencer = ctx.getGraphAttributeInferencer("else_branch");
  if (graphInferencer) {
    else_output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  auto num_outputs = ctx.getNumOutputs();
  auto num_then_outputs = then_output_types.size();
  auto num_else_outputs = else_output_types.size();

  // The then and else branches must produce the same number of outputs.
  if (num_then_outputs != num_else_outputs) {
    fail_type_inference(
        "then_branch and else_branch produce different number of outputs. ", num_then_outputs, " != ",
        num_else_outputs);
  }

  if (num_then_outputs != num_outputs) {
    fail_type_inference("If node has ", num_outputs, " outputs but subgraphs produce ", num_then_outputs);
  }

  for (size_t i = 0, end = then_output_types.size(); i < end; ++i) {
    auto then_output = then_output_types[i];
    auto else_output = else_output_types[i];

    auto* if_output = ctx.getOutputType(i);
    *if_output = *then_output;

    UnionTypeInfo(*else_output, *if_output);
  }
}
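
// A minimal usage sketch of ClearShape (illustration only; ExampleClearShape
// is a hypothetical helper that nothing in ONNX references, and
// [[maybe_unused]] assumes C++17). It shows that ClearShape recurses through
// nested types: clearing seq(tensor(float, [3, 4])) keeps the float element
// type but drops the [3, 4] shape, which is why loop-state-variable types can
// be fed into subgraph inferencing without committing to a shape.
namespace {
[[maybe_unused]] void ExampleClearShape() {
  // Build seq(tensor(float, [3, 4])).
  TypeProto seq_type;
  auto* tensor = seq_type.mutable_sequence_type()->mutable_elem_type()->mutable_tensor_type();
  tensor->set_elem_type(TensorProto_DataType_FLOAT);
  tensor->mutable_shape()->add_dim()->set_dim_value(3);
  tensor->mutable_shape()->add_dim()->set_dim_value(4);

  ClearShape(seq_type);

  // The element type survives; the shape does not.
  assert(seq_type.sequence_type().elem_type().tensor_type().elem_type() == TensorProto_DataType_FLOAT);
  assert(!seq_type.sequence_type().elem_type().tensor_type().has_shape());
}
} // namespace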
void LoopInferenceFunction(InferenceContext& ctx) {
  auto num_inputs = ctx.getNumInputs();
  assert(num_inputs >= 2);
  auto num_loop_state_vars = num_inputs - 2; // skip 'M' and 'cond'

  std::vector<const TypeProto*> subgraph_input_types;
  subgraph_input_types.reserve(num_inputs);

  std::vector<TypeProto> temporary_type_protos;
  temporary_type_protos.reserve(num_inputs - 2);

  // Create a TypeProto for the iteration number. It has the same type as the
  // optional 'M' input that carries the maximum trip count: int64.
  TypeProto iter_num_type;
  iter_num_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
  subgraph_input_types.push_back(&iter_num_type);

  // 'cond'
  subgraph_input_types.push_back(ctx.getInputType(1));

  // Loop state value types get propagated to the outputs, but their shapes may
  // change across iterations, so don't propagate the shape to the outputs and
  // don't pass it into the subgraph inferencing.
  for (size_t i = 2; i < num_inputs; ++i) {
    propagateElemTypeFromInputToOutput(ctx, i, i - 2);

    // Copy so we can remove the shape before passing to the subgraph
    // inferencing.
    temporary_type_protos.push_back(*ctx.getInputType(i));
    auto& input_type = temporary_type_protos.back();
    ClearShape(input_type);
    subgraph_input_types.push_back(&input_type);
  }

  // Run inferencing on the subgraph
  std::vector<const TypeProto*> subgraph_output_types;

  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
  if (graphInferencer) {
    std::vector<const TensorProto*> input_data;
    input_data.push_back(nullptr); // iteration number
    for (size_t i = 1; i < num_inputs; ++i) {
      input_data.push_back(ctx.getInputData(i));
    }

    subgraph_output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  // If empty, assume inferencing was skipped.
  if (!subgraph_output_types.empty()) {
    auto num_outputs = ctx.getNumOutputs();

    // The subgraph outputs the condition value first, but that is only used
    // internally and is not returned by Loop.
    if (subgraph_output_types.size() != num_outputs + 1) {
      fail_type_inference(
          "Graph attribute inferencing returned type information for ", subgraph_output_types.size(),
          " outputs. Expected ", num_outputs + 1);
    }

    // Check that the loop state values match; we should already have
    // type/shape info for them.
    for (size_t i = 0; i < num_outputs; ++i) {
      auto* subgraph_output_type = subgraph_output_types[i + 1]; // skip 'cond'
      auto* loop_output_type = ctx.getOutputType(i);

      const bool is_loop_state_var = i < num_loop_state_vars;

      if (!subgraph_output_type->has_tensor_type() && !subgraph_output_type->has_sequence_type() &&
          !subgraph_output_type->has_optional_type()) {
        fail_type_inference(
            "Loop 'body' subgraph outputs should all be tensors or sequences or optionals, but output ", i, " was ",
            subgraph_output_type->value_case());
      }
      if (!is_loop_state_var && !subgraph_output_type->has_tensor_type()) {
        fail_type_inference(
            "Loop 'body' subgraph scan outputs should all be tensors but output ", i, " was ",
            subgraph_output_type->value_case());
      }

      // If there's an existing type, check that it matches; otherwise propagate.
      propagateElemTypeWithValidation(subgraph_output_type, loop_output_type);

      if (is_loop_state_var) {
        // Shape may change across iterations, so ignore it.
      } else {
        // Propagate shape.
        if (subgraph_output_type->tensor_type().has_shape()) {
          // Per-iteration output: the first dimension will be the number of
          // iterations, but we don't know that value yet.
          TypeProto inferred_type(*subgraph_output_type);
          auto* mutable_inferred_tensor_type = inferred_type.mutable_tensor_type();
          auto* mutable_inferred_shape = mutable_inferred_tensor_type->mutable_shape();

          mutable_inferred_shape->clear_dim();

          // Add an empty dimension for the number of iterations.
          mutable_inferred_shape->add_dim();

          // Add the dimensions from the subgraph output shape.
          for (const auto& dim : subgraph_output_type->tensor_type().shape().dim()) {
            (*mutable_inferred_shape->add_dim()) = dim;
          }

          mergeInShapeInfo(*mutable_inferred_tensor_type, *loop_output_type->mutable_tensor_type());
        }
      }
    }
  }
}
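
// A minimal sketch of the scan-output shape handling above (illustration only;
// ExampleLoopScanOutputShape is a hypothetical helper that nothing references).
// Given a per-iteration body output of shape [2, 3], Loop's corresponding
// output is inferred as [?, 2, 3]: add_dim() with neither dim_value nor
// dim_param prepends an unknown dimension standing in for the iteration count.
namespace {
[[maybe_unused]] TypeProto ExampleLoopScanOutputShape() {
  // Per-iteration output: tensor(float, [2, 3]).
  TypeProto per_iteration;
  auto* tensor = per_iteration.mutable_tensor_type();
  tensor->set_elem_type(TensorProto_DataType_FLOAT);
  tensor->mutable_shape()->add_dim()->set_dim_value(2);
  tensor->mutable_shape()->add_dim()->set_dim_value(3);

  // Same steps as LoopInferenceFunction: clear the shape, prepend the unknown
  // iteration dimension, then append the per-iteration dimensions.
  TypeProto inferred(per_iteration);
  auto* shape = inferred.mutable_tensor_type()->mutable_shape();
  shape->clear_dim();
  shape->add_dim(); // unknown number of iterations
  for (const auto& dim : per_iteration.tensor_type().shape().dim()) {
    *shape->add_dim() = dim;
  }
  return inferred; // tensor(float, [?, 2, 3])
}
} // namespace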
int handle_negative_axis_validate(const std::string& attrib, int axis, int rank) {
  if (!(-rank <= axis && axis < rank)) {
    fail_shape_inference(attrib, " axis value ", axis, " is invalid for a tensor of rank ", rank);
  }
  return (axis >= 0 ? axis : axis + rank);
}
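
// A minimal sketch of the axis normalization above (illustration only;
// ExampleNegativeAxis is a hypothetical helper that nothing references). For a
// rank-4 shape, axis -1 normalizes to 3 and axis -4 to 0; anything outside
// [-4, 3] makes handle_negative_axis_validate fail shape inference.
namespace {
[[maybe_unused]] void ExampleNegativeAxis() {
  assert(handle_negative_axis_validate("scan_input_axes", -1, 4) == 3);
  assert(handle_negative_axis_validate("scan_input_axes", -4, 4) == 0);
  assert(handle_negative_axis_validate("scan_input_axes", 3, 4) == 3);
}
} // namespace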
void ScanInferenceFunction(InferenceContext& ctx) {
  auto num_inputs = ctx.getNumInputs();
  auto num_scan_inputs = narrow_cast<size_t>(ctx.getAttribute("num_scan_inputs")->i());
  auto num_loop_state_vars = num_inputs - num_scan_inputs;
  auto num_outputs = ctx.getNumOutputs();
  auto num_scan_outputs = num_outputs - num_loop_state_vars;

  std::vector<int64_t> axes, output_axes;
  if (getRepeatedAttribute(ctx, "scan_input_axes", axes)) {
    if (axes.size() != num_scan_inputs) {
      fail_shape_inference(
          "Number of scan input axes specified (", axes.size(), ") is not equal to number of scan inputs (",
          num_scan_inputs, ").");
    }
  } else {
    axes.insert(axes.end(), num_scan_inputs, 0);
  }

  if (getRepeatedAttribute(ctx, "scan_output_axes", output_axes)) {
    if (output_axes.size() != num_scan_outputs) {
      fail_shape_inference(
          "Number of scan output axes specified (", output_axes.size(), ") is not equal to number of scan outputs (",
          num_scan_outputs, ").");
    }
  } else {
    output_axes.insert(output_axes.end(), num_scan_outputs, 0);
  }

  std::vector<TypeProto> temporary_type_protos;
  temporary_type_protos.reserve(num_inputs);

  std::vector<const TypeProto*> subgraph_input_types;
  subgraph_input_types.reserve(num_inputs);

  TensorShapeProto_Dimension sequence_len_dim;

  for (size_t i = 0; i < num_inputs; ++i) {
    bool is_loop_state_var = i < num_loop_state_vars;
    bool has_shape = hasInputShape(ctx, i);
    const auto* input_type = ctx.getInputType(i);

    // Enforce type constraint for inputs
    if (!input_type || !input_type->has_tensor_type()) {
      fail_type_inference("Scan input ", i, " was not a tensor.");
    }

    if (is_loop_state_var) {
      // If it's a loop state variable we can propagate type and shape 1:1 to
      // the matching Scan output, and pass the type and shape through
      // unchanged to the matching subgraph input.
      propagateElemTypeFromInputToOutput(ctx, i, i);
      if (has_shape)
        propagateShapeFromInputToOutput(ctx, i, i);

      subgraph_input_types.push_back(input_type);
    } else {
      // For other inputs there is no fixed relationship to the Scan outputs,
      // so we don't propagate type/shape information.
      // We can pass through the type and shape to the subgraph inputs but
      // need to remove the sequence-length dimension from the shape.
      if (has_shape) {
        const auto& shape = input_type->tensor_type().shape();

        // Remove the sequence-length dimension and add the result to
        // subgraph_input_types.
        int axis = static_cast<int>(axes[i - num_loop_state_vars]);
        axis = handle_negative_axis_validate("scan_input_axes", axis, shape.dim_size());

        // Update sequence_len_dim if a value is available.
        const auto& dims = shape.dim();
        mergeInDimensionInfo(dims.Get(axis), sequence_len_dim, 1);

        temporary_type_protos.push_back(RemoveIthDimensionFromShape(*input_type, axis));
        subgraph_input_types.push_back(&temporary_type_protos.back());
      } else {
        subgraph_input_types.push_back(input_type);
      }
    }
  }

  // Run inferencing on the subgraph
  std::vector<const TypeProto*> output_types;

  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
  if (graphInferencer) {
    std::vector<const TensorProto*> input_data;
    input_data.reserve(num_inputs);
    for (size_t i = 0; i < num_inputs; ++i) {
      // ctx.getInputData(i) is the input to Scan itself, not the input to the
      // Scan body, so we pass null to represent an unknown value.
      input_data.push_back(nullptr);
    }

    output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  }

  // If empty, assume inferencing was skipped.
  if (!output_types.empty()) {
    if (output_types.size() != num_outputs) {
      fail_type_inference(
          "Graph attribute inferencing returned type information for ", output_types.size(), " outputs. Expected ",
          num_outputs);
    }

    // Propagate type/shape information for loop state variables and outputs.
    for (size_t i = 0; i < num_outputs; ++i) {
      const bool is_loop_state_var = i < num_loop_state_vars;
      auto* subgraph_output_type = output_types[i];
      auto* scan_output_type = ctx.getOutputType(i);
      auto* mutable_scan_output_tensor_type = scan_output_type->mutable_tensor_type();

      if (!subgraph_output_type->has_tensor_type()) {
        fail_type_inference("Scan 'body' subgraph outputs should all be tensors but output ", i, " was not");
      }
      auto& subgraph_output_tensor_type = subgraph_output_type->tensor_type();

      if (is_loop_state_var) {
        // Merge the shape; the type was already propagated.
        mergeInShapeInfo(subgraph_output_tensor_type, *mutable_scan_output_tensor_type);
      } else {
        scan_output_type->mutable_tensor_type()->set_elem_type(subgraph_output_tensor_type.elem_type());

        // Propagate shape.
        if (subgraph_output_tensor_type.has_shape()) {
          // Infer the shape of the scan output from the shape of the
          // scan-output element by adding the sequence length at the correct
          // axis position.
          const TensorShapeProto& subgraph_output_shape = subgraph_output_tensor_type.shape();
          TensorShapeProto inferred_shape;

          auto subgraph_output_rank = subgraph_output_shape.dim_size();
          auto output_rank = subgraph_output_rank + 1;
          int output_axis = static_cast<int>(output_axes[i - num_loop_state_vars]);
          output_axis = handle_negative_axis_validate("scan_output_axes", output_axis, output_rank);

          for (int j = 0; j < output_axis; ++j)
            *(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);
          *(inferred_shape.add_dim()) = sequence_len_dim;
          for (int j = output_axis; j < subgraph_output_rank; ++j)
            *(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);

          // Merge the inferred shape with existing shape information.
          mergeInShapeInfo(inferred_shape, *mutable_scan_output_tensor_type);
        }
      }
    }
  }
}

} // namespace ONNX_NAMESPACE
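
// A minimal sketch of the scan-output shape construction at the end of
// ScanInferenceFunction (illustration only; ExampleScanOutputShape is a
// hypothetical helper that nothing references, placed in a reopened namespace
// block). A scan-output element of shape [4, 5] with a normalized
// scan_output_axes entry of 1 and a known sequence length of 10 yields a Scan
// output of shape [4, 10, 5]: the sequence-length dimension is spliced in at
// the requested axis.
namespace ONNX_NAMESPACE {
namespace {
[[maybe_unused]] TensorShapeProto ExampleScanOutputShape() {
  // Per-iteration (element) shape: [4, 5].
  TensorShapeProto element_shape;
  element_shape.add_dim()->set_dim_value(4);
  element_shape.add_dim()->set_dim_value(5);

  // Sequence length merged from the scan inputs; 10 is an arbitrary example.
  TensorShapeProto_Dimension sequence_len_dim;
  sequence_len_dim.set_dim_value(10);

  const int output_axis = 1; // assume already validated and normalized

  TensorShapeProto inferred_shape;
  for (int j = 0; j < output_axis; ++j)
    *(inferred_shape.add_dim()) = element_shape.dim(j);
  *(inferred_shape.add_dim()) = sequence_len_dim;
  for (int j = output_axis; j < element_shape.dim_size(); ++j)
    *(inferred_shape.add_dim()) = element_shape.dim(j);

  return inferred_shape; // [4, 10, 5]
}
} // namespace
} // namespace ONNX_NAMESPACE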