File size: 1,855 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
/*

 * SPDX-License-Identifier: Apache-2.0

 */

#pragma once

#include <cmath>
#include <vector>

#include "onnx/defs/schema.h"
#include "onnx/defs/tensor_proto_util.h"

namespace ONNX_NAMESPACE {
// The following is called by ops from opset 11 onward (inclusive).
// Shape-inference routine that operates on the supplied InferenceContext.
// Only the declaration is visible in this header; the definition lives elsewhere.
// NOTE(review): presumably the Resize shape inference for the newest opsets
// (versioned variants are declared further down as resizeShapeInference_opsetX_to_Y) — confirm.
void resizeShapeInference(InferenceContext& ctx);

// Shape-inference routine that operates on the supplied InferenceContext.
// Only the declaration is visible in this header; the definition lives elsewhere.
// NOTE(review): presumably used by the GridSample operator schema — confirm against the caller.
void gridSampleShapeInference(InferenceContext& ctx);

// Overload of resizeShapeInferenceHelper that takes float `scales_data`.
//   input_shape  - read-only source shape.
//   scales_data  - read-only vector of float values.
//   output_shape - pointer to a mutable TensorShapeProto; the only argument the
//                  callee can write through.
// NOTE(review): presumably each output dim is derived from the matching input dim
// and scale; the exact rounding and any null/size checks are in the definition — confirm.
void resizeShapeInferenceHelper(

    const TensorShapeProto& input_shape,

    const std::vector<float>& scales_data,

    TensorShapeProto* output_shape);

// Overload of resizeShapeInferenceHelper that takes int64_t `sizes_data`
// instead of float scales.
//   input_shape  - read-only source shape.
//   sizes_data   - read-only vector of int64_t values.
//   output_shape - pointer to a mutable TensorShapeProto; the only argument the
//                  callee can write through.
// NOTE(review): presumably `sizes_data` holds the explicit output dimensions — confirm
// against the definition.
void resizeShapeInferenceHelper(

    const TensorShapeProto& input_shape,

    const std::vector<int64_t>& sizes_data,

    TensorShapeProto* output_shape);

// The functions below are called by ops whose opset version falls within the range given in the function name (inclusive).
// Versioned shape-inference routine taking the InferenceContext; per its name it
// targets opsets 7 through 10. Declaration only; defined elsewhere.
void resizeShapeInference_opset7_to_10(InferenceContext& ctx);
// Versioned shape-inference routine taking the InferenceContext; per its name it
// targets opsets 11 through 12. Declaration only; defined elsewhere.
void resizeShapeInference_opset11_to_12(InferenceContext& ctx);
// Versioned shape-inference routine taking the InferenceContext; per its name it
// targets opsets 13 through 18. Declaration only; defined elsewhere.
void resizeShapeInference_opset13_to_18(InferenceContext& ctx);
// Versioned shape-inference routine taking the InferenceContext; per its name it
// targets opsets 18 through 19. Declaration only; defined elsewhere.
// NOTE(review): opset 18 also appears in the name resizeShapeInference_opset13_to_18;
// confirm which of the two an opset-18 schema actually registers.
void resizeShapeInference_opset18_to_19(InferenceContext& ctx);

// Scale-based helper for the opset 7-10 range named in the identifier.
// Same parameter list as the float overload of resizeShapeInferenceHelper:
//   input_shape  - read-only source shape.
//   scales_data  - read-only vector of float values.
//   output_shape - pointer to a mutable TensorShapeProto; the only argument the
//                  callee can write through.
// NOTE(review): how its computation differs from resizeShapeInferenceHelper is
// not visible from this header — check the definition.
void resizeShapeInferenceHelper_opset7_to_10(

    const TensorShapeProto& input_shape,

    const std::vector<float>& scales_data,

    TensorShapeProto* output_shape);

// Policy selector accepted by KeepAspectRatioHelper.
// Scoped enum with no explicit values, so the enumerators are 0, 1, 2 in
// declaration order.
// NOTE(review): the names look like they mirror the Resize
// `keep_aspect_ratio_policy` attribute ("stretch", "not_larger", "not_smaller") — confirm.
enum class KeepAspectRatioPolicy {
  STRETCH, // NOTE(review): presumably "no aspect-ratio adjustment" — confirm
  NOT_LARGER, // NOTE(review): presumably "result must not exceed the requested sizes" — confirm
  NOT_SMALLER, // NOTE(review): presumably "result must not fall below the requested sizes" — confirm
};

// Helper parameterized by a KeepAspectRatioPolicy.
//   policy      - which KeepAspectRatioPolicy to apply.
//   input_shape - read-only source shape.
//   axes        - read-only vector of int64_t axis values.
//   sizes_data  - passed by non-const reference, so the callee may modify it;
//                 it is the only mutable argument and the function returns void.
// NOTE(review): presumably rewrites `sizes_data` in place so the selected `axes`
// honor `policy`; the relationship between `axes` and `sizes_data` indices is
// defined in the implementation — confirm.
void KeepAspectRatioHelper(

    KeepAspectRatioPolicy policy,

    const TensorShapeProto& input_shape,

    const std::vector<int64_t>& axes,

    std::vector<int64_t>& sizes_data);

// Declaration of a C-string pointer defined in another translation unit.
// The pointee is const; the pointer itself is not.
// NOTE(review): presumably the documentation text for the opset-9 NonZero schema — confirm.
extern const char* NonZero_ver9_doc;

// Returns a callable that takes an OpSchema by mutable reference.
//   description           - C string; NOTE(review): presumably the operator doc text — confirm.
//   mode_description      - C string; NOTE(review): presumably the doc text for a `mode` attribute — confirm.
//   op_schema             - vector of strings, defaulting to OpSchema::all_tensor_types_ir4().
//                           NOTE(review): given the default, this looks like the list of
//                           allowed type strings rather than a schema — confirm.
//   op_schema_description - string, defaulting to
//                           "Constrain input and output types to all tensor types.".
// Both defaulted parameters are taken by const value, so each call copies them.
// NOTE(review): switching to const references would have to be mirrored in the
// out-of-line definition.
// NOTE(review): this declaration uses std::function and std::string, but among the
// includes visible in this header neither <functional> nor <string> is included
// directly; it appears to rely on transitive includes.
std::function<void(OpSchema&)> PadDocGenerator(

    const char* description,

    const char* mode_description,

    const std::vector<std::string> op_schema = OpSchema::all_tensor_types_ir4(),

    const std::string op_schema_description = "Constrain input and output types to all tensor types.");
} // namespace ONNX_NAMESPACE