File size: 6,891 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""Submodule containing all the ONNX schema definitions."""
from __future__ import annotations

from typing import Sequence, overload

from onnx import AttributeProto, FunctionProto

class SchemaError(Exception):
    """Exception type exported by this schema-definitions module."""

class OpSchema:
    """Type stub for an ONNX operator schema.

    Members are declarations only (``...`` bodies). The implementation is
    supplied elsewhere, presumably by a native extension module (confirm).
    Nested helper types are written as ``OpSchema.<Name>`` in every
    annotation, the form ``__init__`` and the nested classes already use.
    This keeps the stub on one convention, and its annotations resolve
    from module scope as well as class scope.
    """

    def __init__(
        self,
        name: str,
        domain: str,
        since_version: int,
        doc: str = "",
        *,
        inputs: Sequence[OpSchema.FormalParameter] = (),
        outputs: Sequence[OpSchema.FormalParameter] = (),
        type_constraints: Sequence[tuple[str, Sequence[str], str]] = (),
        attributes: Sequence[OpSchema.Attribute] = (),
    ) -> None:
        """Declare the schema of operator ``name`` in ``domain`` as of ``since_version``.

        Every parameter after ``doc`` is keyword-only. Each ``type_constraints``
        entry is a ``(str, Sequence[str], str)`` tuple. NOTE(review): presumably
        this is the same ``(type_param_str, allowed_type_strs, description)``
        triple that ``OpSchema.TypeConstraintParam`` accepts; confirm.
        """

    # --- Metadata -----------------------------------------------------------
    @property
    def file(self) -> str: ...
    @property
    def line(self) -> int: ...
    @property
    def support_level(self) -> OpSchema.SupportType: ...
    # Typed ``str | None`` even though the constructor's ``doc`` defaults to "".
    @property
    def doc(self) -> str | None: ...
    @property
    def since_version(self) -> int: ...
    @property
    def deprecated(self) -> bool: ...
    @property
    def domain(self) -> str: ...
    @property
    def name(self) -> str: ...

    # --- Arity bounds -------------------------------------------------------
    @property
    def min_input(self) -> int: ...
    @property
    def max_input(self) -> int: ...
    @property
    def min_output(self) -> int: ...
    @property
    def max_output(self) -> int: ...

    # --- Signature ----------------------------------------------------------
    @property
    def attributes(self) -> dict[str, OpSchema.Attribute]: ...
    @property
    def inputs(self) -> Sequence[OpSchema.FormalParameter]: ...
    @property
    def outputs(self) -> Sequence[OpSchema.FormalParameter]: ...
    @property
    def type_constraints(self) -> Sequence[OpSchema.TypeConstraintParam]: ...

    # --- Capabilities -------------------------------------------------------
    @property
    def has_type_and_shape_inference_function(self) -> bool: ...
    @property
    def has_data_propagation_function(self) -> bool: ...
    # NOTE(review): presumably tests whether an arity bound such as
    # ``max_input`` holds the "unbounded" sentinel value; confirm.
    @staticmethod
    def is_infinite(v: int) -> bool: ...
    # NOTE(review): declares an explicit ``schema`` argument besides ``self``.
    # If the native binding only takes the index (``schema.consumed(i)``),
    # this stub is over-specified; verify against the binding.
    def consumed(self, schema: OpSchema, i: int) -> tuple[OpSchema.UseType, int]: ...
    # Private hook. NOTE(review): the ``bytes`` arguments and results are
    # presumably serialized protobuf messages keyed by value name; confirm.
    def _infer_node_outputs(
        self,
        node_proto: bytes,
        value_types: dict[str, bytes],
        input_data: dict[str, bytes],
        input_sparse_data: dict[str, bytes],
    ) -> dict[str, bytes]: ...
    @property
    def function_body(self) -> FunctionProto: ...

    class TypeConstraintParam:
        def __init__(
            self,
            type_param_str: str,
            allowed_type_strs: Sequence[str],
            description: str = "",
        ) -> None:
            """Type constraint parameter.

            Args:
                type_param_str: Type parameter string, for example, "T", "T1", etc.
                allowed_type_strs: Allowed type strings for this type parameter. E.g. ["tensor(float)"].
                description: Type parameter description.
            """
        @property
        def type_param_str(self) -> str: ...
        @property
        def description(self) -> str: ...
        @property
        def allowed_type_strs(self) -> Sequence[str]: ...

    # Enum-like holder: each member is an instance of the class itself.
    class FormalParameterOption:
        Single: OpSchema.FormalParameterOption = ...
        Optional: OpSchema.FormalParameterOption = ...
        Variadic: OpSchema.FormalParameterOption = ...

    class DifferentiationCategory:
        Unknown: OpSchema.DifferentiationCategory = ...
        Differentiable: OpSchema.DifferentiationCategory = ...
        NonDifferentiable: OpSchema.DifferentiationCategory = ...

    class FormalParameter:
        def __init__(
            self,
            name: str,
            type_str: str,
            description: str = "",
            *,
            param_option: OpSchema.FormalParameterOption = OpSchema.FormalParameterOption.Single,  # noqa: F821
            is_homogeneous: bool = True,
            min_arity: int = 1,
            differentiation_category: OpSchema.DifferentiationCategory = OpSchema.DifferentiationCategory.Unknown,  # noqa: F821
        ) -> None: ...
        @property
        def name(self) -> str: ...
        @property
        def types(self) -> set[str]: ...
        @property
        def type_str(self) -> str: ...
        @property
        def description(self) -> str: ...
        @property
        def option(self) -> OpSchema.FormalParameterOption: ...
        @property
        def is_homogeneous(self) -> bool: ...
        @property
        def min_arity(self) -> int: ...
        @property
        def differentiation_category(self) -> OpSchema.DifferentiationCategory: ...

    class AttrType:
        FLOAT: OpSchema.AttrType = ...
        INT: OpSchema.AttrType = ...
        STRING: OpSchema.AttrType = ...
        TENSOR: OpSchema.AttrType = ...
        GRAPH: OpSchema.AttrType = ...
        SPARSE_TENSOR: OpSchema.AttrType = ...
        TYPE_PROTO: OpSchema.AttrType = ...
        FLOATS: OpSchema.AttrType = ...
        INTS: OpSchema.AttrType = ...
        STRINGS: OpSchema.AttrType = ...
        TENSORS: OpSchema.AttrType = ...
        GRAPHS: OpSchema.AttrType = ...
        SPARSE_TENSORS: OpSchema.AttrType = ...
        TYPE_PROTOS: OpSchema.AttrType = ...

    class Attribute:
        # Two construction forms: from an explicit attribute type (with a
        # keyword-only ``required`` flag), or from a default AttributeProto.
        @overload
        def __init__(
            self,
            name: str,
            type: OpSchema.AttrType,  # noqa: A002
            description: str = "",
            *,
            required: bool = True,
        ) -> None: ...
        @overload
        def __init__(
            self,
            name: str,
            default_value: AttributeProto,
            description: str = "",
        ) -> None: ...
        @property
        def name(self) -> str: ...
        @property
        def description(self) -> str: ...
        @property
        def type(self) -> OpSchema.AttrType: ...
        @property
        def default_value(self) -> AttributeProto: ...
        @property
        def required(self) -> bool: ...

    class SupportType(int):
        COMMON: OpSchema.SupportType = ...
        EXPERIMENTAL: OpSchema.SupportType = ...

    class UseType:
        DEFAULT: OpSchema.UseType = ...
        CONSUME_ALLOWED: OpSchema.UseType = ...
        CONSUME_ENFORCED: OpSchema.UseType = ...

@overload
def has_schema(op_type: str, domain: str = "") -> bool:
    """Query whether ``op_type`` is known in ``domain``; answers with a ``bool``."""
@overload
def has_schema(op_type: str, max_inclusive_version: int, domain: str = "") -> bool:
    """Version-capped form of the query above.

    NOTE(review): presumably only schemas whose ``since_version`` is at most
    ``max_inclusive_version`` count as a match; confirm.
    """
def schema_version_map() -> dict[str, tuple[int, int]]:
    """Return a mapping from string keys to pairs of ints.

    NOTE(review): presumably domain -> (min_version, max_version); confirm.
    """
@overload
def get_schema(op_type: str, max_inclusive_version: int, domain: str = "") -> OpSchema:
    """Version-capped lookup of the ``OpSchema`` for ``op_type`` in ``domain``.

    NOTE(review): presumably yields the newest schema whose ``since_version``
    does not exceed ``max_inclusive_version``. What happens when nothing
    matches is not visible from this stub; confirm.
    """
@overload
def get_schema(op_type: str, domain: str = "") -> OpSchema:
    """Look up the ``OpSchema`` for ``op_type`` in ``domain``."""
def get_all_schemas() -> Sequence[OpSchema]:
    """Return a sequence of ``OpSchema`` objects.

    NOTE(review): presumably one entry per operator (without older
    versions), unlike ``get_all_schemas_with_history``; confirm.
    """
def get_all_schemas_with_history() -> Sequence[OpSchema]:
    """Return a sequence of ``OpSchema`` objects.

    NOTE(review): presumably includes every registered version of each
    operator, not just the latest; confirm.
    """
def set_domain_to_version(
    domain: str,
    min_version: int,
    max_version: int,
    last_release_version: int = -1,
) -> None:
    """Set the ``min_version`` / ``max_version`` range for ``domain``.

    ``last_release_version`` defaults to ``-1``. NOTE(review): presumably a
    "not specified" sentinel; confirm.
    """
def register_schema(schema: OpSchema) -> None:
    """Add ``schema`` to the schema registry; returns nothing."""
def deregister_schema(op_type: str, version: int, domain: str) -> None:
    """Remove the schema identified by ``op_type``, ``version`` and ``domain``."""