Spaces:
Running
Running
File size: 7,408 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
# mypy: ignore-errors
import torch
import torch.fx
import traceback
from torch._dispatch.python import enable_python_dispatcher
from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict
from torch.fx._compatibility import compatibility
from torch._guards import detect_fake_mode
# Public names exported by a star-import of this module.
__all__ = ['TensorMetadata', 'ShapeProp']
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
    # TensorMetadata is a structure containing pertinent information
    # about a tensor within a PyTorch program. Instances are built by
    # `_extract_tensor_metadata` from a concrete `torch.Tensor`.

    # General Tensor metadata
    # `tensor.shape` of the described tensor.
    shape : torch.Size
    # `tensor.dtype` of the described tensor.
    dtype : torch.dtype
    # `tensor.requires_grad` of the described tensor.
    requires_grad : bool
    # `tensor.stride()` of the described tensor.
    stride : Tuple[int, ...]
    # A memory format the tensor is contiguous in (one of
    # torch.contiguous_format, torch.channels_last, torch.channels_last_3d),
    # or None when none of them matched or the contiguity check was skipped.
    memory_format : Optional[torch.memory_format]

    # Quantization metadata
    # `tensor.is_quantized` of the described tensor.
    is_quantized : bool
    # Empty for non-quantized tensors. For quantized tensors it holds
    # "qscheme" and, for the per-tensor / per-channel schemes handled by
    # `_extract_tensor_metadata`, "scale" and "zero_point" (plus "axis" for
    # per-channel schemes).
    qparams: Dict[str, Any]
def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata:
    """
    Extract a TensorMetadata NamedTuple describing `result`.

    Args:
        result: The tensor whose metadata should be captured.
        include_contiguity: When True, probe which memory format `result`
            is contiguous in and record it. When False the probe is skipped
            and ``memory_format`` is recorded as None.

    Returns:
        TensorMetadata: shape, dtype, requires_grad, stride, memory_format,
        is_quantized and the quantization parameters of `result`.
    """
    shape = result.shape
    dtype = result.dtype
    requires_grad = result.requires_grad
    stride = result.stride()

    memory_format = None
    if include_contiguity:
        # Use an ordered tuple rather than a set. A tensor can be contiguous
        # in more than one format at once (e.g. an NCHW tensor with
        # H == W == 1 is both `contiguous_format` and `channels_last`
        # contiguous), and set iteration order would make the recorded format
        # arbitrary. With a fixed order the default `contiguous_format` always
        # wins ties.
        memory_formats = (
            torch.contiguous_format,
            torch.channels_last,
            torch.channels_last_3d,
        )
        memory_format = next(
            (
                query_format
                for query_format in memory_formats
                if result.is_contiguous(memory_format=query_format)
            ),
            None,
        )

    is_quantized = result.is_quantized
    qparams: Dict[str, Any] = {}
    if is_quantized:
        qscheme = result.qscheme()
        qparams["qscheme"] = qscheme
        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            # Per-tensor schemes carry a single scalar scale / zero point.
            qparams["scale"] = result.q_scale()
            qparams["zero_point"] = result.q_zero_point()
        elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
            # In this branch, scale and zero_point are tensors; they are
            # converted to plain Python lists via `.tolist()` for easier
            # serialization downstream.
            qparams["scale"] = result.q_per_channel_scales().tolist()
            qparams["zero_point"] = result.q_per_channel_zero_points().tolist()
            qparams["axis"] = result.q_per_channel_axis()
        # Any other qscheme is recorded under "qscheme" only.

    return TensorMetadata(
        shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node and
record the shape and type of the result
into the corresponding node.
Example:
In this example, we record the shape
and data type of a module given
an example input ``torch.randn(50, D_in)``.
We print the name, shape and dtype of each node.
class TwoLayerNet(torch.nn.Module):
def __init__(self, D_in, H, D_out):
super().__init__()
self.linear1 = torch.nn.Linear(D_in, H)
self.linear2 = torch.nn.Linear(H, D_out)
def forward(self, x):
h_relu = self.linear1(x).clamp(min=0)
y_pred = self.linear2(h_relu)
return y_pred
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)
model = TwoLayerNet(D_in, H, D_out)
gm = torch.fx.symbolic_trace(model)
sample_input = torch.randn(50, D_in)
ShapeProp(gm).propagate(sample_input)
for node in gm.graph.nodes:
print(node.name, node.meta['tensor_meta'].dtype,
node.meta['tensor_meta'].shape)
The output of this code is:
x torch.float32 torch.Size([50, 1000])
linear1 torch.float32 torch.Size([50, 100])
clamp_1 torch.float32 torch.Size([50, 100])
linear2 torch.float32 torch.Size([50, 10])
output torch.float32 torch.Size([50, 10])
Args:
module (GraphModule): The module to be executed
fake_mode (FakeTensorMode): A fake mode for copying the gm
"""
def __init__(self, gm, fake_mode=None):
super().__init__(gm)
if fake_mode is None:
fake_mode = detect_fake_mode()
if fake_mode is not None:
from torch._dynamo.utils import deepcopy_to_fake_tensor
# Note:
# We need fake execution cause the inputs are fake, however, we cannot fakify the module
# - because we need to write to the tensor_meta of the real module. So we fakify to
# produce a result (L131 below), to extract tensor meta, and then keep going.
#
# If we were to fakify, we would write to the wrong node, and then downstream fusion
# would be missing the tensor_meta.
#
# See torch/_inductor/overrides.py for where this is called upstream of fusion.
self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode)
self.fake_mode = fake_mode
else:
self.fake_module = None
self.fake_mode = None
self.real_module = self.module
def run_node(self, n : Node) -> Any:
try:
if self.fake_module is not None:
# Hacky swap. Alternatively, we could do this with overriding
# call_module and get_attr.
self.module = self.fake_module
try:
if self.fake_mode is not None:
with self.fake_mode, enable_python_dispatcher():
result = super().run_node(n)
else:
result = super().run_node(n)
finally:
self.module = self.real_module
except Exception as e:
traceback.print_exc()
raise RuntimeError(
f"ShapeProp error for: node={n.format_node()} with "
f"meta={n.meta}"
) from e
found_tensor = False
def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
nonlocal found_tensor
found_tensor = True
return _extract_tensor_metadata(obj)
else:
return obj
meta = map_aggregate(result, extract_tensor_meta)
if found_tensor:
n.meta['tensor_meta'] = meta
n.meta['type'] = type(result)
return result
def propagate(self, *args):
"""
Run `module` via interpretation and return the result and
record the shape and type of each node.
Args:
*args (Tensor): the sample input.
Returns:
Any: The value returned from executing the Module
"""
if self.fake_mode is not None:
fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
else:
fake_args = args
return super().run(*fake_args)
|