# mypy: ignore-errors | |
import torch | |
import torch.fx | |
import traceback | |
from torch._dispatch.python import enable_python_dispatcher | |
from torch.fx.node import Node, map_aggregate | |
from typing import Any, Tuple, NamedTuple, Optional, Dict | |
from torch.fx._compatibility import compatibility | |
from torch._guards import detect_fake_mode | |
__all__ = ['TensorMetadata', 'ShapeProp'] | |
class TensorMetadata(NamedTuple):
    # TensorMetadata is a structure containing pertinent information
    # about a tensor within a PyTorch program.

    # General Tensor metadata
    shape : torch.Size                             # size of the tensor, per dimension
    dtype : torch.dtype                            # element dtype
    requires_grad : bool                           # whether autograd records operations on this tensor
    stride : Tuple[int, ...]                       # element stride, per dimension
    # Memory format the tensor is contiguous in, or None when contiguity
    # was not probed / matched no standard format (see _extract_tensor_metadata).
    memory_format : Optional[torch.memory_format]

    # Quantization metadata
    is_quantized : bool                            # True iff the tensor is quantized
    # Quantization parameters keyed by name ("qscheme", "scale", "zero_point",
    # and "axis" for per-channel schemes); empty dict for non-quantized tensors.
    qparams: Dict[str, Any]
def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata:
    """
    Build a TensorMetadata record summarizing `result`.

    Captures shape/dtype/requires_grad/stride, optionally probes which
    standard memory format the tensor is contiguous in, and records the
    quantization parameters of quantized tensors.
    """
    memory_format = None
    if include_contiguity:
        # Probe the known memory formats and keep the first one the tensor
        # is contiguous in; stays None when nothing matches.
        candidate_formats = {
            torch.contiguous_format,
            torch.channels_last,
            torch.channels_last_3d,
        }
        memory_format = next(
            (fmt for fmt in candidate_formats if result.is_contiguous(memory_format=fmt)),
            None,
        )

    qparams: Dict[str, Any] = {}
    if result.is_quantized:
        scheme = result.qscheme()
        qparams["qscheme"] = scheme
        if scheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            qparams["scale"] = result.q_scale()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_zero_point()  # type: ignore[assignment]
        elif scheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
            # Per-channel scale and zero_point come back as tensors; store
            # the values as plain lists (immutable_list in TensorMetadata)
            # for easier serialization downstream.
            qparams["scale"] = result.q_per_channel_scales().tolist()  # type: ignore[assignment]
            qparams["zero_point"] = result.q_per_channel_zero_points().tolist()  # type: ignore[assignment]
            qparams["axis"] = result.q_per_channel_axis()  # type: ignore[assignment]

    return TensorMetadata(
        result.shape,
        result.dtype,
        result.requires_grad,
        result.stride(),
        memory_format,
        result.is_quantized,
        qparams,
    )
class ShapeProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and
    record the shape and type of the result
    into the corresponding node.

    Example:
         In this example, we record the shape
         and data type of a module given
         an example input ``torch.randn(50, D_in)``.
         We print the name, shape and dtype of each node.

        class TwoLayerNet(torch.nn.Module):
            def __init__(self, D_in, H, D_out):
                super().__init__()
                self.linear1 = torch.nn.Linear(D_in, H)
                self.linear2 = torch.nn.Linear(H, D_out)
            def forward(self, x):
                h_relu = self.linear1(x).clamp(min=0)
                y_pred = self.linear2(h_relu)
                return y_pred

        N, D_in, H, D_out = 64, 1000, 100, 10
        x = torch.randn(N, D_in)
        y = torch.randn(N, D_out)
        model = TwoLayerNet(D_in, H, D_out)
        gm = torch.fx.symbolic_trace(model)
        sample_input = torch.randn(50, D_in)
        ShapeProp(gm).propagate(sample_input)

        for node in gm.graph.nodes:
            print(node.name, node.meta['tensor_meta'].dtype,
                node.meta['tensor_meta'].shape)

        The output of this code is:

        x torch.float32 torch.Size([50, 1000])
        linear1 torch.float32 torch.Size([50, 100])
        clamp_1 torch.float32 torch.Size([50, 100])
        linear2 torch.float32 torch.Size([50, 10])
        output torch.float32 torch.Size([50, 10])

    Args:
         module (GraphModule): The module to be executed
         fake_mode (FakeTensorMode): A fake mode for copying the gm
    """
    def __init__(self, gm, fake_mode=None):
        super().__init__(gm)
        if fake_mode is None:
            # Pick up an already-active FakeTensorMode (if any) so callers
            # inside a fake-tensor context don't need to pass one explicitly.
            fake_mode = detect_fake_mode()
        if fake_mode is not None:
            from torch._dynamo.utils import deepcopy_to_fake_tensor
            # Note:
            # We need fake execution cause the inputs are fake, however, we cannot fakify the module
            # - because we need to write to the tensor_meta of the real module. So we fakify to
            # produce a result (L131 below), to extract tensor meta, and then keep going.
            #
            # If we were to fakify, we would write to the wrong node, and then downstream fusion
            # would be missing the tensor_meta.
            #
            # See torch/_inductor/overrides.py for where this is called upstream of fusion.
            self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
            self.fake_mode = fake_mode
        else:
            self.fake_module = None
            self.fake_mode = None

        # Keep a handle on the real module so run_node can restore it after
        # temporarily swapping in the fake copy.
        self.real_module = self.module

    def run_node(self, n : Node) -> Any:
        """
        Execute node ``n`` (via the base Interpreter), then record
        ``n.meta['tensor_meta']`` (TensorMetadata for every tensor found in
        the result) and ``n.meta['type']`` on the node.

        Raises:
            RuntimeError: wrapping any exception from node execution, with
                the node's format_node() and meta attached for debugging.
        """
        try:
            if self.fake_module is not None:
                # Hacky swap. Alternatively, we could do this with overriding
                # call_module and get_attr.
                self.module = self.fake_module
            try:
                if self.fake_mode is not None:
                    # enable_python_dispatcher keeps decompositions/meta kernels
                    # reachable while executing under the fake mode.
                    with self.fake_mode, enable_python_dispatcher():
                        result = super().run_node(n)
                else:
                    result = super().run_node(n)
            finally:
                # Always restore the real module so tensor_meta is written to
                # (and later read from) the real graph module's nodes.
                self.module = self.real_module
        except Exception as e:
            traceback.print_exc()
            raise RuntimeError(
                f"ShapeProp error for: node={n.format_node()} with "
                f"meta={n.meta}"
            ) from e

        # Only attach tensor_meta when the result actually contains a tensor
        # somewhere in its (possibly nested) structure.
        found_tensor = False

        def extract_tensor_meta(obj):
            if isinstance(obj, torch.Tensor):
                nonlocal found_tensor
                found_tensor = True
                return _extract_tensor_metadata(obj)
            else:
                return obj

        meta = map_aggregate(result, extract_tensor_meta)
        if found_tensor:
            n.meta['tensor_meta'] = meta

        n.meta['type'] = type(result)
        return result

    def propagate(self, *args):
        """
        Run `module` via interpretation and return the result and
        record the shape and type of each node.

        Args:
            *args (Tensor): the sample input.

        Returns:
            Any: The value returned from executing the Module
        """
        if self.fake_mode is not None:
            # Real sample inputs must be converted to fake tensors before
            # running under the fake mode.
            fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
        else:
            fake_args = args
        return super().run(*fake_args)