import dataclasses
import importlib
import logging
import os
from typing import (
    Any,
    Callable,
    Dict,
    Final,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)
from typing_extensions import TypeAlias
import torch
import torch._C
import torch._ops
import torch._prims.executor
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx._compatibility import compatibility
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils import _pytree
try:
    # Use try-except to initialize package-dependent global variables.
    import onnx
    import onnxruntime  # type: ignore[import]
    from onnxruntime.capi import _pybind_state as ORTC  # type: ignore[import]

    # This is not used directly in DORT but is needed by the underlying exporter,
    # so we still need to check if it exists.
    importlib.import_module("onnxscript")

    import torch.onnx
    import torch.onnx._internal
    import torch.onnx._internal.diagnostics
    import torch.onnx._internal.exporter
    import torch.onnx._internal.fx.decomposition_table
    import torch.onnx._internal.fx.passes
    from torch.onnx._internal.fx import fx_onnx_interpreter
    from torch.onnx._internal.fx.type_utils import (
        _TORCH_DTYPE_TO_NUMPY_DTYPE,
        _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
        from_python_type_to_onnx_tensor_element_type,
    )

    _SUPPORT_ONNXRT = True
except ImportError:
    _SUPPORT_ONNXRT = False
__all__ = [
    "is_onnxrt_backend_supported",
    "torch_compile_backend",
    "OrtExecutionProvider",
    "OrtBackendOptions",
    "OrtBackend",
]
def is_onnxrt_backend_supported() -> bool: | |
"""Returns ``True`` if ONNX Runtime dependencies are installed and usable | |
to support TorchDynamo backend integration; ``False`` otherwise. | |
Example:: | |
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX) | |
>>> import torch | |
>>> if torch.onnx.is_onnxrt_backend_supported(): | |
... @torch.compile(backend="onnxrt") | |
... def f(x): | |
... return x * x | |
... print(f(torch.randn(10))) | |
... else: | |
... print("pip install onnx onnxscript onnxruntime") | |
... | |
""" | |
return _SUPPORT_ONNXRT | |
_dumped_onnx_model: Dict[str, int] = {} | |
def _dump_onnx_model( | |
model_string: bytes, graph_module: Optional[torch.fx.GraphModule] = None | |
) -> str: | |
"""Stores the onnx model into a file. | |
The name is "{ONNXRT_DUMP_PATH}{N}.onnx" | |
where *N* is the number of files already stored with | |
this prefix. | |
If graph_module is not None, the graph is stored as a string with | |
the same filename except the extension (.txt). | |
""" | |
prefix = os.environ.get("ONNXRT_DUMP_PATH", None) | |
if not prefix: | |
return "" | |
n = _dumped_onnx_model.get(prefix, -1) + 1 | |
filename = f"{prefix}{n}.onnx" | |
with open(filename, "wb") as f: | |
f.write(model_string) | |
_dumped_onnx_model[prefix] = n | |
if graph_module is not None: | |
filename_txt = f"{prefix}{n}.txt" | |
with open(filename_txt, "w", encoding="utf-8") as f: | |
f.write(str(graph_module.graph)) | |
return filename | |
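# Example (derived from the logic above): with ONNXRT_DUMP_PATH="dumped/dumped_model_",
# the first call writes "dumped/dumped_model_0.onnx" (plus "dumped/dumped_model_0.txt" when
# graph_module is given), the second call writes "dumped/dumped_model_1.onnx", and so on.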
def _infer_default_eps() -> Sequence[str]: | |
# TODO: select a good default based on the capabilities of the host | |
# e.g. DML on Windows, etc. | |
return ["CPUExecutionProvider"] | |
def _nvtx_range_push(name: str): | |
"""If PyTorch is installed with CUDA support, this starts NVTX range. | |
Check torch.cuda.nvtx.range_push's document for more details. | |
""" | |
if torch.cuda.is_available(): | |
torch.cuda.nvtx.range_push(name) | |
def _nvtx_range_pop(): | |
"""If PyTorch is installed with CUDA support, this terminates NVTX range. | |
Check torch.cuda.nvtx.range_pop's document for more details. | |
""" | |
if torch.cuda.is_available(): | |
torch.cuda.nvtx.range_pop() | |
def _get_ort_device_type(device_type: str): | |
if device_type == "cuda": | |
return ORTC.OrtDevice.cuda() | |
if device_type == "cpu": | |
return ORTC.OrtDevice.cpu() | |
# The "ort" PyTorch device type is mapped to ORT's NPU device type.
if device_type == "ort": | |
return ORTC.OrtDevice.npu() | |
raise ValueError("Unsupported device type: " + device_type) | |
logger = logging.getLogger(__name__) | |
# Uncomment the following lines to print out development info. | |
# logging.basicConfig(level=logging.WARNING) | |
# logger.setLevel(logging.WARNING) | |
class OrtOperatorSupport(OperatorSupport): | |
"""Operator support for ONNXRuntime backend. | |
It has two-level of support decision. One is via support_dict and the other one | |
is via extra_support_dict. The logic of using support_dict is implemented in | |
OrtOperatorSupport and extra_support_dict is used by OperatorSupport.is_node_supported. | |
""" | |
def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]): | |
# Use extra_support_dict[op_name] = None to indicate | |
# we support op_name with all input types. Otherwise, | |
# see support_dict (type: SupportDict) in operator_support.py | |
# for specifying supported types. | |
super().__init__(extra_support_dict) | |
self._onnx_support_dict = support_dict | |
def is_node_supported( | |
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node | |
) -> bool: | |
# OperatorSupport.is_node_supported returns True for non-callable nodes. | |
# Since ORT can't execute them, we return False here to override the base | |
# behavior. | |
if node.op not in CALLABLE_NODE_OPS: | |
return False | |
# This is the only place to decide whether an aten op is supported.
if node.op == "call_function" and node.target in self._onnx_support_dict: | |
logger.warning( | |
"support_dict supports node.target: %s (type: %s)", | |
node.target, | |
type(node.target), | |
) | |
return True | |
# If node.target is not in support_dict, we still want to check whether torch.jit.script
# can convert it to an ONNX equivalent. Let's use the base mechanism to do this.
# See extra_support_dict for supported ops.
if super().is_node_supported(submodules, node): | |
logger.warning( | |
"extra_support_dict supports node.target: %s (type: %s)", | |
node.target, | |
type(node.target), | |
) | |
return True | |
logger.warning( | |
"support_dict and extra_support_dict don't support node.target: %s (type: %s)", | |
node.target, | |
type(node.target), | |
) | |
return False | |
def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None: | |
""" | |
In torch.fx.Graph, placeholder is a special assignment node. If it's not | |
executed in the beginning, it could overwrite values computed by upstream | |
nodes. | |
""" | |
graph = graph_module.graph | |
placeholders = [] | |
first_not_placeholder = None | |
for node in graph.nodes: | |
if node.op == "placeholder": | |
placeholders.append(node) | |
if first_not_placeholder is None and node.op != "placeholder": | |
first_not_placeholder = node | |
if first_not_placeholder is None: | |
return | |
for placeholder in placeholders: | |
first_not_placeholder.prepend(placeholder) | |
def _infer_ep_from_device(*args) -> Tuple[str, ...]: | |
"""Return the first valid device (i.e., GPU or CPU) in argument list.""" | |
eps = [] | |
for arg in args: | |
if hasattr(arg, "device"): | |
device = arg.device | |
if device.type == "cuda": | |
eps.append("CUDAExecutionProvider") | |
elif device.type == "cpu": | |
eps.append("CPUExecutionProvider") | |
return tuple(eps) | |
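# Example (derived from the loop above): if args contain a CUDA tensor followed by a CPU tensor,
# this returns ("CUDAExecutionProvider", "CPUExecutionProvider"); arguments without a ``device``
# attribute, and devices other than cuda/cpu, are skipped.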
def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> Tuple[Any, ...]: | |
placeholders = [] | |
for node in graph_module.graph.nodes: | |
if node.op == "placeholder": | |
if hasattr(node, "meta") and "val" in node.meta: | |
assert isinstance(node.meta["val"], torch.Tensor) | |
placeholders.append(node) | |
return tuple(placeholders) | |
def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any: | |
"""Collect "val" fields from outputs metadata in this torch.fx.GraphModule.""" | |
for node in graph_module.graph.nodes: | |
if node.op == "output": | |
# Output node is unique. Let's retrieve output values from | |
# this node's input list. And then just return. | |
return node.args[0] | |
raise ValueError("No output node found in this torch.fx.GraphModule.") | |
def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]: | |
"""Return the all valid devices (i.e., GPU or CPU) among outputs of this torch.fx.GraphModule.""" | |
flattened_output_args, _ = _pytree.tree_flatten( | |
_extract_graph_module_outputs(graph_module) | |
) | |
# Output arguments with example value (type: torch.Tensor) in the `graph_module`. | |
selected_output_args = [ | |
output_arg.meta["val"] | |
for output_arg in flattened_output_args | |
# output_arg must have tensor for its device information. | |
# Otherwise, skip it. | |
if (hasattr(output_arg, "meta") and "val" in output_arg.meta) | |
] | |
return _infer_ep_from_device(*selected_output_args) | |
def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]: | |
"""Sort execution providers in eps based on pre-set priority.""" | |
def get_execution_provider_priority(ep: str) -> int: | |
if ep == "CPUExecutionProvider": | |
# Lowest priority. | |
return 2 | |
if ep == "CUDAExecutionProvider": | |
# Higher priority than CPU but lower than | |
# other specialized EPs. | |
return 1 | |
# Highest priority. | |
return 0 | |
unique_eps = set(eps) | |
return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True)) | |
def _get_onnx_devices( | |
values: Tuple[ | |
Union[ | |
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool | |
], | |
..., | |
] | |
) -> Tuple["ORTC.OrtDevice", ...]: | |
def _device_id_or_zero(device_id: int) -> int: | |
return device_id or 0 | |
def _map_tensor_or_sym_to_device( | |
value: Union[ | |
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool | |
], | |
) -> "ORTC.OrtDevice":
if isinstance(value, torch.Tensor): | |
return ORTC.OrtDevice( | |
_get_ort_device_type(value.device.type), | |
ORTC.OrtDevice.default_memory(), | |
_device_id_or_zero(value.device.index), | |
) | |
elif isinstance( | |
value, (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool) | |
): | |
return ORTC.OrtDevice( | |
_get_ort_device_type("cpu"), ORTC.OrtDevice.default_memory(), 0 | |
) | |
else: | |
raise ValueError("Unsupported value type: " + str(type(value))) | |
if len(values) > 0: | |
ort_devices = tuple(_map_tensor_or_sym_to_device(value) for value in values) | |
return ort_devices | |
else: | |
return (_map_tensor_or_sym_to_device(1),) | |
def _get_ortvalues_from_torch_tensors( | |
tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...] | |
) -> "ORTC.OrtValueVector":
ortvalues = ORTC.OrtValueVector() | |
ortvalues.reserve(len(tensors)) | |
dtypes = [] | |
shapes = [] | |
data_ptrs = [] | |
for tensor in tensors: | |
dtypes.append(_TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype]) | |
shapes.append(tensor.size()) | |
data_ptrs.append(tensor.data_ptr()) | |
ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices) | |
return ortvalues | |
def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor: | |
if tensor.is_sparse: | |
raise ValueError("sparse tensor is not yet supported.") | |
out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device) | |
return out | |
def _adjust_scalar_from_fx_to_onnx( | |
dynamo_value: Union[ | |
torch.Tensor, | |
int, | |
float, | |
bool, | |
], | |
value_info: "onnx.ValueInfoProto", # type: ignore[name-defined] | |
) -> torch.Tensor: | |
"""Helper function to wrap PyTorch variables as torch.Tensor""" | |
if ( | |
isinstance(dynamo_value, torch.Tensor) | |
and len(value_info.type.tensor_type.shape.dim) == 0 | |
and dynamo_value.shape == (1,) | |
): | |
# ONNX expects a scalar with an empty shape.
# In contrast, PyTorch usually allows implicit | |
# conversion between shape=() and shape=(1,). | |
# | |
# Below, PyTorch's shape (1,) is reshaped to (). | |
return torch.squeeze(dynamo_value) | |
elif isinstance(dynamo_value, int): | |
return torch.tensor(dynamo_value, dtype=torch.int64) | |
elif isinstance(dynamo_value, float): | |
return torch.tensor(dynamo_value, dtype=torch.float32) | |
elif isinstance(dynamo_value, bool): | |
return torch.tensor(dynamo_value, dtype=torch.bool) | |
else: | |
assert isinstance(dynamo_value, torch.Tensor) | |
return dynamo_value.contiguous() | |
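# Example (derived from the branches above): a Python int 3 becomes torch.tensor(3, dtype=torch.int64),
# a float becomes a float32 tensor, and a shape-(1,) tensor mapped to a rank-0 ONNX input is squeezed
# to shape (); all other tensors are simply made contiguous.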
def _adjust_scalar_from_onnx_to_fx( | |
tensor: torch.Tensor, | |
prim_value: Union[ | |
torch.Tensor, | |
torch.SymInt, | |
int, | |
torch.SymFloat, | |
float, | |
torch.SymBool, | |
bool, | |
], | |
) -> Union[torch.Tensor, int, float, bool,]: | |
"""Helper function to wrap ORT-produced torch.Tensor as PyTorch variables""" | |
assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor." | |
if isinstance( | |
prim_value, | |
(torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool), | |
): | |
# Convert tensor back to scalar to match Dynamo's expectation. | |
return tensor.item() | |
return tensor | |
def _run_onnx_session_with_ortvaluevector( | |
sess: "onnxruntime.InferenceSession", | |
input_names: Tuple[str, ...], | |
inputs: Tuple[torch.Tensor, ...], | |
input_devices: Tuple["ORTC.OrtDevice", ...], | |
output_names: Tuple[str, ...], | |
outputs: Tuple[torch.Tensor, ...], | |
output_devices: Tuple["ORTC.OrtDevice", ...], | |
preallocate_output: bool, | |
input_value_infos: Tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] | |
normalized_prim_outputs: Tuple[ | |
Union[ | |
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool | |
], | |
..., | |
], | |
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]: | |
_nvtx_range_push("contiguous") | |
inputs = tuple( | |
_adjust_scalar_from_fx_to_onnx(arg, value_info) | |
for arg, value_info in zip(inputs, input_value_infos) | |
) | |
_nvtx_range_pop() | |
_nvtx_range_push("push_back_batch") | |
ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices) | |
# Pre-allocate output PyTorch tensors and use their buffers (already bound to the torch device) as the
# backing memory of the output OrtValues. Because the output OrtValues are then neither allocated nor
# owned by ORT, there is no need to convert them back to torch tensors and transfer ownership.
if preallocate_output: | |
pth_outputs = tuple( | |
_to_real_tensor(t) if isinstance(t, FakeTensor) else t for t in outputs | |
) | |
ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices) | |
else: | |
ort_outputs = ORTC.OrtValueVector() | |
_nvtx_range_pop() | |
_nvtx_range_push("run_with_ortvaluevector") | |
run_options = onnxruntime.RunOptions() | |
run_options.add_run_config_entry("disable_synchronize_execution_providers", "1") | |
sess.run_with_ortvaluevector( | |
run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices | |
) | |
_nvtx_range_pop() | |
# Post-processing step: | |
# wrap ORT's outputs to the schema represented by | |
# `prim_output` (obtained by running the original | |
# torch.fx.GraphModule). | |
if preallocate_output: | |
# Profile the ORT-to-PyTorch type cast below | |
_nvtx_range_push("after run_with_ortvaluevector") | |
# Outputs are stored on pre-allocated torch.Tensors' memory, | |
# so this case doesn't need to convert ORTValue to torch.Tensor. | |
pth_outputs = tuple( | |
_adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc] | |
for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs) | |
) | |
_nvtx_range_pop() | |
return pth_outputs | |
else: | |
# Profile the two ORT-to-PyTorch type casts below | |
_nvtx_range_push("after run_with_ortvaluevector") | |
# Map ORTValue to torch.Tensor. | |
pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor( | |
ort_outputs | |
) | |
# Change some torch.Tensor to int, float, bool. | |
pth_outputs = tuple( | |
_adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc] | |
for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs) | |
) | |
_nvtx_range_pop() | |
return pth_outputs | |
def _run_onnx_session_with_fetch( | |
sess: "onnxruntime.InferenceSession", | |
input_names: Tuple[str, ...], | |
inputs: Tuple[torch.Tensor, ...], | |
input_devices: Tuple["ORTC.OrtDevice", ...], | |
output_names: Tuple[str, ...], | |
outputs: Tuple[torch.Tensor, ...], | |
output_devices: Tuple["ORTC.OrtDevice", ...], | |
preallocate_output: bool, | |
input_value_infos: Tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] | |
normalized_prim_outputs: Tuple[ | |
Union[ | |
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool | |
], | |
..., | |
], | |
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]: | |
inputs = tuple( | |
_adjust_scalar_from_fx_to_onnx(arg, value_info) | |
for arg, value_info in zip(inputs, input_value_infos) | |
) | |
feed = { | |
name: onnxruntime.OrtValue.ortvalue_from_numpy(tensor.cpu().numpy()) | |
for name, tensor in zip(input_names, inputs) | |
} | |
ort_outputs = sess.run(output_names, feed) | |
pth_outputs = tuple( | |
_adjust_scalar_from_onnx_to_fx( | |
torch.from_numpy(value), | |
prim_output, | |
) | |
for value, prim_output in zip(ort_outputs, normalized_prim_outputs) | |
) | |
return pth_outputs | |
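# Note: _run_onnx_session_with_fetch is the fallback execution path. It copies inputs to numpy via
# OrtValue.ortvalue_from_numpy and copies outputs back with torch.from_numpy, whereas
# _run_onnx_session_with_ortvaluevector (selected below when OrtValueVector.push_back_batch is available)
# binds torch tensor buffers directly through their data pointers and avoids those copies.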
class OrtExecutionInfoPerSession: | |
"""Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession""" | |
def __init__( | |
self, | |
session: "onnxruntime.InferenceSession", | |
input_names: Tuple[str, ...], | |
input_value_infos: Tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] | |
output_names: Tuple[str, ...], | |
output_value_infos: Tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined] | |
input_devices: Tuple["ORTC.OrtDevice", ...], | |
output_devices: Tuple["ORTC.OrtDevice", ...], | |
example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor], | |
): | |
# Carrier of ONNX model and its executor. | |
self.session: onnxruntime.InferenceSession = session | |
# For the ONNX model stored in self.session, self.input_names[i] is the | |
# name of the i-th positional input. | |
self.input_names: Tuple[str, ...] = input_names | |
# self.input_names[i]'s type information is stored in self.input_value_infos[i].
self.input_value_infos: Tuple[onnx.ValueInfoProto, ...] = input_value_infos # type: ignore[name-defined] | |
# Similar to self.input_names, but for outputs. | |
self.output_names: Tuple[str, ...] = output_names | |
# Similar to self.input_value_infos but for outputs. | |
self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos # type: ignore[name-defined] | |
# For the ONNX model stored in self.session, self.input_devices[i] is the | |
# i-th positional input's device. | |
self.input_devices: Tuple["ORTC.OrtDevice", ...] = input_devices | |
# Similar to self.input_devices, but for outputs. | |
self.output_devices: Tuple["ORTC.OrtDevice", ...] = output_devices | |
# These are the outputs of executing the original torch.fx.GraphModule with example inputs
# (i.e., args passed into OrtBackend._ort_acclerated_call).
self.example_outputs: Union[ | |
Tuple[torch.Tensor, ...], torch.Tensor | |
] = example_outputs | |
def is_supported(self, *args): | |
# Compare the args against the input schema of the ONNX model in this session
# and return whether they match.
if len(args) != len(self.input_value_infos): | |
return False | |
for arg, value_info in zip(args, self.input_value_infos): | |
if not isinstance(arg, (torch.Tensor, float, int)): | |
return False | |
# Check Python scalars such as int, float, and bool. | |
if isinstance(arg, (int, float, bool)): | |
# Map, e.g., float to onnx.TensorProto.FLOAT. | |
onnx_dtype = from_python_type_to_onnx_tensor_element_type(type(arg)) | |
if onnx_dtype != value_info.type.tensor_type.elem_type: | |
return False | |
if len(value_info.type.tensor_type.shape.dim) != 0: | |
return False | |
continue | |
# Check tensor. | |
onnx_dtype = _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE[arg.dtype] | |
if onnx_dtype != value_info.type.tensor_type.elem_type: | |
return False | |
for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim): | |
if isinstance(dim, int) and ( | |
onnx_dim.dim_value == dim or onnx_dim.dim_param | |
): | |
continue | |
elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param: | |
continue | |
else: | |
return False | |
return True | |
class OrtExecutionInfoForAllGraphModules: | |
def __init__(self): | |
# All sessions (and their related information) created by exporting the same GraphModule | |
# with different inputs. | |
self.execution_info_per_graph_module: Dict[ | |
torch.fx.GraphModule, List[OrtExecutionInfoPerSession] | |
] = {} | |
def search_reusable_session_execution_info( | |
self, graph_module: torch.fx.GraphModule, *args | |
): | |
if graph_module not in self.execution_info_per_graph_module: | |
return None | |
# All execution information for ONNX models exported from the same `graph_module` | |
# with different inputs. | |
candidates = self.execution_info_per_graph_module[graph_module] | |
for candidate in candidates: | |
if candidate.is_supported(*args): | |
# Returns the first session that accepts this input schema. | |
return candidate | |
# No reusable session found. | |
return None | |
def cache_session_execution_info( | |
self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession | |
): | |
if graph_module not in self.execution_info_per_graph_module: | |
self.execution_info_per_graph_module[graph_module] = [info] | |
else: | |
self.execution_info_per_graph_module[graph_module].append(info) | |
OrtExecutionProvider: TypeAlias = Union[str, Tuple[str, Mapping[str, Any]]] | |
"""Either the name of an ONNX Runtime execution provider as a string or | |
a 2-tuple of the name and a dictionary of execution provider options. | |
Examples:: | |
>>> "CPUExecutionProvider" | |
>>> ("CUDAExecutionProvider", {"device_id": 3}) | |
""" | |
@dataclasses.dataclass
class OrtBackendOptions:
"""Options for constructing an ``OrtBackend``, the ONNX Runtime | |
backend (``"onnxrt"``) for ``torch.compile``. | |
Example:: | |
>>> @torch.compile( | |
... backend="onnxrt", | |
... options=torch.onnx._OrtBackendOptions(...), | |
... ) | |
... def ort_function(x): | |
... return x ** x | |
""" | |
preferred_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None | |
"""An optional sequence of execution providers to be prioritized ahead of any | |
execution providers that may be inferred (see ``infer_execution_providers``). | |
""" | |
infer_execution_providers: bool = True | |
"""Whether to infer an execution provider from ``torch.device`` bound to inputs or found in the graph.""" | |
default_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None | |
"""The default fallback execution providers. If not specified, one will be | |
be selected based on the host environment (most likely ``"CPUExecutionProvider"``). | |
""" | |
# preallocate_output allows pre-allocating output torch.Tensor buffers and feeding them to InferenceSession
# in order to avoid InferenceSession allocating output buffers internally.
# If an output OrtValue returned from InferenceSession is allocated internally,
# it needs to be converted to a torch.Tensor before being returned, and that torch.Tensor must take ownership.
# When a custom torch device is used with a custom aten allocator, the conversion from OrtValue to
# torch.Tensor must be supported; it is currently done through dlpack, which might not support a custom torch device.
# This can be avoided by pre-allocating output buffers with the custom aten allocator and handing them to
# InferenceSession, which then does not hold any ownership over them.
# TODO(wschin): Make this an inference-session-level flag.
# See https://github.com/pytorch/pytorch/issues/106869.
preallocate_output: bool = False
"""If ``True``, allocate memory for ONNX Runtime's outputs on the PyTorch side."""
use_aot_autograd: bool = True | |
"""Whether to wrap the ``OrtBackend`` with TorchDynamo's aot_autograd backend | |
to support training (i.e., backward graphs are also sent to ``OrtBackend``). | |
Symbolic execution is used to capture the forward pass and backward passes as a single graph. | |
Then, a selected graph partition algorithm (``min_cut_rematerialization_partition``) is used | |
to split the entire graph into forward sub-graph and backward sub-graph. Finally, both | |
sub-graphs are compiled by ``OrtBackend``. | |
""" | |
export_options: Optional["torch.onnx.ExportOptions"] = None | |
"""Options for the TorchDynamo-based ONNX exporter used by the ``OrtBackend``.""" | |
ort_session_options: Optional["onnxruntime.SessionOptions"] = None | |
"""Options for the ``onnxruntime.InferenceSession`` used by the ``OrtBackend``.""" | |
pre_ort_model_transforms: Optional[ # type: ignore[name-defined] | |
Sequence[Callable[["onnx.ModelProto"], None]] | |
] = None | |
"""A list of graph transforms to be applied to the ONNX model before it | |
is fed to ONNXRuntime's InferenceSession.""" | |
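# Usage sketch (illustrative only; assumes CUDA and the ONNX Runtime CUDA package are installed),
# mirroring the docstring example above:
#
#     @torch.compile(
#         backend="onnxrt",
#         options=OrtBackendOptions(
#             preferred_execution_providers=[("CUDAExecutionProvider", {"device_id": 0})],
#             preallocate_output=False,
#         ),
#     )
#     def f(x):
#         return x.relu()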
class OrtBackend: | |
"""A backend compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls. | |
The compiler entry point is OrtBackend.compile, which | |
1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule) and unsupported | |
sub-graphs. | |
2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call. | |
3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph. | |
""" | |
def __init__(self, options: Optional[OrtBackendOptions] = None): | |
self._options: Final = OrtBackendOptions() if options is None else options | |
# options.export_options contains information shared between exporter and DORT. | |
# For example, they should use the same decomposition table when | |
# 1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py) | |
# 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model | |
# (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below). | |
# | |
# Convert user-facing option to internal option used by ONNX exporter | |
# to access required information. | |
# Some useful fields: | |
# - Decomposition table for decomposing FX operators in exporter is | |
# self._resolved_onnx_exporter_options.decomposition_table. | |
# - self._resolved_onnx_exporter_options.onnx_registry records what | |
# aten/prim ops are supported by exporter and their exporters (type: callable). | |
self._resolved_onnx_exporter_options = ( | |
torch.onnx._internal.exporter.ResolvedExportOptions( | |
torch.onnx.ExportOptions() | |
if self._options.export_options is None | |
else self._options.export_options | |
) | |
) | |
# Given DORT's computation flow: | |
# 1. OrtOperatorSupport uses support_dict and extra_support_dict to select operators | |
# and send them to DORT. | |
# 2. Then, DORT exports the selected sub-graphs into ONNX. | |
# 3. Finally DORT calls ORT to do the computation. | |
# OrtOperatorSupport and create_onnx_friendly_decomposition_table(...) | |
# must use the same support_dict. If the support_dict here contains something not
# supported by the exporter, the exporter will fail in step 2 since the selected graphs may
# contain unsupported operators such as aten::_who_you_are.
# This restriction is automatically satisfied since DORT and the exporter share the same
# self._resolved_onnx_exporter_options.
support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table( | |
self._resolved_onnx_exporter_options.onnx_registry | |
) | |
extra_support_dict: Dict[str, Any] = { | |
"getattr": None, | |
# To send operator.getitem to ORT, add the corresponding string | |
# recognized by PyTorch's OperatorSupport class. | |
"_operator.getitem": None, | |
# To send operator.mul to ORT, add the corresponding string | |
# recognized by PyTorch's OperatorSupport class. | |
"_operator.mul": None, | |
"_operator.add": None, | |
"_operator.sub": None, | |
} | |
self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict) | |
# TODO(wschin): this is a naive implementation of cache without proper guard | |
# See https://github.com/pytorch/pytorch/issues/106868. | |
self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {} | |
# Conceptually, this field is a 2-layer dictionary
# GraphModule 0 | |
# ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) | |
# ONNX Model 1 | |
# ... | |
# GraphModule 1 | |
# ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) | |
# ONNX Model 3 | |
# ... | |
# ... | |
# , which caches all previous compilation results so that we can reuse them.
# ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs | |
# (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different | |
# graphs captured by Dynamo and sent to OrtBackend.compile. | |
self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules() | |
self._assert_allclose_to_baseline = False | |
self.execution_count = 0 | |
# Function which invokes ORT to do the real computation.
self.run = ( | |
_run_onnx_session_with_ortvaluevector | |
if hasattr(ORTC.OrtValueVector, "push_back_batch") | |
else _run_onnx_session_with_fetch | |
) | |
def _select_eps( | |
self, graph_module: torch.fx.GraphModule, *args | |
) -> Sequence[Tuple[str, Mapping[str, Any]]]: | |
inferred_eps: Tuple[str, ...] = tuple() | |
if self._options.infer_execution_providers: | |
if eps_from_args := _infer_ep_from_device(*args): | |
# If user feeds CUDA tensor as input argument, | |
# we want to use CUDA EP. | |
# Thus, `eps_from_args` (deduced from input arguments) | |
# has highest priority. | |
inferred_eps = eps_from_args | |
elif eps_from_graph_module := _infer_ep_from_graph_module(graph_module): | |
# If there is no EP in input arguments, we deduce EP from | |
# graph_module's outputs. Those outputs may come from | |
# FakeTensorProp or Dynamo's built-in symbolic shape inference. | |
inferred_eps = eps_from_graph_module | |
selected_eps = [] | |
for ep in ( | |
*(self._options.preferred_execution_providers or []), | |
*_sort_eps(inferred_eps), | |
*(self._options.default_execution_providers or _infer_default_eps()), | |
): | |
if isinstance(ep, str): | |
ep = (ep, {}) | |
elif isinstance(ep, tuple) and ep[1] is None: | |
ep = (ep[0], {}) | |
if ep is not None and ep not in selected_eps: | |
selected_eps.append(ep) | |
return selected_eps | |
def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs): | |
"""This function replaces GraphModule._wrapped_call in compiled model. | |
The _wrapped_call is the underlying implementation of forward method. Replacing | |
it means we delegate the computation to _ort_acclerated_call and therefore | |
onnxruntime.InferenceSession. | |
""" | |
cached_execution_info_per_session = ( | |
self._all_ort_execution_info.search_reusable_session_execution_info( | |
graph_module, *args | |
) | |
) | |
if cached_execution_info_per_session: | |
onnx_session = cached_execution_info_per_session.session | |
input_names = cached_execution_info_per_session.input_names | |
output_names = cached_execution_info_per_session.output_names | |
input_value_infos = cached_execution_info_per_session.input_value_infos | |
output_value_infos = cached_execution_info_per_session.output_value_infos | |
input_devices = cached_execution_info_per_session.input_devices | |
output_devices = cached_execution_info_per_session.output_devices | |
prim_outputs = cached_execution_info_per_session.example_outputs | |
else: | |
# It's the first time we see such a graph. Let's make a new session
# (type: onnxruntime.InferenceSession) for it.
graph_module = torch.onnx._internal.fx.passes.MovePlaceholderToFront( | |
self._resolved_onnx_exporter_options.diagnostic_context, | |
graph_module, | |
).run() | |
# Generate reference outputs. They are used to indicate output | |
# tensors' types and devices when calling ORT. | |
# | |
# WARNING: The downstream code should not change prim_outputs, and
# this backend should always produce outputs with a schema identical to prim_outputs'.
if self._resolved_onnx_exporter_options.dynamic_shapes: | |
# No pre-allocation when dynamic shape is enabled. | |
self.preallocate_output = False | |
extracted_outputs = _extract_graph_module_outputs(graph_module) | |
def maybe_map_to_meta_val(value): | |
if hasattr(value, "meta") and "val" in value.meta: | |
# Select outputs with "val" information. Without "val", | |
# it's not possible access output_arg.meta["val"].device. | |
return value.meta["val"] | |
else: | |
return value | |
prim_outputs = _pytree.tree_map( | |
maybe_map_to_meta_val, extracted_outputs | |
) | |
else: | |
try: | |
prim_outputs = FakeTensorProp(graph_module).propagate( | |
*args, **kwargs | |
) | |
except Exception: | |
logger.warning("FakeTensorProb failed for %s", graph_module) | |
# When FakeTensorProp fails, it is not possible to preallocate output buffers | |
# because the output shapes are not inferred. | |
self.preallocate_output = False | |
# rethrow FakeTensorProb failure because it is not yet currently handled. | |
raise | |
# Create the object that iterates through the nodes in the graph one by one
# and calls the corresponding ONNX exporter for each node.
fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter( | |
diagnostic_context=self._resolved_onnx_exporter_options.diagnostic_context | |
) | |
# Cast FX variables if they would otherwise cause a schema mismatch when searching
# for an ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
# but ONNX expects add(double_tensor, double_tensor).
graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion( | |
self._resolved_onnx_exporter_options.diagnostic_context, graph_module | |
).run() | |
# Start the per-node exporting process. It's conceptually a for loop | |
# scanning through the nodes in the graph. | |
exported = fx_interpreter.run( | |
fx_graph_module=graph_module, | |
onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher, | |
op_level_debug=self._resolved_onnx_exporter_options.op_level_debug, | |
) | |
# Convert the exported result to ONNX ModelProto. | |
onnx_model = exported.to_model_proto( | |
opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version, | |
) | |
# Modify ONNX model using pre-registered graph transforms. | |
# They are in-place modifications for avoiding unnecessary | |
# copy of ONNX initializers. | |
if self._options.pre_ort_model_transforms: | |
for transform in self._options.pre_ort_model_transforms: | |
transform(onnx_model) | |
onnx_model_bytes = onnx_model.SerializeToString() | |
if os.environ.get("ONNXRT_DUMP_PATH", None): | |
# If not empty, environment variable ONNXRT_DUMP_PATH defined the path | |
# where generated onnx files should be stored. | |
# This module keeps a global variables keeping track of the | |
# stored models. | |
# If ONNXRT_DUMP_PATH="dumped/dumped_model_" | |
# The first file name will be 'dumped/dumped_model_0.onnx'. | |
# For every dumped model, a text file 'dumped/dumped_model_0.txt' | |
# is created as well to contain the string representing the graph_module. | |
_dump_onnx_model(onnx_model_bytes, graph_module=graph_module) | |
# Initialize an ORT session to execute this ONNX model.
# Note that TorchDynamo assumes all inputs/outputs are on the
# same device, but it's subject to change (very likely with
# dynamic shape support), so we add execution providers
# based on the logic in _select_eps (explicitly preferred EPs,
# EPs inferred from inputs or graph, and the fallback default EP).
# | |
# TODO(wschin): enable external allocators. | |
# See https://github.com/pytorch/pytorch/issues/106867 | |
onnx_session = onnxruntime.InferenceSession( | |
path_or_bytes=onnx_model_bytes, | |
sess_options=self._options.ort_session_options, | |
providers=self._select_eps(graph_module, *args), | |
) | |
# The ORT session is cached below so it can be reused for the same "graph_module".
# Extract the ONNX model's input and output names.
input_names = tuple(input.name for input in onnx_model.graph.input) | |
output_names = tuple(output.name for output in onnx_model.graph.output) | |
input_devices = _get_onnx_devices(args) | |
# Cache devices for inputs and outputs. They are used to invoke
# the ORT session. Output devices indicate where (e.g., GPU or CPU)
# to store outputs.
if isinstance(prim_outputs, tuple): | |
output_devices = _get_onnx_devices(prim_outputs) | |
else: | |
output_devices = _get_onnx_devices((prim_outputs,)) | |
input_value_infos = tuple(input for input in onnx_model.graph.input) | |
output_value_infos = tuple(output for output in onnx_model.graph.output) | |
execution_info_per_session = OrtExecutionInfoPerSession( | |
session=onnx_session, | |
input_names=input_names, | |
input_value_infos=input_value_infos, | |
output_names=output_names, | |
output_value_infos=output_value_infos, | |
input_devices=input_devices, | |
output_devices=output_devices, | |
example_outputs=prim_outputs, | |
) | |
self._all_ort_execution_info.cache_session_execution_info( | |
graph_module, execution_info_per_session | |
) | |
self.execution_count += 1 | |
# ORT always returns a tuple of outputs. If the original output is a tensor, | |
# ORT output's first element must be extracted and returned. Otherwise, type | |
# mismatch may happen in downstream computation. | |
is_single_tensor_output = isinstance(prim_outputs, torch.Tensor) | |
normalized_prim_outputs = ( | |
(prim_outputs,) if is_single_tensor_output else prim_outputs | |
) | |
assert isinstance(normalized_prim_outputs, tuple) | |
assert all( | |
isinstance(elem, (torch.Tensor, torch.SymInt, int)) | |
for elem in normalized_prim_outputs | |
) | |
_nvtx_range_push("run_onnx_session_with_ortvaluevector") | |
onnx_outputs = self.run( | |
onnx_session, | |
input_names, | |
args, | |
input_devices, | |
output_names, | |
normalized_prim_outputs, | |
output_devices, | |
self._options.preallocate_output, | |
input_value_infos, | |
normalized_prim_outputs, | |
) | |
_nvtx_range_pop() | |
if self._assert_allclose_to_baseline: | |
# Compute baseline. | |
baseline_outputs = torch._prims.executor.execute( | |
graph_module, *args, executor="aten" | |
) | |
normalized_baseline_outputs = (
(baseline_outputs,) if is_single_tensor_output else baseline_outputs
)
# Ensure every output tensor is close to the corresponding baseline.
for onnx_output, baseline_output in zip(
onnx_outputs, normalized_baseline_outputs
):
torch.testing.assert_close(onnx_output, baseline_output) | |
return onnx_outputs[0] if is_single_tensor_output else onnx_outputs | |
def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule: | |
# Deferred import since CapabilityBasedPartitioner is not decorated with | |
# @compatibility; importing it at the module level will result in the test | |
# failing: pytest test/test_fx.py -k test_public_api_surface | |
# because this module is imported into torch.onnx. | |
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner | |
# FX graph based partitioning based on ONNX supported ops. | |
# Given a graph module | |
# GraphModule0 | |
# node_0 | |
# node_1 | |
# node_2 | |
# node_3 | |
# node_4 | |
# If only node_2 is not supported by ONNX, this graph module will be partitioned into | |
# GraphModule0 | |
# GraphModule1 | |
# node_0 | |
# node_1 | |
# node_2 | |
# GraphModule2 | |
# node_3 | |
# node_4 | |
# by calling CapabilityBasedPartitioner.partition_and_fuse. | |
# Then, GraphModule1's and GraphModule2's forward method (GraphModule._wrapped_call) | |
# will be replaced by OrtBackend._ort_accelerated_call to delegate computation to ORT. | |
if graph_module in self._partitioner_cache: | |
partitioned_prim_graph_module = self._partitioner_cache[graph_module] | |
else: | |
prim_graph_module = graph_module | |
partitioner = CapabilityBasedPartitioner( | |
prim_graph_module, | |
self._supported_ops, | |
allows_single_node_partition=True, | |
) | |
partitioned_prim_graph_module = partitioner.partition_and_fuse() | |
self._partitioner_cache[graph_module] = partitioned_prim_graph_module | |
# Override fused_module's __call__() function with _ort_acclerated_call().
# This loop goes through all graph partitions (each of them is an ONNX-representable graph)
# and overrides their _wrapped_call function with _ort_acclerated_call.
# Inside _ort_acclerated_call, the partition's graph is exported into ONNX and executed by ORT.
for node in partitioned_prim_graph_module.graph.nodes: | |
# TODO(wschin): use a better way to identify fused submodule | |
# See https://github.com/pytorch/pytorch/issues/106872. | |
if node.op == "call_module" and "fused_" in node.name: | |
fused_module = getattr(partitioned_prim_graph_module, node.name) | |
# self._ort_acclerated_call is responsible for exporting the graph to ONNX,
# creating the ORT session, and running the ORT session.
fused_module._wrapped_call = self._ort_acclerated_call | |
return partitioned_prim_graph_module | |
def __call__( | |
self, graph_module: torch.fx.GraphModule, args | |
) -> torch.fx.GraphModule: | |
"""If ``OrtBackendOptions.use_aot_autograd`` is ``True``, the `auto_autograd` compiler | |
will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise, | |
the ``compile`` method is invoked directly.""" | |
if self._options.use_aot_autograd: | |
from functorch.compile import min_cut_rematerialization_partition | |
from torch._dynamo.backends.common import aot_autograd | |
return aot_autograd( | |
fw_compiler=self.compile, | |
partition_fn=min_cut_rematerialization_partition, | |
decompositions=self._resolved_onnx_exporter_options.decomposition_table, | |
)(graph_module, args) | |
return self.compile(graph_module, args) | |
__instance_cache_max_count: Final = 8 | |
__instance_cache: Final[List["OrtBackend"]] = [] | |
def get_cached_instance_for_options( | |
options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None, | |
) -> "OrtBackend": | |
"""Returns a possibly cached instance of an ``OrtBackend``. If an existing | |
backend was created previously through this function with the same options, | |
it will be returned. Otherwise a new backend will be created, cached, and | |
returned. | |
Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend`` | |
will always be returned, since ``onnxruntime.SessionOptions`` cannot | |
participate in caching.""" | |
def reusable(a: OrtBackendOptions, b: OrtBackendOptions): | |
if ( | |
a.preferred_execution_providers != b.preferred_execution_providers | |
or a.infer_execution_providers != b.infer_execution_providers | |
or a.default_execution_providers != b.default_execution_providers | |
or a.preallocate_output != b.preallocate_output | |
or a.use_aot_autograd != b.use_aot_autograd | |
or a.pre_ort_model_transforms != b.pre_ort_model_transforms | |
): | |
return False | |
# onnxruntime.SessionOptions is a pybind11 object, cannot be pickled, | |
# and holds too much potential state to reasonably check manually; | |
# if ort_session_options is provided at all, the backend does not participate
# in caching. | |
if a.ort_session_options is not None or b.ort_session_options is not None: | |
return False | |
if a.export_options is b.export_options: | |
return True | |
# Similarly, some objects in ExportOptions are too stateful to use for | |
# caching. We should revisit this. | |
if a.export_options is not None and b.export_options is not None: | |
return ( | |
a.export_options.dynamic_shapes == b.export_options.dynamic_shapes | |
and a.export_options.op_level_debug | |
== b.export_options.op_level_debug | |
and a.export_options.diagnostic_options | |
== b.export_options.diagnostic_options | |
and a.export_options.onnx_registry is b.export_options.onnx_registry | |
and a.export_options.fake_context is b.export_options.fake_context | |
) | |
# We can't account for how the two option sets may differ, so it's not safe to reuse. | |
return False | |
if not isinstance(options, OrtBackendOptions): | |
options = OrtBackendOptions(**(options or {})) | |
backend = next( | |
(b for b in OrtBackend.__instance_cache if reusable(b._options, options)), | |
None, | |
) | |
if backend is None: | |
assert ( | |
len(OrtBackend.__instance_cache) < OrtBackend.__instance_cache_max_count | |
), ( | |
f"No more than {OrtBackend.__instance_cache_max_count} instances of " | |
f"{OrtBackend} allowed. Please instantiate `{OrtBackend}` explicitly " | |
"to pass to `torch.compile`. " | |
"See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 " | |
"for discussion." | |
) | |
OrtBackend.__instance_cache.append(backend := OrtBackend(options)) | |
return backend | |
def clear_cached_instances(): | |
OrtBackend.__instance_cache.clear() | |
def get_cached_instances(): | |
return tuple(OrtBackend.__instance_cache) | |
def torch_compile_backend( | |
graph_module: torch.fx.GraphModule, | |
args, | |
*, | |
options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None, | |
): | |
return OrtBackend.get_cached_instance_for_options(options)(graph_module, args) | |
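# Usage sketch (illustrative only): this function backs the "onnxrt" torch.compile backend, so the
# typical entry point is
#
#     compiled = torch.compile(model, backend="onnxrt")
#
# rather than calling torch_compile_backend directly. Options may also be passed as a plain mapping,
# e.g. options={"preallocate_output": False}, which get_cached_instance_for_options expands into an
# OrtBackendOptions instance.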