Spaces:
Running
Running
File size: 55,624 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 |
import itertools
import logging
import operator
import os
import re
import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple
import sympy
import torch
import torch._logging
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import defake, dynamo_timed
from torch._logging import LazyString, trace_structured
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.fx.experimental.symbolic_shapes import has_free_symbols, ShapeEnv, SymTypes
from torch.utils._mode_utils import no_dispatch
from . import config, ir
from .codegen.common import (
DeviceOpOverrides,
get_device_op_overrides,
get_scheduling_for_device,
get_wrapper_codegen_for_device,
register_backend_for_device,
)
from .codegen.cpp_wrapper_cpu import CppWrapperCpu
from .codegen.cpp_wrapper_cuda import CppWrapperCuda
from .codegen.wrapper import WrapperCodeGen
from .exc import (
CppWrapperCodeGenError,
LoweringException,
MissingOperatorWithDecomp,
MissingOperatorWithoutDecomp,
)
from .ir import (
Constant,
FixedLayout,
InputBuffer,
Pointwise,
Reduction,
StorageBox,
TensorBox,
)
from .lowering import (
constrain_to_fx_strides,
FALLBACK_ALLOW_LIST,
fallback_handler,
fallback_node_due_to_unsupported_type,
layout_constraints,
lowerings,
make_fallback,
needs_realized_inputs,
unsupported_output_tensor,
)
from .sizevars import SizeVarAllocator
from .utils import convert_shape_to_inductor, gather_origins, get_sympy_Expr_dtype
from .virtualized import V
# Module-level loggers: the plain module logger plus two torch._logging
# artifact loggers ("perf_hints" and "output_code") for this module.
log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")

if not config.is_fbcode():

    def log_module_code(*args, **kwargs):
        """No-op stand-in used outside fbcode; accepts and ignores any arguments."""
        return None

else:
    from torch._inductor.fb.utils import log_module_code
def supported_dtype_of_cpp_wrapper(dtype, cuda):
supported_dtype = {
torch.float32,
torch.float64,
torch.int64,
torch.int32,
torch.int16,
torch.int8,
torch.uint8,
torch.bool,
torch.bfloat16,
torch.complex32,
torch.complex64,
torch.complex128,
torch.float16,
}
if cuda:
supported_dtype.add(torch.float8_e4m3fn)
supported_dtype.add(torch.float8_e5m2)
supported_dtype.add(torch.float8_e4m3fnuz)
supported_dtype.add(torch.float8_e5m2fnuz)
return dtype in supported_dtype
def may_get_constant_buffer_dtype(constant_buffer):
    """Return the torch dtype used to materialize a sympy constant buffer.

    ``sympy.Integer`` values map to ``torch.int64``. Every other accepted
    input is a ``sympy.Expr`` and is delegated to ``get_sympy_Expr_dtype``.

    Args:
        constant_buffer: a ``sympy.Symbol``, ``sympy.Expr`` or ``sympy.Integer``.

    Returns:
        The ``torch.dtype`` for the constant.

    Raises:
        AssertionError: if ``constant_buffer`` is not one of the accepted sympy types.
    """
    assert isinstance(
        constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
    ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
    if isinstance(constant_buffer, sympy.core.numbers.Integer):
        return torch.int64
    # sympy.Symbol and sympy.Integer are both sympy.Expr subclasses, so after
    # the assert every remaining input is an Expr. A separate Symbol branch
    # (is_integer / is_float) could never be reached.
    return get_sympy_Expr_dtype(constant_buffer)
def is_magic_method(op):
magic_ops = {method_to_operator(m) for m in magic_methods}
return op in magic_ops
def getattr_recursive(obj, target):
    """Resolve a dotted attribute path such as ``"sub.weight"`` starting at ``obj``.

    Args:
        obj: object to start the lookup from.
        target: dot-separated attribute path.

    Returns:
        The object reached after following every component of ``target``.

    Raises:
        RuntimeError: if a component does not exist. The message names the
            path up to and including the missing component.
    """
    target_atoms = target.split(".")
    attr_itr = obj
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Include the missing atom itself. Slicing with [:i] reported only
            # the existing prefix, which is empty when the first atom is missing.
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[: i + 1])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
class GraphLowering(torch.fx.Interpreter):
graph_outputs: List[ir.IRNode]
def symbolic_sizes_strides(self, ex: torch.Tensor):
    """
    Return ``(sizes, strides)`` for ``ex`` with dimensions as sympy expressions.

    When the graph reuses a caller-provided ShapeEnv (``self.reuse_shape_env``),
    the tensor's own sizes/strides are converted with
    ``convert_shape_to_inductor``. Otherwise new symbols are allocated in
    ``self._shape_env`` for each dimension. Tensors are duck-shaped, so two
    tensors with the same size get the same symbolic variable.
    """
    if not self.reuse_shape_env:
        from torch._dynamo.source import ConstantSource

        # TODO: this should not be needed once #93059 lands
        # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816
        # TODO: make a dedicated UnknownSource for this?
        # NB: This is using the legacy default behavior from
        # create_symbolic_sizes_strides_storage_offset but we hope we can
        # just delete this entirely
        unknown_source = ConstantSource(
            f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}"
        )
        env = self._shape_env
        sym_sizes, sym_strides, _ = env.create_symbolic_sizes_strides_storage_offset(
            ex,
            unknown_source,
        )

        def _to_expr(dim):
            # SymInts are unwrapped to their underlying sympy expression.
            return dim.node.expr if isinstance(dim, torch.SymInt) else dim

        return [_to_expr(d) for d in sym_sizes], [_to_expr(d) for d in sym_strides]
    return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor(
        ex.stride()
    )
def static_sizes_strides(self, ex: torch.Tensor):
    """
    Return ``ex``'s sizes and strides as lists of concrete ``sympy.Integer``.

    Primarily used for weights.
    """

    def _as_integers(values):
        return list(map(sympy.Integer, values))

    return _as_integers(ex.size()), _as_integers(ex.stride())
def init_backend_registration(self):
    """Register default "cpu" and "cuda" backends for devices with no scheduling registered.

    Both devices use ``WrapperCodeGen`` as the wrapper codegen. Each scheduling
    class is imported lazily, only if its device still needs registering.
    """

    def _cpu_scheduling():
        from .codegen.cpp import CppScheduling

        return CppScheduling

    def _cuda_scheduling():
        # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
        from .codegen.cuda_combined_scheduling import CUDACombinedScheduling

        return CUDACombinedScheduling

    for device, load_scheduling in (("cpu", _cpu_scheduling), ("cuda", _cuda_scheduling)):
        if get_scheduling_for_device(device) is not None:
            continue
        register_backend_for_device(device, load_scheduling(), WrapperCodeGen)
def __init__(
    self,
    gm: torch.fx.GraphModule,
    example_inputs: Optional[List[torch.Tensor]] = None,
    shape_env=None,
    num_static_inputs=None,
    graph_id=None,
    cpp_wrapper=False,
    aot_mode=False,
    user_visible_outputs=frozenset(),
    layout_opt=None,
    extern_node_serializer=None,
    is_inference=False,
    is_const_graph=False,
    const_output_index=None,
    const_code=None,
    const_module=None,
    name=None,
):
    """
    Initialize the state used while lowering ``gm`` to Inductor IR.

    Args:
        gm: FX graph module handed to the ``torch.fx.Interpreter`` base class.
            A copy made with ``gm.__copy__()`` is kept in ``self.orig_gm``.
        example_inputs: stored as-is on ``self.example_inputs``.
        shape_env: ShapeEnv to reuse. When None, a fresh ``ShapeEnv()`` is
            created and ``self.reuse_shape_env`` is False.
        num_static_inputs: stored as-is.
        graph_id: stored as-is.
        cpp_wrapper: stored on ``self.cpp_wrapper`` and also copied to
            ``self.record_multi_kernel_choice``.
        aot_mode: stored as-is.
        user_visible_outputs: stored as-is.
        layout_opt: whether to apply layout optimization. When None, it is
            decided by ``decide_layout_opt(gm, is_inference=...)``.
        extern_node_serializer: optional callable taking the list of
            ``ir.ExternKernelNode`` (see ``self.extern_kernel_nodes``).
        is_inference: forwarded to ``decide_layout_opt`` and stored.
        is_const_graph: stored as-is.
        const_output_index: mapping of name -> int. Its keys seed
            ``self.folded_constants``.
        const_code: stored as-is.
        const_module: when given, its ``device_types``, ``device_idxs``,
            ``constants`` and ``allocated_constant_name`` objects are reused
            directly (shared, not copied) by this graph.
        name: optional graph name that ``qualify_name`` prefixes onto
            generated names.
    """
    super().__init__(gm)
    self.example_inputs = example_inputs
    # decide_layout_opt is a staticmethod that only inspects gm, so it can run
    # before the rest of the instance state exists.
    self.layout_opt = (
        layout_opt
        if layout_opt is not None
        else self.decide_layout_opt(gm, is_inference=is_inference)
    )
    self.num_channels_last_conv = 0
    self.is_inference = is_inference
    self.is_const_graph = is_const_graph
    self.const_code = const_code
    self.const_module = const_module
    self.extra_traceback = False  # we do our own error wrapping
    if shape_env is None:
        shape_env = ShapeEnv()
        self.reuse_shape_env = False
    else:
        # NOTE(review): redundant with the unconditional assignment just below.
        self._shape_env = shape_env
        self.reuse_shape_env = True
    self._shape_env = shape_env
    self.sizevars = SizeVarAllocator(shape_env)
    self.graph_input_names: List[str] = []
    self.graph_inputs: Dict[str, TensorBox] = {}
    self.graph_inputs_original: Dict[str, InputBuffer] = {}
    # With a const_module these are the const module's own set objects, so
    # devices recorded here are also visible through const_module.
    self.device_types: Set[str] = (
        const_module.device_types if const_module else set()
    )
    self.device_idxs: Set[int] = const_module.device_idxs if const_module else set()
    self.cuda = False
    self.buffers: List[ir.Buffer] = []
    self.const_output_index: Dict[str, int] = (
        const_output_index if const_output_index else {}
    )
    self.folded_constants: Set[str] = (
        set(const_output_index.keys()) if const_output_index else set()
    )
    # Shared with const_module (not copied) when one is given.
    self.constants: Dict[str, torch.Tensor] = (
        const_module.constants if const_module else {}
    )
    self.constant_reprs: Dict[str, str] = {}
    self.removed_buffers: Set[str] = set()
    self.removed_inplace_buffers: Set[str] = set()
    self.mutated_buffers: Set[str] = set()
    self.never_reuse_buffers: Set[str] = set()
    self.inplaced_to_remove: Set[str] = set()
    self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
    self.wrapper_code: WrapperCodeGen = None  # type: ignore[assignment]
    # See `ProxyExecutor Design Note` in ir.py for more details
    self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
    self.extern_node_serializer: Optional[
        Callable[[List[ir.ExternKernelNode]], Any]
    ] = extern_node_serializer
    self.current_node: torch.fx.Node = None  # type: ignore[assignment]
    self.num_static_inputs = num_static_inputs
    self.lists: Dict[str, List[str]] = {}
    self.mutated_inputs: Set[str] = set()
    self.mutated_input_idxs: List[int] = []
    self.name_to_buffer: Dict[str, ir.Buffer] = {}
    # buffer name -> IR values that read it (a defaultdict, so readers can be
    # appended without first creating the entry).
    self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
    self.creation_time = time.time()
    self.name = name
    self.cpp_wrapper = cpp_wrapper
    # record multi_kernel choice for cpp_wrapper so the second pass knows
    # which sub-kernel is picked. Copy cpp_wrapper to another variable
    # since cpp_wrapper flag is set to false for the first pass of codegen.
    self.record_multi_kernel_choice = cpp_wrapper
    self.multi_kernel_to_choice: Dict[str, int] = {}
    self.aot_mode = aot_mode
    self.graph_id = graph_id
    self.scheduler: "torch._inductor.scheduler.Scheduler" = None  # type: ignore[assignment]
    # Only computed when layout optimization is enabled. This reads
    # self.module, which is presumably set by the Interpreter base class.
    self.nodes_prefer_channels_last = (
        self.find_nodes_prefer_channels_last() if self.layout_opt else set()
    )
    # Pre-seeded so warn_fallback never emits a perf hint for this op.
    self._warned_fallback = {"aten.convolution_backward"}
    self.user_visible_outputs = user_visible_outputs
    self.cache_key: str = ""  # This is the cache key for the compiled artifact
    self.cache_path: str = ""  # This is the path in the filesystem where the compiled artifact is stored
    self.cache_linemap: List[
        Tuple[int, str]
    ] = (
        []
    )  # This is the linemap used by the profiler to mark custom compiled kernels getting run
    # Used if lowering encounters cases where cudagraphs are not supported
    self.disable_cudagraphs_reason: Optional[str] = None
    # only keeping one node per device for stack trace purposes
    self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
    self.orig_gm: torch.fx.GraphModule = gm.__copy__()
    self.dynamo_flat_name_to_original_fqn = self.module.meta.get(
        "dynamo_flat_name_to_original_fqn", {}
    )
    # Shared with const_module (not copied) when one is given.
    self.allocated_constant_name = (
        const_module.allocated_constant_name if const_module is not None else {}
    )
    self.init_backend_registration()
@staticmethod
def decide_layout_opt(gm, *, is_inference) -> bool:
    """
    Decide if we should enable layout optimization for this graph based on
    heuristics.

    Checks are applied in order, and the first decisive one wins:
      1. ``config.layout_optimization`` is off -> False.
      2. ``config.force_layout_optimization`` is on -> True.
      3. The graph has no ``aten.convolution.default`` nodes -> False.
      4. mkldnn is enabled and available, and every conv's ``args[0]`` and
         ``args[1]`` are on CPU -> True.
      5. The graph has at least 300 nodes per convolution -> False.
      6. Any conv ``args[0]``/``args[1]`` has free symbols (dynamic shapes) -> False.
      7. Inference: enable only if benchmark-weighted conv FLOPs do not
         exceed the raw conv FLOPs.
      8. Otherwise (training): disable if any conv is grouped with >1
         in-channel per group, if any conv has in-channels >= 2x out-channels
         with kernel height > 1, or if every conv has <= 64 in/out channels.
         Else enable.

    Args:
        gm: FX graph module to inspect. The conv arguments must carry ``meta["val"]``.
        is_inference: selects the FLOP-weighted heuristic (step 7).

    Returns:
        True if layout optimization (channels-last) should be applied.
    """
    if not config.layout_optimization:
        return False
    if config.force_layout_optimization:
        return True
    conv_nodes = [
        n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default
    ]
    nconv = len(conv_nodes)
    if nconv == 0:
        return False
    # For cpu backend and mkldnn enabled, we always use channels_last for better performance.
    if (
        torch.backends.mkldnn.enabled
        and torch.backends.mkldnn.is_available()
        and all(
            n.args[idx].meta["val"].device == torch.device("cpu")
            for n in conv_nodes
            for idx in [0, 1]
        )
    ):
        return True
    # Following models are skipped due to this:
    # jx_nest_base
    # volo_d1_224
    # (i.e. convolutions are a tiny fraction of the graph: >= 300 nodes per conv)
    if len(list(gm.graph.nodes)) >= 300 * nconv:
        log.debug("Skipped layout opt because only a few conv")
        return False
    if any(
        has_free_symbols(n.args[idx].meta["val"])
        for n in conv_nodes
        for idx in [0, 1]
    ):
        log.debug(
            "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670"
        )
        return False

    # The helpers below read the conv weight args[1] and groups args[-1].
    # They assume the non-transposed aten.convolution weight layout
    # (out_channels, in_channels / groups, *kernel) -- TODO confirm for
    # transposed convolutions.
    def is_grouped(n):
        # groups > 1 and more than one in-channel per group (excludes depthwise)
        return n.args[-1] > 1 and n.args[1].meta["val"].size(1) > 1

    def is_in_out_channel(n):
        # in-channels at least twice out-channels, with kernel height > 1
        return (
            n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)
            and n.args[1].meta["val"].size(2) > 1
        )

    def is_small_channel(n):
        # both weight dims 0 and 1 are at most 64
        return (
            n.args[1].meta["val"].size(0) <= 64
            and n.args[1].meta["val"].size(1) <= 64
        )

    # only grouped convolutions benchmarked as slower in conv samples for inference only
    if is_inference:
        from torch.utils.flop_counter import FlopCounterMode

        # Conv FLOPs bucketed by category ("grouped", "small", "in_out", "default").
        flop_counts: Dict[str, float] = defaultdict(float)
        for node in conv_nodes:
            success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
                node
            )

            if success:
                # Re-run the conv under the fake mode purely to count its FLOPs.
                with FlopCounterMode(display=False) as flop_counter_mode:
                    with V.fake_mode:
                        node.target(*args, **kwargs)

                counted_flops = flop_counter_mode.get_total_flops()
                if is_grouped(node):
                    node_type = "grouped"
                elif is_small_channel(node):
                    node_type = "small"
                elif is_in_out_channel(node):
                    node_type = "in_out"
                else:
                    node_type = "default"

                flop_counts[node_type] += counted_flops
            else:
                log.debug("Conv inputs meta not found")

        # average benchmarked channels last speedup / slowdown, < 1 is speedup.
        # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/
        # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb
        GROUPED_MULTIPLIER = 1.358
        DEFAULT_MULTIPLIER = 0.823
        IN_OUT_MULTIPLIER = 0.725
        SMALL_MULTIPLIER = 0.783

        total_flops = sum(flop_counts.values())
        # TODO - get different values per hardware
        weighted_flops = (
            flop_counts["grouped"] * GROUPED_MULTIPLIER
            + flop_counts["small"] * SMALL_MULTIPLIER
            + flop_counts["in_out"] * IN_OUT_MULTIPLIER
            + flop_counts["default"] * DEFAULT_MULTIPLIER
        )
        # Enable only if the estimated channels-last cost is no worse.
        do_layout_opt = weighted_flops <= total_flops
        if not do_layout_opt:
            # NOTE(review): %d truncates these float totals in the log text.
            log.debug(
                "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d",
                total_flops,
                weighted_flops,
            )
        return do_layout_opt

    # Channels last layout can dramatically hurt grouped conv perf. E.g.
    # Conv with arguments like
    #   {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3],
    #    "stride": [2, 2], "padding": [1, 1], "groups": 2}
    # slows down 31x using channels last..

    # But a lot of timm models use depthwise separable convolution which will
    # result in grouped convolution with in-channel size == 1.
    # For those grouped convolution, channels last still helps a lot.
    # E.g.
    # Conv with arguments
    #   {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3],
    #    "stride": [2, 2], "padding": [1, 1], "groups": 58}
    # get 1.86x speedup with channels last layout.
    #
    # The following heuristics skip using channels-last if the model contains
    # grouped convolution with in-channels > 1.
    if any(map(is_grouped, conv_nodes)):
        log.debug(
            "Skip layout opt because found grouped convolution with >1 in_channels!"
        )
        return False

    # For some models that contain convolution with larger in-channel than out-channel, applying
    # channels last hurts performance.
    # Following models are skipped due to this:
    # - pytorch_unet
    # - phlippe_densenet (slightly worse)
    # - Background_Matting (1.22x -> 0.821x)
    # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x)
    if any(map(is_in_out_channel, conv_nodes)):
        log.debug(
            "Skip layout opt because some convolutions have smaller out_channel"
        )
        return False

    # Following models are skipped due to this:
    # - functorch_maml_omniglot
    if all(map(is_small_channel, conv_nodes)):
        log.debug("Skip layout opt because all convolution channels are too small")
        return False

    return True
def qualify_name(self, name: str) -> str:
    """Return ``name``, prefixed with ``"<graph name>_"`` when this graph has a name."""
    graph_name = self.name
    return name if graph_name is None else f"{graph_name}_{name}"
def make_subgraph(
    self,
    gm: torch.fx.GraphModule,
    example_inputs: List[torch.Tensor],
    subgraph_name: str,
) -> "GraphLowering":
    """
    Build a child ``GraphLowering`` for ``gm``/``example_inputs``.

    The child reuses this graph's shape env, cpp_wrapper, aot_mode,
    extern_node_serializer and is_inference settings. Subgraphs are lowered
    separately but are meant to be inlined into the parent's codegen, which is
    why they share the same ``shape_env``. The child is named
    ``qualify_name(subgraph_name)``, so its name is prefixed by the parent's name.
    """
    inherited = dict(
        shape_env=self._shape_env,
        cpp_wrapper=self.cpp_wrapper,
        aot_mode=self.aot_mode,
        extern_node_serializer=self.extern_node_serializer,
        is_inference=self.is_inference,
    )
    return GraphLowering(
        gm=gm,
        example_inputs=example_inputs,
        name=self.qualify_name(subgraph_name),
        **inherited,
    )
def find_nodes_prefer_channels_last(self):
    """
    Return the set of FX nodes that should prefer a channels-last layout.

    Pass 1 walks the graph in reverse. A node is selected if it is an
    ``aten.convolution.default`` or if any of its users was already selected.
    Convolution inputs (direct or indirect, e.g. conv -> batch-norm -> relu ->
    conv) therefore stay channels last. Without this, each step in such a
    chain could force an extra contiguous <-> channels-last copy.

    Pass 2 walks the graph forward and adds every user of a selected node.
    Because the set grows while walking, this propagates downstream
    transitively. It avoids mixed-layout kernel inputs in the backward pass,
    e.g. a relu output returned from the forward graph being forced
    contiguous while other tensors are channels last. Per the original notes,
    this pass improved yolov3, dla102, res2net101_26w_4s, res2net50_14w_8s and
    sebotnet33ts_256 training speedups.
    """
    conv_op = torch.ops.aten.convolution.default
    preferred = set()
    for node in reversed(self.module.graph.nodes):
        if node.target == conv_op or any(user in preferred for user in node.users):
            preferred.add(node)
    for node in self.module.graph.nodes:
        if node in preferred:
            preferred.update(node.users)
    return preferred
def warn_fallback(self, name):
    """Emit a perf hint that ``name`` uses a FallbackKernel, at most once per name."""
    if name in self._warned_fallback:
        return
    self._warned_fallback.add(name)
    perf_hint_log.info("Using FallbackKernel: %s", name)
def add_device_info(self, device: torch.device):
    """Record ``device``'s type and index, and the first FX node seen using it.

    The node comes from ``V.graph.current_node`` and is recorded only when one
    is set and ``device`` has no entry yet in ``self.device_node_mapping``.
    """
    self.device_types.add(device.type)
    index = device.index
    if index is not None:
        self.device_idxs.add(index)
    if not V.graph.current_node or device in self.device_node_mapping:
        return
    self.device_node_mapping[device] = V.graph.current_node
@property
def fake_mode(self):
    """The fake mode exposed by the virtualized ``V`` handler (``V.fake_mode``)."""
    current_mode = V.fake_mode
    return current_mode
def get_buffer(self, buffer_name: str):
    """Return the registered buffer or graph input named ``buffer_name``, or None.

    Registered buffers take precedence over graph inputs with the same name.
    """
    for table in (self.name_to_buffer, self.graph_inputs):
        if buffer_name in table:
            return table[buffer_name]
    return None
def get_dtype(self, buffer_name: str):
    """Return the dtype of the constant, buffer or graph input named ``buffer_name``.

    Lookup order is constants, registered buffers, then graph inputs. Codegen
    expressions of the form ``as_strided(<name>, ...)`` or
    ``reinterpret_tensor(<name>, ...)`` are also accepted; the dtype of the
    wrapped ``<name>`` is returned.

    Raises:
        KeyError: if no dtype can be found for ``buffer_name``.
    """
    if buffer_name in self.constants:
        return self.constants[buffer_name].dtype
    if buffer_name in self.name_to_buffer:
        return self.name_to_buffer[buffer_name].get_dtype()
    if buffer_name in self.graph_inputs:
        return self.graph_inputs[buffer_name].get_dtype()
    m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
    if m:
        # Group 1 is the wrapper function name; group 2 is the wrapped buffer.
        # Recursing on group 1 looked up "as_strided"/"reinterpret_tensor"
        # itself and always ended in KeyError.
        return self.get_dtype(m.group(2))
    raise KeyError(f"could not find {buffer_name}")
def get_numel(self, buffer_name: str):
    """
    Return the element count of the constant, buffer or graph input called
    ``buffer_name``, checked in that order.

    Buffers whose ``layout`` is a ``MultiOutputLayout`` report 1.

    Raises:
        KeyError: if ``buffer_name`` is none of the above.
    """
    from .ir import MultiOutputLayout

    constants = self.constants
    if buffer_name in constants:
        return constants[buffer_name].numel()
    buffers = self.name_to_buffer
    if buffer_name in buffers:
        candidate = buffers[buffer_name]
        layout = getattr(candidate, "layout", None)
        return 1 if isinstance(layout, MultiOutputLayout) else candidate.get_numel()
    inputs = self.graph_inputs
    if buffer_name in inputs:
        return inputs[buffer_name].get_numel()
    raise KeyError(f"could not find {buffer_name}")
@dynamo_timed
def run(self, *args):
    """Interpret the FX graph on ``args`` via the Interpreter base class (timed by ``dynamo_timed``)."""
    lowered = super().run(*args)
    return lowered
def register_buffer(self, buffer: ir.Buffer):
    """Assign ``buffer`` the next ``buf<N>`` name (qualified by graph name), record it, return the name.

    The buffer's device is recorded via ``add_device_info`` unless the buffer
    is a zero-element ``ir.ComputedBuffer``.
    """
    buffer_name = self.qualify_name(f"buf{len(self.buffers)}")
    self.buffers.append(buffer)
    self.name_to_buffer[buffer_name] = buffer
    # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
    is_empty_computed = isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements()
    if not is_empty_computed:
        self.add_device_info(buffer.get_device())
    return buffer_name
def register_list(self, buffer_names: List[str]):
    """Store ``buffer_names`` under a generated, qualified ``list_<a>_<b>...`` name and return it."""
    joined = "_".join(buffer_names)
    list_name = self.qualify_name(f"list_{joined}")
    self.lists[list_name] = buffer_names
    return list_name
def register_users_of(self, node_output):
    """Record, per buffer name read, the IR values that read it.

    ``node_output`` may be an IR node or an arbitrarily nested list/tuple of
    them. Only IR nodes whose ``.data`` and ``.data.data`` are both IR nodes
    are recorded. Each such value is appended to ``self.name_to_users`` under
    every name in ``value.get_read_names()``.
    """

    def _record(value):
        if isinstance(value, (list, tuple)):
            for item in value:
                _record(item)
        if not isinstance(value, ir.IRNode):
            return
        wraps_two_levels = (
            hasattr(value, "data")
            and isinstance(value.data, ir.IRNode)
            and hasattr(value.data, "data")
            and isinstance(value.data.data, ir.IRNode)
        )
        if not wraps_two_levels:
            return
        for read_name in value.get_read_names():
            self.name_to_users[read_name].append(value)

    _record(node_output)
def mark_buffer_mutated(self, name: str):
    """
    Record that buffer ``name`` is mutated, and realize every IR value
    registered as reading it so reads of the old version happen first.
    """
    assert isinstance(name, str)
    self.mutated_buffers.add(name)
    if name in self.name_to_users:
        for reader in self.name_to_users[name]:
            reader.realize()
def add_tensor_constant(self, data, name=None):
    """
    Register tensor ``data`` as a graph constant.

    Returns a ``TensorBox`` wrapping an ``ir.ConstantBuffer`` with a
    ``FixedLayout`` built from ``data``'s static sizes/strides.

    Unless ``config.aot_inductor.use_runtime_constant_folding`` is set, an
    existing constant with the same size, stride, dtype and device and equal
    values is reused (only when ``data`` is not mkldnn). Otherwise a new
    unique name is derived from ``name`` (or ``constant<N>`` when ``name`` is
    None) and ``data`` is stored in ``self.constants``.

    ``self.allocated_constant_name`` maps the chosen constant name back to the
    requested ``name``.
    """

    def allocate(name):
        # Returns the name under which ``data`` is stored in self.constants.
        if not config.aot_inductor.use_runtime_constant_folding:
            # Deduplicate against existing constants by full value comparison.
            for constant_name, value in self.constants.items():
                if (
                    not data.is_mkldnn
                    and data.size() == value.size()
                    and data.stride() == value.stride()
                    and data.dtype == value.dtype
                    and data.device == value.device
                    and torch.eq(data, value).all()
                ):
                    return constant_name

        if name is None:
            name = f"constant{len(self.constants)}"
        # NOTE(review): an empty-string name would raise IndexError here.
        if name[0].isdigit():
            name = f"constant_{name}"
        name = self.qualify_name(name)
        # We may generate a var name for each constant in the codegen.
        # Let's only keep sane characters.
        prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name)
        name = prefix
        cnt = 0
        # Append _0, _1, ... until the name is unused.
        while name in self.constants:
            name = f"{prefix}_{cnt}"
            cnt += 1
        self.constants[name] = data
        self.constant_reprs[name] = (
            f"{data.device!r} {data.dtype!r} "
            f"{tuple(data.size())!r} {tuple(data.stride())!r} "
            f"{hash(data):x}"
        )
        return name

    new_name = allocate(name)
    # NOTE(review): when an existing constant was reused, this overwrites its
    # entry with the newly requested name (which may be None).
    self.allocated_constant_name[new_name] = name

    return TensorBox.create(
        ir.ConstantBuffer(
            new_name,
            FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)),
        )
    )
def constant_name(self, name: str, device_override: Optional[torch.device]):
    """
    Return the name to use for constant ``name`` on ``device_override``.

    Constants are copied ahead of time to the devices they are needed on. If
    ``device_override`` is None or matches the constant's device, ``name`` is
    returned unchanged. Otherwise a copy moved to ``device_override`` is stored
    once under ``<name>_<type><index or 0>``, and that name is returned.
    """
    same_device = self.constants[name].device == device_override
    if same_device or device_override is None:
        return name
    device_suffix = f"{device_override.type}{device_override.index or 0}"
    copied_name = f"{name}_{device_suffix}"
    if copied_name not in self.constants:
        self.constants[copied_name] = self.constants[name].to(device_override)
    return copied_name
def placeholder(self, target: str, args, kwargs):
    """
    Lower an FX ``placeholder`` node into a graph input.

    - SymInt/SymFloat/SymBool examples (``SymTypes``) are recorded as their
      sympy expression.
    - Python ``int``/``bool``/``float`` examples are ``sympy.sympify``-ed.
    - A ``BackwardState`` example returns None and is not registered.
    - Tensors become a ``TensorBox`` around an ``InputBuffer`` with a
      ``FixedLayout``. Static sizes/strides are used unless the example has
      symbolic sizes/strides. The tensor's device is recorded.

    Returns the recorded expression, None, or the TensorBox.
    """
    example = super().placeholder(target, args, kwargs)
    # NOTE(review): graph_input_names and the non-tensor graph_inputs keys use
    # the unqualified target, while tensor inputs are keyed by the qualified
    # name below. Confirm this is intended for named subgraphs.
    self.graph_input_names.append(target)
    if isinstance(example, SymTypes):
        expr = example.node.expr
        self.graph_inputs[target] = expr
        return expr
    elif isinstance(example, (int, bool, float)):
        expr = sympy.sympify(example)
        self.graph_inputs[target] = expr
        return expr
    if isinstance(example, BackwardState):
        # Ignored arg, must be unused
        # Alternately we could filter this out in AotAutograd
        return None
    assert isinstance(example, torch.Tensor), example
    # todo(chilli): We can remove the last check once we turn buffers into
    # static shape tensors. That's a hack to workaround Inductor believing
    # the buffer should be static but us passing in a fake tensor with
    # symbolic shapes.
    if not example._has_symbolic_sizes_strides:
        # the first N inputs are weights
        sizes, strides = self.static_sizes_strides(example)
    else:
        sizes, strides = self.symbolic_sizes_strides(example)
    # TODO(jansel): handle input aliasing
    target = self.qualify_name(target)
    tensor = TensorBox.create(
        InputBuffer(
            target,
            FixedLayout(example.device, example.dtype, sizes, strides),
        )
    )
    self.graph_inputs[target] = tensor
    # Keep the underlying InputBuffer so later mutation of the input can be
    # detected and copied back into it.
    self.graph_inputs_original[target] = tensor.data.data
    self.add_device_info(example.device)
    return tensor
def call_function(self, target, args, kwargs):
    """
    Lower an FX ``call_function`` node to Inductor IR.

    - ``operator.getitem`` on a list/tuple/dict is delegated to the base
      ``torch.fx.Interpreter``.
    - Targets with an ``_inductor_lowering_function`` attribute (pattern
      matcher passthroughs) are called directly.
    - Targets missing from ``lowerings`` must be ``OpOverload``s. A fallback is
      registered when the part of ``target.name()`` before the first '.' is in
      ``FALLBACK_ALLOW_LIST``, or when ``config.implicit_fallbacks`` is set.
      Otherwise ``MissingOperatorWithDecomp`` or
      ``MissingOperatorWithoutDecomp`` is raised, depending on whether a
      decomposition exists.

    Any exception raised by the lowering itself is re-raised as a
    ``LoweringException`` carrying the original traceback.
    """
    if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
        return super().call_function(target, args, kwargs)

    if hasattr(target, "_inductor_lowering_function"):
        # passthrough lowerings from .pattern_matcher
        return target(*args, **kwargs)

    def get_custom_op_layout_constraints(target, args, kwargs):
        # Custom operations that require preserving stride order
        # which run through implicit fallback must constrain their
        # arguments' fx strides
        layout_constraint = None
        if torch._C.Tag.needs_fixed_stride_order in target.tags:
            # We have to set the current args because call_function will immediately
            # evaluate this lowering after creating the fallback, without evaluating
            # the layout constraint
            args, kwargs = constrain_to_fx_strides(
                self.current_node, *args, **kwargs
            )
            # Also register the layout constraint so when the fallback
            # is used again, we can constrain the args to the same layout
            layout_constraint = constrain_to_fx_strides
        return layout_constraint, args, kwargs

    if target not in lowerings:
        assert isinstance(
            target, torch._ops.OpOverload
        ), f"{target} is not an OpOverload"
        base_name = target.name().split(".")[0]
        # Presumably make_fallback registers a lowering for target, since
        # lowerings[target] is looked up unconditionally below.
        if base_name in FALLBACK_ALLOW_LIST:
            make_fallback(target)
        elif config.implicit_fallbacks:
            layout_constraint, args, kwargs = get_custom_op_layout_constraints(
                target, args, kwargs
            )
            # Only used to format the operator description in the log line.
            error = (
                MissingOperatorWithDecomp
                if get_decompositions([target])
                else MissingOperatorWithoutDecomp
            )
            log.info(
                "Creating implicit fallback for:\n%s",
                error.operator_str(target, args, kwargs),
            )
            make_fallback(target, layout_constraint)
        elif get_decompositions([target]):
            # There isn't a good way to dynamically patch this in
            # since AOT Autograd already ran. The error message tells
            # the user how to fix it.
            raise MissingOperatorWithDecomp(target, args, kwargs)
        else:
            raise MissingOperatorWithoutDecomp(target, args, kwargs)

    try:
        log.debug(" via %s", lowerings[target])
        out = lowerings[target](*args, **kwargs)
        return out
    except Exception as e:
        # Wrap with op/args context but keep the original traceback.
        raise LoweringException(e, target, args, kwargs).with_traceback(
            e.__traceback__
        ) from None
@staticmethod
def can_inline_constant(t: torch.Tensor) -> bool:
"""
True if this is a small constant attr that will be inlined.
"""
return len(t.shape) == 1 and t.shape[0] <= 8
def get_attr(self, target, args, kwargs):
    """Lower an FX ``get_attr`` node (a constant or a nested GraphModule).

    - A ``torch.fx.GraphModule`` becomes an ``ir.Subgraph``.
    - With runtime constant folding or ``always_keep_tensor_constants``
      enabled, or when ``unsupported_output_tensor`` flags the value, it is
      registered via ``add_tensor_constant``.
    - Otherwise, with dispatch disabled, a 0-d tensor becomes a ``Constant``
      scalar and a tensor accepted by ``can_inline_constant`` is inlined via
      the ``tensor`` lowering.
    - Anything else is registered via ``add_tensor_constant``.
    """
    # this is a constant
    value = getattr_recursive(self.module, target)

    if isinstance(value, torch.fx.GraphModule):
        return ir.Subgraph(name=target, graph_module=value)

    keep_as_buffer = (
        config.aot_inductor.use_runtime_constant_folding
        or config.always_keep_tensor_constants
        or unsupported_output_tensor(value)
    )
    if keep_as_buffer:
        return self.add_tensor_constant(value, target)

    with no_dispatch():
        is_scalar = value.shape == ()
        if is_scalar:
            return Constant(value.item(), value.dtype, value.device)
        if self.can_inline_constant(value):
            # tensor lowering has constant inlining logic
            from .lowering import tensor

            return tensor(value.tolist(), dtype=value.dtype, device=value.device)

    return self.add_tensor_constant(value, target)
def call_module(self, target, args, kwargs):
    """Reject FX ``call_module`` nodes; GraphLowering never lowers them.

    Raises:
        AssertionError: always. The message names the offending target
            (previously a bare ``AssertionError()`` with no context).
    """
    raise AssertionError(
        f"GraphLowering does not support call_module nodes (target: {target!r})"
    )
def call_method(self, target, args, kwargs):
    """Reject FX ``call_method`` nodes; GraphLowering never lowers them.

    Raises:
        AssertionError: always. The message names the offending method
            (previously a bare ``AssertionError()`` with no context).
    """
    raise AssertionError(
        f"GraphLowering does not support call_method nodes (target: {target!r})"
    )
def output(self, target, args, kwargs):
    """
    Lower the FX ``output`` node.

    Steps:
    1. Realize every graph output into ``self.graph_outputs``.
    2. For each tensor graph input whose storage no longer holds the original
       ``InputBuffer`` (i.e. the input was mutated), copy the new value back
       into the original input via ``ir.MutationLayout.realize_into``. If that
       storage box is also an output, substitute the original input for the
       first matching output.
    3. Call ``self.finalize()`` and log the channels-last conv count.
    """
    outputs = super().output(target, args, kwargs)
    assert isinstance(outputs, (tuple, list)), type(outputs)
    allowed_output_types = (
        TensorBox,
        ir.Constant,
        type(None),
        ir.ConstantBuffer,
        sympy.Expr,
        sympy.logic.boolalg.Boolean,
        int,
    )
    assert all(isinstance(out, allowed_output_types) for out in outputs), outputs
    self.graph_outputs = [ir.ExternKernel.realize_input(out) for out in outputs]

    for name, inp in self.graph_inputs.items():
        assert isinstance(
            inp, (TensorBox, sympy.Expr)
        ), f"Unsupported inductor graph input type: {type(inp)}"
        if not isinstance(inp, TensorBox):
            # Symbolic (sympy) inputs cannot be mutated; nothing to do.
            continue

        inp.realize()
        assert isinstance(inp, TensorBox)
        storage = inp.data
        assert isinstance(storage, ir.StorageBox)
        buf = storage.data
        if isinstance(buf, InputBuffer) and buf.get_name() == name:
            # Storage still holds the untouched input buffer.
            continue

        # The input was mutated: turn that into a copy into the original input.
        original_input = self.graph_inputs_original[name]
        ir.MutationLayout.realize_into(buf, original_input)
        # Return the original input in place of the mutated storage box.
        try:
            slot = self.graph_outputs.index(storage)
        except ValueError:
            continue
        self.graph_outputs[slot] = original_input

    self.finalize()
    log.debug(
        "Force channels last inputs for %d conv for the current graph with id %d",
        self.num_channels_last_conv,
        -1 if self.graph_id is None else self.graph_id,
    )
def finalize(self):
    """Call ``decide_layout()`` on every buffer in ``self.buffers``, in iteration order."""
    for buffer in self.buffers:
        buffer.decide_layout()
@contextmanager
def set_current_node(self, node: torch.fx.Node):
    """
    Context manager that sets ``self.current_node`` to ``node``.

    The previous value is restored on exit, including when the ``with`` body
    raises.
    """
    previous_node, self.current_node = self.current_node, node
    try:
        yield
    finally:
        self.current_node = previous_node
def run_node(self, n: torch.fx.Node):
    """
    Lower a single FX node and return its IR result.

    The node is dispatched to one of:
      * ``fallback_handler`` when ``fallback_node_due_to_unsupported_type(n)``
        is true (non-getitem ``call_function`` nodes only),
      * a registered entry in ``layout_constraints`` that may rewrite the
        args/kwargs before ``self.call_function``,
      * the magic-method path (SymInt results become their sympy expression),
      * the default ``super().run_node(n)``.

    The result is then post-processed: graph outputs and inputs of
    ``as_strided`` ops get their eager stride order enforced, several
    realization heuristics are applied, ``origin_node`` is recorded on a
    best-effort basis, and ``self.register_users_of(result)`` is called.
    """

    def debug(msg):
        log.debug("lowering %s %s", LazyString(n.format_node), msg)

    # FX nodes recorded as origins of any IR created while lowering this node.
    origins = {n}
    if n.op == "call_function":
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        origins |= gather_origins(args, kwargs)
    with ir.IRNode.current_origins(origins), self.set_current_node(
        n
    ), V.set_current_node(n):
        if (
            n.op == "call_function"
            and n.target is not operator.getitem
            and fallback_node_due_to_unsupported_type(n)
        ):
            debug("fallback_handler")
            # args/kwargs are bound above because this branch requires
            # n.op == "call_function".
            result = fallback_handler(n.target, add_to_fallback_set=False)(
                *args, **kwargs  # type: ignore[possibly-undefined]
            )
        elif n.op == "call_function" and n.target in layout_constraints:
            debug("layout_constraints")
            args, kwargs = layout_constraints[n.target](n, *args, **kwargs)  # type: ignore[index]
            result = self.call_function(n.target, args, kwargs)
        elif is_magic_method(n.target):
            # TODO: this is sus, it probably should be handled in the
            # lowerings themselves similarly to sym_size/sym-stride
            debug("is_magic_method")
            if isinstance(n.meta["val"], torch.SymInt):
                result = n.meta["val"].node.expr
            else:
                result = super().run_node(n)
        else:
            debug("")
            result = super().run_node(n)

        # require the same stride order for dense outputs,
        # 1. user-land view() will not throw because inductor
        # output different strides than eager
        # long term the solution is to make view() always succeed
        # with infallible strides.
        # 2: as_strided ops, we need make sure its input has same size/stride with
        # eager model to align with eager behavior.
        as_strided_ops = [
            torch.ops.aten.as_strided.default,
            torch.ops.aten.as_strided_.default,
            torch.ops.aten.as_strided_scatter.default,
        ]
        is_output = any(user.op == "output" for user in n.users)
        is_input_for_as_strided = any(
            user.target in as_strided_ops for user in n.users
        )
        if (
            is_output
            and isinstance(result, TensorBox)
            and isinstance(result.data, ir.BaseView)
        ):
            # Realize so that outputs are correctly aliased
            result.realize()

        if (is_output or is_input_for_as_strided) and isinstance(
            n.meta["val"], torch.Tensor
        ):
            # Strides of the example value recorded on the FX node.
            strides = n.meta["val"].stride()
            dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"])
            # requiring a stride order for a non-dense output wouldn't
            # recreate the same strides, and would fail with view, defer for now.
            if dense and len(strides):
                stride_order = ir.get_stride_order(strides)
                # 4-D nodes marked as preferring channels-last that are not
                # user-visible outputs (and not as_strided inputs) use NHWC.
                if (
                    len(result.get_size()) == 4
                    and n in self.nodes_prefer_channels_last
                    and n.name not in self.user_visible_outputs
                    and not is_input_for_as_strided
                ):
                    stride_order = ir.NHWC_STRIDE_ORDER
                result = ir.ExternKernel.require_stride_order(result, stride_order)

        # Realize if (1) any user need inputs realized, or (2) there is
        # already too many reads and rematerializing can be bad.
        # (Counts distinct users of the FX node.)
        num_users = len(set(n.users))
        if num_users > 1 and isinstance(result, TensorBox):
            for user in n.users:
                if user.target in needs_realized_inputs:
                    result.realize_hint()
                    # This inclusion is somewhat controversial (from
                    # discussion between Horace, Natalia, and Elias).
                    # Currently, it's not very clear why this is helpful.
                    # The general idea here is that even though a node may
                    # have FlexibleLayout, we still often *treat* it as if
                    # it was contiguous. This appears to sometimes result in
                    # suboptimal behavior.
                    #
                    # When we do a better job selecting layout, we should
                    # revisit this.
                    # NOTE(review): this list is rebuilt for every such user;
                    # it depends only on self.layout_opt and build flags and
                    # could be hoisted out of the loop.
                    need_fixed_layout = [
                        torch.ops.aten.convolution_backward.default,
                        torch.ops.aten.mm.default,
                        torch.ops.aten._int_mm.default,
                    ]
                    if not self.layout_opt:
                        need_fixed_layout.append(torch.ops.aten.convolution.default)
                    if torch._C._has_mkldnn:
                        need_fixed_layout += [
                            torch.ops.mkldnn._convolution_pointwise.default,
                            torch.ops.mkldnn._convolution_pointwise.binary,
                            torch.ops.mkldnn._convolution_pointwise_.binary,
                            torch.ops.mkldnn._convolution_transpose_pointwise.default,
                            torch.ops.mkldnn._linear_pointwise.default,
                            torch.ops.mkldnn._linear_pointwise.binary,
                            torch.ops.aten.mkldnn_rnn_layer.default,
                            torch.ops.onednn.qconv2d_pointwise.default,
                            torch.ops.onednn.qconv2d_pointwise.binary,
                            torch.ops.onednn.qlinear_pointwise.default,
                            torch.ops.onednn.qlinear_pointwise.tensor,
                        ]
                        if torch._C.has_mkl:
                            need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
                    if user.target in need_fixed_layout:
                        # Pin the eager stride order of the example value.
                        result = ir.ExternKernel.require_stride_order(
                            result, ir.get_stride_order(n.meta["val"].stride())
                        )
                if user.op == "output":
                    if isinstance(result.data.data, (Pointwise, Reduction)):
                        result.realize()

            # TODO(jansel): introduce a store vs inline choice
            result.mark_reuse(len(n.users))

        # Realize if the IRNode already has accumulated lots of reads
        if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
            # Prevent excessive accumulation in a computed buffer, when
            # there are multiple branches each with small number of memory
            # reads, but they converge to a user.
            result.realize_hint()

        # Realize if a Pointwise has too much stuff to be inlined.
        # As this may cause RecursionError during Inductor's evaluation.
        if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
            curr = result.data.data
            if isinstance(curr, Pointwise):
                # Use inner fn as a rough proxy. Good enough.
                if curr.has_large_inner_fn():
                    result.realize()

    # This is not complete, but it doesn't have to be: origin_node
    # tracking is best effort. The logic here critically relies on direct
    # TensorBox -> StorageBox denoting a non-view; we don't bother trying
    # to get views to work. Feel free to add any extra cases as needed.
    #
    # Note: we can't YOLO tree_map over this result, because if there are
    # buffers or a view involved, we might not be able to validly assign
    # the origin_node here.
    if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
        if isinstance(result.data.data, ir.Loops):
            result.data.data.origin_node = n
        elif isinstance(result.data.data, ir.Buffer):
            result.data.data.origin_node = n
            if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
                result.data.data.data, ir.Loops
            ):
                result.data.data.data.origin_node = n
            # Not really multi-output, can straightforwardly recurse in
            elif (
                isinstance(result.data.data, ir.MultiOutput)
                and not result.data.data.indices
            ):
                if isinstance(result.data.data.inputs[0], ir.Buffer):
                    result.data.data.inputs[0].origin_node = n

    self.register_users_of(result)

    return result
def validate_can_generate_cpp_wrapper(self):
    """
    Raise ``CppWrapperCodeGenError`` if the C++ wrapper cannot be generated.

    Rejects the graph when C++ codegen is disabled in ``config``, when the
    platform is neither Linux nor macOS, or when any graph input has a dtype
    for which ``supported_dtype_of_cpp_wrapper`` returns false. Tensor inputs
    report their dtype directly; sympy inputs go through
    ``may_get_constant_buffer_dtype``; any other input is checked with dtype
    ``None``.
    """
    if config.disable_cpp_codegen:
        raise CppWrapperCodeGenError("C++ codegen is disabled")

    if sys.platform not in ("linux", "darwin"):
        raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}")

    for inp in self.graph_inputs.values():
        if isinstance(inp, TensorBox):
            dtype = inp.get_dtype()
        elif isinstance(inp, sympy.Expr):
            # sympy.Symbol and sympy.Integer are both sympy.Expr subclasses.
            dtype = may_get_constant_buffer_dtype(inp)
        else:
            dtype = None
        if not supported_dtype_of_cpp_wrapper(dtype, self.cuda):
            raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}")
def init_wrapper_code(self):
    """
    Choose and instantiate the wrapper code generator for this graph.

    Sets ``self.cuda``. In C++-wrapper mode the graph is validated first and
    ``CppWrapperCuda`` / ``CppWrapperCpu`` is used. Otherwise the single
    non-CPU device type (or ``"cpu"`` when there is none) selects both
    ``self.device_ops`` and the wrapper codegen class. When a const module is
    attached, its kernel-name iterator and ``src_to_kernel`` map are shared so
    kernels can be reused instead of recompiled.
    """
    self.cuda = "cuda" in self.device_types

    if self.cpp_wrapper:
        self.validate_can_generate_cpp_wrapper()
        cpp_wrapper_cls = CppWrapperCuda if self.cuda else CppWrapperCpu
        self.wrapper_code = cpp_wrapper_cls()
    else:
        # TODO(Eikan): Only support mixing cpu and other device now.
        other_devices = {dev for dev in self.device_types if dev != "cpu"}
        assert len(other_devices) <= 1, "Does not support mixing {}".format(
            "+".join(other_devices)
        )
        device_type = next(iter(other_devices), "cpu")

        self.device_ops = get_device_op_overrides(device_type)
        codegen_cls = get_wrapper_codegen_for_device(device_type)
        assert codegen_cls is not None, f"Device {device_type} not supported"
        self.wrapper_code = codegen_cls()

    if self.const_module:
        # If we have const module, we could reuse the kernels
        # This could avoid duplication and save time on doing recompilation (if Triton.)
        const_wrapper = self.const_module.wrapper_code
        self.wrapper_code._names_iter = const_wrapper._names_iter
        self.wrapper_code.src_to_kernel = const_wrapper.src_to_kernel
def codegen_with_cpp_wrapper(self):
    """
    Generate code for this graph with the C++ wrapper.

    For CPU, the cpp wrapper codegen is done in one pass.
    For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
    wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
    generate cpp wrapper code and compile it to a dynamic library in the second pass.

    ``self.cpp_wrapper`` is temporarily set to False for the first pass and is
    always set back to True afterwards, even if the first pass raises.

    Returns:
        Whatever ``self.codegen()`` returns for the final (C++ wrapper) pass.
    """
    if "cuda" not in self.device_types:
        # cpu: a single codegen pass is enough
        return self.codegen()

    def materialize(x):
        # Turn symbolic scalars and fake tensors into concrete values so the
        # first-pass module can actually be executed.
        if isinstance(x, (torch.SymInt, torch.SymFloat)):
            # Need concrete value to run dynamic shapes and tune the result
            return x.node.hint
        elif isinstance(x, FakeTensor):
            return defake(x)
        else:
            assert isinstance(
                x, torch.Tensor
            ), f"Unknown type when creating real inputs: {type(x)}"
            return x

    # first pass: python wrapper, executed once on real inputs
    self.cpp_wrapper = False
    try:
        compiled = self.compile_to_module().call

        if tracing_context := torch._guards.TracingContext.try_get():
            if tracing_context.output_strides:
                tracing_context.output_strides.clear()

            params_flat = [
                param
                for param in tracing_context.params_flat  # type: ignore[union-attr]
                if param is not None
            ]
            real_inputs = [
                materialize(x) for x in itertools.chain(params_flat, V.real_inputs)
            ]
        else:
            real_inputs = [materialize(x) for x in V.real_inputs]

        with torch.utils._python_dispatch._disable_current_modes():
            assert self.example_inputs is not None
            compiled(real_inputs)
        del real_inputs
    finally:
        # Switch back to C++-wrapper mode even if the first pass failed, so the
        # graph is not left configured for the python wrapper.
        self.cpp_wrapper = True

    # second pass
    # TODO: reuse self.scheduler from the first pass to speed up the second pass
    self.removed_buffers.clear()
    self.inplaced_to_remove.clear()
    return self.codegen()
def codegen(self):
    """
    Generate the wrapper code for the whole graph.

    Initializes the wrapper code generator, builds a ``Scheduler`` over
    ``self.buffers`` (stored on ``self.scheduler``), draws the original FX
    graph for debugging, lets the scheduler emit kernels, and returns
    ``self.wrapper_code.generate(self.is_inference)``.
    """
    from .scheduler import Scheduler

    self.init_wrapper_code()

    scheduler = Scheduler(self.buffers)
    self.scheduler = scheduler
    V.debug.draw_orig_fx_graph(self.orig_gm, scheduler.nodes)
    scheduler.codegen()

    return self.wrapper_code.generate(self.is_inference)
def codegen_subgraph(self, parent_graph):
    """
    Generate code for this graph as a subgraph of ``parent_graph``.

    A more compact variant of ``codegen()``. The parent's wrapper code,
    device ops and ``cpp_wrapper`` flag are adopted, so the subgraph's code
    (including its generated kernels) is emitted inline into the parent's
    wrapper. The wrapper code is not finalized here (no ``.generate()``
    call); that happens in the parent graph's ``codegen()``.
    """
    from .scheduler import Scheduler

    # Share the parent's code-generation state.
    self.wrapper_code = parent_graph.wrapper_code
    self.device_ops = parent_graph.device_ops
    self.cpp_wrapper = parent_graph.cpp_wrapper

    scheduler = Scheduler(self.buffers)
    self.scheduler = scheduler
    scheduler.codegen()
def count_bytes(self):
    """
    Summarize memory traffic and estimated runtime of the scheduled nodes.

    Returns:
        ``(total_bytes, node_counts, node_runtimes)`` where ``total_bytes`` is
        the sum of ``get_read_write_buffers_sizes()`` over all scheduler nodes,
        ``node_counts`` pairs each node with that size ``// 4``
        (NOTE(review): presumably an element count assuming 4-byte elements;
        confirm), and ``node_runtimes`` pairs each node with
        ``get_estimated_runtime()``.
    """
    from .scheduler import Scheduler

    scheduler = Scheduler(self.buffers)
    total = 0
    counts, runtimes = [], []
    for snode in scheduler.nodes:
        nbytes = snode.get_read_write_buffers_sizes()
        total += nbytes
        counts.append((snode, nbytes // 4))
        runtimes.append((snode, snode.get_estimated_runtime()))
    return total, counts, runtimes
@dynamo_timed(phase_name="code_gen")
def compile_to_module(self):
    """
    Generate the wrapper code, write it to ``PyCodeCache`` and load it.

    Uses the C++-wrapper path when ``self.cpp_wrapper`` is set. Records
    ``cache_key``, ``cache_path`` and ``cache_linemap`` on ``self``, logs and
    traces the generated code, and returns the loaded module.
    """
    from .codecache import PyCodeCache

    generate = self.codegen_with_cpp_wrapper if self.cpp_wrapper else self.codegen
    code, node_linemap = generate()
    # Map each generated line number to its FX node's stack trace.
    linemap = [(lineno, fx_node.stack_trace) for lineno, fx_node in node_linemap]

    key, path = PyCodeCache.write(code)
    mod = PyCodeCache.load_by_key_path(key, path, linemap=linemap, attrs=self.constants)
    self.cache_key = key
    self.cache_path = path
    self.cache_linemap = linemap

    # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
    # TODO. Revisit this once the logging API is more mature
    assert mod.__file__ is not None
    mod_path = mod.__file__
    log_module_code(mod_path)
    log.debug("Output code written to: %s", mod_path)
    output_code_log.debug("Output code: \n%s", code)
    trace_structured(
        "inductor_output_code",
        lambda: {"filename": mod_path},
        payload_fn=lambda: code,
    )
    output_code_log.info("Output code written to: %s", mod_path)
    if config.benchmark_kernel:
        print(f"Compiled module path: {mod_path}", file=sys.stderr)
    V.debug.output_code(mod_path)
    V.debug.copy(os.path.splitext(mod_path)[0] + ".debug")
    return mod
def compile_to_fn(self):
    """
    Compile the graph and return something callers can use to run it.

    Outside AOT mode this is the ``call`` function of the compiled Python
    module. In AOT mode (which requires the C++ wrapper) the code is generated
    with ``codegen_with_cpp_wrapper`` and handed to ``AotCodeCompiler.compile``,
    whose result (the compiled file path) is returned directly. In fbcode,
    extern kernel nodes are serialized first when a serializer is available.
    """
    if not self.aot_mode:
        return self.compile_to_module().call

    from .codecache import AotCodeCompiler

    assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
    code, _linemap = self.codegen_with_cpp_wrapper()
    output_code_log.debug("Output code: \n%s", code)

    serialized_extern_kernel_nodes = None
    can_serialize = (
        config.is_fbcode()
        and self.extern_kernel_nodes
        and self.extern_node_serializer
    )
    if can_serialize:
        serialized_extern_kernel_nodes = self.extern_node_serializer(
            self.extern_kernel_nodes
        )
        output_code_log.debug(
            "Serialized Extern Kernel Nodes: \n%s",
            serialized_extern_kernel_nodes,
        )

    # Directly return the file path with the compiled code
    return AotCodeCompiler.compile(
        self, code, serialized_extern_kernel_nodes, cuda=self.cuda
    )
def get_output_names(self):
    """
    Return the buffer names of the graph outputs.

    Outputs that are ``ir.NoneAsConstantBuffer`` or
    ``ir.ShapeAsConstantBuffer`` are skipped.
    """
    skipped_types = (ir.NoneAsConstantBuffer, ir.ShapeAsConstantBuffer)
    return [
        out.get_name()
        for out in self.graph_outputs
        if not isinstance(out, skipped_types)
    ]
def is_unspec_arg(self, name: str):
    """
    Return whether graph input ``name`` is a single-element CPU input.

    Dynamo wraps unspec variables as 0d CPU tensors, which need to be
    converted to scalars during codegen (Triton only). Names that are not
    graph inputs yield False.
    """
    if name not in self.graph_inputs:
        return False
    inp = self.graph_inputs[name]
    return inp.get_numel() == 1 and inp.get_device().type == "cpu"
|