import dataclasses
import importlib
import logging
import os
from typing import (
Any,
Callable,
Dict,
Final,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)
from typing_extensions import TypeAlias
import torch
import torch._C
import torch._ops
import torch._prims.executor
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx._compatibility import compatibility
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils import _pytree
try:
# Use try-except to initialize package-dependent global variables.
import onnx
import onnxruntime # type: ignore[import]
from onnxruntime.capi import _pybind_state as ORTC # type: ignore[import]
# This is not used directly in DORT, but it is needed by the underlying
# exporter, so we still need to check if it exists.
importlib.import_module("onnxscript")
import torch.onnx
import torch.onnx._internal
import torch.onnx._internal.diagnostics
import torch.onnx._internal.exporter
import torch.onnx._internal.fx.decomposition_table
import torch.onnx._internal.fx.passes
from torch.onnx._internal.fx import fx_onnx_interpreter
from torch.onnx._internal.fx.type_utils import (
_TORCH_DTYPE_TO_NUMPY_DTYPE,
_TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
from_python_type_to_onnx_tensor_element_type,
)
_SUPPORT_ONNXRT = True
except ImportError:
_SUPPORT_ONNXRT = False
__all__ = [
"is_onnxrt_backend_supported",
"torch_compile_backend",
"OrtExecutionProvider",
"OrtBackendOptions",
"OrtBackend",
]
def is_onnxrt_backend_supported() -> bool:
"""Returns ``True`` if ONNX Runtime dependencies are installed and usable
to support TorchDynamo backend integration; ``False`` otherwise.
Example::
# xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
>>> import torch
>>> if torch.onnx.is_onnxrt_backend_supported():
... @torch.compile(backend="onnxrt")
... def f(x):
... return x * x
... print(f(torch.randn(10)))
... else:
... print("pip install onnx onnxscript onnxruntime")
...
"""
return _SUPPORT_ONNXRT
_dumped_onnx_model: Dict[str, int] = {}
def _dump_onnx_model(
model_string: bytes, graph_module: Optional[torch.fx.GraphModule] = None
) -> str:
"""Stores the onnx model into a file.
The name is "{ONNXRT_DUMP_PATH}{N}.onnx"
where *N* is the number of files already stored with
this prefix.
If graph_module is not None, the graph is stored as a string with
the same filename except the extension (.txt).
"""
prefix = os.environ.get("ONNXRT_DUMP_PATH", None)
if not prefix:
return ""
n = _dumped_onnx_model.get(prefix, -1) + 1
filename = f"{prefix}{n}.onnx"
with open(filename, "wb") as f:
f.write(model_string)
_dumped_onnx_model[prefix] = n
if graph_module is not None:
filename_txt = f"{prefix}{n}.txt"
with open(filename_txt, "w", encoding="utf-8") as f:
f.write(str(graph_module.graph))
return filename
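# A minimal usage sketch (hypothetical prefix; assumes the environment variable is
# set before any model is dumped):
#
#   os.environ["ONNXRT_DUMP_PATH"] = "dumped/model_"
#   _dump_onnx_model(onnx_model.SerializeToString())  # -> "dumped/model_0.onnx"
#   _dump_onnx_model(onnx_model.SerializeToString())  # -> "dumped/model_1.onnx"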
def _infer_default_eps() -> Sequence[str]:
# TODO: select a good default based on the capabilities of the host
# e.g. DML on Windows, etc.
return ["CPUExecutionProvider"]
def _nvtx_range_push(name: str):
"""If PyTorch is installed with CUDA support, this starts NVTX range.
Check torch.cuda.nvtx.range_push's documentation for more details.
"""
if torch.cuda.is_available():
torch.cuda.nvtx.range_push(name)
def _nvtx_range_pop():
"""If PyTorch is installed with CUDA support, this terminates NVTX range.
Check torch.cuda.nvtx.range_pop's documentation for more details.
"""
if torch.cuda.is_available():
torch.cuda.nvtx.range_pop()
def _get_ort_device_type(device_type: str):
if device_type == "cuda":
return ORTC.OrtDevice.cuda()
if device_type == "cpu":
return ORTC.OrtDevice.cpu()
# The "ort" PyTorch device is mapped to the NPU OrtDevice type.
if device_type == "ort":
return ORTC.OrtDevice.npu()
raise ValueError("Unsupported device type: " + device_type)
logger = logging.getLogger(__name__)
# Uncomment the following lines to print out development info.
# logging.basicConfig(level=logging.WARNING)
# logger.setLevel(logging.WARNING)
class OrtOperatorSupport(OperatorSupport):
"""Operator support for ONNXRuntime backend.
It makes a two-level support decision. One level is via support_dict and the
other is via extra_support_dict. The logic of using support_dict is implemented in
OrtOperatorSupport, while extra_support_dict is consumed by OperatorSupport.is_node_supported.
"""
def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]):
# Use extra_support_dict[op_name] = None to indicate
# we support op_name with all input types. Otherwise,
# see support_dict (type: SupportDict) in operator_support.py
# for specifying supported types.
super().__init__(extra_support_dict)
self._onnx_support_dict = support_dict
def is_node_supported(
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
# OperatorSupport.is_node_supported returns True for non-callable nodes.
# Since ORT can't execute them, we return False here to override the base
# behavior.
if node.op not in CALLABLE_NODE_OPS:
return False
# This is the only place to decide whether an aten op is supported.
if node.op == "call_function" and node.target in self._onnx_support_dict:
logger.warning(
"support_dict supports node.target: %s (type: %s)",
node.target,
type(node.target),
)
return True
# If node.target is not in support_dict, we still want to check whether
# torch.jit.script can convert it to an ONNX equivalent. Let's use the base
# mechanism to do this. See extra_support_dict for the supported ops.
if super().is_node_supported(submodules, node):
logger.warning(
"extra_support_dict supports node.target: %s (type: %s)",
node.target,
type(node.target),
)
return True
logger.warning(
"support_dict and extra_support_dict don't support node.target: %s (type: %s)",
node.target,
type(node.target),
)
return False
def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None:
"""
In a torch.fx.Graph, a placeholder is a special assignment node. If it's not
executed at the beginning, it could overwrite values computed by upstream
nodes. This function therefore moves all placeholder nodes to the front of the graph.
"""
graph = graph_module.graph
placeholders = []
first_not_placeholder = None
for node in graph.nodes:
if node.op == "placeholder":
placeholders.append(node)
if first_not_placeholder is None and node.op != "placeholder":
first_not_placeholder = node
if first_not_placeholder is None:
return
for placeholder in placeholders:
first_not_placeholder.prepend(placeholder)
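# Illustrative effect on a hypothetical graph (node names are made up): a node
# order of
#   [add, placeholder_x, placeholder_y, output]
# becomes
#   [placeholder_x, placeholder_y, add, output]
# because every placeholder is prepended before the first non-placeholder node.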
def _infer_ep_from_device(*args) -> Tuple[str, ...]:
"""Return the first valid device (i.e., GPU or CPU) in argument list."""
eps = []
for arg in args:
if hasattr(arg, "device"):
device = arg.device
if device.type == "cuda":
eps.append("CUDAExecutionProvider")
elif device.type == "cpu":
eps.append("CPUExecutionProvider")
return tuple(eps)
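# For example (illustrative), with one CUDA tensor and one CPU tensor:
#
#   _infer_ep_from_device(torch.randn(2, device="cuda"), torch.randn(2))
#   # -> ("CUDAExecutionProvider", "CPUExecutionProvider")
#
# Arguments without a .device attribute are simply skipped.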
def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> Tuple[Any, ...]:
placeholders = []
for node in graph_module.graph.nodes:
if node.op == "placeholder":
if hasattr(node, "meta") and "val" in node.meta:
assert isinstance(node.meta["val"], torch.Tensor)
placeholders.append(node)
return tuple(placeholders)
def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any:
"""Collect "val" fields from outputs metadata in this torch.fx.GraphModule."""
for node in graph_module.graph.nodes:
if node.op == "output":
# Output node is unique. Let's retrieve output values from
# this node's input list. And then just return.
return node.args[0]
raise ValueError("No output node found in this torch.fx.GraphModule.")
def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]:
"""Return the all valid devices (i.e., GPU or CPU) among outputs of this torch.fx.GraphModule."""
flattened_output_args, _ = _pytree.tree_flatten(
_extract_graph_module_outputs(graph_module)
)
# Output arguments with example value (type: torch.Tensor) in the `graph_module`.
selected_output_args = [
output_arg.meta["val"]
for output_arg in flattened_output_args
# output_arg must have tensor for its device information.
# Otherwise, skip it.
if (hasattr(output_arg, "meta") and "val" in output_arg.meta)
]
return _infer_ep_from_device(*selected_output_args)
def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]:
"""Sort execution providers in eps based on pre-set priority."""
def get_execution_provider_priority(ep: str) -> int:
if ep == "CPUExecutionProvider":
# Lowest priority.
return 2
if ep == "CUDAExecutionProvider":
# Higher priority than CPU but lower than
# other specialized EPs.
return 1
# Highest priority.
return 0
unique_eps = set(eps)
return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True))
def _get_onnx_devices(
values: Tuple[
Union[
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
],
...,
]
) -> Tuple["ORTC.OrtDevice", ...]:
def _device_id_or_zero(device_id: int) -> int:
return device_id or 0
def _map_tensor_or_sym_to_device(
value: Union[
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
],
) -> "ORTC.OrtDevice":
if isinstance(value, torch.Tensor):
return ORTC.OrtDevice(
_get_ort_device_type(value.device.type),
ORTC.OrtDevice.default_memory(),
_device_id_or_zero(value.device.index),
)
elif isinstance(
value, (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool)
):
return ORTC.OrtDevice(
_get_ort_device_type("cpu"), ORTC.OrtDevice.default_memory(), 0
)
else:
raise ValueError("Unsupported value type: " + str(type(value)))
if len(values) > 0:
ort_devices = tuple(_map_tensor_or_sym_to_device(value) for value in values)
return ort_devices
else:
return (_map_tensor_or_sym_to_device(1),)
def _get_ortvalues_from_torch_tensors(
tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...]
) -> "ORTC.OrtValueVector":
ortvalues = ORTC.OrtValueVector()
ortvalues.reserve(len(tensors))
dtypes = []
shapes = []
data_ptrs = []
for tensor in tensors:
dtypes.append(_TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype])
shapes.append(tensor.size())
data_ptrs.append(tensor.data_ptr())
ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices)
return ortvalues
def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor:
if tensor.is_sparse:
raise ValueError("sparse tensor is not yet supported.")
out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device)
return out
def _adjust_scalar_from_fx_to_onnx(
dynamo_value: Union[
torch.Tensor,
int,
float,
bool,
],
value_info: "onnx.ValueInfoProto", # type: ignore[name-defined]
) -> torch.Tensor:
"""Helper function to wrap PyTorch variables as torch.Tensor"""
if (
isinstance(dynamo_value, torch.Tensor)
and len(value_info.type.tensor_type.shape.dim) == 0
and dynamo_value.shape == (1,)
):
# ONNX expects a scalar to have an empty shape (i.e., rank 0).
# In contrast, PyTorch usually allows implicit
# conversion between shape=() and shape=(1,).
#
# Below, PyTorch's shape (1,) is reshaped to ().
return torch.squeeze(dynamo_value)
elif isinstance(dynamo_value, int):
return torch.tensor(dynamo_value, dtype=torch.int64)
elif isinstance(dynamo_value, float):
return torch.tensor(dynamo_value, dtype=torch.float32)
elif isinstance(dynamo_value, bool):
return torch.tensor(dynamo_value, dtype=torch.bool)
else:
assert isinstance(dynamo_value, torch.Tensor)
return dynamo_value.contiguous()
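# For example (illustrative): a Python float such as 3.14 becomes
# torch.tensor(3.14, dtype=torch.float32), and a shape-(1,) tensor paired with a
# rank-0 ONNX value_info is squeezed to shape () to match ONNX's scalar convention.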
def _adjust_scalar_from_onnx_to_fx(
tensor: torch.Tensor,
prim_value: Union[
torch.Tensor,
torch.SymInt,
int,
torch.SymFloat,
float,
torch.SymBool,
bool,
],
) -> Union[torch.Tensor, int, float, bool]:
"""Helper function to wrap ORT-produced torch.Tensor as PyTorch variables"""
assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor."
if isinstance(
prim_value,
(torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool),
):
# Convert tensor back to scalar to match Dynamo's expectation.
return tensor.item()
return tensor
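# For example (illustrative): if prim_value was the Python int 7, an ORT output of
# torch.tensor(3) is converted back to the Python int 3 via .item(); if prim_value
# was a tensor, the ORT output tensor is returned unchanged.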
def _run_onnx_session_with_ortvaluevector(
sess: "onnxruntime.InferenceSession",
input_names: Tuple[str, ...],
inputs: Tuple[torch.Tensor, ...],
input_devices: Tuple["ORTC.OrtDevice", ...],
output_names: Tuple[str, ...],
outputs: Tuple[torch.Tensor, ...],
output_devices: Tuple["ORTC.OrtDevice", ...],
preallocate_output: bool,
input_value_infos: Tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined]
normalized_prim_outputs: Tuple[
Union[
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
],
...,
],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
_nvtx_range_push("contiguous")
inputs = tuple(
_adjust_scalar_from_fx_to_onnx(arg, value_info)
for arg, value_info in zip(inputs, input_value_infos)
)
_nvtx_range_pop()
_nvtx_range_push("push_back_batch")
ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices)
# Preallocate output PyTorch tensors and use buffers associated with the torch device for the output OrtValues.
# Because the output OrtValues are then not allocated and owned by ORT, there is no need to convert them
# to torch tensors (which would transfer ownership).
if preallocate_output:
pth_outputs = tuple(
_to_real_tensor(t) if isinstance(t, FakeTensor) else t for t in outputs
)
ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices)
else:
ort_outputs = ORTC.OrtValueVector()
_nvtx_range_pop()
_nvtx_range_push("run_with_ortvaluevector")
run_options = onnxruntime.RunOptions()
run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")
sess.run_with_ortvaluevector(
run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices
)
_nvtx_range_pop()
# Post-processing step:
# wrap ORT's outputs to the schema represented by
# `prim_output` (obtained by running the original
# torch.fx.GraphModule).
if preallocate_output:
# Profile the ORT-to-PyTorch type cast below
_nvtx_range_push("after run_with_ortvaluevector")
# Outputs are stored on pre-allocated torch.Tensors' memory,
# so this case doesn't need to convert ORTValue to torch.Tensor.
pth_outputs = tuple(
_adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc]
for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
)
_nvtx_range_pop()
return pth_outputs
else:
# Profile the two ORT-to-PyTorch type casts below
_nvtx_range_push("after run_with_ortvaluevector")
# Map ORTValue to torch.Tensor.
pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor(
ort_outputs
)
# Change some torch.Tensor to int, float, bool.
pth_outputs = tuple(
_adjust_scalar_from_onnx_to_fx(onnx_output, prim_output) # type: ignore[misc]
for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
)
_nvtx_range_pop()
return pth_outputs
def _run_onnx_session_with_fetch(
sess: "onnxruntime.InferenceSession",
input_names: Tuple[str, ...],
inputs: Tuple[torch.Tensor, ...],
input_devices: Tuple["ORTC.OrtDevice", ...],
output_names: Tuple[str, ...],
outputs: Tuple[torch.Tensor, ...],
output_devices: Tuple["ORTC.OrtDevice", ...],
preallocate_output: bool,
input_value_infos: Tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined]
normalized_prim_outputs: Tuple[
Union[
torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
],
...,
],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
inputs = tuple(
_adjust_scalar_from_fx_to_onnx(arg, value_info)
for arg, value_info in zip(inputs, input_value_infos)
)
feed = {
name: onnxruntime.OrtValue.ortvalue_from_numpy(tensor.cpu().numpy())
for name, tensor in zip(input_names, inputs)
}
ort_outputs = sess.run(output_names, feed)
pth_outputs = tuple(
_adjust_scalar_from_onnx_to_fx(
torch.from_numpy(value),
prim_output,
)
for value, prim_output in zip(ort_outputs, normalized_prim_outputs)
)
return pth_outputs
class OrtExecutionInfoPerSession:
"""Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession"""
def __init__(
self,
session: "onnxruntime.InferenceSession",
input_names: Tuple[str, ...],
input_value_infos: Tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined]
output_names: Tuple[str, ...],
output_value_infos: Tuple["onnx.ValueInfoProto", ...], # type: ignore[name-defined]
input_devices: Tuple["ORTC.OrtDevice", ...],
output_devices: Tuple["ORTC.OrtDevice", ...],
example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor],
):
# Carrier of ONNX model and its executor.
self.session: onnxruntime.InferenceSession = session
# For the ONNX model stored in self.session, self.input_names[i] is the
# name of the i-th positional input.
self.input_names: Tuple[str, ...] = input_names
# self.input_names[i]'s type information is stored in self.input_value_infos[i].
self.input_value_infos: Tuple[onnx.ValueInfoProto, ...] = input_value_infos # type: ignore[name-defined]
# Similar to self.input_names, but for outputs.
self.output_names: Tuple[str, ...] = output_names
# Similar to self.input_value_infos but for outputs.
self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos # type: ignore[name-defined]
# For the ONNX model stored in self.session, self.input_devices[i] is the
# i-th positional input's device.
self.input_devices: Tuple["ORTC.OrtDevice", ...] = input_devices
# Similar to self.input_devices, but for outputs.
self.output_devices: Tuple["ORTC.OrtDevice", ...] = output_devices
# These are the outputs of executing the original torch.fx.GraphModule with example inputs
# (i.e., the args passed into OrtBackend._ort_acclerated_call).
self.example_outputs: Union[
Tuple[torch.Tensor, ...], torch.Tensor
] = example_outputs
def is_supported(self, *args):
# Compare the args against the input schema of the ONNX model and
# return True only if they match.
if len(args) != len(self.input_value_infos):
return False
for arg, value_info in zip(args, self.input_value_infos):
if not isinstance(arg, (torch.Tensor, float, int)):
return False
# Check Python scalars such as int, float, and bool.
if isinstance(arg, (int, float, bool)):
# Map, e.g., float to onnx.TensorProto.FLOAT.
onnx_dtype = from_python_type_to_onnx_tensor_element_type(type(arg))
if onnx_dtype != value_info.type.tensor_type.elem_type:
return False
if len(value_info.type.tensor_type.shape.dim) != 0:
return False
continue
# Check tensor.
onnx_dtype = _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE[arg.dtype]
if onnx_dtype != value_info.type.tensor_type.elem_type:
return False
for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim):
if isinstance(dim, int) and (
onnx_dim.dim_value == dim or onnx_dim.dim_param
):
continue
elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param:
continue
else:
return False
return True
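# Illustrative match (hypothetical shapes): a float32 tensor of shape (2, 3) is
# accepted by a value_info with elem_type FLOAT and dims [2, 3] (or dims carrying
# symbolic dim_param entries); a Python scalar is accepted only when its mapped
# ONNX element type matches and the value_info has rank 0.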
@dataclasses.dataclass
class OrtExecutionInfoForAllGraphModules:
def __init__(self):
# All sessions (and their related information) created by exporting the same GraphModule
# with different inputs.
self.execution_info_per_graph_module: Dict[
torch.fx.GraphModule, List[OrtExecutionInfoPerSession]
] = {}
def search_reusable_session_execution_info(
self, graph_module: torch.fx.GraphModule, *args
):
if graph_module not in self.execution_info_per_graph_module:
return None
# All execution information for ONNX models exported from the same `graph_module`
# with different inputs.
candidates = self.execution_info_per_graph_module[graph_module]
for candidate in candidates:
if candidate.is_supported(*args):
# Returns the first session that accepts this input schema.
return candidate
# No reusable session found.
return None
def cache_session_execution_info(
self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession
):
if graph_module not in self.execution_info_per_graph_module:
self.execution_info_per_graph_module[graph_module] = [info]
else:
self.execution_info_per_graph_module[graph_module].append(info)
OrtExecutionProvider: TypeAlias = Union[str, Tuple[str, Mapping[str, Any]]]
"""Either the name of an ONNX Runtime execution provider as a string or
a 2-tuple of the name and a dictionary of execution provider options.
Examples::
>>> "CPUExecutionProvider"
>>> ("CUDAExecutionProvider", {"device_id": 3})
"""
@dataclasses.dataclass(frozen=True)
@compatibility(is_backward_compatible=False)
class OrtBackendOptions:
"""Options for constructing an ``OrtBackend``, the ONNX Runtime
backend (``"onnxrt"``) for ``torch.compile``.
Example::
>>> @torch.compile(
... backend="onnxrt",
... options=torch.onnx._OrtBackendOptions(...),
... )
... def ort_function(x):
... return x ** x
"""
preferred_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
"""An optional sequence of execution providers to be prioritized ahead of any
execution providers that may be inferred (see ``infer_execution_providers``).
"""
infer_execution_providers: bool = True
"""Whether to infer an execution provider from ``torch.device`` bound to inputs or found in the graph."""
default_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
"""The default fallback execution providers. If not specified, one will be
selected based on the host environment (most likely ``"CPUExecutionProvider"``).
"""
# preallocate_output allows for allocating output torch Tensor buffers and feeding them to InferenceSession
# in order to avoid internal allocation of output buffers in InferenceSession.
# If the output ortvalue returned from InferenceSession is allocated internally,
# it needs to be converted to a torch Tensor for return, and that torch Tensor should take ownership.
# When a custom torch device is used with a custom aten allocator, the conversion from ortvalue to torch Tensor
# must be supported; this is currently done through dlpack, which might not support a custom torch device.
# That can be avoided by preallocating the output buffers with the custom aten allocator and letting
# InferenceSession write into them without taking any ownership of them.
# TODO(wschin): Make this an inference-session-level flag.
# See https://github.com/pytorch/pytorch/issues/106869.
preallocate_output: bool = False
"""If ``True``, allocate memory for ONNX Runtime's outputs on the PyTorch side."""
use_aot_autograd: bool = True
"""Whether to wrap the ``OrtBackend`` with TorchDynamo's aot_autograd backend
to support training (i.e., backward graphs are also sent to ``OrtBackend``).
Symbolic execution is used to capture the forward pass and backward passes as a single graph.
Then, a selected graph partition algorithm (``min_cut_rematerialization_partition``) is used
to split the entire graph into forward sub-graph and backward sub-graph. Finally, both
sub-graphs are compiled by ``OrtBackend``.
"""
export_options: Optional["torch.onnx.ExportOptions"] = None
"""Options for the TorchDynamo-based ONNX exporter used by the ``OrtBackend``."""
ort_session_options: Optional["onnxruntime.SessionOptions"] = None
"""Options for the ``onnxruntime.InferenceSession`` used by the ``OrtBackend``."""
pre_ort_model_transforms: Optional[ # type: ignore[name-defined]
Sequence[Callable[["onnx.ModelProto"], None]]
] = None
"""A list of graph transforms to be applied to the ONNX model before it
is fed to ONNXRuntime's InferenceSession."""
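# A hedged construction sketch (field names are the dataclass fields above; the
# CUDA provider option and `model` are placeholders, not prescriptions):
#
#   options = OrtBackendOptions(
#       preferred_execution_providers=[
#           ("CUDAExecutionProvider", {"device_id": 0}),
#           "CPUExecutionProvider",
#       ],
#       preallocate_output=False,
#   )
#   compiled = torch.compile(model, backend=OrtBackend(options))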
@compatibility(is_backward_compatible=False)
class OrtBackend:
"""A backend compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls.
The compiler entry point is OrtBackend.compile, which
1. partitions the original graph into supported and unsupported sub-graphs (type: torch.fx.GraphModule);
2. replaces, for each supported sub-graph, its _wrapped_call function with _ort_acclerated_call;
3. inside _ort_acclerated_call, creates an onnxruntime.InferenceSession and calls it to execute the sub-graph.
"""
def __init__(self, options: Optional[OrtBackendOptions] = None):
self._options: Final = OrtBackendOptions() if options is None else options
# options.export_options contains information shared between exporter and DORT.
# For example, they should use the same decomposition table when
# 1. capturing the FX graph in torch.compile (see how we create aot_ort in register_backend.py)
# 2. calling the exporter's API to convert `torch.fx.GraphModule` to an ONNX model
# (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below).
#
# Convert user-facing option to internal option used by ONNX exporter
# to access required information.
# Some useful fields:
# - Decomposition table for decomposing FX operators in exporter is
# self._resolved_onnx_exporter_options.decomposition_table.
# - self._resolved_onnx_exporter_options.onnx_registry records what
# aten/prim ops are supported by exporter and their exporters (type: callable).
self._resolved_onnx_exporter_options = (
torch.onnx._internal.exporter.ResolvedExportOptions(
torch.onnx.ExportOptions()
if self._options.export_options is None
else self._options.export_options
)
)
# Given DORT's computation flow:
# 1. OrtOperatorSupport uses support_dict and extra_support_dict to select operators
# and send them to DORT.
# 2. Then, DORT exports the selected sub-graphs into ONNX.
# 3. Finally DORT calls ORT to do the computation.
# OrtOperatorSupport and create_onnx_friendly_decomposition_table(...)
# must use the same support_dict. If the support_dict here contains something not
# supported by the exporter, the exporter will fail in step 2 since the selected graphs may
# contain unsupported operators such as aten::_who_you_are.
# This restriction is automatically satisfied since DORT and the exporter share the same
# self._resolved_onnx_exporter_options.
support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
self._resolved_onnx_exporter_options.onnx_registry
)
extra_support_dict: Dict[str, Any] = {
"getattr": None,
# To send operator.getitem to ORT, add the corresponding string
# recognized by PyTorch's OperatorSupport class.
"_operator.getitem": None,
# To send operator.mul to ORT, add the corresponding string
# recognized by PyTorch's OperatorSupport class.
"_operator.mul": None,
"_operator.add": None,
"_operator.sub": None,
}
self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict)
# TODO(wschin): this is a naive implementation of cache without proper guard
# See https://github.com/pytorch/pytorch/issues/106868.
self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
# Conceptually, this field is a 2-level dictionary
# GraphModule 0
# ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
# ONNX Model 1
# ...
# GraphModule 1
# ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
# ONNX Model 3
# ...
# ...
# , which caches all previous compilation results so that we can reuse them.
# ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs
# (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different
# graphs captured by Dynamo and sent to OrtBackend.compile.
self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules()
self._assert_allclose_to_baseline = False
self.execution_count = 0
# Function which invokes ORT to do the real computation.
self.run = (
_run_onnx_session_with_ortvaluevector
if hasattr(ORTC.OrtValueVector, "push_back_batch")
else _run_onnx_session_with_fetch
)
def _select_eps(
self, graph_module: torch.fx.GraphModule, *args
) -> Sequence[Tuple[str, Mapping[str, Any]]]:
inferred_eps: Tuple[str, ...] = tuple()
if self._options.infer_execution_providers:
if eps_from_args := _infer_ep_from_device(*args):
# If the user feeds a CUDA tensor as an input argument,
# we want to use the CUDA EP.
# Thus, `eps_from_args` (deduced from input arguments)
# has the highest priority.
inferred_eps = eps_from_args
elif eps_from_graph_module := _infer_ep_from_graph_module(graph_module):
# If there is no EP in input arguments, we deduce EP from
# graph_module's outputs. Those outputs may come from
# FakeTensorProp or Dynamo's built-in symbolic shape inference.
inferred_eps = eps_from_graph_module
selected_eps = []
for ep in (
*(self._options.preferred_execution_providers or []),
*_sort_eps(inferred_eps),
*(self._options.default_execution_providers or _infer_default_eps()),
):
if isinstance(ep, str):
ep = (ep, {})
elif isinstance(ep, tuple) and ep[1] is None:
ep = (ep[0], {})
if ep is not None and ep not in selected_eps:
selected_eps.append(ep)
return selected_eps
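# Illustrative result (assuming no preferred/default overrides are configured):
# for a graph fed a single CUDA tensor, this would yield something like
#   [("CUDAExecutionProvider", {}), ("CPUExecutionProvider", {})]
# i.e., inferred EPs followed by the fallback default EP, each normalized to a
# (name, options) tuple.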
def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs):
"""This function replaces GraphModule._wrapped_call in compiled model.
The _wrapped_call is the underlying implementation of forward method. Replacing
it means we delegate the computation to _ort_acclerated_call and therefore
onnxruntime.InferenceSession.
"""
cached_execution_info_per_session = (
self._all_ort_execution_info.search_reusable_session_execution_info(
graph_module, *args
)
)
if cached_execution_info_per_session:
onnx_session = cached_execution_info_per_session.session
input_names = cached_execution_info_per_session.input_names
output_names = cached_execution_info_per_session.output_names
input_value_infos = cached_execution_info_per_session.input_value_infos
output_value_infos = cached_execution_info_per_session.output_value_infos
input_devices = cached_execution_info_per_session.input_devices
output_devices = cached_execution_info_per_session.output_devices
prim_outputs = cached_execution_info_per_session.example_outputs
else:
# It's the first time we have seen such a graph. Let's make a new session
# (type: onnxruntime.InferenceSession) for it.
graph_module = torch.onnx._internal.fx.passes.MovePlaceholderToFront(
self._resolved_onnx_exporter_options.diagnostic_context,
graph_module,
).run()
# Generate reference outputs. They are used to indicate output
# tensors' types and devices when calling ORT.
#
# WARNING: The downstream code should not change prim_outputs, and
# this backend should always produce outputs with a schema identical to prim_outputs'.
if self._resolved_onnx_exporter_options.dynamic_shapes:
# No pre-allocation when dynamic shape is enabled.
self.preallocate_output = False
extracted_outputs = _extract_graph_module_outputs(graph_module)
def maybe_map_to_meta_val(value):
if hasattr(value, "meta") and "val" in value.meta:
# Select outputs with "val" information. Without "val",
# it's not possible to access output_arg.meta["val"].device.
return value.meta["val"]
else:
return value
prim_outputs = _pytree.tree_map(
maybe_map_to_meta_val, extracted_outputs
)
else:
try:
prim_outputs = FakeTensorProp(graph_module).propagate(
*args, **kwargs
)
except Exception:
logger.warning("FakeTensorProb failed for %s", graph_module)
# When FakeTensorProp fails, it is not possible to preallocate output buffers
# because the output shapes are not inferred.
self.preallocate_output = False
# Rethrow the FakeTensorProp failure because it is not currently handled.
raise
# Create the object that iterates through the nodes in the graph one-by-one
# and calls the corresponding ONNX exporter for each node.
fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
diagnostic_context=self._resolved_onnx_exporter_options.diagnostic_context
)
# Cast FX variables if they would result in a schema mismatch when searching
# for an ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
# but ONNX expects add(double_tensor, double_tensor).
graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion(
self._resolved_onnx_exporter_options.diagnostic_context, graph_module
).run()
# Start the per-node exporting process. It's conceptually a for loop
# scanning through the nodes in the graph.
exported = fx_interpreter.run(
fx_graph_module=graph_module,
onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher,
op_level_debug=self._resolved_onnx_exporter_options.op_level_debug,
)
# Convert the exported result to ONNX ModelProto.
onnx_model = exported.to_model_proto(
opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version,
)
# Modify ONNX model using pre-registered graph transforms.
# They are in-place modifications for avoiding unnecessary
# copy of ONNX initializers.
if self._options.pre_ort_model_transforms:
for transform in self._options.pre_ort_model_transforms:
transform(onnx_model)
onnx_model_bytes = onnx_model.SerializeToString()
if os.environ.get("ONNXRT_DUMP_PATH", None):
# If not empty, the environment variable ONNXRT_DUMP_PATH defines the path prefix
# where generated ONNX files are stored.
# This module keeps a global variable to track the number of models
# already stored under that prefix.
# If ONNXRT_DUMP_PATH="dumped/dumped_model_",
# the first file name will be 'dumped/dumped_model_0.onnx'.
# For every dumped model, a text file 'dumped/dumped_model_0.txt'
# is created as well, containing the string representation of the graph_module.
_dump_onnx_model(onnx_model_bytes, graph_module=graph_module)
# Initialize an ORT session to execute this ONNX model.
# Note that TorchDynamo assumes all inputs/outputs are on the
# same device, but it's subject to change (very likely with
# dynamic shape support), so we add execution providers
# based on the logic in _select_eps: explicitly preferred EPs,
# EPs inferred from inputs or the graph, and the fallback default EP.
#
# TODO(wschin): enable external allocators.
# See https://github.com/pytorch/pytorch/issues/106867
onnx_session = onnxruntime.InferenceSession(
path_or_bytes=onnx_model_bytes,
sess_options=self._options.ort_session_options,
providers=self._select_eps(graph_module, *args),
)
# Extract the ONNX model's input and output names. They are cached below,
# together with the ORT session, so they can be reused for the same "graph_module".
input_names = tuple(input.name for input in onnx_model.graph.input)
output_names = tuple(output.name for output in onnx_model.graph.output)
input_devices = _get_onnx_devices(args)
# Cache devices for inputs and outputs. They are used to invoke
# ORT session. Output devices indicate where (e.g., GPU or CPU)
# to store outputs
if isinstance(prim_outputs, tuple):
output_devices = _get_onnx_devices(prim_outputs)
else:
output_devices = _get_onnx_devices((prim_outputs,))
input_value_infos = tuple(input for input in onnx_model.graph.input)
output_value_infos = tuple(output for output in onnx_model.graph.output)
execution_info_per_session = OrtExecutionInfoPerSession(
session=onnx_session,
input_names=input_names,
input_value_infos=input_value_infos,
output_names=output_names,
output_value_infos=output_value_infos,
input_devices=input_devices,
output_devices=output_devices,
example_outputs=prim_outputs,
)
self._all_ort_execution_info.cache_session_execution_info(
graph_module, execution_info_per_session
)
self.execution_count += 1
# ORT always returns a tuple of outputs. If the original output is a tensor,
# ORT output's first element must be extracted and returned. Otherwise, type
# mismatch may happen in downstream computation.
is_single_tensor_output = isinstance(prim_outputs, torch.Tensor)
normalized_prim_outputs = (
(prim_outputs,) if is_single_tensor_output else prim_outputs
)
assert isinstance(normalized_prim_outputs, tuple)
assert all(
isinstance(elem, (torch.Tensor, torch.SymInt, int))
for elem in normalized_prim_outputs
)
_nvtx_range_push("run_onnx_session_with_ortvaluevector")
onnx_outputs = self.run(
onnx_session,
input_names,
args,
input_devices,
output_names,
normalized_prim_outputs,
output_devices,
self._options.preallocate_output,
input_value_infos,
normalized_prim_outputs,
)
_nvtx_range_pop()
if self._assert_allclose_to_baseline:
# Compute baseline.
baseline_outputs = torch._prims.executor.execute(
graph_module, *args, executor="aten"
)
normalized_baseline_outputs = (
(baseline_outputs,) if is_single_tensor_output else baseline_outputs
)
# Ensure every output tensor is close to the corresponding baseline.
for onnx_output, baseline_output in zip(
onnx_outputs, normalized_baseline_outputs
):
torch.testing.assert_close(onnx_output, baseline_output)
return onnx_outputs[0] if is_single_tensor_output else onnx_outputs
def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule:
# Deferred import since CapabilityBasedPartitioner is not decorated with
# @compatibility; importing it at the module level will result in the test
# failing: pytest test/test_fx.py -k test_public_api_surface
# because this module is imported into torch.onnx.
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
# FX graph based partitioning based on ONNX supported ops.
# Given a graph module
# GraphModule0
# node_0
# node_1
# node_2
# node_3
# node_4
# If only node_2 is not supported by ONNX, this graph module will be partitioned into
# GraphModule0
# GraphModule1
# node_0
# node_1
# node_2
# GraphModule2
# node_3
# node_4
# by calling CapabilityBasedPartitioner.partition_and_fuse.
# Then, GraphModule1's and GraphModule2's forward method (GraphModule._wrapped_call)
# will be replaced by OrtBackend._ort_accelerated_call to delegate computation to ORT.
if graph_module in self._partitioner_cache:
partitioned_prim_graph_module = self._partitioner_cache[graph_module]
else:
prim_graph_module = graph_module
partitioner = CapabilityBasedPartitioner(
prim_graph_module,
self._supported_ops,
allows_single_node_partition=True,
)
partitioned_prim_graph_module = partitioner.partition_and_fuse()
self._partitioner_cache[graph_module] = partitioned_prim_graph_module
# Override fused_module's __call__() function with _ort_acclerated_call().
# This loop goes through all graph partitions (each of them is an ONNX-representable graph)
# and overrides their _wrapped_call function with _ort_acclerated_call.
# Inside _ort_acclerated_call, the partition's graph is exported to ONNX and executed by ORT.
for node in partitioned_prim_graph_module.graph.nodes:
# TODO(wschin): use a better way to identify fused submodule
# See https://github.com/pytorch/pytorch/issues/106872.
if node.op == "call_module" and "fused_" in node.name:
fused_module = getattr(partitioned_prim_graph_module, node.name)
# self._ort_acclerated_call is responsible for exporting the graph to ONNX,
# creating the ORT session, and running the ORT session.
fused_module._wrapped_call = self._ort_acclerated_call
return partitioned_prim_graph_module
def __call__(
self, graph_module: torch.fx.GraphModule, args
) -> torch.fx.GraphModule:
"""If ``OrtBackendOptions.use_aot_autograd`` is ``True``, the `auto_autograd` compiler
will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise,
the ``compile`` method is invoked directly."""
if self._options.use_aot_autograd:
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
return aot_autograd(
fw_compiler=self.compile,
partition_fn=min_cut_rematerialization_partition,
decompositions=self._resolved_onnx_exporter_options.decomposition_table,
)(graph_module, args)
return self.compile(graph_module, args)
__instance_cache_max_count: Final = 8
__instance_cache: Final[List["OrtBackend"]] = []
@staticmethod
def get_cached_instance_for_options(
options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
) -> "OrtBackend":
"""Returns a possibly cached instance of an ``OrtBackend``. If an existing
backend was created previously through this function with the same options,
it will be returned. Otherwise a new backend will be created, cached, and
returned.
Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend``
will always be returned, since ``onnxruntime.SessionOptions`` cannot
participate in caching."""
def reusable(a: OrtBackendOptions, b: OrtBackendOptions):
if (
a.preferred_execution_providers != b.preferred_execution_providers
or a.infer_execution_providers != b.infer_execution_providers
or a.default_execution_providers != b.default_execution_providers
or a.preallocate_output != b.preallocate_output
or a.use_aot_autograd != b.use_aot_autograd
or a.pre_ort_model_transforms != b.pre_ort_model_transforms
):
return False
# onnxruntime.SessionOptions is a pybind11 object, cannot be pickled,
# and holds too much potential state to reasonably check manually;
# if ort_session_options is provided at all, the backend does not participate
# in caching.
if a.ort_session_options is not None or b.ort_session_options is not None:
return False
if a.export_options is b.export_options:
return True
# Similarly, some objects in ExportOptions are too stateful to use for
# caching. We should revisit this.
if a.export_options is not None and b.export_options is not None:
return (
a.export_options.dynamic_shapes == b.export_options.dynamic_shapes
and a.export_options.op_level_debug
== b.export_options.op_level_debug
and a.export_options.diagnostic_options
== b.export_options.diagnostic_options
and a.export_options.onnx_registry is b.export_options.onnx_registry
and a.export_options.fake_context is b.export_options.fake_context
)
# We can't account for how the two option sets may differ, so it's not safe to reuse.
return False
if not isinstance(options, OrtBackendOptions):
options = OrtBackendOptions(**(options or {}))
backend = next(
(b for b in OrtBackend.__instance_cache if reusable(b._options, options)),
None,
)
if backend is None:
assert (
len(OrtBackend.__instance_cache) < OrtBackend.__instance_cache_max_count
), (
f"No more than {OrtBackend.__instance_cache_max_count} instances of "
f"{OrtBackend} allowed. Please instantiate `{OrtBackend}` explicitly "
"to pass to `torch.compile`. "
"See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 "
"for discussion."
)
OrtBackend.__instance_cache.append(backend := OrtBackend(options))
return backend
@staticmethod
def clear_cached_instances():
OrtBackend.__instance_cache.clear()
@staticmethod
def get_cached_instances():
return tuple(OrtBackend.__instance_cache)
@compatibility(is_backward_compatible=False)
def torch_compile_backend(
graph_module: torch.fx.GraphModule,
args,
*,
options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
):
return OrtBackend.get_cached_instance_for_options(options)(graph_module, args)
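# A minimal end-to-end sketch (assumes onnx, onnxscript, and onnxruntime are
# installed; mirrors the example in is_onnxrt_backend_supported's docstring):
#
#   import torch
#
#   @torch.compile(backend="onnxrt")
#   def f(x):
#       return x * x
#
#   f(torch.randn(10))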