Spaces:
Running
Running
# mypy: ignore-errors | |
import functools | |
import sys | |
from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple | |
import torch | |
from torch import fx | |
class CompiledFn(Protocol):
    """Structural type for whatever a backend compiler returns.

    Any object that can be called with positional tensors and returns a
    tuple of tensors satisfies this protocol.
    """

    def __call__(self, *args: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Invoke the compiled callable on positional tensor arguments."""
# Signature every backend compiler must have: it receives an FX graph module
# plus a list of (fake) example tensor inputs and returns a CompiledFn.
CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]

# Global registry: backend name (the string accepted by
# ``torch.compile(..., backend="name")``) -> compiler function.
_BACKENDS: Dict[str, CompilerFn] = {}
def register_backend(
    compiler_fn: Optional[CompilerFn] = None,
    name: Optional[str] = None,
    tags: Sequence[str] = (),
):
    """
    Decorator to add a given compiler to the registry to allow calling
    `torch.compile` with string shorthand. Note: for projects not
    imported by default, it might be easier to pass a function directly
    as a backend and not use a string.

    Usable bare (``@register_backend``), with arguments
    (``@register_backend(name="foo", tags=("debug",))``), or as a plain call.

    Args:
        compiler_fn: Callable taking a FX graph and fake tensor inputs
        name: Optional name, defaults to `compiler_fn.__name__`
        tags: Optional set of string tags to categorize backend with. A single
            non-empty bare string is treated as one tag, not as a sequence
            of characters.

    Returns:
        ``compiler_fn`` itself (now carrying a ``_tags`` tuple), or, when
        ``compiler_fn`` is None, a partial that registers the function it is
        later applied to.

    Raises:
        AssertionError: (when assertions are enabled) if ``compiler_fn`` is
            not callable or ``name`` is already registered.
        AttributeError: if ``compiler_fn`` does not accept new attributes
            (e.g. a bound method); nothing is registered in that case.
    """
    if compiler_fn is None:
        # @register_backend(name="") syntax
        return functools.partial(register_backend, name=name, tags=tags)
    assert callable(compiler_fn)
    name = name or compiler_fn.__name__
    assert name not in _BACKENDS, f"duplicate name: {name}"
    if isinstance(tags, str) and tags:
        # tuple("debug") would otherwise yield ("d", "e", "b", "u", "g").
        tags = (tags,)
    # Attach the tags *before* publishing into the registry: if the object
    # rejects attribute assignment, we fail without leaving a half-registered
    # entry behind (one lacking `_tags` would break list_backends() and block
    # re-registration under this name).
    compiler_fn._tags = tuple(tags)
    _BACKENDS[name] = compiler_fn
    return compiler_fn
# Pre-tagged variants of ``register_backend``. Backends registered through
# them carry the "debug" or "experimental" tag, which ``list_backends`` hides
# under its default ``exclude_tags``. Both accept the same call forms as
# ``register_backend`` (bare decorator, decorator with arguments, plain call).
register_debug_backend = functools.partial(
    register_backend,
    tags=("debug",),
)
register_experimental_backend = functools.partial(
    register_backend,
    tags=("experimental",),
)
def lookup_backend(compiler_fn):
    """Expand backend strings to functions.

    Non-string arguments are returned unchanged. A string is resolved in the
    registry, retrying first after importing the bundled backend modules and
    then after consulting installed package entry points.

    Raises:
        InvalidBackend: if the string still names no registered backend.
    """
    if not isinstance(compiler_fn, str):
        return compiler_fn

    backend_name = compiler_fn
    if backend_name not in _BACKENDS:
        # Import the bundled backend submodules, which may register more names.
        _lazy_import()
    if backend_name not in _BACKENDS:
        # Fall back to third-party backends advertised via entry points.
        _lazy_import_entry_point(backend_name)
    if backend_name not in _BACKENDS:
        from ..exc import InvalidBackend

        raise InvalidBackend(name=backend_name)
    return _BACKENDS[backend_name]
def list_backends(exclude_tags=("debug", "experimental")) -> List[str]:
    """
    Return valid strings that can be passed to:

        torch.compile(..., backend="name")

    Args:
        exclude_tags: Backends carrying any of these tags are left out.
            ``None`` or an empty sequence lists every registered backend;
            a single non-empty bare string is treated as one tag.

    Returns:
        The matching backend names, sorted.
    """
    _lazy_import()
    if isinstance(exclude_tags, str) and exclude_tags:
        # set("debug") would otherwise exclude single characters.
        exclude_tags = (exclude_tags,)
    excluded = set(exclude_tags or ())
    return sorted(
        name
        for name, backend in _BACKENDS.items()
        # Tolerate registry entries without `_tags` (e.g. inserted into
        # _BACKENDS directly) instead of failing the whole listing.
        if excluded.isdisjoint(getattr(backend, "_tags", ()))
    )
def _lazy_import():
    """Import the bundled backend submodules and the dynamo minifier module.

    Called before registry lookups/listings so that backends defined in
    those modules have a chance to register themselves.
    """
    from .. import backends as backends_pkg
    from ..utils import import_submodule as import_all_submodules

    import_all_submodules(backends_pkg)

    # Imported for its side effect (presumably registering the minifier
    # backend — verify in repro.after_dynamo); the assert keeps the name used.
    from ..repro.after_dynamo import dynamo_minifier_backend as minifier

    assert minifier is not None
def _lazy_import_entry_point(backend_name: str):
    """Try to load ``backend_name`` from the ``torch_dynamo_backends`` entry points.

    If an installed entry point with that name exists, it is loaded; unless
    loading it already made the name visible to ``list_backends``, the loaded
    object is registered under ``backend_name`` with no tags. Does nothing
    when no matching entry point is found.
    """
    from importlib.metadata import entry_points

    group = "torch_dynamo_backends"
    loaded = None
    if sys.version_info >= (3, 10):
        # 3.10+: entry_points() supports selecting a group directly.
        selected = entry_points(group=group)
        if backend_name in selected.names:
            loaded = selected[backend_name].load()
    else:
        # Older interpreters: entry_points() returns a mapping keyed by group.
        matches = [
            ep
            for ep in entry_points().get(group, ())
            if ep.name == backend_name
        ]
        if matches:
            loaded = matches[0].load()

    if loaded is None:
        return
    if backend_name not in list_backends(tuple()):
        register_backend(compiler_fn=loaded, name=backend_name)