File size: 4,319 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0

import itertools
import os
import platform
import unittest
from typing import Any, Optional, Sequence, Tuple

import numpy

import onnx.backend.base
import onnx.backend.test
import onnx.shape_inference
import onnx.version_converter
from onnx import ModelProto, NodeProto, TensorProto
from onnx.backend.base import Device, DeviceType
from onnx.backend.test.runner import BackendIsNotSupposedToImplementIt

# The following just executes the fake backend through the backend test
# infrastructure. Since we don't have a full reference implementation of all
# ops in the ONNX repo, it's impossible to produce the proper results. However,
# we can run the 'checker' (that's what the base Backend class does) to verify
# that all tests fed in are actually well-formed ONNX models.
#
# If everything is fine, all the tests will be marked as "skipped".
#
# We don't enable reporting in this test because the report collection logic
# itself fails when models are mal-formed.


class DummyBackend(onnx.backend.base.Backend):
    """Stub backend that validates models but never executes them.

    ``prepare`` runs the ONNX checker and strict shape inference, then raises
    ``BackendIsNotSupposedToImplementIt`` so the test infrastructure reports
    every case as skipped rather than failed. ``run_node`` behaves the same
    way. Only CPU devices are claimed as supported.
    """

    @classmethod
    def prepare(
        cls, model: ModelProto, device: str = "CPU", **kwargs: Any
    ) -> Optional[onnx.backend.base.BackendRep]:
        super().prepare(model, device, **kwargs)

        # Verify the model is well-formed before anything else.
        onnx.checker.check_model(model)

        # Default to strict, type-checked shape inference; callers may
        # override either flag through **kwargs.
        inference_kwargs = {"check_type": True, "strict_mode": True, **kwargs}
        model = onnx.shape_inference.infer_shapes(model, **inference_kwargs)

        value_infos = {
            info.name: info
            for info in itertools.chain(model.graph.value_info, model.graph.output)
        }

        if do_enforce_test_coverage_safelist(model):
            # Safelisted models must have every node output fully inferred:
            # known element type and concrete (integer) dimensions.
            for node in model.graph.node:
                for index, output_name in enumerate(node.output):
                    # Non-primary Dropout outputs are exempt from this check.
                    if node.op_type == "Dropout" and index != 0:
                        continue
                    assert output_name in value_infos
                    tensor_type = value_infos[output_name].type.tensor_type
                    assert tensor_type.elem_type != TensorProto.UNDEFINED
                    for dim in tensor_type.shape.dim:
                        assert dim.WhichOneof("value") == "dim_value"

        raise BackendIsNotSupposedToImplementIt(
            "This is the dummy backend test that doesn't verify the results but does run the checker"
        )

    @classmethod
    def run_node(
        cls,
        node: NodeProto,
        inputs: Any,
        device: str = "CPU",
        outputs_info: Optional[Sequence[Tuple[numpy.dtype, Tuple[int, ...]]]] = None,
        **kwargs: Any,
    ) -> Optional[Tuple[Any, ...]]:
        super().run_node(node, inputs, device=device, outputs_info=outputs_info)
        raise BackendIsNotSupposedToImplementIt(
            "This is the dummy backend test that doesn't verify the results but does run the checker"
        )

    @classmethod
    def supports_device(cls, device: str) -> bool:
        """Return True only for CPU devices."""
        return Device(device).type == DeviceType.CPU


# Graph names whose test models must have complete shape/type coverage
# (see do_enforce_test_coverage_safelist).
test_coverage_safelist = {
    "SingleRelu",
    "bvlc_alexnet",
    "densenet121",
    "inception_v1",
    "inception_v2",
    "resnet50",
    "shufflenet",
    "squeezenet_old",
    "vgg19",
    "zfnet",
}


def do_enforce_test_coverage_safelist(model: ModelProto) -> bool:
    """Return True iff *model* is safelisted and free of recurrent ops.

    A model qualifies when its graph name appears in ``test_coverage_safelist``
    and no node uses RNN, LSTM, or GRU.
    """
    if model.graph.name not in test_coverage_safelist:
        return False
    recurrent_ops = {"RNN", "LSTM", "GRU"}
    return not any(node.op_type in recurrent_ops for node in model.graph.node)


# Per-test keyword overrides forwarded to the backend's prepare().
test_kwargs = {
    # https://github.com/onnx/onnx/issues/5510 (test_mvn fails with test_backend_test.py)
    "test_mvn": {"strict_mode": False},
}

# Run the full ONNX backend test suite through the dummy backend above.
backend_test = onnx.backend.test.BackendTest(
    DummyBackend, __name__, test_kwargs=test_kwargs
)
# Exclude the heaviest models on constrained environments — presumably due to
# CI resource limits (AppVeyor) and 32-bit address-space limits; confirm if
# these environments are still in use.
if os.getenv("APPVEYOR"):
    backend_test.exclude(r"(test_vgg19|test_zfnet)")
if platform.architecture()[0] == "32bit":
    backend_test.exclude(r"(test_vgg19|test_zfnet|test_bvlc_alexnet)")

# Needs investigation on onnxruntime.
backend_test.exclude("test_dequantizelinear_e4m3fn_float16")

# import all test cases at global scope to make them visible to python.unittest
globals().update(backend_test.test_cases)

if __name__ == "__main__":
    unittest.main()