Kano001's picture
Upload 2707 files
dc2106c verified
raw
history blame
4.32 kB
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import itertools
import os
import platform
import unittest
from typing import Any, Optional, Sequence, Tuple
import numpy
import onnx.backend.base
import onnx.backend.test
import onnx.shape_inference
import onnx.version_converter
from onnx import ModelProto, NodeProto, TensorProto
from onnx.backend.base import Device, DeviceType
from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt
# The following just executes the fake backend through the backend test
# infrastructure. Since we don't have a full reference implementation of all
# ops in the ONNX repo, it's impossible to produce the proper results. However,
# we can run the 'checker' (that's what the base Backend class does) to verify
# that all the tests fed in are actually well-formed ONNX models.
#
# If everything is fine, all the tests will be marked as "skipped".
#
# We don't enable the report in this test because the report collection logic
# itself fails when models are malformed.
class DummyBackend(onnx.backend.base.Backend):
    """Backend stand-in that validates models but never computes results.

    ``prepare`` runs the ONNX checker and shape inference on the model, and
    both ``prepare`` and ``run_node`` always finish by raising
    ``BackendIsNotSupposedToImplementIt``.
    """

    @classmethod
    def prepare(
        cls, model: ModelProto, device: str = "CPU", **kwargs: Any
    ) -> Optional[onnx.backend.base.BackendRep]:
        """Check ``model`` and its inferred shapes, then refuse to run it.

        Args:
            model: The model under test.
            device: Device string handed to the base class ``prepare``.
            **kwargs: Passed to the base class ``prepare`` and also forwarded
                to ``onnx.shape_inference.infer_shapes``, where they override
                the defaults ``check_type=True`` and ``strict_mode=True``.

        Raises:
            AssertionError: For a model accepted by
                ``do_enforce_test_coverage_safelist``, when a checked node
                output has no inferred value info, has an undefined element
                type, or has a dimension that is not a concrete ``dim_value``.
            BackendIsNotSupposedToImplementIt: Always, once every check above
                has passed.
        """
        super().prepare(model, device, **kwargs)
        onnx.checker.check_model(model)

        # Strict, type-checked shape inference is the default; keyword
        # arguments supplied by the caller take precedence over it.
        inference_options = {"check_type": True, "strict_mode": True}
        inference_options.update(kwargs)
        model = onnx.shape_inference.infer_shapes(model, **inference_options)

        # Map every known value name (intermediate values first, then graph
        # outputs) to its ValueInfo; a later entry replaces an earlier one.
        inferred = {}
        for info in itertools.chain(model.graph.value_info, model.graph.output):
            inferred[info.name] = info

        if do_enforce_test_coverage_safelist(model):
            for node in model.graph.node:
                is_dropout = node.op_type == "Dropout"
                for position, name in enumerate(node.output):
                    # For Dropout only the first output is verified.
                    if is_dropout and position != 0:
                        continue
                    assert name in inferred
                    tensor_type = inferred[name].type.tensor_type
                    assert tensor_type.elem_type != TensorProto.UNDEFINED
                    # Every dimension must be a concrete value.
                    assert all(
                        dim.WhichOneof("value") == "dim_value"
                        for dim in tensor_type.shape.dim
                    )

        raise BackendIsNotSupposedToImplementIt(
            "This is the dummy backend test that doesn't verify the results but does run the checker"
        )

    @classmethod
    def run_node(
        cls,
        node: NodeProto,
        inputs: Any,
        device: str = "CPU",
        outputs_info: Optional[Sequence[Tuple[numpy.dtype, Tuple[int, ...]]]] = None,
        **kwargs: Any,
    ) -> Optional[Tuple[Any, ...]]:
        """Delegate to the base class ``run_node``, then refuse to run.

        ``kwargs`` is accepted for signature compatibility but is not passed
        on to the base class.

        Raises:
            BackendIsNotSupposedToImplementIt: Always, after the base class
                call returns.
        """
        super().run_node(node, inputs, device=device, outputs_info=outputs_info)
        raise BackendIsNotSupposedToImplementIt(
            "This is the dummy backend test that doesn't verify the results but does run the checker"
        )

    @classmethod
    def supports_device(cls, device: str) -> bool:
        """Return True only when ``device`` parses to a CPU device."""
        return Device(device).type == DeviceType.CPU
# Graph names eligible for the exhaustive per-node-output checks:
# do_enforce_test_coverage_safelist() returns False for any model whose
# graph name is not listed here.
test_coverage_safelist = {
    "SingleRelu",
    "bvlc_alexnet",
    "densenet121",
    "inception_v1",
    "inception_v2",
    "resnet50",
    "shufflenet",
    "squeezenet_old",
    "vgg19",
    "zfnet",
}
def do_enforce_test_coverage_safelist(model: ModelProto) -> bool:
    """Tell whether the exhaustive output checks apply to ``model``.

    Returns True only when the model's graph name is in
    ``test_coverage_safelist`` and none of its nodes is an RNN, LSTM or GRU
    operator; otherwise returns False.
    """
    graph = model.graph
    if graph.name in test_coverage_safelist:
        recurrent_ops = {"RNN", "LSTM", "GRU"}
        return not any(node.op_type in recurrent_ops for node in graph.node)
    return False
# Per-test keyword overrides handed to BackendTest below.
test_kwargs = {
    # test_mvn fails with test_backend_test.py, so it opts out of strict mode;
    # see https://github.com/onnx/onnx/issues/5510
    "test_mvn": {"strict_mode": False},
}

backend_test = onnx.backend.test.BackendTest(
    DummyBackend, __name__, test_kwargs=test_kwargs
)

# Environment-specific exclusions. No helper names are introduced here on
# purpose: everything at module scope is visible to unittest discovery.
if os.environ.get("APPVEYOR"):
    backend_test.exclude(r"(test_vgg19|test_zfnet)")

if platform.architecture()[0] == "32bit":
    backend_test.exclude(r"(test_vgg19|test_zfnet|test_bvlc_alexnet)")

# Needs investigation on onnxruntime.
backend_test.exclude("test_dequantizelinear_e4m3fn_float16")

# Publish the generated test cases as module globals so that python.unittest
# can see them.
globals().update(backend_test.test_cases)

if __name__ == "__main__":
    unittest.main()