"""
This file does three things:
- Contains the definition of SymNode
- Installs all the magic methods into SymBool, SymInt, and SymFloat at import time
- Does not depend on sympy at import time
As this file is imported from within torch/__init__.py, we do not want it to
depend on SymPy, since loading SymPy at import time is *very* slow.
"""
import builtins
import itertools
import logging
import math
import operator
import sys
from functools import lru_cache, update_wrapper
from typing import Optional, Type, TYPE_CHECKING, Union
import torch
# NB: The sym_* functions are used via getattr() and must be imported here.
from torch import ( # noqa: F401
sym_float,
sym_ite,
sym_max,
sym_min,
sym_not,
SymBool,
SymFloat,
SymInt,
)
from torch.fx.experimental._sym_dispatch_mode import (
handle_sym_dispatch,
sym_function_mode,
)
if TYPE_CHECKING:
from torch.fx.experimental.symbolic_shapes import ShapeEnv
log = logging.getLogger(__name__)
sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")
__all__ = ["SymNode", "method_to_operator", "magic_methods"]
SymTypes = (SymInt, SymFloat, SymBool)
def _to_symtype(t):
if t is bool:
return SymBool
if t is int:
return SymInt
if t is float:
return SymFloat
return t
# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
class SymNode:
"""
This is a type erased SymInt/SymFloat which we use to do actual operations.
End users don't touch this. Magic methods are NOT defined on this object.
"""
def __init__(
self,
expr,
shape_env,
pytype,
hint: Optional[Union[int, float, bool]],
constant=None,
fx_node=None,
):
self._expr = expr
self.shape_env = shape_env
self.pytype = pytype
# What's the difference between hint and constant?
#
# - A constant is known to be invariant across invocations of the model;
# it will always be this value. We only really know this when we
# encounter an honest-to-goodness literal (when wrapping it into
# a SymNode, we set constant.) Most of the time, constant is None
#
# - A hint is a *particular* value from the particular run we are
# tracing, but it may vary the next time around. It's useful to
# keep this around, as if we need a concrete value from a SymNode,
# we will return the hint and guard on the expression that produced
# it giving the same hint next time around. The hint is not
# guaranteed to be set either: if you have an unbacked SymNode,
# there won't be any hint; it was the result of some tensor-dependent
# computation, but we don't know what it actually is because we
# haven't actually run the tensor computation.
#
# If _hint is None, we will query maybe_evaluate_static(compute_hint=True)
# in hopes that we've learned enough about the unbacked symints to
# discharge the hint; otherwise, you're likely to just error out.
#
# (A previous version of this system had some optimizations to only
# recompute when it was possible we had learned enough about the
# unbacked symint that a hint was now possible, but as we added more
# potential refinements to unbacked symints this got harder to keep
# in sync, so we've deleted it for now.)
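        # A rough illustration (hypothetical values; `env` is a ShapeEnv, `s0`
        # a backed sympy symbol that was 4 on this trace, `u0` an unbacked one):
        #
        #   SymNode(sympy.Integer(5), env, int, hint=5, constant=5)  # literal
        #   SymNode(s0, env, int, hint=4)     # backed: hint from this run
        #   SymNode(u0, env, int, hint=None)  # unbacked: no hint available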
if hint is not None:
assert type(hint) is pytype or type(hint) is _to_symtype(pytype), (
"Cannot create SymNode of type "
f"{pytype} with incompatible hint of type {type(hint)}"
)
self._hint = hint
self.constant: Optional[Union[int, float, bool]] = constant
        # Record the FX node of the current node if we are doing translation
        # validation. It will be used for building the input assertions for
        # the translation validation problem.
self.fx_node = (
fx_node if self.shape_env._translation_validation_enabled else None
)
def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode":
return SymNode(
self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
)
@property
def expr(self):
return self.shape_env.replace(self._expr)
# Recompute the hint and see if we've got it now
# Precondition: self._hint is None
def _update_hint(self):
r = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
if r is not None:
self._hint = self.pytype(r) if not isinstance(r, SymTypes) else r
@property
def hint(self):
if self._hint is None:
self._update_hint()
return self._hint
def has_hint(self):
if self._hint is None:
self._update_hint()
return self._hint is not None
def require_hint(self, fallback=None):
if self._hint is None:
self._update_hint()
if self._hint is None:
if fallback is not None:
return fallback
# NB: we expect this to raise
return self.shape_env.size_hint(self.expr)
return self._hint
def maybe_as_int(self):
if self.expr.is_number:
return int(self.expr)
else:
return None
def is_int(self):
return self.pytype is int
def is_float(self):
return self.pytype is float
def is_bool(self):
return self.pytype is bool
def is_nested_int(self):
# Unbacked SymInts cannot be nested int today
return (
self._hint is not None
and isinstance(self._hint, SymInt)
and self._hint.node.is_nested_int()
)
def wrap_int(self, num):
assert type(num) is int
import sympy
return SymNode(
sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num
)
def wrap_float(self, num):
assert type(num) is float
import sympy
return SymNode(
sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num
)
def wrap_bool(self, num):
assert type(num) is bool
import sympy
return SymNode(
sympy.true if num else sympy.false,
self.shape_env,
bool,
num,
constant=num,
fx_node=num,
)
def clone(self):
return self
def str(self):
return f"{self.expr}"
def __str__(self):
return self.str()
def __repr__(self):
return self.str()
    # These methods call the metaprogrammed methods; they're hand-written
    # here so we get good stack traces.
def abs(self) -> "SymNode":
return self._abs() # type: ignore[attr-defined]
def pos(self) -> "SymNode":
return self._pos() # type: ignore[attr-defined]
def round(self, ndigits=None) -> "SymNode":
return self._round(ndigits) # type: ignore[attr-defined]
def add(self, other) -> "SymNode":
return self._add(other) # type: ignore[attr-defined]
def sub(self, other) -> "SymNode":
return self._sub(other) # type: ignore[attr-defined]
def mul(self, other) -> "SymNode":
return self._mul(other) # type: ignore[attr-defined]
def mod(self, other) -> "SymNode":
return self._mod(other) # type: ignore[attr-defined]
def pow(self, other) -> "SymNode":
return self._pow(other) # type: ignore[attr-defined]
def and_(self, other) -> "SymNode":
return self._and_(other) # type: ignore[attr-defined]
def or_(self, other) -> "SymNode":
return self._or_(other) # type: ignore[attr-defined]
def truediv(self, other) -> "SymNode":
return self._truediv(other) # type: ignore[attr-defined]
def floordiv(self, other) -> "SymNode":
return self._floordiv(other) # type: ignore[attr-defined]
def lshift(self, other) -> "SymNode":
return self._lshift(other) # type: ignore[attr-defined]
def rshift(self, other) -> "SymNode":
return self._rshift(other) # type: ignore[attr-defined]
def sym_not(self) -> "SymNode": # noqa: F811
return self._sym_not() # type: ignore[attr-defined]
def eq(self, other) -> "SymNode":
return self._eq(other) # type: ignore[attr-defined]
def ne(self, other) -> "SymNode":
return self._ne(other) # type: ignore[attr-defined]
def gt(self, other) -> "SymNode":
return self._gt(other) # type: ignore[attr-defined]
def lt(self, other) -> "SymNode":
return self._lt(other) # type: ignore[attr-defined]
def le(self, other) -> "SymNode":
return self._le(other) # type: ignore[attr-defined]
def ge(self, other) -> "SymNode":
return self._ge(other) # type: ignore[attr-defined]
def floor(self) -> "SymNode":
return self._floor() # type: ignore[attr-defined]
def is_integer(self) -> "SymNode":
return self._is_integer() # type: ignore[attr-defined]
def sym_float(self) -> "SymNode": # noqa: F811
return self._sym_float() # type: ignore[attr-defined]
def sym_int(self) -> "SymNode":
return self._sym_int() # type: ignore[attr-defined]
def ceil(self) -> "SymNode":
return self._ceil() # type: ignore[attr-defined]
def neg(self) -> "SymNode":
return self._neg() # type: ignore[attr-defined]
def sym_min(self, other) -> "SymNode": # noqa: F811
return self._sym_min(other) # type: ignore[attr-defined]
def sym_max(self, other) -> "SymNode": # noqa: F811
return self._sym_max(other) # type: ignore[attr-defined]
def sym_ite(self, then_val, else_val) -> "SymNode":
return self._sym_ite(then_val, else_val) # type: ignore[attr-defined]
def is_contiguous(self, sizes, strides) -> "SymNode":
return self._is_contiguous(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode":
return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode":
return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode":
return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined]
def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode":
return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined]
def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode":
return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined]
# Make C++ happy
def sym_or(self, other):
return self.or_(other)
def sym_and(self, other):
return self.and_(other)
def is_non_overlapping_and_dense(self, sizes, strides):
return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined]
def int_(self):
return self.guard_int("", 0) # NB: uses Python backtrace
# You can manually trigger a guard with this function
def guard_int(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
try:
return int(r)
except Exception:
log.warning("Failed to convert to int: %s", r)
raise
def guard_float(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.shape_env.evaluate_expr(
self.expr, self.hint, fx_node=self.fx_node, expect_rational=False
)
try:
return float(r)
except Exception:
log.warning("Failed to convert to float: %s", r)
raise
def guard_bool(self, file, line):
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
try:
return bool(r)
except Exception:
log.warning("Failed to convert to bool: %s", r)
raise
def expect_true(self, file, line):
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
if self.has_hint() and not free_unbacked_symbols(self.expr):
# OK to generate guards
return self.guard_bool(file, line)
# Generate a deferred runtime assert (this might actually end up doing
# a regular guard if we can!)
# TODO: file/line here is very important, because the assert has been
# deferred so you can't backtrace easily
return self.shape_env.defer_runtime_assert(
self.expr, f"{file}:{line}", fx_node=self.fx_node
)
def expect_size(self, file, line):
from torch.fx.experimental.symbolic_shapes import _advise_is_size
b = self.ge(self.wrap_int(0))
# Generate a deferred runtime assert
r = b.expect_true(file, line)
# Refine compile time range, but only if it's unbacked.
# If you refine range for hinted variables, you can end up making
# improper deductions since compile time reasoning may be
# incompatible with runtime reasoning.
if r and not self.has_hint():
_advise_is_size(SymInt(self))
return r
def guard_size_oblivious(self, file, line):
"""
Like guard_bool, but if we encounter unbacked symbols, if those symbols
are size-like, we will treat them as >= 2 for the purposes of the analysis.
This CHANGES the runtime semantics, but all size-oblivious sites have been
audited to ensure that the runtime semantics don't change in a material way.
        Acceptable runtime semantic changes are, e.g., squeeze() no longer
        dropping an unbacked dimension of size 1, or a tensor reporting as
        non-contiguous even when it is in fact contiguous, if it would only
        have been reported contiguous by virtue of being empty.
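        For example, if u0 is a size-like unbacked symbol, a size-oblivious
        guard on Eq(u0, 1) evaluates to False, since u0 is assumed >= 2.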
"""
# TODO: use the file/line for some useful diagnostic on why a
# guard occurred
r = self.shape_env.evaluate_expr(
self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True
)
try:
return bool(r)
except Exception:
log.warning("Failed to convert to bool: %s", r)
raise
def bool_(self):
return self.guard_bool("", 0)
def is_symbolic(self):
return True
def nested_int(self):
return None
def is_constant(self):
return False
# TODO: this probably needs the sizes-strides eval functions
METHOD_TO_OPERATOR = {
"pos": operator.pos,
"abs": operator.abs,
"add": operator.add,
"and": operator.and_,
"ceil": math.ceil,
"eq": operator.eq,
"floor": math.floor,
"floordiv": operator.floordiv,
"ge": operator.ge,
"gt": operator.gt,
"is_integer": lambda x: x.is_integer(),
"le": operator.le,
"lshift": operator.lshift,
"lt": operator.lt,
"mod": operator.mod,
"mul": operator.mul,
"ne": operator.ne,
"neg": operator.neg,
"or": operator.or_,
"pow": operator.pow,
"round": builtins.round,
"rshift": operator.rshift,
"sub": operator.sub,
"sym_float": sym_float,
"sym_ite": sym_ite,
"sym_max": sym_max,
"sym_min": sym_min,
"sym_not": sym_not,
"truediv": operator.truediv,
}
unary_magic_methods = {
"abs",
"sym_float",
"ceil",
"floor",
"neg",
"sym_not",
"pos",
}
# Adding math ops: sqrt, cos, sin, ...
def _get_sym_node_fn(name):
def fn(self):
return getattr(self, f"_sym_{name}")()
return fn
math_op_names = (
"sqrt",
"cos",
"cosh",
"sin",
"sinh",
"tan",
"tanh",
"asin",
"acos",
"atan",
)
for name in math_op_names:
sym_name = f"sym_{name}"
priv_sym_name = f"_{sym_name}"
setattr(SymNode, sym_name, _get_sym_node_fn(name))
METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
unary_magic_methods.add(sym_name)
__all__.append(sym_name)
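# The loop above metaprograms, e.g., SymNode.sym_sqrt; a sketch of what one
# iteration installs (illustrative, not the literal generated source):
#
#   def sym_sqrt(self):          # bound as SymNode.sym_sqrt
#       return self._sym_sqrt()  # sympy-backed impl installed further below
#
# and records METHOD_TO_OPERATOR["sym_sqrt"] = torch._sym_sqrt.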
# Unary methods that are not magic methods
unary_nonmagic_methods = {
"is_integer",
}
unary_methods = unary_magic_methods | unary_nonmagic_methods
# Most methods are only registered on SymInt and SymFloat
# Some methods are only registered on SymBool
only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
# Methods that implicitly convert SymBool into SymInt
bool_becomes_int_magic_methods = {"add", "sub", "mul"}
# Methods that are also on SymBool, in addition to SymInt and SymFloat
also_bool_magic_methods = {"eq"}
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
# Methods that are only for float
only_float_magic_methods = {"is_integer"}
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
always_float_magic_methods = {"truediv", "sym_float", "pow"}
for name in math_op_names:
sym_name = f"sym_{name}"
always_float_magic_methods.add(sym_name)
always_int_magic_methods = {"ceil", "floor"}
always_bool_magic_methods = {
"eq",
"ne",
"gt",
"lt",
"le",
"ge",
"and",
"or",
"sym_not",
"is_non_overlapping_and_dense",
"is_integer",
}
# Methods that have a `__foo__` as well as `__rfoo__`
def _sympy_truediv(a, b):
from torch.utils._sympy.functions import TrueDiv
return TrueDiv(a, b)
def _sympy_floordiv(a, b):
from torch.utils._sympy.functions import FloorDiv
return FloorDiv(a, b)
def _sympy_mod(a, b):
from torch.utils._sympy.functions import Mod
return Mod(a, b)
def _sympy_pow(a, b):
from torch.utils._sympy.functions import Pow
return Pow(a, b)
def _sympy_and(a, b):
import sympy
return sympy.And(a, b)
def _sympy_or(a, b):
import sympy
return sympy.Or(a, b)
def _sympy_lshift(a, b):
from torch.utils._sympy.functions import LShift
return LShift(a, b)
def _sympy_rshift(a, b):
from torch.utils._sympy.functions import RShift
return RShift(a, b)
reflectable_magic_methods = {
"add": operator.add,
"sub": operator.sub,
"mul": operator.mul,
"mod": _sympy_mod,
"pow": _sympy_pow,
"and": _sympy_and,
"or": _sympy_or,
"truediv": _sympy_truediv,
"floordiv": _sympy_floordiv,
"lshift": _sympy_lshift,
"rshift": _sympy_rshift,
}
def _floor_ceil_helper(a, fn):
import sympy
if isinstance(a, sympy.Mul):
aa = a.args
if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
coef = sympy.Integer(aa[0])
if aa[0] == coef: # structural equality test
return coef * aa[1]
if (
isinstance(a, sympy.Float)
and a == sympy.Integer(a)
or isinstance(a, sympy.Integer)
):
return sympy.Integer(a)
return fn(a)
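# E.g., for an integer symbol s, _floor_ceil_helper(2.0 * s, sympy.floor)
# returns 2*s directly, since the float coefficient 2.0 equals Integer(2);
# non-integral coefficients fall through to fn(a).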
def _sympy_floor(a):
import sympy
return _floor_ceil_helper(a, sympy.floor)
def _sympy_ceil(a):
import sympy
return _floor_ceil_helper(a, sympy.ceiling)
def _sympy_eq(a, b):
import sympy
return sympy.Eq(a, b)
def _sympy_ne(a, b):
import sympy
return sympy.Ne(a, b)
def _sympy_gt(a, b):
import sympy
return sympy.Gt(a, b)
def _sympy_lt(a, b):
import sympy
return sympy.Lt(a, b)
def _sympy_le(a, b):
import sympy
return sympy.Le(a, b)
def _sympy_ge(a, b):
import sympy
return sympy.Ge(a, b)
def _sympy_min(a, b):
import sympy
return sympy.Min(a, b)
def _sympy_max(a, b):
import sympy
return sympy.Max(a, b)
def _sympy_ite(a, t, f):
import sympy
return sympy.Piecewise((t, a), (f, True))
current_module = sys.modules[__name__]
def _get_sym_math_fn(name):
def fn(a):
import sympy
return getattr(sympy, name)(a)
return fn
for name in math_op_names:
priv_sympy_name = f"_sympy_{name}"
fn = _get_sym_math_fn(name)
fn.__qualname__ = fn.__name__ = priv_sympy_name
setattr(current_module, priv_sympy_name, fn)
del fn, name, priv_sympy_name # type: ignore[possibly-undefined]
def _sympy_abs(a):
import sympy
return sympy.Abs(a)
def _sympy_round(number, ndigits=None):
from torch.utils._sympy.functions import Round, RoundDecimal
if ndigits is None:
return Round(number)
else:
return RoundDecimal(number, ndigits)
def _sympy_sym_float(a):
    # Cannot use sympy.Float(a) here because it expects Python literals.
    # Multiply by 1.0 to cast to float. This is needed when the input is a
    # SymInt carrying the integer assumption, as SymPy would otherwise assume
    # that the return value cannot be a float.
return a * 1.0
def _sympy_is_integer(a):
import sympy
return sympy.Eq(sympy.floor(a), a)
magic_methods = {
**reflectable_magic_methods,
"sym_not": operator.invert,
"pos": operator.pos,
"eq": _sympy_eq,
"ne": _sympy_ne,
"gt": _sympy_gt,
"lt": _sympy_lt,
"le": _sympy_le,
"ge": _sympy_ge,
"floor": _sympy_floor,
"sym_float": _sympy_sym_float,
"ceil": _sympy_ceil,
"neg": operator.neg,
"sym_min": _sympy_min,
"sym_max": _sympy_max,
"sym_ite": _sympy_ite,
"abs": _sympy_abs,
"round": _sympy_round,
"is_integer": _sympy_is_integer,
}
for name in math_op_names:
sym_name = f"sym_{name}"
magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")
del name, sym_name, math_op_names, current_module # type: ignore[possibly-undefined]
def sympy_is_contiguous(sizes, strides):
dim = len(sizes)
return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1)))
def sympy_is_contiguous_generic(sizes, strides, dim_order):
import sympy
dim = len(sizes)
if len(dim_order) != dim:
return sympy.false
is_contiguous = sympy.true
z = sympy.Integer(1)
# Contiguous if the strides make sense (or the dim is size 1)
for d in dim_order:
is_contiguous &= sympy.Eq(sizes[d], sympy.Integer(1)) | sympy.Eq(strides[d], z)
z *= sizes[d]
# OR if any size is zero
for d in range(dim):
is_contiguous |= sympy.Eq(sizes[d], sympy.Integer(0))
return is_contiguous
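# For instance, a contiguous [2, 3, 4] tensor has strides [12, 4, 1]
# (illustrative, with concrete sympy integers):
#
#   sizes = [sympy.Integer(x) for x in (2, 3, 4)]
#   strides = [sympy.Integer(x) for x in (12, 4, 1)]
#   sympy_is_contiguous(sizes, strides)  # sympy.true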
# NB: There is a TODO in C++ to allow omitting the batch dim. If that
# happens you will need to refactor this
def sympy_is_channels_last_contiguous_2d(sizes, strides):
return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0])
def sympy_is_channels_last_contiguous_3d(sizes, strides):
return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0])
def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
import sympy
dim = len(sizes)
if dim != len(dim_order):
return sympy.false
m = sympy.Integer(0)
r = sympy.true
    # Special case for a trivial C dimension: default to NCHW.
r &= sympy.Ne(strides[1], 0)
for d in dim_order:
r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m)
        # Fall back to NCHW as the default layout for ambiguous cases.
        # This is the flaw of inferring memory_format implicitly from strides:
        # an N111 tensor has identical strides for its size-1 dimensions, and
        # two cases could lead us here:
        # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
        # b. N11W contiguous Tensor sliced on the W-dimension
        #    ([N,1,1,1]@[W,W,W,W])
if d == 0:
r &= sympy.Ne(m, strides[1])
# This is necessary to:
# 1. distinguish the memory_format of N1H1;
# [H, 1, 1, 1] channels_last stride
# [H, H, 1, 1] contiguous stride
# 2. permutation of 1C1W:
# [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
# [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as
# channels_last
m = strides[d] * sympy.Max(sizes[d], 1)
return r
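# For instance, an NHWC-strided [N, C, H, W] = [2, 3, 4, 5] tensor has strides
# [H*W*C, 1, W*C, C] = [60, 1, 15, 3], for which
# sympy_is_channels_last_strides_2d(sizes, strides) yields sympy.true.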
def sympy_is_channels_last_strides_2d(sizes, strides):
return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0])
def sympy_is_channels_last_strides_3d(sizes, strides):
return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0])
def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator
return IsNonOverlappingAndDenseIndicator(*sizes, *strides)
sizes_strides_methods = {
# TODO: These could also be done with indicators, maybe it is better
# for reasoning to do it that way
"is_contiguous": sympy_is_contiguous,
"is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
"is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
"is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
"is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
"is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
}
alternate_impl_if_hinted_methods = {
"sym_min": builtins.min,
"sym_max": builtins.max,
}
def to_node(self, num):
if isinstance(num, SymTypes):
return num.node
elif type(num) is bool:
return self.wrap_bool(num)
elif type(num) is int:
return self.wrap_int(num)
elif type(num) is float:
return self.wrap_float(num)
else:
# NotImplemented is important so that Python tries the
# other magic method
return NotImplemented
def wrap_node(x):
# TODO: let C++ also take advantage of this
if isinstance(x, SymNode) and x.constant is not None:
return x.constant
if x.is_int():
return SymInt(x)
elif x.is_float():
return SymFloat(x)
elif x.is_bool():
return SymBool(x)
else:
raise AssertionError(f"unrecognized return type {x}")
def method_to_operator(method):
return METHOD_TO_OPERATOR[method]
def _make_node_magic(method, func):
func = lru_cache(256)(func)
if method in magic_methods_on_operator_with_trailing_underscore:
method_attr = f"{method}_"
else:
method_attr = method
def binary_magic_impl(self, other):
from torch.fx.experimental.symbolic_shapes import safe_expand
op = method_to_operator(method)
out_hint = None
if self.hint is not None and other.hint is not None:
out_hint = op(self.hint, other.hint)
alternate_impl = alternate_impl_if_hinted_methods.get(method)
if alternate_impl and out_hint is not None:
return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))
if sym_function_mode():
return to_node(
self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
)
assert isinstance(other, SymNode)
# TODO: consider constant prop here
try:
out = func(self.expr, other.expr)
except Exception:
log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
raise
out = safe_expand(out)
sym_node_log.debug("%s %s %s -> %s", func, self.expr, other.expr, out)
pytype: Type
        # This is not strictly correct. In Python, a**b may return complex when
        # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). pow
        # can also return a float when both arguments are ints: 2**(-1). Also,
        # max and min do not type promote. To avoid having data-dependent
        # control flow here, we just set the type to float if one of the args
        # is a float. In case of a type mismatch, we assume that it will be
        # detected during evaluation.
if method in always_float_magic_methods:
pytype = float
elif method in always_bool_magic_methods:
pytype = bool
elif self.pytype is float or other.pytype is float:
pytype = float
else:
pytype = self.pytype
if (
pytype is not None
and out_hint is not None
and not isinstance(out_hint, SymTypes)
):
out_hint = pytype(out_hint)
# Create a FX node that corresponds to the operation being applied to
# this node.
fx_node, _ = self.shape_env._create_fx_call_function(
op, (self.fx_node, other.fx_node)
)
return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
def unary_magic_impl(self):
from torch.fx.experimental.symbolic_shapes import safe_expand
op = method_to_operator(method)
if sym_function_mode():
return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
# TODO: consider constant prop here
expr = self.expr
if method == "floor" or method == "ceiling":
expr = self.shape_env._simplify_floor_div(expr)
try:
out = func(expr)
except Exception:
log.warning("failed to eval %s(%s)", method, expr)
raise
sym_node_log.debug("%s %s -> %s", func, expr, out)
out_hint = None
if self.hint is not None:
out_hint = op(self.hint)
out = safe_expand(out)
pytype: Type
if method in always_int_magic_methods:
pytype = int
elif method in always_bool_magic_methods:
pytype = bool
elif method in always_float_magic_methods:
pytype = float
else:
pytype = self.pytype
fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,))
return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
if method in unary_methods:
setattr(SymNode, f"_{method_attr}", unary_magic_impl)
elif method == "sym_ite":
def sym_ite_impl(pred_node, then_node, else_node):
from torch.fx.experimental.symbolic_shapes import safe_expand
out_hint = then_node.hint if pred_node.hint else else_node.hint
if sym_function_mode():
return to_node(
pred_node,
handle_sym_dispatch(
sym_ite,
(
wrap_node(pred_node),
wrap_node(then_node),
wrap_node(else_node),
),
{},
),
)
try:
out = func(pred_node.expr, then_node.expr, else_node.expr)
except Exception:
log.warning(
"failed to eval %s(%s, %s, %s)",
method,
pred_node.expr,
then_node.expr,
else_node.expr,
)
raise
out = safe_expand(out)
fx_node, _ = pred_node.shape_env._create_fx_call_function(
sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node)
)
return SymNode(
out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node
)
setattr(SymNode, f"_{method_attr}", sym_ite_impl)
elif method == "round":
def round_impl(self, ndigits=None):
from torch.fx.experimental.symbolic_shapes import safe_expand
op = builtins.round
if sym_function_mode():
return to_node(
self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
)
expr = self.expr
try:
out = func(expr, ndigits)
except Exception:
log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
raise
out = safe_expand(out)
pytype = int if ndigits is None else self.pytype
out_hint = None
if self.hint is not None:
out_hint = op(self.hint, ndigits)
            # Internally, None is used as a sentinel to indicate that something
            # is not a node on an FX graph. At the same time, there is no way
            # to wrap a plain None into an FX node. Thus, there is no way to
            # pass None here without triggering asserts that check whether we
            # are mixing FX nodes with untracked arguments. The hack below
            # works because all round functions down the line take
            # ndigits=None as the default in their signature.
            # TODO: Remove the args construction below if a different sentinel
            # is used by FX.
args = [self.fx_node]
if ndigits is not None:
args.append(ndigits)
fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args))
return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
setattr(SymNode, f"_{method_attr}", round_impl)
else:
setattr(SymNode, f"_{method_attr}", binary_magic_impl)
def _make_node_sizes_strides(method, func):
# NB: don't LRU cache, lots of arguments
def sizes_strides_impl(self, sizes, strides):
op = getattr(sys.modules[__name__], method)
if sym_function_mode():
return to_node(
self,
handle_sym_dispatch(
op,
([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]),
{},
),
)
size_exprs = [s.expr for s in sizes]
stride_exprs = [s.expr for s in strides]
try:
out = func(size_exprs, stride_exprs)
except Exception:
log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs)
raise
        # bool is never expandable, so no safe_expand here
size_hints = []
out_hint = None
for s in sizes:
if s.hint is None:
break
size_hints.append(s.hint)
else:
stride_hints = []
for s in strides:
if s.hint is None:
break
stride_hints.append(s.hint)
else:
out_hint = op(size_hints, stride_hints)
# NB: This is the indicator function, not the actual bool!
pytype: Type
if method.endswith("_indicator"):
pytype = int
else:
pytype = bool
return SymNode(out, self.shape_env, pytype, out_hint)
setattr(SymNode, f"_{method}", sizes_strides_impl)
# TODO: This is technically hotpath, but in the ideal end state
# guards on this will resolve at a higher level so you never
# spend time in this code
def sizes_strides_user(sizes, strides):
import sympy
from torch.fx.experimental.symbolic_shapes import (
eval_is_non_overlapping_and_dense,
)
for a in itertools.chain(sizes, strides):
if isinstance(a, SymInt):
return wrap_node(
getattr(a.node, method)(
[to_node(a.node, b) for b in sizes],
[to_node(a.node, b) for b in strides],
)
)
if method == "is_non_overlapping_and_dense_indicator":
return eval_is_non_overlapping_and_dense(sizes, strides)
else:
# TODO: this is an awful implementation
return bool(
func(
[sympy.sympify(a) for a in sizes],
[sympy.sympify(a) for a in strides],
)
)
# Skip for is_non_overlapping_and_dense_indicator
if not hasattr(sys.modules[__name__], method):
setattr(sys.modules[__name__], method, sizes_strides_user)
for method, func in magic_methods.items():
_make_node_magic(method, func)
for method, func in sizes_strides_methods.items():
_make_node_sizes_strides(method, func)
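# After these loops, each magic method exists on SymNode under a leading
# underscore. A sketch of the dispatch for addition, given two symbolic
# SymNodes a and b (hypothetical):
#
#   a.add(b)       # hand-written wrapper above, for good stack traces
#   # -> a._add(b)  (binary_magic_impl installed by _make_node_magic)
#   # -> hints combined with operator.add, exprs with sympy addition,
#   #    producing SymNode(out, shape_env, pytype, out_hint, fx_node=...)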
def _make_user_magic(method, user_type):
# User magic takes care of wrapping the other operand into a node,
# so that our internal logic can assume everything is nodes
if method in magic_methods_on_operator_with_trailing_underscore:
method_attr = f"sym_{method}"
else:
method_attr = method
def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
if isinstance(x, (int, float, bool)):
return x
if isinstance(x, SymBool):
return x.node.guard_bool("", 0)
raise AssertionError("expect to be called with constant SymBools")
def is_constant(x):
if isinstance(x, (int, float, bool)):
return True
if isinstance(x, (SymInt, SymFloat, SymBool)):
return x.node.is_constant()
return False
if method in bool_becomes_int_magic_methods:
def promote(x):
"""Implements True+True=2, which works in python but not sympy"""
if isinstance(x, SymBool):
return SymInt(x.node.wrap_int(int(x)))
return x
else:
def promote(x):
return x
    # Before and after performing the operation, check if any operands are
    # constant. If so, extract out the constant values first. If `self` itself
    # is a constant, then "redispatch" by calling back into the operator.
    # Sometimes this means that operations involving SymBool return plain
    # bools. Alternatively, we could also rewrap into a constant SymBool
    # (i.e., by implementing wrap_bool in ConstantSymNodeImpl), but we're not
    # doing that today for no particular reason.
def unary_magic_impl(self):
self = promote(self)
if is_constant(self):
return (method_to_operator(method))(get_constant(self))
return wrap_node(getattr(self.node, method_attr)())
def binary_magic_impl(self, other):
sym_node_log.debug("MAGIC %s %s %s", method, self, other)
self = promote(self)
other = promote(other)
if is_constant(self):
return (method_to_operator(method))(get_constant(self), other)
if is_constant(other):
other = get_constant(other)
other_node = to_node(self.node, other)
if other_node is NotImplemented:
return NotImplemented
ret = wrap_node(getattr(self.node, method_attr)(other_node))
return get_constant(ret) if is_constant(ret) else ret
def rbinary_magic_impl(self, other):
self = promote(self)
other = promote(other)
if is_constant(self):
return (method_to_operator(method))(get_constant(self), other)
if is_constant(other):
other = get_constant(other)
other_node = to_node(self.node, other)
if other_node is NotImplemented:
return NotImplemented
ret = wrap_node(getattr(other_node, method_attr)(self.node))
return get_constant(ret) if is_constant(ret) else ret
if method in unary_magic_methods:
setattr(user_type, f"__{method}__", unary_magic_impl)
elif method in unary_nonmagic_methods:
orig = getattr(user_type, method)
setattr(user_type, method, update_wrapper(unary_magic_impl, orig))
elif method == "sym_ite":
def sym_ite_magic_impl(pred, then_val, else_val):
pred_node = pred.node
then_node = to_node(pred_node, then_val)
else_node = to_node(pred_node, else_val)
if then_node is NotImplemented or else_node is NotImplemented:
return NotImplemented
assert (
isinstance(then_node, SymNode)
and isinstance(else_node, SymNode)
and then_node.pytype == else_node.pytype
)
ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
return get_constant(ret) if ret.node.is_constant() else ret
setattr(user_type, f"__{method}__", sym_ite_magic_impl)
elif method == "round":
def round_magic_impl(self, ndigits=None):
if is_constant(self):
return builtins.round(get_constant(self), ndigits)
return wrap_node(getattr(self.node, method)(ndigits))
setattr(user_type, f"__{method}__", round_magic_impl)
else:
setattr(user_type, f"__{method}__", binary_magic_impl)
if method in reflectable_magic_methods:
setattr(user_type, f"__r{method}__", rbinary_magic_impl)
for method, func in magic_methods.items(): # type: ignore[assignment]
if method in only_bool_magic_methods:
_make_user_magic(method, SymBool)
continue
if method in only_float_magic_methods:
_make_user_magic(method, SymFloat)
continue
if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
_make_user_magic(method, SymBool)
_make_user_magic(method, SymInt)
_make_user_magic(method, SymFloat)
del method
del func
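# End to end, user code only sees SymInt/SymFloat/SymBool. A rough sketch for
# a SymInt x with a symbolic node (hypothetical):
#
#   y = x + 1
#   # -> SymInt.__add__ (binary_magic_impl installed just above)
#   # -> to_node wraps 1 into a constant SymNode via x.node.wrap_int(1)
#   # -> x.node.add(...) builds the sympy expression, hint, and fx_node
#   # -> wrap_node re-wraps the resulting SymNode as a SymInt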