Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
1.86 kB
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <cmath>
#include <vector>
#include "onnx/defs/schema.h"
#include "onnx/defs/tensor_proto_util.h"
namespace ONNX_NAMESPACE {
// Declarations shared by tensor op schemas. Only declarations appear in this
// header; the definitions live in another translation unit (not visible here),
// so every signature and default argument below must stay in sync with them.

// The declarations in this group are used by ops from opset 11 onward
// (inclusive).

// Shape-inference entry point for Resize, judging by the name. `ctx` is taken
// by non-const reference, presumably so inferred output types/shapes can be
// recorded on it.
void resizeShapeInference(InferenceContext& ctx);
// Shape-inference entry point for GridSample, judging by the name.
void gridSampleShapeInference(InferenceContext& ctx);
// Computes an output shape from `input_shape` and a list of float scale factors
// (`scales_data`). The result is presumably written through the non-const
// out-pointer `output_shape` — confirm against the definition.
// Overload resolution picks this version when the second argument is a
// std::vector<float>.
void resizeShapeInferenceHelper(
const TensorShapeProto& input_shape,
const std::vector<float>& scales_data,
TensorShapeProto* output_shape);
// Same as above, but driven by explicit int64_t target sizes (`sizes_data`)
// instead of float scales. Selected when the second argument is a
// std::vector<int64_t>.
void resizeShapeInferenceHelper(
const TensorShapeProto& input_shape,
const std::vector<int64_t>& sizes_data,
TensorShapeProto* output_shape);
// The functions below are used by ops whose opset version falls in the range
// encoded in the function name, inclusive at both ends (e.g. `_opset7_to_10`
// covers opsets 7 through 10).
// NOTE(review): `_opset13_to_18` and `_opset18_to_19` both name opset 18, so
// under the inclusive convention opset 18 would belong to both ranges — confirm
// which one the opset-18 op actually registers.
void resizeShapeInference_opset7_to_10(InferenceContext& ctx);
void resizeShapeInference_opset11_to_12(InferenceContext& ctx);
void resizeShapeInference_opset13_to_18(InferenceContext& ctx);
void resizeShapeInference_opset18_to_19(InferenceContext& ctx);
// Opset 7-10 counterpart of resizeShapeInferenceHelper; only a float-scales
// form is declared for this range. Output is presumably written through
// `output_shape`.
void resizeShapeInferenceHelper_opset7_to_10(
const TensorShapeProto& input_shape,
const std::vector<float>& scales_data,
TensorShapeProto* output_shape);
// Policy selector passed to KeepAspectRatioHelper.
// NOTE(review): the enumerator names presumably mirror the string values of
// Resize's `keep_aspect_ratio_policy` attribute ("stretch", "not_larger",
// "not_smaller") — confirm the mapping in the definition.
enum class KeepAspectRatioPolicy {
STRETCH,
NOT_LARGER,
NOT_SMALLER,
};
// Applies `policy` to the requested sizes for the dimensions listed in `axes`
// of `input_shape`. `sizes_data` is taken by non-const reference, so it may be
// rewritten in place; `axes` and `input_shape` are read-only.
void KeepAspectRatioHelper(
KeepAspectRatioPolicy policy,
const TensorShapeProto& input_shape,
const std::vector<int64_t>& axes,
std::vector<int64_t>& sizes_data);
// C string defined elsewhere; by its name, presumably the documentation text
// for NonZero version 9.
extern const char* NonZero_ver9_doc;
// Returns a callback that fills in an OpSchema, presumably for the Pad family
// of ops, using:
//   description            - op description text
//   mode_description       - text describing the padding modes
//   op_schema              - list of allowed type strings; defaults to
//                            OpSchema::all_tensor_types_ir4()
//   op_schema_description  - description attached to that type constraint
// The last two are taken by const value rather than const reference; keep it
// that way unless the out-of-view definition changes too, or this would
// declare a different function.
// NOTE(review): std::function and std::string are used here, but <functional>
// and <string> are not included directly above; they presumably arrive
// transitively via onnx/defs/schema.h — consider including them explicitly.
std::function<void(OpSchema&)> PadDocGenerator(
const char* description,
const char* mode_description,
const std::vector<std::string> op_schema = OpSchema::all_tensor_types_ir4(),
const std::string op_schema_description = "Constrain input and output types to all tensor types.");
} // namespace ONNX_NAMESPACE