import builtins
import dataclasses
import inspect
import math
import sys
import weakref
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union

import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import SUPPORTED_NODES

from .exported_program import ExportedProgram

if TYPE_CHECKING:
    from sympy import Symbol

    from torch._guards import Source

    from ..fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint

__all__ = ["Constraint", "Dim", "dims", "dynamic_dim"]


class _Dim(type):
    """
    Metaclass for :func:`Dim` types.
    """

    @staticmethod
    def readable(name, min_, max_):
        if min_ == 2:
            min_ = None
        if max_ == sys.maxsize - 1:
            max_ = None
        if min_ is None and max_ is None:
            return f"Dim('{name}')"
        if min_ is None:
            return f"Dim('{name}', max={max_})"
        if max_ is None:
            return f"Dim('{name}', min={min_})"
        return f"Dim('{name}', min={min_}, max={max_})"

    def __add__(cls, other):
        # e.g., dim + 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to add {other} to {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x + other)

    def __radd__(cls, other):
        return cls + other

    def __sub__(cls, other):
        # e.g., dim - 1
        if type(other) is not int:
            raise NotImplementedError(
                f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x - other)

    def __rsub__(cls, other):
        raise NotImplementedError(
            f"Attempted to negate {cls.__name__}. "
            "(Only increasing linear operations with integer coefficients are supported.)"
        )

    def __mul__(cls, other):
        # e.g., dim * 2
        if type(other) is not int or other <= 0:
            raise NotImplementedError(
                f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
        return cls._derive(lambda x: x * other)

    def __rmul__(cls, other):
        return cls * other

    def _derived_name(cls, fn):
        from sympy import sympify

        return str(fn(sympify(cls.__name__)))

    def _derive(cls, fn):
        return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn})


class _DerivedDim(_Dim):
    """
    Metaclass for derived :func:`Dim` types.

    Currently we only support increasing linear expressions with integer coefficients.
    In other words, a derived Dim can always be written in the form Ax + B, where
    x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive.
    (In particular, the latter ensures that x < y => Ax + B < Ay + B.)

    These restrictions on the form of derived Dims make the metatheory simpler: e.g.,
    they simplify computing ranges for derived Dims, solving for underlying regular Dims,
    deciding equalities between derived Dims, and so on.

    The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`.
    The range of a derived Dim is computed by mapping `fn` over the range of its `root`.
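
    For example, the derived Dim `2*dim + 1` has root `dim` and `fn = lambda x: 2*x + 1`;
    if `dim` ranges over [2, 10], then `2*dim + 1` ranges over [5, 21].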
""" | |

    @property
    def min(self):
        # assume that self.fn is an increasing function
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        _min_symint = self.fn(Integer(self.root.min))  # type: ignore[attr-defined]
        assert _min_symint >= 2, (
            f"Expected derived min value of {self.__name__} to be >= 2. "
            f"Please specify an appropriate min value for {self.root.__name__} "  # type: ignore[attr-defined]
            f"(currently {self.root.min})."  # type: ignore[attr-defined]
        )
        return int(_min_symint)

    @property
    def max(self):
        # assume that self.fn is an increasing function
        # TODO(avik): use sympy value range analysis instead?
        from sympy import Integer

        _max_symint = self.fn(Integer(self.root.max))  # type: ignore[attr-defined]
        assert _max_symint <= sys.maxsize - 1, (
            f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. "
            f"Please specify an appropriate max value for {self.root.__name__} "  # type: ignore[attr-defined]
            f"(currently {self.root.max})."  # type: ignore[attr-defined]
        )
        return int(_max_symint)

    def _derive(self, fn):
        # We support nesting, e.g., 2*dim + 1.
        # This is implemented by composing operations on the same root.
        # As a consequence, roots are always regular Dims (i.e., not derived Dims).
        return _DerivedDim(
            self._derived_name(fn),
            (int,),
            {"root": self.root, "fn": lambda x: fn(self.fn(x))},  # type: ignore[attr-defined]
        )


def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
    """
    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
    It can be used to describe multiple possible values of a dynamic tensor dimension.
    Note that different dynamic dimensions of the same tensor, or of different tensors,
    can be described by the same type.

    Args:
        name (str): Human-readable name for debugging.
        min (Optional[int]): Minimum possible value of given symbol (inclusive)
        max (Optional[int]): Maximum possible value of given symbol (inclusive)

    Returns:
        A type that can be used in dynamic shape specifications for tensors.
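
    For example (an illustrative sketch; ``fn``, ``x``, and ``y`` stand in for an
    exportable callable and its example inputs)::

        batch = Dim("batch", min=2, max=1024)
        # the first dimension of both inputs may vary together over [2, 1024]
        ep = torch.export.export(
            fn, (x, y), dynamic_shapes={"x": {0: batch}, "y": {0: batch}}
        )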
""" | |
_min = 2 if min is None else builtins.max(min, 2) | |
_max = sys.maxsize - 1 if max is None else builtins.min(max, sys.maxsize - 1) | |
assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}" | |
dim = _Dim(name, (int,), {"min": _min, "max": _max}) | |
dim.__module__ = getattr( | |
inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__" | |
) | |
return dim | |


def dims(*names: str, min: Optional[int] = None, max: Optional[int] = None):
    """
    Util to create multiple :func:`Dim` types.
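
    For example, ``batch, seq = dims("batch", "seq", min=2)`` creates two
    independent :func:`Dim` types that share the same bounds.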
""" | |
return tuple(Dim(name, min=min, max=max) for name in names) | |


@dataclasses.dataclass
class _ConstraintTarget:
    """
    This represents input tensor dimensions. Don't create this
    class directly; instead, use :func:`dynamic_dim`.
    """

    w_tensor: Any  # weakref to torch.Tensor
    # TODO: We don't need t_id; we can get it off of w_tensor
    t_id: int
    dim: int


class _ConstraintFactory(type):
    """
    Metaclass that ensures a private constructor for :class:`_Constraint`.
    """

    def __call__(cls, *args, **kwargs):
        raise TypeError(
            f"{cls.__module__}.{cls.__qualname__} has no public constructor. "
            f"Please use torch.export.dynamic_dim() to create one"
        )

    def _create(
        cls, w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None
    ):
        return super().__call__(
            w_tensor, t_id, dim, constraint_range, shared, debug_name
        )


def _create_constraint(
    w_tensor, t_id, dim, constraint_range, shared=None, debug_name=None
):
    return _Constraint._create(
        w_tensor, t_id, dim, constraint_range, shared, debug_name
    )


@dataclasses.dataclass
class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory):
    """
    .. warning::
        Do not construct :class:`_Constraint` directly, use :func:`dynamic_dim` instead.

    This represents constraints on input tensor dimensions, e.g., requiring
    them to be fully polymorphic or within some range.
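
    For example, ``dynamic_dim(t, 0) <= 16`` produces a :class:`_Constraint`
    whose value range has an inclusive upper bound of 16 (via ``__le__`` below).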
""" | |

    # NOTE(avik): In the future, this could be Union[StrictMinMaxConstraint, <other kinds>]
    constraint_range: "StrictMinMaxConstraint"
    # Represent that `constraint_range` is shared with another _ConstraintTarget, which
    # typically arises because of a specified equality with another dynamic dimension.
    shared: Optional[_ConstraintTarget] = None
    debug_name: Optional[str] = None

    def _clone_with_range(self, lower=2, upper=math.inf):
        # Import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.value_ranges import ValueRanges

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper),
            warn_only=False,
        )
        return _create_constraint(
            self.w_tensor,
            self.t_id,
            self.dim,
            constraint_range,
            self.shared,
            self.debug_name,
        )

    def __ge__(self, lower):
        return self._clone_with_range(lower=lower)

    def __gt__(self, lower):
        return self._clone_with_range(lower=lower + 1)

    def __le__(self, upper):
        return self._clone_with_range(upper=upper)

    def __lt__(self, upper):
        return self._clone_with_range(upper=upper - 1)

    def __bool__(self):
        # NOTE(avik): We do not support compound expressions like a <= x <= b.
        # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b),
        # and moreover, enforces that any overload of __bool__ must return True or False.
        # FWIW, sympy also raises TypeError in this case.
        raise TypeError(
            "Cannot determine truth value of _Constraint. "
            "If you are trying to combine _Constraint's with logical connectives, "
            "you can specify them separately instead."
        )

    @property
    def serializable_spec(self):
        # We need a serialization compatible format of the constraint so that it
        # can be saved in the graph module w/o breaking the module serialization.
        # The saved constraints will be used directly for the post-export pass
        # that converts constraints to runtime assertions. The saved constraints
        # will not be saved in the serialized module.
        # TODO: A better way is needed. Currently we use 't_id' to map the constraint,
        # which is not reliable.
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }

    def __eq__(self, other):
        if not isinstance(other, _Constraint):
            raise TypeError(
                "A dynamic dim can be specified equal only to another dynamic dim. "
                f"Equality with {type(other)} is not supported."
            )

        # import sympy locally
        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint

        constraint_range = StrictMinMaxConstraint(
            vr=self.constraint_range.vr & other.constraint_range.vr,
            warn_only=False,
        )
        if self.debug_name is None:
            debug_name = other.debug_name
        else:
            assert other.debug_name is None or self.debug_name == other.debug_name
            debug_name = self.debug_name
        return _create_constraint(
            self.w_tensor,
            self.t_id,
            self.dim,
            constraint_range,
            shared=_ConstraintTarget(other.w_tensor, other.t_id, other.dim),
            debug_name=debug_name,
        )


@dataclasses.dataclass
class _PhantomRoot:
    """
    This represents the root of a derived Dim where the root does not directly
    specify the shape of any input dimension, but the derived Dim does.

    e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim.

    The fields `name`, `constraint_range`, and `val` carried by a phantom root
    help create a symbol for it. Any derived dims with this phantom root are
    backed by expressions over this symbol.
    """

    name: str
    constraint_range: "StrictMinMaxConstraint"
    val: int


@dataclasses.dataclass
class _DerivedConstraint(_ConstraintTarget):
    """
    This represents a derived Dim, whose root is either a regular constraint target
    (which directly specifies the shape of some input dimension) or a phantom root
    (which does so indirectly).
    """

    # NOTE: This is not currently a subclass of _Constraint because we do not support
    # `shared` for derived `Dim`s. Indeed, sharing is a necessary concept only for
    # legacy constraints based on `dynamic_dim`: equality can be expressed simply by
    # reusing the same (derived or normal) `Dim`.
    root: Union[_ConstraintTarget, _PhantomRoot]
    fn: Callable
    constraint_range: "StrictMinMaxConstraint"
    debug_name: Optional[str] = None

    @property
    def shared(self):
        # Some code paths expect a union of _Constraint and _DerivedConstraint.
        # Thus we expose a `shared` field that is always None.
        # TODO(avik): clean this up
        return None

    @property
    def serializable_spec(self):
        # same as _Constraint.serializable_spec
        return {
            "t_id": self.t_id,
            "dim": self.dim,
            "min": self.constraint_range.vr.lower,
            "max": self.constraint_range.vr.upper,
        }
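

# A Constraint is either a legacy `dynamic_dim`-style constraint or a
# `Dim`-based derived constraint; this is the public alias listed in __all__.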
Constraint = Union[_Constraint, _DerivedConstraint]


def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None):
    """
    .. warning::
        (This feature is DEPRECATED. See :func:`Dim` instead.)

    :func:`dynamic_dim` constructs a :class:`_Constraint` object that describes the dynamism of
    a dimension ``index`` of tensor ``t``. :class:`_Constraint` objects should be passed to the
    ``constraints`` argument of :func:`export`.

    Args:
        t (torch.Tensor): Example input tensor that has dynamic dimension size(s)
        index (int): Index of dynamic dimension

    Returns:
        A :class:`_Constraint` object that describes shape dynamism. It can be passed to :func:`export` so
        that :func:`export` does not assume a static size for the specified tensor dimension, i.e. keeping it
        dynamic as a symbolic size rather than specializing according to the size of the example tracing input.

    Specifically :func:`dynamic_dim` can be used to express the following types of dynamism.

    - Size of a dimension is dynamic and unbounded::

        t0 = torch.rand(2, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size rather than always being static size 2
        constraints = [dynamic_dim(t0, 0)]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic with a lower bound::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size with a lower bound of 5 (inclusive)
        # Second dimension of t1 can be dynamic size with a lower bound of 2 (exclusive)
        constraints = [
            dynamic_dim(t0, 0) >= 5,
            dynamic_dim(t1, 1) > 2,
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic with an upper bound::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # First dimension of t0 can be dynamic size with an upper bound of 16 (inclusive)
        # Second dimension of t1 can be dynamic size with an upper bound of 8 (exclusive)
        constraints = [
            dynamic_dim(t0, 0) <= 16,
            dynamic_dim(t1, 1) < 8,
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Size of a dimension is dynamic and it is always equal to the size of another dynamic dimension::

        t0 = torch.rand(10, 3)
        t1 = torch.rand(3, 4)

        # Sizes of the second dimension of t0 and the first dimension of t1 are always equal
        constraints = [
            dynamic_dim(t0, 1) == dynamic_dim(t1, 0),
        ]
        ep = export(fn, (t0, t1), constraints=constraints)

    - Mix and match all types above as long as they do not express conflicting requirements
    """
    from torch._dynamo.exc import UserError, UserErrorType

    if not isinstance(t, torch.Tensor):
        raise UserError(
            UserErrorType.DYNAMIC_DIM,
            f"Expected tensor as input to dynamic_dim but got {type(t)}",
        )

    if t.dim() < 1:
        raise UserError(
            UserErrorType.DYNAMIC_DIM, "Cannot mark 0-dimension tensors to be dynamic"
        )

    if index >= t.dim():
        raise UserError(
            UserErrorType.DYNAMIC_DIM,
            f"Expected the dimension passed to dynamic_dim to be in the range [0:{t.dim()-1}]"
            f" but got {index}, which is out of bounds for the given tensor.",
        )

    # Import sympy locally
    import sympy

    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
    from torch.utils._sympy.value_ranges import ValueRanges

    return _create_constraint(
        weakref.ref(t),
        id(t),
        index,
        StrictMinMaxConstraint(
            vr=ValueRanges(lower=2, upper=sympy.oo), warn_only=False
        ),
        debug_name=debug_name,
    )


def _process_equalities(
    constraint: Constraint,
    get_sources: Callable[[int, int], List["Source"]],
    shape_env: "ShapeEnv",
    source_pairs: List[Tuple["Source", "Source"]],
    derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]],
    phantom_symbols: Dict[str, "Symbol"],
):
    """
    Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become
    fields of `EqualityConstraint`) based on a given input `constraint`.
    """
    source, *other_sources = get_sources(constraint.t_id, constraint.dim)
    # When t.size()[dim] maps to src0, src1, ..., srcN, we add
    # constraints that make src0 "equal" to src1, ..., srcN.
    source_pairs.extend((source, other_source) for other_source in other_sources)
    if not isinstance(constraint, _DerivedConstraint):
        if constraint.shared is not None:
            # Moreover, when t.size()[dim] is specified equal to t'.size()[dim']
            # and t'.size()[dim'] maps to src1', ..., srcN', we add
            # constraints that also make src0 "equal" to src1', ..., srcN'.
            other_sources = get_sources(constraint.shared.t_id, constraint.shared.dim)
            source_pairs.extend(
                (source, other_source) for other_source in other_sources
            )
    else:
        # branch based on the root of the _DerivedConstraint
        if not isinstance(constraint.root, _PhantomRoot):
            # either root points to an input source
            root = get_sources(constraint.root.t_id, constraint.root.dim)[0]  # type: ignore[assignment]
        else:
            # or root points to a phantom symbol
            if constraint.root.name in phantom_symbols:
                root = phantom_symbols[constraint.root.name]  # type: ignore[assignment]
            else:
                # create a phantom symbol in the shape env based on the _PhantomRoot
                root = shape_env.create_symbol(
                    val=constraint.root.val,
                    source=torch._dynamo.source.ConstantSource(constraint.root.name),
                    dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC,
                    constraint_dim=constraint.root.constraint_range,
                )
                phantom_symbols[constraint.root.name] = root  # type: ignore[assignment]

        fn = constraint.fn
        # A derived equality (source, root, fn) informally corresponds to source = fn(root).
        # Here source describes an input and root might describe another input or a phantom symbol.
        derived_equalities.append((source, root, fn))


def _process_dynamic_shapes(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
) -> Optional[List[Constraint]]:
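    """
    Converts a `dynamic_shapes` specification (a pytree of :func:`Dim` types,
    derived Dims, and Nones mirroring the structure of `args` and `kwargs`)
    into an equivalent list of :class:`Constraint` objects. Returns None if
    no dynamic shapes are specified.
    """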
    from collections import defaultdict
    from collections.abc import Mapping, Sequence

    from torch._dynamo.exc import UserError, UserErrorType

    if dynamic_shapes is None or len(dynamic_shapes) == 0:
        return None

    kwargs = kwargs if kwargs is not None else {}

    def tree_zip(combined_args, dynamic_shapes):
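        # Recursively pair each tensor in `combined_args` with its corresponding
        # dynamic-shape spec, yielding (tensor, shape) tuples at the leaves.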
        if isinstance(combined_args, (tuple, list)):
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Sequence, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for i, shape in enumerate(dynamic_shapes):
                yield from tree_zip(combined_args[i], shape)
        elif isinstance(combined_args, dict):
            if not isinstance(dynamic_shapes, Mapping):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Mapping, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for k, shape in dynamic_shapes.items():
                yield from tree_zip(combined_args[k], shape)
        elif type(combined_args) in SUPPORTED_NODES:
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a user-registered class (e.g., "
                    f"{type(combined_args)}) to be a Sequence that matches the "
                    f"flattened structure, but got {dynamic_shapes} instead",
                )
            yield from tree_zip(
                SUPPORTED_NODES[type(combined_args)].flatten_fn(combined_args)[0],
                dynamic_shapes,
            )
        elif isinstance(combined_args, torch.Tensor):
            yield (combined_args, dynamic_shapes)
        else:
            if dynamic_shapes is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be None, "
                    f"got {dynamic_shapes} instead",
                )

    # map of Dim names representing input shape dimensions to constraints on them
    symbols: Dict[str, List[Constraint]] = defaultdict(list)
    # track roots that do not directly represent input shape dimensions
    phantom_roots: Dict[str, _PhantomRoot] = {}
    derived_constraints_with_phantom_root: List[_DerivedConstraint] = []

    def to_constraint(dim, tensor, i):
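        # Convert the Dim (or derived Dim) specified for dimension `i` of `tensor`
        # into a _Constraint / _DerivedConstraint carrying its value range.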
        import sympy

        from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
        from torch.utils._sympy.solve import try_solve
        from torch.utils._sympy.value_ranges import ValueRanges

        def root_value():
            # given tensor.shape[i] is the value of dim = fn(root),
            # find the value of root
            symbol = sympy.Symbol(dim.root.__name__, integer=True)
            expr = dim.fn(symbol)
            solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol)
            if solution is not None:
                return int(solution[1])  # type: ignore[call-overload]
            else:
                raise UserError(  # noqa: TRY200
                    UserErrorType.CONSTRAINT_VIOLATION,
                    f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be "
                    f"of the form {expr}, where {symbol} is an integer",
                )

        if isinstance(dim, _DerivedDim):
            # generate a _DerivedConstraint where the root is:
            # - either a _ConstraintTarget (if dim.root directly describes an input shape)
            # - or a _PhantomRoot (otherwise)
            dim_root = dim.root  # type: ignore[attr-defined]
            if dim_root.__name__ in symbols:
                # root represents an input shape dimension
                root_constraint = symbols[dim_root.__name__][0]
                root = _ConstraintTarget(
                    root_constraint.w_tensor,
                    root_constraint.t_id,
                    root_constraint.dim,
                )
            elif dim_root.__name__ not in phantom_roots:
                # create a phantom root
                root = _PhantomRoot(  # type: ignore[assignment]
                    name=dim_root.__name__,
                    constraint_range=StrictMinMaxConstraint(
                        vr=ValueRanges(lower=dim_root.min, upper=dim_root.max),
                        warn_only=False,
                    ),
                    val=root_value(),
                )
                phantom_roots[dim_root.__name__] = root  # type: ignore[assignment]
            else:
                root = phantom_roots[dim_root.__name__]  # type: ignore[assignment]
            constraint = _DerivedConstraint(
                weakref.ref(tensor),
                id(tensor),
                i,
                root,
                dim.fn,  # type: ignore[attr-defined]
                StrictMinMaxConstraint(
                    vr=ValueRanges(lower=dim.min, upper=dim.max),
                    warn_only=False,
                ),
                debug_name=dim.__name__,
            )
            if isinstance(root, _PhantomRoot):
                # NOTE(avik): since we have not processed all inputs yet, we may replace this
                # with a root that does represent an input shape dimension later (see below)
                derived_constraints_with_phantom_root.append(constraint)
        else:
            constraint = dynamic_dim(tensor, i, debug_name=dim.__name__)
            if dim.min != 2:
                constraint = constraint >= dim.min
            if dim.max != sys.maxsize - 1:
                constraint = constraint <= dim.max
        return constraint

    bounds: Dict[str, Tuple[int, int]] = {}

    def check_same_bounds(dim):
        if dim.__name__ in symbols:
            min_, max_ = bounds[dim.__name__]
            if dim.min != min_ or dim.max != max_:
                this_ = _Dim.readable(dim.__name__, min_, max_)
                that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Found different definitions {this_} and {that_} "
                    f"for the same symbolic dimension {dim}!",
                )
        else:
            bounds[dim.__name__] = (dim.min, dim.max)

    def update_symbols(tensor, shape):
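        # Record a constraint for every Dim mentioned in `shape`, which may be a
        # dict keyed by dimension index or a tuple/list indexed by position
        # (None means the dimension is left static).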
        if isinstance(shape, dict):
            for i, dim in shape.items():
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                    constraint = to_constraint(dim, tensor, i)
                    symbols[dim.__name__].append(constraint)
                else:
                    if dim is not None:
                        raise UserError(
                            UserErrorType.INVALID_INPUT,
                            f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
                            "try None instead",
                        )
        elif isinstance(shape, (tuple, list)):
            for i, dim in enumerate(shape):
                if isinstance(dim, _Dim):
                    check_same_bounds(dim)
                    constraint = to_constraint(dim, tensor, i)
                    symbols[dim.__name__].append(constraint)
                else:
                    if dim is not None:
                        raise UserError(
                            UserErrorType.INVALID_INPUT,
                            f"Unexpected item #{i} ({dim}) in dynamic_shape {shape} of Tensor, "
                            "try None instead",
                        )
        else:
            if shape is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Unexpected dynamic_shape {shape} of Tensor, try None instead",
                )

    import inspect

    if isinstance(f, ExportedProgram):
        f = f.module()
    signature = (
        inspect.signature(f.forward)
        if isinstance(f, torch.nn.Module)
        else inspect.signature(f)
    )
    combined_args = signature.bind(*args, **kwargs).arguments

    # If dynamic shapes are not specified by argument name (i.e., dynamic_shapes
    # is not a Mapping), match them to arguments positionally.
    combined_args = combined_args if isinstance(dynamic_shapes, Mapping) else list(combined_args.values())  # type: ignore[assignment]
    for tensor, shape in tree_zip(combined_args, dynamic_shapes):
        update_symbols(tensor, shape)

    constraints = []
    for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root:
        phantom_root_name = derived_constraint_with_phantom_root.root.name  # type: ignore[union-attr]
        if phantom_root_name in symbols:
            # We found an input shape dimension corresponding to this name, so we
            # do not need a phantom symbol for it after all.
            # NOTE(avik): Overall we want to maintain the invariant that roots that
            # are phantom symbols are really "phantom," i.e., they cannot be represented
            # by any input source. This is important when we are deciding derived equalities,
            # since we can focus our attention exclusively on input sources: deciding
            # derived equalities involving phantom symbols is, in comparison, trivial.
            derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0]

    for dynamic_dims in symbols.values():
        if all(
            isinstance(dynamic_dim, _DerivedConstraint) for dynamic_dim in dynamic_dims
        ):
            constraints.extend(dynamic_dims)
        else:
            primary, *others = dynamic_dims
            if others:
                for other in others:
                    constraints.append(primary == other)  # type: ignore[arg-type]
            else:
                constraints.append(primary)

    return constraints  # type: ignore[return-value]


def _process_constraints(
    fake_mode,
    graph_module: torch.fx.GraphModule,
    num_lifted_params_buffers: int,
    example_inputs: List[torch.Tensor],
) -> Dict:
    """
    Process the constraints stored in the graph module to return something more readable.

    Args:
        graph_module (torch.fx.GraphModule): GraphModule returned from
            dynamo.export, which contains the "input_shape_constraints" and
            "inline_constraints" metadata

        example_inputs: Flattened list of example inputs used to export the graph module

    Returns:
        range_constraints (Dict[sympy.Symbol, ValueRanges]): Mapping of
            symbols (from SymInts) appearing in the fake tensors in
            node.meta["val"] to their range constraints, which are a tuple
            containing (lower, upper) constraints.
    """
    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
        InputDim,
    )

    # Import sympy locally
    from torch.fx.experimental.symbolic_shapes import SymInt
    from torch.utils._sympy.value_ranges import ValueRanges

    input_shape_constraints = graph_module.meta.get("input_shape_constraints", [])
    inline_constraints = graph_module.meta.get("inline_constraints", [])

    # Create dict mapping tensor_id to node names
    tensor_id_to_nodes: Dict[int, List[str]] = defaultdict(list)
    # Create dict mapping placeholder node names to their nodes
    placeholder_nodes: Dict[str, torch.fx.Node] = {}
    for i, node in enumerate(graph_module.graph.nodes):
        if node.op != "placeholder":
            # All placeholder nodes should be together at the beginning of the
            # graph
            break
        if i >= num_lifted_params_buffers:
            example_input = example_inputs[i - num_lifted_params_buffers]
            tensor_id_to_nodes[id(example_input)].append(node.name)
            placeholder_nodes[node.name] = node

    # Create dict mapping (node name, dim) to a list of range (lower, upper)
    # constraints
    multi_range_constraints: Dict[InputDim, List[ValueRanges]] = defaultdict(list)
    for constraint in input_shape_constraints:
        for node in tensor_id_to_nodes[constraint["t_id"]]:
            node_dim = InputDim(node, constraint["dim"])

            # Accumulate range constraints
            multi_range_constraints[node_dim].append(
                ValueRanges(constraint["min"], constraint["max"])
            )

    # Create dict mapping symbol to a singular range (lower, upper)
    range_constraints: Dict[Any, ValueRanges] = {}

    # Add inline constraints to range_constraints
    range_constraints = {
        symbol: inline_constraints[symbol] for symbol in inline_constraints
    }

    free_symbols: Set["Symbol"] = set()
    # Add input range constraints to range_constraints
    for input_dim, multi_range_constraint in multi_range_constraints.items():  # type: ignore[assignment]
        # Simplify the range constraints into a single range constraint
        # Ex. ranges [2, 10] and [3, 11] would get merged to [3, 10]
        min_vals = [rc.lower for rc in multi_range_constraint]
        max_vals = [rc.upper for rc in multi_range_constraint]
        min_val = max(min_vals)  # type: ignore[type-var]
        max_val = min(max_vals)  # type: ignore[type-var]
        assert min_val <= max_val  # type: ignore[operator]

        # Add input node range constraints
        val = placeholder_nodes[input_dim.input_name].meta["val"]
        assert isinstance(val, FakeTensor)
        symint = val.shape[input_dim.dim]
        assert isinstance(
            symint, SymInt
        ), f"Expected SymInt but got {symint}: {type(symint)}"
        symbol = symint.node.expr
        range_constraints[symbol] = ValueRanges(min_val, max_val)
        free_symbols.update(symbol.free_symbols)

    for symbol in free_symbols:
        if symbol not in range_constraints:
            # Placeholders can have symbolic shapes that are derived expressions.
            # The above code will record direct range constraints for them
            # so that we can do runtime assertions. In addition, for serde checks
            # we want to record range constraints for their root symbols.
            range_constraints[symbol] = fake_mode.shape_env.var_to_range[symbol]

    return range_constraints