Spaces:
Sleeping
Sleeping
File size: 4,319 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import itertools
import os
import platform
import unittest
from typing import Any, Optional, Sequence, Tuple
import numpy
import onnx.backend.base
import onnx.backend.test
import onnx.shape_inference
import onnx.version_converter
from onnx import ModelProto, NodeProto, TensorProto
from onnx.backend.base import Device, DeviceType
from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt
# The following just executes the fake backend through the backend test
# infrastructure. Since we don't have a full reference implementation of all ops
# in the ONNX repo, it's impossible to produce the proper results. However, we can
# run the 'checker' (that's what the base Backend class does) to verify that all
# the tests fed in are actually well-formed ONNX models.
#
# If everything is fine, all the tests will be marked as "skipped".
#
# We don't enable the report in this test because the report collection logic
# itself fails when models are malformed.
class DummyBackend(onnx.backend.base.Backend):
    """Fake backend that validates models and nodes instead of executing them.

    ``prepare`` and ``run_node`` delegate to the base class for checking and
    then always raise ``BackendIsNotSupposedToImplementIt``; no results are
    ever computed.
    """

    # Single source for the message raised by both prepare() and run_node().
    _NOT_IMPLEMENTED_MESSAGE = (
        "This is the dummy backend test that doesn't verify the results but does run the checker"
    )

    @classmethod
    def prepare(
        cls, model: ModelProto, device: str = "CPU", **kwargs: Any
    ) -> Optional[onnx.backend.base.BackendRep]:
        """Check *model* and run shape inference on it; never returns normally.

        Args:
            model: The ONNX model to validate.
            device: Device string forwarded to the base implementation.
            **kwargs: Forwarded to the base ``prepare`` and, merged over the
                defaults ``check_type=True`` and ``strict_mode=True``, to
                ``onnx.shape_inference.infer_shapes``.

        Raises:
            AssertionError: For a model selected by
                ``do_enforce_test_coverage_safelist``, if a checked node output
                has no inferred entry, has an undefined element type, or has a
                dimension that is not a concrete ``dim_value``.
            BackendIsNotSupposedToImplementIt: Always, once every check passed.
        """
        super().prepare(model, device, **kwargs)
        onnx.checker.check_model(model)

        # by default test strict shape inference; caller-supplied kwargs win
        inference_options = {"check_type": True, "strict_mode": True, **kwargs}
        model = onnx.shape_inference.infer_shapes(model, **inference_options)

        if do_enforce_test_coverage_safelist(model):
            # Only needed for the coverage check, so build it lazily.
            value_infos = {
                vi.name: vi
                for vi in itertools.chain(model.graph.value_info, model.graph.output)
            }
            for node in model.graph.node:
                for i, output in enumerate(node.output):
                    # Only the first output of a Dropout node is checked.
                    if node.op_type == "Dropout" and i != 0:
                        continue
                    assert output in value_infos, (
                        f"{node.op_type} output {output!r} has no inferred value_info"
                    )
                    tt = value_infos[output].type.tensor_type
                    assert tt.elem_type != TensorProto.UNDEFINED, (
                        f"{node.op_type} output {output!r} has an undefined element type"
                    )
                    for dim in tt.shape.dim:
                        assert dim.WhichOneof("value") == "dim_value", (
                            f"{node.op_type} output {output!r} has a non-concrete dimension"
                        )

        raise BackendIsNotSupposedToImplementIt(cls._NOT_IMPLEMENTED_MESSAGE)

    @classmethod
    def run_node(
        cls,
        node: NodeProto,
        inputs: Any,
        device: str = "CPU",
        outputs_info: Optional[Sequence[Tuple[numpy.dtype, Tuple[int, ...]]]] = None,
        **kwargs: Any,
    ) -> Optional[Tuple[Any, ...]]:
        """Delegate *node* to the base class for checking; never returns normally.

        Args:
            node: The node to check.
            inputs: Node inputs, forwarded unchanged.
            device: Device string forwarded to the base implementation.
            outputs_info: Optional (dtype, shape) pairs, forwarded unchanged.
            **kwargs: Extra options forwarded to the base ``run_node``.

        Raises:
            BackendIsNotSupposedToImplementIt: Always, after the base call.
        """
        # Forward **kwargs as prepare() does, rather than silently dropping them.
        # NOTE(review): the base run_node is expected to honor options such as
        # opset_version -- confirm against onnx.backend.base.Backend.
        super().run_node(
            node, inputs, device=device, outputs_info=outputs_info, **kwargs
        )
        raise BackendIsNotSupposedToImplementIt(cls._NOT_IMPLEMENTED_MESSAGE)

    @classmethod
    def supports_device(cls, device: str) -> bool:
        """Return True only when *device* parses to a CPU device type."""
        return Device(device).type == DeviceType.CPU
# Graph names (compared against ``model.graph.name``) that are candidates for
# the per-output shape/type coverage check; the final decision is made by
# do_enforce_test_coverage_safelist(). Entries are in plain string sort order.
test_coverage_safelist = {
    "SingleRelu",
    "bvlc_alexnet",
    "densenet121",
    "inception_v1",
    "inception_v2",
    "resnet50",
    "shufflenet",
    "squeezenet_old",
    "vgg19",
    "zfnet",
}
def do_enforce_test_coverage_safelist(model: ModelProto) -> bool:
    """Decide whether the strict output coverage check applies to *model*.

    Returns True only when the graph's name is listed in
    ``test_coverage_safelist`` and none of the nodes in ``model.graph.node``
    is an RNN, LSTM or GRU op; otherwise returns False.
    """
    if model.graph.name in test_coverage_safelist:
        recurrent_ops = ("RNN", "LSTM", "GRU")
        return not any(node.op_type in recurrent_ops for node in model.graph.node)
    return False
# Per-test keyword overrides handed to the test runner, keyed by test name.
test_kwargs = {
    # https://github.com/onnx/onnx/issues/5510 (test_mvn fails with test_backend_test.py)
    "test_mvn": {"strict_mode": False},
}

backend_test = onnx.backend.test.BackendTest(
    DummyBackend, __name__, test_kwargs=test_kwargs
)

# Drop two model tests when APPVEYOR is set to a non-empty value.
if os.environ.get("APPVEYOR"):
    backend_test.exclude(r"(test_vgg19|test_zfnet)")

# Drop three model tests when platform.architecture() reports "32bit".
if platform.architecture()[0] == "32bit":
    backend_test.exclude(r"(test_vgg19|test_zfnet|test_bvlc_alexnet)")

# Needs investigation on onnxruntime.
backend_test.exclude("test_dequantizelinear_e4m3fn_float16")

# import all test cases at global scope to make them visible to python.unittest
globals().update(backend_test.test_cases)

if __name__ == "__main__":
    unittest.main()
|