Spaces:
Sleeping
Sleeping
File size: 13,832 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 |
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "onnx/defs/controlflow/utils.h"
#include <string>
#include <vector>
namespace ONNX_NAMESPACE {
// Recursively strips shape information from a TypeProto in place.
// Tensors get their shape cleared directly; sequence and optional types
// are descended through to the underlying element type.
void ClearShape(TypeProto& input_type) {
  if (input_type.has_tensor_type()) {
    input_type.mutable_tensor_type()->clear_shape();
    return;
  }
  if (input_type.has_sequence_type()) {
    auto* sequence_type = input_type.mutable_sequence_type();
    if (sequence_type->has_elem_type()) {
      ClearShape(*sequence_type->mutable_elem_type());
    }
    return;
  }
  if (input_type.has_optional_type()) {
    auto* optional_type = input_type.mutable_optional_type();
    if (optional_type->has_elem_type()) {
      ClearShape(*optional_type->mutable_elem_type());
    }
  }
}
// Type/shape inference for the If operator.
// The then/else subgraphs take no inputs from If, so we simply run subgraph
// inferencing on both branches, verify they agree on the number of outputs,
// and union each pair of branch output types into the corresponding If output.
void IfInferenceFunction(InferenceContext& ctx) {
  const std::vector<const TypeProto*> no_input_types; // branches take no inputs
  const std::vector<const TensorProto*> no_input_data; // none

  // Runs subgraph inferencing for the named branch attribute.
  // Returns an empty vector when no inferencer is available
  // (i.e. inferencing was skipped).
  auto infer_branch = [&ctx, &no_input_types, &no_input_data](const char* attr_name) {
    std::vector<const TypeProto*> branch_output_types;
    GraphInferencer* inferencer = ctx.getGraphAttributeInferencer(attr_name);
    if (inferencer != nullptr) {
      branch_output_types = inferencer->doInferencing(no_input_types, no_input_data);
    }
    return branch_output_types;
  };

  const auto then_output_types = infer_branch("then_branch");
  const auto else_output_types = infer_branch("else_branch");

  const auto num_outputs = ctx.getNumOutputs();
  const auto num_then_outputs = then_output_types.size();
  const auto num_else_outputs = else_output_types.size();

  // Both branches must produce the same number of outputs.
  if (num_then_outputs != num_else_outputs) {
    fail_type_inference(
        "then_branch and else_branch produce different number of outputs. ",
        num_then_outputs,
        " != ",
        num_else_outputs);
  }
  // And that count must match the number of outputs declared on the If node.
  if (num_then_outputs != num_outputs) {
    fail_type_inference("If node has ", num_outputs, " but subgraphs produce ", num_then_outputs);
  }

  // Seed each output from the then-branch type, then union in the else-branch
  // type so the result covers both possible branches.
  for (size_t i = 0; i < num_then_outputs; ++i) {
    auto* if_output = ctx.getOutputType(i);
    *if_output = *then_output_types[i];
    UnionTypeInfo(*else_output_types[i], *if_output);
  }
}
// Type/shape inference for the Loop operator.
// Loop's inputs are [M (max trip count), cond, loop_state_var_0, ...] and its
// outputs are the final loop state values followed by the per-iteration scan
// outputs. The 'body' subgraph additionally outputs the updated condition
// first, which Loop consumes internally and does not expose.
void LoopInferenceFunction(InferenceContext& ctx) {
  auto num_inputs = ctx.getNumInputs();
  assert(num_inputs >= 2); // 'M' and 'cond' slots are always present
  auto num_loop_state_vars = num_inputs - 2; // skip 'M' and 'cond'
  std::vector<const TypeProto*> subgraph_input_types;
  subgraph_input_types.reserve(num_inputs);
  std::vector<TypeProto> temporary_type_protos;
  // Reserve the exact count up front: subgraph_input_types stores raw pointers
  // into this vector, so push_back must never reallocate and invalidate them.
  temporary_type_protos.reserve(num_inputs - 2);
  // create TypeProto to validate iteration number type is the same as the
  // optional 'M' input for max iterations.
  TypeProto iter_num_type;
  iter_num_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64);
  subgraph_input_types.push_back(&iter_num_type);
  // 'cond'
  subgraph_input_types.push_back(ctx.getInputType(1));
  // loop state value types get propagated to outputs, but shape may change
  // across iterations so don't propagate it to the outputs and don't pass it
  // into the subgraph inferencing
  for (size_t i = 2; i < num_inputs; ++i) {
    // loop state input i maps to Loop output i-2 (element type only)
    propagateElemTypeFromInputToOutput(ctx, i, i - 2);
    // copy so we can remove the shape before passing to the subgraph
    // inferencing
    temporary_type_protos.push_back(*ctx.getInputType(i));
    auto& input_type = temporary_type_protos.back();
    ClearShape(input_type);
    subgraph_input_types.push_back(&input_type);
  }
  // Run inferencing on the subgraph
  std::vector<const TypeProto*> subgraph_output_types;
  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
  if (graphInferencer) {
    std::vector<const TensorProto*> input_data;
    input_data.push_back(nullptr); // iteration number
    for (size_t i = 1; i < num_inputs; ++i) {
      input_data.push_back(ctx.getInputData(i));
    }
    subgraph_output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  }
  // if empty(), assume inferencing was skipped
  if (!subgraph_output_types.empty()) {
    auto num_outputs = ctx.getNumOutputs();
    // subgraph outputs the condition value first but that is only used
    // internally and not returned by Loop.
    if (subgraph_output_types.size() != num_outputs + 1) {
      fail_type_inference(
          "Graph attribute inferencing returned type information for ",
          subgraph_output_types.size(),
          " outputs. Expected ",
          num_outputs + 1);
    }
    // check loop state values match. we should already have type/shape info
    for (size_t i = 0; i < num_outputs; ++i) {
      auto* subgraph_output_type = subgraph_output_types[i + 1]; // skip 'cond'
      auto* loop_output_type = ctx.getOutputType(i);
      const bool is_loop_state_var = i < num_loop_state_vars;
      // every subgraph output must be a tensor, sequence, or optional...
      if (!subgraph_output_type->has_tensor_type() && !subgraph_output_type->has_sequence_type() &&
          !subgraph_output_type->has_optional_type()) {
        fail_type_inference(
            "Loop 'body' subgraph outputs should all be tensors or sequences or optionals, but output ",
            i,
            " was ",
            subgraph_output_type->value_case());
      }
      // ...and scan outputs (non-state outputs) are restricted to tensors
      if (!is_loop_state_var && !subgraph_output_type->has_tensor_type()) {
        fail_type_inference(
            "Loop 'body' subgraph scan outputs should all be tensors but output ",
            i,
            " was ",
            subgraph_output_type->value_case());
      }
      // if there's an existing type check it matches. otherwise propagate
      propagateElemTypeWithValidation(subgraph_output_type, loop_output_type);
      if (is_loop_state_var) {
        // shape may change across iterations so ignore.
      } else {
        // propagate shape
        if (subgraph_output_type->tensor_type().has_shape()) {
          // per iteration output. first dimension will be number of iterations
          // but we don't know that value yet
          TypeProto inferred_type(*subgraph_output_type);
          auto* mutable_inferred_tensor_type = inferred_type.mutable_tensor_type();
          auto* mutable_inferred_shape = mutable_inferred_tensor_type->mutable_shape();
          mutable_inferred_shape->clear_dim();
          // add empty dimension for number of iterations
          mutable_inferred_shape->add_dim();
          // add dimensions from subgraph output shape
          for (const auto& dim : subgraph_output_type->tensor_type().shape().dim()) {
            (*mutable_inferred_shape->add_dim()) = dim;
          }
          mergeInShapeInfo(*mutable_inferred_tensor_type, *loop_output_type->mutable_tensor_type());
        }
      }
    }
  }
}
// Validates that 'axis' lies in the ONNX-standard range [-rank, rank) and
// normalizes negative values to the equivalent non-negative axis.
// 'attrib' names the attribute being validated, for the failure message.
// Fails shape inference (does not return) when the axis is out of range.
int handle_negative_axis_validate(const std::string& attrib, int axis, int rank) {
  const bool in_range = (axis >= -rank) && (axis < rank);
  if (!in_range) {
    fail_shape_inference(attrib, " axis value ", axis, " is invalid for a tensor of rank ", rank);
  }
  if (axis < 0) {
    axis += rank;
  }
  return axis;
}
// Type/shape inference for the Scan operator.
// Inputs are [loop_state_var_0, ..., scan_input_0, ...] and outputs are
// [final_state_0, ..., scan_output_0, ...]. Scan inputs are sliced along
// their scan_input_axes before being fed to the 'body' subgraph; scan outputs
// gain a sequence-length dimension at their scan_output_axes position.
void ScanInferenceFunction(InferenceContext& ctx) {
  auto num_inputs = ctx.getNumInputs();
  // NOTE(review): assumes the required 'num_scan_inputs' attribute is present;
  // a null return from getAttribute would be dereferenced here — presumably
  // schema validation guarantees it. Verify against the caller.
  auto num_scan_inputs = narrow_cast<size_t>(ctx.getAttribute("num_scan_inputs")->i());
  auto num_loop_state_vars = num_inputs - num_scan_inputs;
  auto num_outputs = ctx.getNumOutputs();
  auto num_scan_outputs = num_outputs - num_loop_state_vars;
  std::vector<int64_t> axes, output_axes;
  // scan_input_axes: one axis per scan input; defaults to 0 for each.
  if (getRepeatedAttribute(ctx, "scan_input_axes", axes)) {
    if (axes.size() != num_scan_inputs) {
      fail_shape_inference(
          "Number of scan input axes specified (",
          axes.size(),
          ") is not equal to number of scan inputs (",
          num_scan_inputs,
          ").");
    }
  } else {
    axes.insert(axes.end(), num_scan_inputs, 0);
  }
  // scan_output_axes: one axis per scan output; defaults to 0 for each.
  if (getRepeatedAttribute(ctx, "scan_output_axes", output_axes)) {
    if (output_axes.size() != num_scan_outputs) {
      fail_shape_inference(
          "Number of scan output axes specified (",
          output_axes.size(),
          ") is not equal to number of scan outputs (",
          num_scan_outputs,
          ").");
    }
  } else {
    output_axes.insert(output_axes.end(), num_scan_outputs, 0);
  }
  std::vector<TypeProto> temporary_type_protos;
  // Reserve up front: subgraph_input_types stores raw pointers into this
  // vector, so push_back must never reallocate and invalidate them.
  temporary_type_protos.reserve(num_inputs);
  std::vector<const TypeProto*> subgraph_input_types;
  subgraph_input_types.reserve(num_inputs);
  // Accumulates the (merged) sequence-length dimension observed across all
  // scan inputs; later appended to each scan output's shape.
  TensorShapeProto_Dimension sequence_len_dim;
  for (size_t i = 0; i < num_inputs; ++i) {
    bool is_loop_state_var = i < num_loop_state_vars;
    bool has_shape = hasInputShape(ctx, i);
    const auto* input_type = ctx.getInputType(i);
    // Enforce type constraint for inputs
    if (!input_type || !input_type->has_tensor_type()) {
      fail_type_inference("Scan input ", i, " was not a tensor.");
    }
    if (is_loop_state_var) {
      // If it's a loop state variable we can propagate type and shape 1:1 to
      // the matching Scan output.
      // We can also pass through the type and shape to the subgraph but need to
      // remove the batch size dimension from the shape.
      propagateElemTypeFromInputToOutput(ctx, i, i);
      if (has_shape)
        propagateShapeFromInputToOutput(ctx, i, i);
      subgraph_input_types.push_back(input_type);
    } else {
      // For other inputs there is no fixed relationships to the Scan outputs,
      // so we don't propagate type/shape information.
      // We can pass through the type and shape to the subgraph inputs but
      // need to remove the sequence length dimensions from the shape.
      if (has_shape) {
        const auto& shape = input_type->tensor_type().shape();
        // remove sequence length dimensions and add to subgraph_input_types
        int axis = static_cast<int>(axes[i - num_loop_state_vars]);
        axis = handle_negative_axis_validate("scan_input_axes", axis, shape.dim_size());
        // update sequence_len if a value is available
        const auto& dims = shape.dim();
        mergeInDimensionInfo(dims.Get(axis), sequence_len_dim, 1);
        temporary_type_protos.push_back(RemoveIthDimensionFromShape(*input_type, axis));
        subgraph_input_types.push_back(&temporary_type_protos.back());
      } else {
        subgraph_input_types.push_back(input_type);
      }
    }
  }
  // Run inferencing on the subgraph
  std::vector<const TypeProto*> output_types;
  GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body");
  if (graphInferencer) {
    std::vector<const TensorProto*> input_data;
    input_data.reserve(num_inputs);
    for (size_t i = 0; i < num_inputs; ++i) {
      // ctx.getInputData(i), the input to scan, does not represent the input to
      // scan body. So, we pass in null, to represent an unknown value.
      input_data.push_back(nullptr);
    }
    output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  }
  // if empty(), assume inferencing was skipped
  if (!output_types.empty()) {
    if (output_types.size() != num_outputs) {
      fail_type_inference(
          "Graph attribute inferencing returned type information for ",
          output_types.size(),
          " outputs. Expected ",
          num_outputs);
    }
    // propagate type/shape information for loop state variables and outputs
    for (size_t i = 0; i < num_outputs; ++i) {
      const bool is_loop_state_var = i < num_loop_state_vars;
      auto* subgraph_output_type = output_types[i];
      auto* scan_output_type = ctx.getOutputType(i);
      auto* mutable_scan_output_tensor_type = scan_output_type->mutable_tensor_type();
      if (!subgraph_output_type->has_tensor_type()) {
        fail_type_inference("Scan 'body' subgraph outputs should all be tensors but output ", i, " was not");
      }
      auto& subgraph_output_tensor_type = subgraph_output_type->tensor_type();
      if (is_loop_state_var) {
        // merge shape; type already propagated
        mergeInShapeInfo(subgraph_output_tensor_type, *mutable_scan_output_tensor_type);
      } else {
        scan_output_type->mutable_tensor_type()->set_elem_type(subgraph_output_tensor_type.elem_type());
        // propagate shape
        if (subgraph_output_tensor_type.has_shape()) {
          // infer shape of scan-output from the shape of scan-output-element
          // by adding sequence-length at the correct axis position
          const TensorShapeProto& subgraph_output_shape = subgraph_output_tensor_type.shape();
          TensorShapeProto inferred_shape;
          auto subgraph_output_rank = subgraph_output_shape.dim_size();
          // scan output rank = subgraph output rank + the sequence-length dim
          auto output_rank = subgraph_output_rank + 1;
          int output_axis = static_cast<int>(output_axes[i - num_loop_state_vars]);
          output_axis = handle_negative_axis_validate("scan_output_axes", output_axis, output_rank);
          // copy dims before the axis, insert sequence length, copy the rest
          for (int j = 0; j < output_axis; ++j)
            *(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);
          *(inferred_shape.add_dim()) = sequence_len_dim;
          for (int j = output_axis; j < subgraph_output_rank; ++j)
            *(inferred_shape.add_dim()) = subgraph_output_shape.dim(j);
          // Merge inferred shape with existing shape information
          mergeInShapeInfo(inferred_shape, *mutable_scan_output_tensor_type);
        }
      }
    }
  }
}
} // namespace ONNX_NAMESPACE
|