# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
import itertools
import os
import platform
import unittest
from typing import Any, Optional, Sequence, Tuple

import numpy

import onnx.backend.base
import onnx.backend.test
import onnx.shape_inference
import onnx.version_converter
from onnx import ModelProto, NodeProto, TensorProto
from onnx.backend.base import Device, DeviceType
from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt

# The following just executes the fake backend through the backend test
# infrastructure. Since we don't have a full reference implementation of all
# ops in the ONNX repo, it's impossible to produce the proper results.
# However, we can run the checker (which is what the base Backend class does)
# to verify that all tests fed in are actually well-formed ONNX models.
#
# If everything is fine, all the tests will be marked as "skipped".
#
# We don't enable reporting in this test because the report collection logic
# itself fails when models are malformed.
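#
# A rough sketch of how the runner exercises this backend (the actual logic
# lives in onnx.backend.test.runner; this is an illustration, not its code):
#
#   try:
#       DummyBackend.prepare(model)  # runs the checker and shape inference
#   except BackendIsNotSupposedToImplementIt:
#       pass  # the test case is reported as skipped rather than failed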
class DummyBackend(onnx.backend.base.Backend):
    @classmethod
    def prepare(
        cls, model: ModelProto, device: str = "CPU", **kwargs: Any
    ) -> Optional[onnx.backend.base.BackendRep]:
        super().prepare(model, device, **kwargs)
        onnx.checker.check_model(model)
        # By default, test strict shape inference.
        kwargs = {"check_type": True, "strict_mode": True, **kwargs}
        model = onnx.shape_inference.infer_shapes(model, **kwargs)
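        # Map each value name (graph outputs plus inferred intermediate
        # values) to its ValueInfoProto, used by the coverage checks below.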
        value_infos = {
            vi.name: vi
            for vi in itertools.chain(model.graph.value_info, model.graph.output)
        }
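        # For safelisted models, require that inference produced a concrete
        # element type and fully static dims for every node output. Dropout
        # is special-cased: only its first output is checked, since the
        # optional mask output may not carry inferred shape information.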
        if do_enforce_test_coverage_safelist(model):
            for node in model.graph.node:
                for i, output in enumerate(node.output):
                    if node.op_type == "Dropout" and i != 0:
                        continue
                    assert output in value_infos
                    tt = value_infos[output].type.tensor_type
                    assert tt.elem_type != TensorProto.UNDEFINED
                    for dim in tt.shape.dim:
                        assert dim.WhichOneof("value") == "dim_value"
        raise BackendIsNotSupposedToImplementIt(
            "This is the dummy backend test that doesn't verify the results but does run the checker"
        )

    @classmethod
    def run_node(
        cls,
        node: NodeProto,
        inputs: Any,
        device: str = "CPU",
        outputs_info: Optional[Sequence[Tuple[numpy.dtype, Tuple[int, ...]]]] = None,
        **kwargs: Any,
    ) -> Optional[Tuple[Any, ...]]:
        super().run_node(node, inputs, device=device, outputs_info=outputs_info)
        raise BackendIsNotSupposedToImplementIt(
            "This is the dummy backend test that doesn't verify the results but does run the checker"
        )

    @classmethod
    def supports_device(cls, device: str) -> bool:
        return Device(device).type == DeviceType.CPU
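

# Real-model test graphs for which prepare() enforces fully static output
# shapes and types after inference.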
test_coverage_safelist = {
    "bvlc_alexnet",
    "densenet121",
    "inception_v1",
    "inception_v2",
    "resnet50",
    "shufflenet",
    "SingleRelu",
    "squeezenet_old",
    "vgg19",
    "zfnet",
}
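

# Coverage is enforced only for safelisted graphs containing no recurrent
# ops (RNN/LSTM/GRU), presumably because shape inference does not guarantee
# static shapes for those.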
def do_enforce_test_coverage_safelist(model: ModelProto) -> bool:
    if model.graph.name not in test_coverage_safelist:
        return False
    return all(node.op_type not in {"RNN", "LSTM", "GRU"} for node in model.graph.node)


test_kwargs = {
    # https://github.com/onnx/onnx/issues/5510 (test_mvn fails with test_backend_test.py)
    "test_mvn": {"strict_mode": False},
}
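# Per-test kwargs are forwarded by the runner into prepare(), where they
# override the defaults passed to infer_shapes(); here "strict_mode": False
# relaxes strict shape inference for test_mvn (see the issue linked above).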

backend_test = onnx.backend.test.BackendTest(
    DummyBackend, __name__, test_kwargs=test_kwargs
)
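
# Skip the heavier real-model tests in constrained environments (AppVeyor CI
# and 32-bit builds), presumably to keep memory and runtime in check.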
if os.getenv("APPVEYOR"):
    backend_test.exclude(r"(test_vgg19|test_zfnet)")
if platform.architecture()[0] == "32bit":
    backend_test.exclude(r"(test_vgg19|test_zfnet|test_bvlc_alexnet)")

# Needs investigation on onnxruntime.
backend_test.exclude("test_dequantizelinear_e4m3fn_float16")

# Import all test cases at global scope to make them visible to python.unittest.
globals().update(backend_test.test_cases)

if __name__ == "__main__":
    unittest.main()