Spaces:
Sleeping
Sleeping
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import inspect | |
from collections import defaultdict | |
from textwrap import dedent | |
from typing import Any, ClassVar, Dict, List, Tuple, Type | |
import numpy as np | |
def process_snippet(op_name: str, name: str, export: Any) -> Tuple[str, str]:
    """Turn an ``export*`` static method into a named, dedented code snippet.

    Args:
        op_name: Name of the class that defines the method. Its lowercased
            form is used as the snippet name when the method is named plain
            ``export`` (or ``export_``).
        name: Attribute name of the method; expected to start with ``export``.
        export: The underlying function (e.g. as returned by ``getattr`` on
            the class). Its source must begin with an ``@staticmethod`` line
            followed by a ``def export...`` line. Only that one ``def`` line
            is stripped, so the signature is assumed to fit on a single line.

    Returns:
        ``(snippet_name, body)``. ``snippet_name`` is ``name`` with the
        ``export`` prefix and at most one separating underscore removed, or
        ``op_name.lower()`` if nothing remains. ``body`` is the dedented
        source of the function with the decorator and signature lines
        dropped.

    Raises:
        ValueError: If the source does not start with ``@staticmethod``
            followed by a ``def export...`` line.
    """
    # Accept both ``export_foo`` and ``exportFoo``. Callers may pass any name
    # starting with "export", so slicing a fixed len("export_") would chop a
    # real character off names that have no underscore separator.
    suffix = name[len("export"):]
    if suffix.startswith("_"):
        suffix = suffix[1:]
    snippet_name = suffix or op_name.lower()

    source_code = dedent(inspect.getsource(export))
    lines = source_code.splitlines()
    # Explicit checks rather than ``assert`` so they still run under
    # ``python -O``; otherwise the wrong two lines would be silently dropped.
    if not lines or lines[0] != "@staticmethod":
        raise ValueError(
            f"{op_name}.{name}: source must start with '@staticmethod'"
        )
    if len(lines) < 2 or not lines[1].startswith("def export"):
        raise ValueError(
            f"{op_name}.{name}: second source line must be a 'def export...' signature"
        )
    # Drop the decorator and signature lines; keep only the function body.
    return snippet_name, dedent("\n".join(lines[2:]))
# Module-level registry appended to by the ``_Exporter`` metaclass. It maps a
# class name to the list of ``(snippet_name, snippet_source)`` pairs that
# ``process_snippet`` produced for that class's ``export*`` static methods.
Snippets: Dict[str, List[Tuple[str, str]]] = defaultdict(list)
class _Exporter(type): | |
exports: ClassVar[Dict[str, List[Tuple[str, str]]]] = defaultdict(list) | |
def __init__( | |
cls, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any] | |
) -> None: | |
for k, v in dct.items(): | |
if k.startswith("export"): | |
if not isinstance(v, staticmethod): | |
raise ValueError("Only staticmethods could be named as export.*") | |
export = getattr(cls, k) | |
Snippets[name].append(process_snippet(name, k, export)) | |
# export functions should call expect and so populate | |
# TestCases | |
np.random.seed(seed=0) | |
export() | |
super().__init__(name, bases, dct) | |
class Base(metaclass=_Exporter):
    """Root class for classes carrying ``export*`` static methods.

    Because the metaclass is inherited, each subclass has its own
    ``export*`` static methods collected and executed by ``_Exporter`` when
    the subclass is defined.
    """