Spaces:
Running
Running
import functools | |
import logging | |
import math | |
import operator | |
import sympy | |
import builtins | |
from dataclasses import dataclass | |
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union | |
import torch | |
import torch.fx | |
import torch.fx.traceback as fx_traceback | |
from torch._dynamo.exc import TorchDynamoException | |
from torch.fx.node import Argument, Target | |
from torch.utils._sympy.interp import sympy_interp | |
log = logging.getLogger(__name__) | |
try: | |
import z3 # type: ignore[import] | |
# Translation Validation for Dynamo guards | |
# ======================================== | |
# | |
# Checks whether optimizations applied to the collected guards are | |
# valid. In other words, whether the guard function we actually run | |
# does not have false positives (unsound). | |
# | |
# In order to do so, we build the guards using 2 different information | |
# attached to each 'SymNode': | |
# 1. SymPy expressions | |
# 2. FX nodes | |
# | |
# SymPy expressions have implicit optimizations baked within itself, | |
# which may have a few bugs. On the other hand, we build the FX graph | |
# manually, with no optimizations enabled. This gives us access to | |
# the "ground truth". | |
# | |
# We then convert into Z3 expressions both the SymPy expressions | |
# (see [Note: SympyToZ3]) that reach 'ShapeEnv.produce_guards' function | |
# and the FX nodes (see [Note: PopulateValidator]) that go through | |
# 'ShapeEnv.evaluate_expr' function. Finally, we run the validation. | |
# (see [Note: TranslationValidator]) | |
# Better Z3 to string implementation (for a small fraction of Z3). | |
# | |
# Here are the things we clean before showing the Z3 expression: | |
# - Rename a few ops (e.g. "Distinct" ==> "!=") | |
# | |
# - Ignore ToInt and ToReal operations: | |
# usually they don't really matter | |
# | |
# - Transform (ToInt (/ ...)) into (idiv ...): | |
# this is the pattern for floor division | |
# | |
# - Collect a chain of the same operations into one | |
def z3str(e: z3.ExprRef) -> str: | |
assert z3.is_expr(e), f"unsupported expression type: {e}" | |
def get_args_str(e: z3.ExprRef) -> List[str]: | |
return [z3str(e.arg(i)) for i in range(e.num_args())] | |
# First, we simplify the given expression. | |
# This is done using rewriting rules, so shouldn't take long. | |
e = z3.simplify(e) | |
# Only support function applications. | |
# Even Z3 "variables" are, in fact, function applications. | |
if not z3.is_app(e): | |
raise ValueError(f"can't print Z3 expression: {e}") | |
if z3.is_int_value(e) or z3.is_rational_value(e): | |
return e.as_string() # type: ignore[attr-defined] | |
decl = e.decl() | |
kind = decl.kind() | |
op = str(decl) | |
args = get_args_str(e) | |
if kind == z3.Z3_OP_POWER: | |
op = "pow" | |
elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL): | |
# Collect the arguments of chains of ADD and MUL. | |
# This is safe, since they are associative. | |
def collect_str_args(e): | |
if not (z3.is_app(e) and e.decl().kind() == kind): | |
return [z3str(e)] | |
else: | |
return [ | |
x | |
for i in range(e.num_args()) | |
for x in collect_str_args(e.arg(i)) | |
] | |
args = collect_str_args(e) | |
elif kind == z3.Z3_OP_NOT: | |
# Revert some conversions that z3.simplify applies: | |
# - a != b ==> (Not (== a b)) ==> (!= a b) | |
# - a < b ==> (Not (<= b a)) ==> (> b a) | |
# - a > b ==> (Not (<= a b)) ==> (> a b) | |
assert e.num_args() == 1 | |
arg = e.arg(0) | |
assert z3.is_app(arg) | |
argkind = arg.decl().kind() | |
logic_inverse = { | |
z3.Z3_OP_EQ: "!=", | |
z3.Z3_OP_LE: ">", | |
z3.Z3_OP_GE: "<", | |
} | |
if argkind in logic_inverse: | |
op = logic_inverse[argkind] | |
args = get_args_str(arg) | |
elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL): | |
assert e.num_args() == 1 | |
argstr = z3str(e.arg(0)) | |
# Check if it's the floor division pattern. | |
if argstr.startswith("(/"): | |
return "(idiv" + argstr[2:] | |
# Otherwise, just ignore it. | |
return argstr | |
elif kind == z3.Z3_OP_UNINTERPRETED: | |
assert e.num_args() == 0 | |
return str(decl) | |
string = op + " " + " ".join(args) | |
return f"({string.rstrip()})" | |
# Implementation of Python semantics as Z3 expressions. | |
# | |
# Z3 Real-Int theory has operators with semantics that differ that of | |
# Python. Therefore, in order to get it right, we need to implement | |
# the (Python) semantics we are relying on in Z3. | |
class _Z3Ops: | |
# Validator used for adding assertions as needed. | |
# e.g. div(a, b) requires b != 0. | |
validator: "TranslationValidator" | |
# The 2 functions below are used for conditionally casting between | |
# integer and reals. | |
# | |
# Returns a real expression from 'x'. | |
def to_real(x: z3.ArithRef) -> z3.ArithRef: | |
return x if x.is_real() else z3.ToReal(x) | |
# Returns an integer expression from 'x'. | |
def to_int(x: z3.ArithRef) -> z3.ArithRef: | |
return x if x.is_int() else z3.ToInt(x) | |
# Implements Python division semantics. | |
def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: | |
self.validator.add_assertion(denominator != 0) # type: ignore[arg-type] | |
return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator) | |
def floor(self, number: z3.ArithRef) -> z3.ArithRef: | |
# Z3 ToInt function rounds a real number towards negative infinity. | |
return _Z3Ops.to_int(number) | |
# Python semantics for 'FloorDiv' states that before applying the floor | |
# function, the operands are converted to their common type. | |
def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: | |
cast_result_to_real = numerator.is_real() or denominator.is_real() | |
result = _Z3Ops.to_int(self.div(numerator, denominator)) | |
# Since the 'result' is already an integer, we just have to check | |
# whether we should cast it to real. | |
return _Z3Ops.to_real(result) if cast_result_to_real else result | |
def ceil(self, number: z3.ArithRef) -> z3.ArithRef: | |
return z3.If( | |
self.floor(number) < number, | |
self.floor(number + 1), | |
number | |
) # type: ignore[return-value] | |
def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: | |
return z3.If(a > b, a, b) # type: ignore[return-value] | |
def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: | |
return z3.If(a < b, a, b) # type: ignore[return-value] | |
# Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q | |
# It should work with both integer and reals. | |
def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: | |
return p - self.floordiv(p, q) * q | |
def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: | |
# Z3 can't handle complex numbers very well. | |
self.validator.add_assertion(z3.Or(base != 0, exp > 0)) # type: ignore[arg-type] | |
return base ** exp | |
def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: | |
# Square-root: | |
# 1. Only work with reals | |
number = _Z3Ops.to_real(number) | |
# 2. The number should be positive or zero. | |
# Otherwise, Z3 returns 'unknown'. | |
self.validator.add_assertion(number >= 0) | |
return number ** 0.5 | |
def abs(self, number: z3.ArithRef) -> z3.ArithRef: | |
return z3.Abs(number) | |
def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: | |
if ndigits is not None: | |
raise ValueError("round(..., ndigits=) is currently not supported by shape validations.") | |
# Pythons builtin 'round' implements the 'round half to even' strategy | |
# See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even | |
# z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to | |
# floating point numbers, which is different from real numbers that we are dealing with here. | |
# Instead, we implement 'round half to even' in terms of 'round half up' (floor(x + 0.5)) and | |
# 'round half down' (ceil(x - 0.5)). | |
# Assuming 'round half up' is the default case, we need to correct ..., -3.5, -1.5, 0.5, 2.5, 4.5, ... | |
# to round down, i.e. use the 'round half down' strategy | |
return z3.If( | |
self.mod(number, z3.IntVal(2)) == 0.5, | |
self.ceil(number - 0.5), | |
self.floor(number + 0.5), | |
) | |
# Lifts a callable to be used in Z3. | |
# | |
# This function replaces the given 'op' by a function that: | |
# | |
# 1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3) | |
# | |
# 2. Calls an operation that corresponds to 'op', but works with Z3 | |
# inhabitants (left as is if it works as is) | |
def z3op(op: Callable, validator: "TranslationValidator") -> Callable: | |
# Operations that have booleans as their argument. | |
# This is needed because the argument of some FX nodes were | |
# literal integers, instead of booleans. So, whenever this flag | |
# is set, we also convert ints to booleans. | |
boolean_ops = {operator.not_, operator.and_, operator.or_} | |
as_bool = op in boolean_ops | |
# Lifts the function into 'z3.ExprRef' domain. | |
def lift(func): | |
def wrap(a) -> z3.ExprRef: | |
if isinstance(a, (z3.ArithRef, z3.BoolRef)): | |
return a | |
# Convert it into a Z3 value, if it is some of the supported | |
# types below. | |
if isinstance(a, bool) or (as_bool and isinstance(a, int)): | |
return z3.BoolVal(bool(a)) | |
if isinstance(a, (int, sympy.Integer)): | |
return z3.IntVal(int(a)) | |
if isinstance(a, (float, sympy.Float)): | |
return z3.RealVal(float(a)) | |
raise ValueError(f"can't lift type: {type(a)}") | |
def wrapper(*args): | |
# Lifts the arguments into a list of Z3 inhabitants. | |
wrapped_args = (wrap(a) for a in args) | |
# Run the function on the Z3 expressions. | |
return func(*wrapped_args) | |
return wrapper | |
ops = _Z3Ops(validator) | |
replacement_map = { | |
# Operator module. | |
operator.not_: lift(z3.Not), | |
operator.and_: lift(z3.And), | |
operator.or_: lift(z3.Or), | |
operator.floordiv: lift(ops.floordiv), | |
operator.truediv: lift(ops.div), | |
operator.mod: lift(ops.mod), | |
operator.abs: lift(ops.abs), | |
builtins.round: lift(ops.round), | |
# Math module. | |
math.ceil: lift(ops.ceil), | |
math.floor: lift(ops.floor), | |
# Torch module. | |
torch.sym_float: lift(ops.to_real), | |
torch.sym_max: lift(ops.max), | |
torch.sym_min: lift(ops.min), | |
torch.sym_ite: lift(lambda b, t, f: t if b else f), | |
torch._sym_sqrt: lift(ops.sqrt), # type: ignore[attr-defined] | |
# Not lifted because we only use this function as a | |
# marker for adding the expression as validator input. | |
torch._assert: torch._assert, | |
} | |
return replacement_map[op] if op in replacement_map else lift(op) | |
# Processes an FX graph, populating the given validator. | |
# | |
# [Note: PopulateValidator] | |
# This class walks through each node in the FX graph, translating | |
# them into the Z3 world. | |
# | |
# Then, whenever it finds an 'torch._assert' call_function operation, | |
# it adds the Z3 expression corresponding to the argument as validator | |
# input. | |
class PopulateValidator(torch.fx.Interpreter): | |
def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"): | |
# Reference to the translation validator. | |
self.validator = validator | |
# Build the graph module and call `Interpreter` constructor. | |
module = torch.fx.GraphModule(root={}, graph=graph) | |
super().__init__(module, garbage_collect_values=True) | |
def placeholder(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: | |
symbol = fx_traceback.get_current_meta()["symbol"] | |
return self.validator.z3var(symbol) | |
def call_function(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any: | |
if target != torch._assert: | |
# Actually runs the node target function (which is already | |
# lifted) with its arguments. | |
return super().call_function(target, args, kwargs) | |
# Adds the Z3 expression corresponding to the first argument | |
# as a validator input. | |
assert len(args) == 1, f"expected 1 argument on assertion. Got: {len(args)} " | |
self.validator.add_source_expr(args[0]) # type: ignore[arg-type] | |
# Translates SymPy expressions into Z3 expressions. | |
# | |
# [Note: SympyToZ3] | |
# At the time of the translation, all free variables present in the | |
# SymPy expression being translated must be already mapped to a Z3 | |
# integer variable. | |
class SympyToZ3: | |
OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"} | |
def __init__( | |
self, | |
validator: "TranslationValidator", | |
) -> None: | |
self._validator = validator | |
self._ops = _Z3Ops(self._validator) | |
def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: | |
if dtype is torch.int64: | |
return z3.IntVal(int(value)) | |
if dtype is torch.double: | |
return z3.RealVal(float(value)) | |
if dtype is torch.bool: | |
return z3.BoolVal(bool(value)) | |
raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") | |
def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: | |
return self._ops.div(numerator, denominator) | |
def floordiv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: | |
return self._ops.floordiv(numerator, denominator) | |
def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: | |
return self._ops.floordiv(numerator, denominator) | |
def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: | |
return self._ops.pow(base, exp) | |
def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: | |
return self._ops.mod(p, q) | |
def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: | |
return self._ops.round(number, ndigits) | |
def __getattr__(self, name: str) -> Any: | |
REPLACEMENT = { | |
"and_": z3.And, | |
"or_": z3.Or, | |
"not_": z3.Not, | |
"floor": self._ops.floor, | |
"ceil": self._ops.ceil, | |
"minimum": self._ops.min, | |
"maximum": self._ops.max, | |
} | |
if name in REPLACEMENT: | |
return REPLACEMENT[name] | |
if name in self.OPERATOR_HANDLES: | |
return getattr(operator, name) | |
raise AttributeError(f"unhandled operator: {name}") | |
def run(self, expr: sympy.Basic) -> z3.ExprRef: | |
return sympy_interp(self, self._validator.symbols, expr) # type: ignore[arg-type] | |
# Dynamo guards translation validator. | |
# | |
# [Note: TranslationValidator] | |
# Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound. | |
# That is: whether those (target) guards only yield TRUE whenever the original, | |
# unoptimized, (source) guards yield TRUE. | |
# | |
# More concretely, given 'source' and 'target' guard expressions, we wish to | |
# check whether the following expression holds: | |
# | |
# Not(And(source)) AND And(target) | |
# | |
# i.e. whether there is an assignment of the free variables where the opposite | |
# happens: target is TRUE, but source is FALSE. | |
class TranslationValidator: | |
def __init__(self) -> None: | |
log.debug("new instance") | |
# Mapping of SymPy symbols to Z3 variables. | |
self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {} | |
# Set of source Z3 expressions. | |
# They represent the generated guards without any kind of | |
# simplification or transformation. | |
self._source_exprs: Set[z3.BoolRef] = set() | |
# Set of target Z3 expressions. | |
# They represent the actual checked guards at runtime. They might | |
# be simplified or transformed versions of the source guards. | |
self._target_exprs: Set[z3.BoolRef] = set() | |
# Set of Z3 expressions representing assertions over both the | |
# source and target expressions. | |
self._assertions: Set[z3.BoolRef] = set() | |
# Retrieves the corresponding Z3 variable. | |
def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef: | |
assert symbol in self.symbols, f"Z3 variable not found for: {symbol}" | |
return self.symbols[symbol] | |
# Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exists. | |
def add_var(self, symbol: sympy.Symbol, type: Type) -> z3.ExprRef: | |
if symbol in self.symbols: | |
return self.symbols[symbol] | |
log.debug("new variable: %s (%s)", symbol.name, type.__name__) | |
if type is int: | |
var = z3.Int(symbol.name) | |
# If 'symbol' is positive (SymPy assumption), we have to | |
# convey it to Z3 as well. | |
if symbol.is_positive: # type: ignore[attr-defined] | |
self._target_exprs.add(var > 0) | |
elif type is float: | |
var = z3.Real(symbol.name) | |
elif type is bool: | |
var = z3.Bool(symbol.name) | |
else: | |
raise RuntimeError(f"unsupported type for Z3 variable: {type}") | |
self.symbols[symbol] = var | |
return var | |
# Checks whether all symbols were already added. | |
def _check_freesymbols(self, e: sympy.Basic) -> None: | |
for s in e.free_symbols: | |
assert isinstance(s, sympy.Symbol) | |
# Call 'z3var' just to check whether there's already a | |
# Z3 variable corresponding to 's'. | |
self.z3var(s) | |
def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef: | |
z3expr = SympyToZ3(self).run(e) | |
assert isinstance(z3expr, z3.BoolRef), f"expected boolean expression. Got: {z3expr}" | |
return z3expr | |
def add_source_expr(self, e: z3.BoolRef) -> None: | |
if e not in self._source_exprs: | |
log.debug("add source guard: %s", z3str(e)) | |
self._source_exprs.add(e) | |
def add_target_expr(self, e: sympy.Expr) -> None: | |
self._check_freesymbols(e) | |
z3expr = self.to_z3_boolean_expr(e) | |
if e not in self._target_exprs: | |
log.debug("add target guard: %s", z3str(z3expr)) | |
self._target_exprs.add(z3expr) | |
def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None: | |
if isinstance(e, sympy.Basic): | |
self._check_freesymbols(e) | |
ref = self.to_z3_boolean_expr(e) | |
else: | |
ref = e | |
assert isinstance(ref, z3.BoolRef) | |
if ref not in self._assertions: | |
log.debug("add assertion: %s", z3str(ref)) | |
self._assertions.add(ref) | |
def validate(self) -> None: | |
from torch._dynamo.utils import dynamo_timed | |
if len(self._source_exprs) == 0 or len(self._target_exprs) == 0: | |
# If there are no source/target expressions, there's nothing we really | |
# wish to prove. So, we just return. | |
return None | |
# Here, we use "QF_NRA" logic for the solver: | |
# "Quantifier-free Non-linear Real Arithmetic". | |
# | |
# Most of the guards expressions have: | |
# 1. arithmetic between integer and reals | |
# 2. no quantifiers | |
# 3. potentially non-linear. | |
# | |
# Although there's also "QF_NIRA" (mixed integer-real arithmetic), | |
# "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'. | |
solver = z3.SolverFor("QF_NRA") | |
# Set a timeout for finding a solution. | |
solver.set(timeout=translation_validation_timeout()) | |
# Add all the assertions to the solver. | |
for assertion in self._assertions: | |
solver.add(assertion) | |
# "Is there any case where it's TRUE for the target expressions, | |
# but FALSE for the source expressions?" | |
solver.add(z3.Not(z3.And(*self._source_exprs))) | |
solver.add(*self._target_exprs) | |
log.debug("translation validation: start") | |
r = dynamo_timed()(solver.check)() | |
if r == z3.sat: | |
# Target expressions are unsound. | |
# Log the found model and the source expressions that failed. | |
model = solver.model() | |
raise ValidationException( | |
model, self._assertions, self._target_exprs, | |
failed_source_exprs=[ | |
inp for inp in self._source_exprs if not model.evaluate(inp) | |
] | |
) | |
else: | |
if r == z3.unknown: | |
# Could not find a solution. It didn't fail, but it also | |
# didn't succeed. Canceling the validation execution (keyboard | |
# interrupt) also gets to this branch. | |
log.warning("translation validation: could not validate: got z3.unknown") | |
else: | |
# Target expressions are sound. | |
assert r == z3.unsat | |
log.debug("translation validation: success") | |
except ImportError: | |
_HAS_Z3 = False | |
__all__ = [ | |
"translation_validation_enabled", "translation_validation_timeout", | |
"ValidationException", "BisectValidationException", | |
] | |
else: | |
_HAS_Z3 = True | |
__all__ = [ | |
"z3str", "z3op", "PopulateValidator", "SympyToZ3", "TranslationValidator", | |
"translation_validation_enabled", "translation_validation_timeout", | |
"ValidationException", "BisectValidationException", | |
] | |
from torch.fx.experimental import _config as config | |
def translation_validation_enabled() -> bool: | |
# Checks everytime this function is called, in case the Dynamo | |
# option is set, but Z3 is not installed. | |
_assert_z3_installed_if_tv_set() | |
return _HAS_Z3 and config.translation_validation | |
def translation_validation_timeout() -> int: | |
return config.translation_validation_timeout | |
def _assert_z3_installed_if_tv_set(): | |
assert _HAS_Z3 or not config.translation_validation, ( | |
"translation validation requires Z3 package. Please, either install " | |
"z3-solver or disable translation validation." | |
) | |
class ValidationException(TorchDynamoException): | |
def __init__(self, model, assertions, target_exprs, failed_source_exprs): | |
assert _HAS_Z3 | |
def symbolstr(sym) -> str: | |
return f"{sym}: {model[sym]}" | |
def joinlines(xs) -> str: | |
return "\n".join(f" ==> {x}" for x in xs) | |
model_str = joinlines(sorted(map(symbolstr, model))) | |
assertions_str = joinlines(sorted(map(z3str, assertions))) | |
target_exprs_str = joinlines(sorted(map(z3str, target_exprs))) | |
failed_source_exprs_str = joinlines(sorted(map(z3str, failed_source_exprs))) | |
self.msg = "translation validation failed." | |
self.details = f"""\ | |
Model: | |
{model_str} | |
Assertions: | |
{assertions_str} | |
Target Expressions: | |
{target_exprs_str} | |
Failed Source Expressions: | |
{failed_source_exprs_str}""" | |
def __str__(self): | |
return f"{self.msg}\n\n{self.details}" | |
class BisectValidationException(TorchDynamoException): | |
def __init__(self, validation_exc, expr, failed_action, traced_node): | |
self.msg = f"translation validation failed when {failed_action}: {expr}" | |
self.details = f"""\ | |
Failure occurred while running node: | |
{traced_node.format_node()} | |
{validation_exc.details}""" | |
def __str__(self): | |
return f"{self.msg}\n\n{self.details}" | |
# Checks when this module is loaded. | |
_assert_z3_installed_if_tv_set() | |
# Translation validation bisection. | |
# | |
# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise | |
# the earliest ValidationException. | |
# | |
# As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors | |
# might be silently happening. This function tries to nail down exactly at which | |
# point things went wrong from a validation perspective. | |
def bisect(shape_env): | |
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SHAPEENV_EVENT_KEY, CURRENT_NODE_KEY | |
from torch.fx.experimental.recording import FakeTensorMeta, ShapeEnvEvent, replay_shape_env_events | |
events = shape_env.events | |
# Retrieves the ShapeEnvEvent associated with node. | |
def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent: | |
assert SHAPEENV_EVENT_KEY in node.meta | |
return events[node.meta[SHAPEENV_EVENT_KEY]] | |
# Creates a new instance of fake, but updating every symbolic value's ShapeEnv | |
# reference to the one given as argument. | |
# | |
# This is needed so as not to simplify a symbolic expression using a ShapeEnv | |
# "from the future", where it may have a different set of replacements. | |
def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any: | |
if isinstance(fake, int): | |
return fake | |
if isinstance(fake, torch.SymInt): | |
return torch.SymInt(fake.node.with_shape_env(shape_env)) | |
assert isinstance(fake, FakeTensorMeta) | |
return FakeTensorMeta( | |
tuple(new_with_shape_env(shape_env, s) for s in fake.size()), | |
tuple(new_with_shape_env(shape_env, s) for s in fake.stride()), | |
new_with_shape_env(shape_env, fake.storage_offset()), | |
fake.is_nested, | |
) | |
# Checks whether the given shape_env fails when produce_guards is called. | |
def check_shapeenv_fails(shape_env: ShapeEnv, tracked_fakes: Optional[List[Any]]) -> Optional[ValidationException]: | |
assert tracked_fakes is not None | |
try: | |
# This produce_guards call is a best-effort replication, since we | |
# don't populate EqualityConstraint list. Reason: we would also have | |
# to save OutputGraph.tracked_fakes_id_to_source. | |
shape_env.produce_guards( | |
[new_with_shape_env(shape_env, a.fake) for a in tracked_fakes], | |
[a.source for a in tracked_fakes], | |
input_contexts=[a.symbolic_context for a in tracked_fakes], | |
) | |
return None | |
except ValidationException as e: | |
return e | |
# Checks whether the ShapeEnv reconstructed by replaying the events until | |
# node is created fails when produce_guards is called. | |
def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]: | |
number = node.meta[SHAPEENV_EVENT_KEY] | |
# Reconstruct shape_env until the event at event_number. | |
shape_env = replay_shape_env_events(events[:number + 1]) | |
shape_env.graph.lint() | |
return check_shapeenv_fails(shape_env, events[number].tracked_fakes) | |
last_exception = check_shapeenv_fails(shape_env, shape_env._snapshot_tracked_fakes()) | |
if not last_exception: | |
# We don't actually fail due to a produce_guards call. | |
# Stop and don't bisect. | |
log.info("translation validation succeeded: no errors found.") | |
return | |
if not shape_env.should_record_events or config.translation_validation_no_bisect: | |
# Bisection is off. | |
# Return the last ValidationException we got. | |
raise last_exception | |
# Cache the raised exception (if any) at each bisection point. | |
exception = {} | |
# Bisection happens on the assertion nodes of the recorded FX graph for | |
# dynamic shapes. | |
assert_nodes = [node for node in shape_env.graph.nodes if node.target == torch._assert] | |
# Preparing the indices for binary search. | |
left, mid, right = 0, 0, len(assert_nodes) - 1 | |
while left < right: | |
mid = (left + right) // 2 | |
node = assert_nodes[mid] | |
log.debug("bisecting at %s: %s", mid, get_node_event(node)) | |
# Check whether the new shape_env raises a ValidationException or not. | |
exception[mid] = check_node_fails(node) | |
if exception[mid]: | |
right = mid | |
else: | |
left = mid + 1 | |
assert left in exception and isinstance(exception[left], ValidationException) | |
node = assert_nodes[left] | |
event = get_node_event(node) | |
if event.is_evaluate_expr(): | |
failed_action = "evaluating" | |
else: | |
assert event.is_defer_runtime_assert(), f"unexpected event type: {event}" | |
failed_action = "adding runtime assert" | |
args = event.args | |
assert args is not None | |
assert len(args) >= 2, ( | |
f"bisecting expects {event.name} to have at least 2 positional arguments. " | |
f"Got: {len(args)}" | |
) | |
assert isinstance(args[1], sympy.Basic), ( | |
f"bisecting expects {event.name} to have a SymPy expression as its second argument. " | |
f"Got: {type(args[1])}" | |
) | |
raise BisectValidationException( | |
exception[left], | |
expr=args[1], | |
failed_action=failed_action, | |
traced_node=node.meta[CURRENT_NODE_KEY], | |
) | |