Spaces:
Running
Running
import collections | |
import dataclasses | |
import enum | |
from typing import Any, Optional, Union | |
from torch._guards import ChainedSource, GuardSource, Source | |
from . import utils | |
from .bytecode_transformation import create_call_function, create_instruction | |
from .utils import enum_repr | |
# Each pair holds a "local" guard source and its "global" counterpart.  All
# three remapping tables below send every local member of the pairs they cover
# to one target and every global member to another, so they are generated from
# these pairs rather than spelled out entry by entry.  Key order matches the
# pair order (local first, then global, pair by pair).
_LOCAL_GLOBAL_GUARD_SOURCE_PAIRS = (
    (GuardSource.LOCAL, GuardSource.GLOBAL),
    (GuardSource.LOCAL_NN_MODULE, GuardSource.GLOBAL_NN_MODULE),
    (GuardSource.LOCAL_FSDP_MODULE, GuardSource.GLOBAL_FSDP_MODULE),
)


def _local_global_guard_source_map(pairs, local_target, global_target):
    """Map each local source in ``pairs`` to ``local_target`` and each global one to ``global_target``."""
    table = {}
    for local_kind, global_kind in pairs:
        table[local_kind] = local_target
        table[global_kind] = global_target
    return table


# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
# so those cases are omitted intentionally (only the first two pairs are used).
_GUARD_SOURCE_NN_MODULE = _local_global_guard_source_map(
    _LOCAL_GLOBAL_GUARD_SOURCE_PAIRS[:2],
    GuardSource.LOCAL_NN_MODULE,
    GuardSource.GLOBAL_NN_MODULE,
)
_GUARD_SOURCE_FSDP_MODULE = _local_global_guard_source_map(
    _LOCAL_GLOBAL_GUARD_SOURCE_PAIRS,
    GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_FSDP_MODULE,
)
_GUARD_SOURCE_NOT_NN_MODULE = _local_global_guard_source_map(
    _LOCAL_GLOBAL_GUARD_SOURCE_PAIRS,
    GuardSource.LOCAL,
    GuardSource.GLOBAL,
)
def is_constant_source(source):
    """Report whether ``source`` refers to a constant.

    True for any ``ConstantSource``, and for any other source whose
    ``guard_source()`` is ``GuardSource.CONSTANT``.  A source whose
    ``guard_source()`` raises ``NotImplementedError`` is treated as
    non-constant.
    """
    if isinstance(source, ConstantSource):
        return True
    try:
        kind = source.guard_source()
    except NotImplementedError:
        return False
    return kind == GuardSource.CONSTANT
def reconstruct_getitem(
    source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice
):
    """Emit code that pushes ``source.base`` and then ``source.index``.

    A ``Source`` index is reconstructed recursively.  Any other index is
    pushed as a constant; when ``index_is_slice`` is set the stored reduced
    slice is first rebuilt via ``GetItemSource.unpack_slice``.  The caller
    emits the actual subscript / call afterwards.
    """
    source.base.reconstruct(codegen)
    key = source.index
    if isinstance(key, Source):
        key.reconstruct(codegen)
        return
    if index_is_slice:
        # Only GetItemSource knows how to rebuild a slice from its stored form.
        assert isinstance(source, GetItemSource)
        constant = source.unpack_slice()
    else:
        constant = key
    codegen.append_output(codegen.create_load_const(constant))
class LocalSource(Source): | |
local_name: str | |
cell_or_freevar: bool = False | |
def reconstruct(self, codegen): | |
codegen.append_output(codegen.create_load(self.local_name)) | |
def guard_source(self): | |
return GuardSource.LOCAL | |
def name(self): | |
return f"L[{repr(self.local_name)}]" | |
class SyntheticLocalSource(Source): | |
local_name: str | |
def reconstruct(self, codegen): | |
codegen.append_output(codegen.create_load(self.local_name)) | |
def guard_source(self): | |
return GuardSource.SYNTHETIC_LOCAL | |
def name(self): | |
return f"SYNTHETIC_LOCAL[{self.local_name!r}]" | |
class RandomValueSource(Source): | |
random_call_index: int | |
def guard_source(self): | |
return GuardSource.RANDOM_VALUE | |
def reconstruct(self, codegen): | |
codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var)) | |
codegen.append_output(codegen.create_load_const(self.random_call_index)) | |
codegen.append_output(create_instruction("BINARY_SUBSCR")) | |
def name(self): | |
return f"random_value_{self.random_call_index}" | |
class GlobalSource(Source): | |
global_name: str | |
def reconstruct(self, codegen): | |
codegen.append_output( | |
codegen.create_load_global(self.global_name, False, add=True) | |
) | |
def guard_source(self): | |
return GuardSource.GLOBAL | |
def name(self): | |
return f"G[{repr(self.global_name)}]" | |
class GlobalWeakRefSource(Source): | |
global_name: str | |
def reconstruct(self, codegen): | |
codegen.append_output( | |
codegen.create_load_global(self.global_name, True, add=True) | |
) | |
codegen.extend_output(create_call_function(0, False)) | |
def guard_source(self): | |
return GuardSource.GLOBAL | |
def name(self): | |
return f"G[{repr(self.global_name)}]()" | |
class AttrSource(ChainedSource): | |
member: str | |
get_static: bool = False | |
def __post_init__(self): | |
assert self.base, "Can't construct an AttrSource without a valid base source" | |
if "." in self.member: | |
member_parts = self.member.split(".") | |
object.__setattr__( | |
self, "base", AttrSource(self.base, ".".join(member_parts[:-1])) | |
) | |
object.__setattr__(self, "member", member_parts[-1]) | |
def reconstruct(self, codegen): | |
self.base.reconstruct(codegen) | |
codegen.extend_output(codegen.create_load_attrs(self.member)) | |
def guard_source(self): | |
return self.base.guard_source() | |
def name(self): | |
if self.get_static: | |
return f"inspect.getattr_static({self.base.name()}, {self.member!r})" | |
elif not self.member.isidentifier(): | |
return f"getattr({self.base.name()}, {self.member!r})" | |
return f"{self.base.name()}.{self.member}" | |
# Frozen dataclass like its AttrSource parent, so the inherited fields and
# __post_init__ (dotted-member splitting) apply to this subclass too.
@dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource):
    """An ``AttrSource`` whose guard source is promoted to the NN-module variant.

    NOTE(review): presumably used for module parameters/buffers (per the class
    name). ``guard_source`` raises ``KeyError`` for base guard sources missing
    from ``_GUARD_SOURCE_NN_MODULE`` (e.g. the FSDP ones).
    """

    def guard_source(self):
        return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()]
# This source is intended to be used in places where a source is needed but it is expected | |
# that the symbol will be simplified out later on. Symbols with ephemeral sources are | |
# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral | |
# source. Guarding on this source is an error. | |
# | |
# Example: During subclass view fake-ification, any close-over ViewFunc state should be | |
# symbolicized / fake-ified to avoid invalid specialization during view replay. This source | |
# is useful for symbols utilized in the middle of the view chain that are not expected to be | |
# present within the final view shape metadata. | |
class EphemeralSource(Source): | |
desc: Optional[str] = None | |
def guard_source(self): | |
return GuardSource.EPHEMERAL | |
def name(self): | |
return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>" | |
def make_guard(self): | |
raise NotImplementedError() | |
def is_ephemeral(self): | |
return True | |
class TensorProperty(enum.Enum):
    """Tensor metadata that a ``TensorPropertySource`` can read."""

    SIZE = 0
    STRIDE = 1
    STORAGE_OFFSET = 2

    def method_name(self):
        """Return the name of the method called on the base object to read this property.

        Raises:
            AssertionError: for a member with no method name.  This replaces a
                silent implicit ``None`` and matches ``TensorPropertySource.name``.
        """
        if self is TensorProperty.SIZE:
            return "size"
        elif self is TensorProperty.STRIDE:
            return "stride"
        elif self is TensorProperty.STORAGE_OFFSET:
            return "storage_offset"
        raise AssertionError(f"unhandled {self}")
# Frozen dataclass: generates the (base, prop, idx) constructor and makes the
# dataclass machinery run the __post_init__ validation below.
@dataclasses.dataclass(frozen=True)
class TensorPropertySource(ChainedSource):
    """``base.size()[idx]``, ``base.stride()[idx]`` or ``base.storage_offset()``.

    ``idx`` is required for SIZE and STRIDE and must be ``None`` for
    STORAGE_OFFSET (enforced in ``__post_init__``).
    """

    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        if self.prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
        else:
            assert self.idx is not None

    def reconstruct(self, codegen):
        """Emit ``base.<method>(idx)``, or ``base.<method>()`` when ``idx`` is None."""
        self.base.reconstruct(codegen)
        codegen.append_output(codegen.create_load_attr(self.prop.method_name()))
        if self.idx is not None:
            codegen.append_output(codegen.create_load_const(self.idx))
        codegen.extend_output(
            create_call_function(1 if self.idx is not None else 0, True)
        )

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        if self.prop is TensorProperty.SIZE:
            return f"{self.base.name()}.size()[{self.idx}]"
        elif self.prop is TensorProperty.STRIDE:
            return f"{self.base.name()}.stride()[{self.idx}]"
        elif self.prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        else:
            raise AssertionError(f"unhandled {self.prop}")
class NegateSource(ChainedSource): | |
def __post_init__(self): | |
assert self.base is not None | |
def reconstruct(self, codegen): | |
raise NotImplementedError() | |
def guard_source(self): | |
return self.base.guard_source() | |
def name(self): | |
# NB: use method call so that function stripping regexes work | |
return f"{self.base.name()}.__neg__()" | |
class ConvertIntSource(ChainedSource): | |
def __post_init__(self): | |
assert self.base is not None | |
def reconstruct(self, codegen): | |
self.base.reconstruct(codegen) | |
def guard_source(self): | |
return self.base.guard_source() | |
def name(self): | |
return f"cast_symbool_to_symint_guardless({self.base.name()})" | |
class DefaultsSource(ChainedSource): | |
idx_key: Union[int, str] | |
is_kw: bool = False | |
field: str = dataclasses.field(init=False, repr=False, compare=False) | |
_name: str = dataclasses.field(init=False, repr=False, compare=False) | |
def __post_init__(self): | |
assert ( | |
self.base | |
), "Base must be a valid source in order to properly track and guard this Defaults to its origin." | |
if self.is_kw: | |
assert isinstance(self.idx_key, str) | |
object.__setattr__(self, "field", "__kwdefaults__") | |
object.__setattr__( | |
self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']" | |
) | |
else: | |
assert isinstance(self.idx_key, int) | |
object.__setattr__(self, "field", "__defaults__") | |
object.__setattr__( | |
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]" | |
) | |
def reconstruct(self, codegen): | |
self.base.reconstruct(codegen) | |
codegen.extend_output(codegen.create_load_attrs(self.field)) | |
codegen.append_output(codegen.create_load_const(self.idx_key)) | |
codegen.append_output(create_instruction("BINARY_SUBSCR")) | |
def guard_source(self): | |
return self.base.guard_source() | |
def name(self): | |
return self._name | |
class GetItemSource(ChainedSource): | |
index: Any | |
index_is_slice: bool = False | |
def __post_init__(self): | |
assert self.base is not None | |
if isinstance(self.index, slice): | |
# store the hashable version of the slice so the whole GetItemSource is hashable | |
super().__setattr__("index", self.index.__reduce__()) | |
super().__setattr__("index_is_slice", True) | |
def reconstruct(self, codegen): | |
reconstruct_getitem(self, codegen, index_is_slice=self.index_is_slice) | |
codegen.append_output(create_instruction("BINARY_SUBSCR")) | |
def guard_source(self): | |
return self.base.guard_source() | |
def unpack_slice(self): | |
assert self.index_is_slice | |
slice_class, slice_args = self.index | |
return slice_class(*slice_args) | |
def name(self): | |
# Index can be of following types | |
# 1) ConstDictKeySource | |
# 2) enum.Enum | |
# 3) index is a slice - example 1:4 | |
# 4) index is a constant - example string, integer | |
if isinstance(self.index, Source): | |
if not isinstance(self.index, ConstDictKeySource): | |
raise ValueError( | |
"GetItemSource index must be a constant, enum or ConstDictKeySource" | |
) | |
return f"{self.base.name()}[{self.index.name()}]" | |
elif self.index_is_slice: | |
return f"{self.base.name()}[{self.unpack_slice()!r}]" | |
elif isinstance(self.index, enum.Enum): | |
return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]" | |
else: | |
return f"{self.base.name()}[{self.index!r}]" | |
# Frozen dataclass like its GetItemSource parent.
@dataclasses.dataclass(frozen=True)
class ConstDictKeySource(GetItemSource):
    """The ``index``-th key of the dict named by ``base`` (``index`` is a position)."""

    def is_dict_key(self):
        return True

    def reconstruct(self, codegen):
        """Emit ``utils.dict_keys_getitem(base, index)``."""
        codegen.load_import_from(utils.__name__, "dict_keys_getitem")
        self.base.reconstruct(codegen)
        codegen.append_output(codegen.create_load_const(self.index))
        codegen.extend_output(create_call_function(2, True))

    def name(self):
        # The list creation will be CSE'd by PyExprCSEPass
        return f"list({self.base.name()}.keys())[{self.index!r}]"
# Frozen dataclass like its GetItemSource parent.
@dataclasses.dataclass(frozen=True)
class TupleIteratorGetItemSource(GetItemSource):
    """Element ``index`` of the tuple iterator named by ``base``."""

    def reconstruct(self, codegen):
        """Emit ``utils.tuple_iterator_getitem(base, index)``."""
        codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
        self.base.reconstruct(codegen)
        codegen.append_output(codegen.create_load_const(self.index))
        codegen.extend_output(create_call_function(2, True))

    def name(self):
        return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
class TypeSource(ChainedSource): | |
def __post_init__(self): | |
assert self.base is not None | |
def reconstruct(self, codegen): | |
codegen.load_import_from("builtins", "type") | |
self.base.reconstruct(codegen) | |
codegen.extend_output(create_call_function(1, True)) | |
def guard_source(self): | |
return self.base.guard_source() | |
def name(self): | |
return f"type({self.base.name()})" | |
class ODictGetItemSource(ChainedSource): | |
index: Any | |
def __post_init__(self): | |
assert self.base is not None | |
def reconstruct(self, codegen): | |
codegen.append_output( | |
codegen._create_load_const(collections.OrderedDict.__getitem__) | |
) | |
reconstruct_getitem(self, codegen, index_is_slice=False) | |
codegen.extend_output(create_call_function(2, True)) | |
def guard_source(self): | |
return self.base.guard_source() | |
def name(self): | |
if isinstance(self.index, type): | |
rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}' | |
return f"___odict_getitem({self.base.name()}, {rep})" | |
elif isinstance(self.index, Source): | |
return f"___odict_getitem({self.base.name()}, {self.index.name()})" | |
else: | |
return f"___odict_getitem({self.base.name()}, {self.index!r})" | |
class NNModuleSource(ChainedSource): | |
def reconstruct(self, codegen): | |
self.base.reconstruct(codegen) | |
def guard_source(self): | |
return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()] | |
def name(self): | |
return self.base.name() | |
# Decorated like every other Source so instances are frozen.
@dataclasses.dataclass(frozen=True)
class NotNNModuleSource(NNModuleSource):
    """Like ``NNModuleSource``, but demotes the guard source to plain LOCAL / GLOBAL.

    NN-module and FSDP-module guard sources map back to their plain local or
    global counterpart via ``_GUARD_SOURCE_NOT_NN_MODULE``.
    """

    def guard_source(self):
        return _GUARD_SOURCE_NOT_NN_MODULE[self.base.guard_source()]
# Decorated like every other Source so instances are frozen.
@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
    """Like ``NNModuleSource``, but promotes the guard source to the FSDP-module variant."""

    def guard_source(self):
        return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()]
class GlobalStateSource(Source): | |
def name(self): | |
return "" | |
def guard_source(self): | |
return GuardSource.GLOBAL | |
class ConstantSource(Source): | |
source_name: str | |
def reconstruct(self, codegen): | |
codegen.append_output( | |
codegen.create_load_global(self.source_name, False, add=False) | |
) | |
def guard_source(self): | |
return GuardSource.CONSTANT | |
def name(self): | |
return self.source_name | |
def make_guard(self, fn): | |
raise NotImplementedError() | |
class NumpyTensorSource(ChainedSource): | |
def name(self) -> str: | |
return f"___from_numpy({self.base.name()})" | |
def guard_source(self): | |
return self.base.guard_source() | |
def reconstruct(self, codegen): | |
codegen.load_import_from("torch", "as_tensor") | |
self.base.reconstruct(codegen) | |
codegen.extend_output(create_call_function(1, True)) | |
# This is a synthetic source that is associated with the singleton | |
# shape env guard we always register for all frames. We get the actual | |
# guard contents from the ambient ShapeEnv | |
class ShapeEnvSource(Source): | |
def name(self): | |
return "" | |
def guard_source(self): | |
return GuardSource.SHAPE_ENV | |
class BackwardStateSource(Source): | |
def name(self): | |
return "" | |
def guard_source(self): | |
return GuardSource.BACKWARD_STATE | |
def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
    """Return True if the root of ``source``'s ChainedSource chain is a LocalSource.

    The chain is followed through ``.base`` until a non-chained source is
    reached.  When ``allow_cell_or_freevar`` is False, a root LocalSource
    flagged ``cell_or_freevar`` does not count.
    """
    root = source
    while isinstance(root, ChainedSource):
        root = root.base
    if not isinstance(root, LocalSource):
        return False
    if root.cell_or_freevar and not allow_cell_or_freevar:
        return False
    return True
# TODO: can probably write a generic "test this on everything in the chain"
# helper
def is_from_defaults(source: Source):
    """Return True if ``source`` or any source in its ``.base`` chain is a DefaultsSource."""
    node = source
    while True:
        if isinstance(node, DefaultsSource):
            return True
        if not isinstance(node, ChainedSource):
            return False
        node = node.base