Spaces:
Sleeping
Sleeping
/* | |
* SPDX-License-Identifier: Apache-2.0 | |
*/ | |
namespace ONNX_NAMESPACE { | |
namespace defs { | |
namespace math { | |
namespace utils { | |
template <typename T> | |
T GetScalarValueFromTensor(const ONNX_NAMESPACE::TensorProto* t) { | |
if (t == nullptr) { | |
return T{}; | |
} | |
auto data_type = t->data_type(); | |
switch (data_type) { | |
case ONNX_NAMESPACE::TensorProto::FLOAT: | |
return static_cast<T>(ONNX_NAMESPACE::ParseData<float>(t).at(0)); | |
case ONNX_NAMESPACE::TensorProto::DOUBLE: | |
return static_cast<T>(ONNX_NAMESPACE::ParseData<double>(t).at(0)); | |
case ONNX_NAMESPACE::TensorProto::INT32: | |
return static_cast<T>(ONNX_NAMESPACE::ParseData<int32_t>(t).at(0)); | |
case ONNX_NAMESPACE::TensorProto::INT64: | |
return static_cast<T>(ONNX_NAMESPACE::ParseData<int64_t>(t).at(0)); | |
default: | |
fail_shape_inference("Unsupported input data type of ", data_type); | |
} | |
} | |
void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx); | |
void QLinearMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx); | |
const char* QLinearMatMulDoc(); | |
} // namespace utils | |
} // namespace math | |
} // namespace defs | |
} // namespace ONNX_NAMESPACE | |