File size: 1,540 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0

import inspect
from collections import defaultdict
from textwrap import dedent
from typing import Any, ClassVar, Dict, List, Tuple, Type

import numpy as np


def process_snippet(op_name: str, name: str, export: Any) -> Tuple[str, str]:
    """Extract a snippet name and body source from one ``export*`` function.

    Args:
        op_name: Name of the class that defines ``export``. Its lowercased
            form is the snippet name when ``name`` has no suffix after the
            ``export`` prefix.
        name: Attribute name of the export function, e.g. ``"export"`` or
            ``"export_with_axis"``.
        export: The function object (already unwrapped from ``staticmethod``)
            whose source is read with ``inspect.getsource``.

    Returns:
        ``(snippet_name, body)``. ``snippet_name`` is ``name`` without its
        ``export`` prefix and one separating underscore, falling back to
        ``op_name.lower()`` when nothing remains. ``body`` is the dedented
        function body with the ``@staticmethod`` line and the ``def`` line
        removed.

    Raises:
        ValueError: if the source does not start with an ``@staticmethod``
            line followed by a ``def export...`` line.
    """
    prefix = "export"
    if name.startswith(prefix):
        # Callers select attributes with ``startswith("export")``, so the
        # ``_`` separator is not guaranteed. Slicing by len("export_") would
        # drop a real character from names such as ``exportfoo``, so strip
        # the prefix and then at most one underscore.
        suffix = name[len(prefix) :]
        if suffix.startswith("_"):
            suffix = suffix[1:]
    else:
        suffix = name
    snippet_name = suffix or op_name.lower()

    source_code = dedent(inspect.getsource(export))
    lines = source_code.splitlines()
    # Use explicit checks rather than ``assert`` so that validation still runs
    # under ``python -O``. Otherwise a mismatch would silently yield a wrong
    # body.
    if len(lines) < 2 or lines[0] != "@staticmethod":
        raise ValueError(
            f"Expected {name!r} to start with an '@staticmethod' line, "
            f"got {lines[:1]!r}"
        )
    if not lines[1].startswith("def export"):
        raise ValueError(
            f"Expected a 'def export...' line after the decorator of {name!r}, "
            f"got {lines[1]!r}"
        )
    # Drop the decorator line and the ``def`` signature line, keeping only
    # the body.
    return snippet_name, dedent("\n".join(lines[2:]))


# Registry of documentation snippets, keyed by the name of each class whose
# creation goes through the ``_Exporter`` metaclass. Each value is a list of
# ``(snippet_name, body_source)`` pairs returned by ``process_snippet``, in the
# order the class's ``export*`` attributes appear in its body.
Snippets: Dict[str, List[Tuple[str, str]]] = defaultdict(list)


class _Exporter(type):
    """Metaclass that records and runs the ``export*`` staticmethods of a class.

    For every class created with this metaclass, each attribute in the class
    body whose name begins with ``export`` must be a ``staticmethod``;
    otherwise ``ValueError`` is raised. For each such method, in body order:
    its snippet (see ``process_snippet``) is appended to
    ``Snippets[<class name>]``, NumPy's global random seed is reset to 0, and
    the method is called once with no arguments.
    """

    # NOTE(review): not referenced by the visible code in this module; kept
    # for compatibility with possible external users.
    exports: ClassVar[Dict[str, List[Tuple[str, str]]]] = defaultdict(list)

    def __init__(
        cls, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any]
    ) -> None:
        for attr_name, raw_attr in dct.items():
            # Only attributes following the ``export*`` naming convention
            # are collected; everything else is ignored.
            if not attr_name.startswith("export"):
                continue
            if not isinstance(raw_attr, staticmethod):
                raise ValueError("Only staticmethods could be named as export.*")
            # Fetching through the class yields the underlying plain function
            # rather than the ``staticmethod`` wrapper stored in ``dct``.
            example_fn = getattr(cls, attr_name)
            snippet = process_snippet(name, attr_name, example_fn)
            Snippets[name].append(snippet)
            # Per the original authors, export functions are expected to call
            # ``expect`` and thereby populate the test cases. The seed reset
            # presumably makes any random inputs they draw deterministic.
            np.random.seed(seed=0)
            example_fn()
        super().__init__(name, bases, dct)


class Base(metaclass=_Exporter):
    """Base class whose subclasses are processed by the ``_Exporter`` metaclass.

    Subclasses inherit the metaclass, so defining one with ``export*``
    staticmethods appends their snippets to ``Snippets`` under the subclass
    name and calls each of them once when the class is created.
    """

    pass