File size: 3,688 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""ONNX exporter exceptions."""
from __future__ import annotations

import textwrap
from typing import Optional

from torch import _C
from torch.onnx import _constants
from torch.onnx._internal import diagnostics

# Public names of this module (what ``from <module> import *`` exposes).
# Every entry is an exception or warning class defined further down in this file.
__all__ = [
    "OnnxExporterError",
    "OnnxExporterWarning",
    "CheckerError",
    "SymbolicValueError",
    "UnsupportedOperatorError",
]


class OnnxExporterWarning(UserWarning):
    """Root warning category for the ONNX exporter.

    Adds no behavior of its own on top of ``UserWarning``.
    """


class OnnxExporterError(RuntimeError):
    """Root error type for failures reported by the ONNX exporter.

    Adds no behavior of its own on top of ``RuntimeError``.
    """


class CheckerError(OnnxExporterError):
    """Error signalling that the ONNX checker found the model to be invalid.

    Adds no behavior of its own on top of ``OnnxExporterError``.
    """


class UnsupportedOperatorError(OnnxExporterError):
    """Error for an operator that the exporter does not support.

    The message comes from one of three diagnostic rules. Before the exception
    is initialized, the same message is reported through
    ``diagnostics.diagnose`` at ERROR level.
    """

    def __init__(self, name: str, version: int, supported_version: Optional[int]):
        """Pick the matching diagnostic rule, report it, and set the message.

        Args:
            name: Operator name; its prefix decides between the "standard" and
                "custom" missing-symbolic-function rules.
            version: Version number forwarded to the rule's message formatter.
            supported_version: When not ``None``, selects the
                "supported in newer opset version" rule and is forwarded to
                its message formatter.
        """
        rule: diagnostics.infra.Rule
        if supported_version is not None:
            rule = diagnostics.rules.operator_supported_in_newer_opset_version
            format_args = (name, version, supported_version)
        elif name.startswith(("aten::", "prim::", "quantized::")):
            # Operators in these namespaces get the rule whose message carries
            # the GitHub issues URL.
            rule = diagnostics.rules.missing_standard_symbolic_function
            format_args = (name, version, _constants.PYTORCH_GITHUB_ISSUES_URL)
        else:
            rule = diagnostics.rules.missing_custom_symbolic_function
            format_args = (name,)

        text = rule.format_message(*format_args)
        diagnostics.diagnose(rule, diagnostics.levels.ERROR, text)
        super().__init__(text)


class SymbolicValueError(OnnxExporterError):
    """Error tied to a specific TorchScript value and the node containing it."""

    def __init__(self, msg: str, value: _C.Value):
        """Compose the error text from ``msg`` plus details read off ``value``.

        Args:
            msg: Caller-supplied description placed at the start of the message.
            value: Value whose string form, type, and containing node (kind,
                source range, inputs, outputs) are appended to the message.
        """

        def _listing(heading: str, values) -> str:
            # One "#index: value  (type '...')" row per value; "Empty" if none.
            rows = "\n".join(
                f"    #{index}: {item}  (type '{item.type()}')"
                for index, item in enumerate(values)
            )
            return heading + (rows or "    Empty")

        pieces = [
            f"{msg}  [Caused by the value '{value}' (type '{value.type()}') in the "
            f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] "
        ]

        source_range = value.node().sourceRange()
        if source_range:
            pieces.append(f"\n    (node defined in {source_range})")

        # The blank-line separator is added before the attempt, so it is part
        # of the message even when the listing below cannot be produced.
        pieces.append("\n\n")
        try:
            io_summary = "\n".join(
                [
                    _listing("Inputs:\n", value.node().inputs()),
                    _listing("Outputs:\n", value.node().outputs()),
                ]
            )
            pieces.append(textwrap.indent(io_summary, "    "))
        except AttributeError:
            # Best effort only: fall back to a hint instead of failing while
            # building the error message.
            pieces.append(
                " Failed to obtain its input and output for debugging. "
                "Please refer to the TorchScript graph for debugging information."
            )

        super().__init__("".join(pieces))