# mypy: ignore-errors

import functools
import sys
from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple

import torch
from torch import fx


class CompiledFn(Protocol):
    def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        ...


CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]

_BACKENDS: Dict[str, CompilerFn] = dict()


def register_backend(
    compiler_fn: Optional[CompilerFn] = None,
    name: Optional[str] = None,
    tags: Sequence[str] = (),
):
    """
    Decorator to add a given compiler to the registry to allow calling
    `torch.compile` with string shorthand.  Note: for projects not
    imported by default, it might be easier to pass a function directly
    as a backend and not use a string.

    Args:
        compiler_fn: Callable taking a FX graph and fake tensor inputs
        name: Optional name, defaults to `compiler_fn.__name__`
        tags: Optional set of string tags to categorize backend with
    """
    if compiler_fn is None:
        # @register_backend(name="") syntax
        return functools.partial(register_backend, name=name, tags=tags)
    assert callable(compiler_fn)
    name = name or compiler_fn.__name__
    assert name not in _BACKENDS, f"duplicate name: {name}"
    _BACKENDS[name] = compiler_fn
    compiler_fn._tags = tuple(tags)
    return compiler_fn
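

# Usage sketch (illustrative, not part of this module): a backend registered
# here becomes available to `torch.compile(backend="...")` by its string name.
# The backend name "my_eager" and the trivial compiler below are hypothetical.
#
#     @register_backend(name="my_eager")
#     def my_eager(gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
#         return gm.forward  # skip compilation, run the captured graph eagerly
#
#     torch.compile(model, backend="my_eager")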


register_debug_backend = functools.partial(register_backend, tags=("debug",))
register_experimental_backend = functools.partial(
    register_backend, tags=("experimental",)
)


def lookup_backend(compiler_fn):
    """Expand backend strings to functions"""
    if isinstance(compiler_fn, str):
        if compiler_fn not in _BACKENDS:
            _lazy_import()
        if compiler_fn not in _BACKENDS:
            _lazy_import_entry_point(compiler_fn)
        if compiler_fn not in _BACKENDS:
            from ..exc import InvalidBackend

            raise InvalidBackend(name=compiler_fn)
        compiler_fn = _BACKENDS[compiler_fn]
    return compiler_fn
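

# Example (sketch): `lookup_backend` expands a registered name to its compiler
# function and passes non-string callables through unchanged, e.g.
#
#     lookup_backend("inductor")   # -> the inductor CompilerFn
#     lookup_backend(custom_fn)    # -> custom_fn, returned as-is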


def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
    """

    Return valid strings that can be passed to:



        torch.compile(..., backend="name")

    """
    _lazy_import()
    exclude_tags = set(exclude_tags or ())
    return sorted(
        [
            name
            for name, backend in _BACKENDS.items()
            if not exclude_tags.intersection(backend._tags)
        ]
    )
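

# Example (illustrative): debug and experimental backends are hidden by default;
# pass an empty tuple to list every registered backend.
#
#     torch._dynamo.list_backends()                 # excludes debug/experimental
#     torch._dynamo.list_backends(exclude_tags=())  # includes them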


@functools.lru_cache(None)
def _lazy_import():
    from .. import backends
    from ..utils import import_submodule

    import_submodule(backends)

    from ..repro.after_dynamo import dynamo_minifier_backend

    # Importing the module above registers the minifier backend as a side
    # effect; the assert just ensures the import is not flagged as unused.
    assert dynamo_minifier_backend is not None


@functools.lru_cache(None)
def _lazy_import_entry_point(backend_name: str):
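    # Look up `backend_name` in the "torch_dynamo_backends" entry-point group so
    # third-party packages can provide backends without being imported up front.
    # The importlib.metadata API changed in Python 3.10 (older versions return a
    # dict keyed by group name; newer ones accept a `group=` filter), hence the
    # version branch below.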
    from importlib.metadata import entry_points

    compiler_fn = None
    group_name = "torch_dynamo_backends"
    if sys.version_info < (3, 10):
        backend_eps = entry_points()
        eps = [ep for ep in backend_eps.get(group_name, ()) if ep.name == backend_name]
        if len(eps) > 0:
            compiler_fn = eps[0].load()
    else:
        backend_eps = entry_points(group=group_name)
        if backend_name in backend_eps.names:
            compiler_fn = backend_eps[backend_name].load()

    if compiler_fn is not None and backend_name not in list_backends(tuple()):
        register_backend(compiler_fn=compiler_fn, name=backend_name)
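

# Registration via entry points (sketch): a third-party package can expose a
# backend in the "torch_dynamo_backends" group so that
# `torch.compile(backend="my_backend")` works without importing the package
# first. The project and module names below are hypothetical; in pyproject.toml:
#
#     [project.entry-points."torch_dynamo_backends"]
#     my_backend = "my_package.compiler:my_backend_fn"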