Spaces:
Running
Running
File size: 1,640 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
# This file is for testing ONNX with ONNX Runtime
# Create a general scenario to use ONNX Runtime with ONNX
import unittest
class TestONNXRuntime(unittest.TestCase):
def test_with_ort_example(self) -> None:
try:
import onnxruntime
del onnxruntime
except ImportError:
raise unittest.SkipTest("onnxruntime not installed") from None
from numpy import float32, random
from onnxruntime import InferenceSession
from onnxruntime.datasets import get_example
from onnx import checker, load, shape_inference, version_converter
# get certain example model from ORT using opset 9
example1 = get_example("sigmoid.onnx")
# test ONNX functions
model = load(example1)
checker.check_model(model)
checker.check_model(model, full_check=True)
inferred_model = shape_inference.infer_shapes(
model, check_type=True, strict_mode=True, data_prop=True
)
converted_model = version_converter.convert_version(inferred_model, 10)
# test ONNX Runtime functions
sess = InferenceSession(
converted_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
x = random.random((3, 4, 5))
x = x.astype(float32)
sess.run([output_name], {input_name: x})
if __name__ == "__main__":
unittest.main()
|