Spaces:
Sleeping
Sleeping
import re | |
import torch._C as C | |
""" | |
PythonDispatcher class is a thin python-binding to C++ dispatcher and it | |
is designed to show how dispatcher precompute works. In particular, | |
it shows, for a certain op `foo`, what the computed dispatch table looks
like after users register their kernels to certain dispatch keys.
In the real C++ dispatcher we support many dispatch keys for different | |
functionalities. For simplicity PythonDispatcher only supports dispatch | |
keys for a single example of each use case. These use cases are listed below: | |
- CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference & | |
autograd kernel in pytorch core library. | |
E.g. CPU, CUDA | |
- FPGA/AutogradOther: represents in-tree backends which we usually have backend specific | |
inference kernels, but they share the same autograd kernel specified in AutogradOther. | |
E.g. FPGA, SparseCsrCPU | |
- XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd | |
kernel defined in pytorch core library. Backend owner is responsible for registering both | |
inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support. | |
E.g. XLA, XPU, MPS | |
- CompositeExplicitAutograd: alias key mapped to inference kernels of all backends like CPU, CUDA, XLA etc. | |
Kernels registered to this key MUST work for inference for all backends. | |
- Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther. | |
Kernels registered to this key MUST work for autograd for all backends. | |
- CompositeImplicitAutograd: alias key CompositeImplicitAutograd = CompositeExplicitAutograd + Autograd | |
Kernels registered to this key MUST work for both inference + autograd for all backends. | |
Note we only allow registrations to alias keys inside pytorch core library. E.g | |
you shouldn't register a CompositeImplicitAutograd or CompositeExplicitAutograd | |
kernel from torch-xla extension, instead you should upstream the kernel into | |
pytorch/pytorch repo so that it's available for all backends and continuously | |
tested even without the extension. | |
Usage: | |
dispatcher = PythonDispatcher() | |
dispatcher.register(["CPU", "XLA", "CompositeImplicitAutograd"]) | |
print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend. | |
# For more debugging information | |
# print(dispatcher.keys()) | |
# print(dispatcher.registrations()) | |
# print(dispatcher.rawRegistrations()) | |
# print(dispatcher.rawDispatchTable()) | |
PythonDispatcher calls the C++ dispatcher under the hood to precompute the dispatch table.
This file only provides the simplified API for developers, relevant test code is located in | |
test/test_dispatch.py | |
""" | |
class PythonDispatcher:
    """Simplified Python front end to the C++ dispatcher for one example op.

    An instance defines the operator ``foo(Tensor x) -> Tensor`` in the
    ``__test__`` namespace and lets callers register auto-generated kernels
    (named ``fn_<key>``) to a fixed set of dispatch keys, then inspect the
    resulting registrations and computed dispatch table as text.

    Class attributes:
        namespace: library namespace the example operator is defined in.
        name: name of the example operator.
        runtime_keys: dispatch keys that appear in the computed dispatch table.
        alias_keys: alias dispatch keys that can be registered to.
        supported_keys: every key accepted by ``register``
            (``runtime_keys + alias_keys``).
    """

    namespace = "__test__"
    name = "foo"
    # fmt: off
    runtime_keys = [
        "CPU", "AutogradCPU",
        "FPGA", "AutogradOther",
        "XLA", "AutogradXLA",
        "Lazy", "AutogradLazy",
    ]
    # fmt: on
    alias_keys = [
        "CompositeExplicitAutograd",
        "Autograd",
        "CompositeImplicitAutograd",
    ]
    supported_keys = runtime_keys + alias_keys

    def __init__(self):
        """Define the example operator ``foo`` in a library fragment."""
        C._dispatch_check_invariants(self.name)  # type: ignore[attr-defined]
        # Keep a reference to the library handle; `register` adds kernels
        # through it.
        self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")
        self.ref.def_("foo(Tensor x) -> Tensor")

    def keys(self):
        """Return the list of dispatch keys supported by PythonDispatcher.

        You can register kernels to these keys.
        """
        return self.supported_keys

    def register(self, dispatchKeys):
        """Register kernels to the target dispatch keys.

        Args:
            dispatchKeys (list[str]): dispatch keys to register a kernel to.
                You don't need to write the kernel yourself: for each key a
                kernel named ``fn_<key>`` (e.g. ``fn_CPU`` for ``CPU``) is
                registered automatically.

        Raises:
            RuntimeError: if ``dispatchKeys`` contains duplicates, contains
                both ``CompositeImplicitAutograd`` and
                ``CompositeExplicitAutograd``, or contains a key that is not
                in ``supported_keys``. No kernel is registered in any of
                these cases.
        """
        # Overriding is not supported and triggers a warning in C++ dispatcher.
        if len(set(dispatchKeys)) != len(dispatchKeys):
            raise RuntimeError(
                f"Overriden is not allowed but found duplicates in {dispatchKeys}."
            )
        # We currently forbid this in codegen instead of C++ dispatcher.
        if (
            "CompositeImplicitAutograd" in dispatchKeys
            and "CompositeExplicitAutograd" in dispatchKeys
        ):
            raise RuntimeError(
                "Registration to both CompositeImplicitAutograd and CompositeExplicitAutograd is not allowed."
            )
        # Validate every key before touching the dispatcher so that a bad key
        # late in the list cannot leave earlier keys half-registered.
        for key in dispatchKeys:
            if key not in self.supported_keys:
                raise RuntimeError(
                    f"{key} is not supported, please select a dispatch key in {self.supported_keys}."
                )
        for key in dispatchKeys:
            self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key)

    def _format_line(self, key, kernel):
        """Format one ``(key, kernel)`` table row, terminated by a newline."""
        return f"{key:<15} {kernel}\n"

    def _format_header(self, header):
        """Format a table header: title line, column names, and a rule."""
        return "".join(
            (
                f"\n{header}\n",
                self._format_line("key", "kernel"),
                "---------------------------\n",
            )
        )

    def rawRegistrations(self):
        """Return the raw dump of all registration info, for debugging only.

        Use registrations() for a simplified version.
        """
        return C._dispatch_dump(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    def rawDispatchTable(self):
        """Return the raw dump of the computed dispatch table, for debugging only.

        Use dispatchTable() for a simplified version.
        """
        return C._dispatch_dump_table(f"{self.namespace}::{self.name}")  # type: ignore[attr-defined]

    def registrations(self):
        """Return a table (str) of all the registrations from users.

        Note this includes registrations to both runtime keys and alias keys.
        """
        rows = [self._format_header("Registered Kernels")]
        # str.startswith accepts a tuple of prefixes, so one call covers
        # every supported key.
        key_prefixes = tuple(self.supported_keys)
        for line in self.rawRegistrations().split("\n"):
            first = line.split(":")[0]
            if first.startswith(key_prefixes):
                # The kernel name is the second space-separated token of the
                # text preceding the first "::".
                kernel = line.split("::")[0].split(" ")[1]
                rows.append(self._format_line(first, kernel))
        return "".join(rows)

    def dispatchTable(self):
        """Return the computed dispatch table (str).

        Note this only includes runtime keys; registrations to alias keys
        have been decoded to their mapped runtime keys.
        """
        rows = [self._format_header("Computed Dispatch Table")]
        # Collapses "registered at ...FallbackKernel.cpp... [" down to "[" so
        # the source location of fallback kernels is not shown.
        regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)")
        for line in self.rawDispatchTable().split("\n"):
            k = line.split(":")[0]
            if k in self.runtime_keys:
                entry = regex.sub("[", line)
                rows.append(self._format_line(k, entry.split(": ")[1]))
        return "".join(rows)