Spaces:
Running
Running
File size: 14,146 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 |
from .graph_module import GraphModule
from .graph import Graph
from .node import Node
from ._symbolic_trace import symbolic_trace
from ._compatibility import compatibility
import copy
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from .passes.utils.matcher_with_name_node_map_utils import InternalMatch
__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"]
@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
    """
    Describes one location in a graph where a pattern subgraph was matched.

    Instances are what ``replace_pattern`` returns, one per match.
    """
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
@compatibility(is_backward_compatible=False)
@dataclass
class ReplacedPatterns:
    """
    Describes one pattern match together with the nodes inserted in its place.

    Carries the same ``anchor`` / ``nodes_map`` information as ``Match`` plus
    the list of replacement nodes; this is the element type of the list
    returned by ``replace_pattern_with_filters``.
    """
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]
def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
    """
    Make sure every attribute referenced by ``gm``'s graph exists on ``gm``.

    After a rewrite, ``gm.graph`` may contain ``call_module`` / ``get_attr``
    nodes whose targets only exist on ``replacement``. This walks those nodes
    and deep-copies any missing attribute from ``replacement`` onto ``gm``.
    Submodules of ``gm`` that are no longer used are deleted first.

    Args:
        gm: The rewritten GraphModule; modified in place.
        replacement: The module the replacement subgraph was taken from.

    Raises:
        RuntimeError: If a ``call_module`` / ``get_attr`` target exists on
            neither ``gm`` nor ``replacement``.
    """
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_attr(module: torch.nn.Module, target: str) -> Optional[Any]:
        # Resolve a (possibly dotted) target; None means "not found".
        module_path, _, attr_name = target.rpartition(".")
        try:
            owner: torch.nn.Module = module.get_submodule(module_path)
        except AttributeError:
            return None
        return getattr(owner, attr_name, None)

    for node in gm.graph.nodes:
        if node.op not in ("call_module", "get_attr"):
            continue

        # CASE 1: This target already exists as an attribute in our
        #         result GraphModule. Whether or not it exists in
        #         `replacement`, the existing submodule takes precedence.
        if try_get_attr(gm, node.target) is not None:
            continue

        replacement_attr = try_get_attr(replacement, node.target)

        # CASE 2: The target exists as an attribute in `replacement`
        #         only, so we need to copy it over.
        if replacement_attr is not None:
            new_attr = copy.deepcopy(replacement_attr)
            if isinstance(replacement_attr, torch.nn.Module):
                gm.add_submodule(node.target, new_attr)
            else:
                setattr(gm, node.target, new_attr)
            continue

        # CASE 3: The target doesn't exist as an attribute in `gm`
        #         or `replacement`.
        # The message is built as ONE string: passing the pieces as separate
        # arguments would make the exception render as a tuple.
        raise RuntimeError(
            f'Attempted to create a "{node.op}" node during subgraph '
            f"rewriting with target {node.target}, but the referenced "
            "attribute does not exist in the replacement GraphModule"
        )

    gm.graph.lint()
@compatibility(is_backward_compatible=True)
def replace_pattern(
    gm: GraphModule,
    pattern: Union[Callable, GraphModule],
    replacement: Union[Callable, GraphModule]
) -> List[Match]:
    """
    Find every non-overlapping occurrence of ``pattern`` (a set of
    operators together with their data dependencies) in the Graph of
    ``gm`` and swap each occurrence for the ``replacement`` subgraph.

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on

        ``pattern``: The subgraph to match in ``gm`` for replacement

        ``replacement``: The subgraph to replace ``pattern`` with

    Returns:
        List[Match]: One ``Match`` per place in the original graph where
        ``pattern`` was matched; empty when nothing matched. ``Match``
        is defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The call above first looks for ``pattern`` inside the ``forward``
    method of ``traced_module``. Matching follows use-def relationships
    rather than node names: ``p = torch.cat([a, b])`` in ``pattern``
    matches ``m = torch.cat([a, b])`` in the original ``forward`` even
    though the variable names (``p`` vs ``m``) differ.

    The ``return`` statement of ``pattern`` is matched by its value
    only; it may or may not line up with the ``return`` statement of the
    larger graph. Put differently, the pattern does not need to reach
    the end of the larger graph.

    Each matched occurrence is removed from the larger function and
    ``replacement`` is put in its place. When ``pattern`` matches several
    times, every non-overlapping match is replaced. Among a set of
    overlapping matches, the first one found is the one replaced, where
    "first" means first in a topological ordering of the Nodes' use-def
    relationships. (In most cases the first Node is the parameter right
    after ``self`` and the last Node is whatever the function returns.)

    Two rules apply to the callables: every parameter of ``pattern``
    must actually be used inside ``pattern``, and the parameters of
    ``replacement`` must match those of ``pattern``. The first rule is
    why ``forward`` above takes ``x, w1, w2`` while ``pattern`` only
    takes ``w1, w2``: ``pattern`` never uses ``x``, so it must not list
    ``x`` as a parameter. To illustrate the second rule, consider
    replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    Here ``replacement`` must still declare the same number of
    parameters as ``pattern`` (both ``x`` and ``y``) even though it
    never uses ``y``.

    After calling ``subgraph_rewriter.replace_pattern`` on the first
    example, the generated Python code looks roughly like this (node
    names are illustrative). Note that both the ``cat`` and the ``sum``
    of each match are gone, because the whole matched subgraph is
    replaced by the single ``stack`` call of ``replacement``:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            max_1 = torch.max(stack_1)
            add_1 = x + max_1
            stack_2 = torch.stack([w1, w2])
            max_2 = torch.max(stack_2)
            add_2 = add_1 + max_2
            return add_2
    """
    replaced = _replace_pattern(gm, pattern, replacement)

    # Callers of this backward-compatible API only get the ``Match`` view
    # (anchor + nodes_map) of each replaced pattern.
    found: List[Match] = []
    for item in replaced:
        found.append(Match(anchor=item.anchor, nodes_map=item.nodes_map))
    return found
# Experimental API, not backward compatible
@compatibility(is_backward_compatible=False)
def replace_pattern_with_filters(
    gm: GraphModule,
    pattern: Union[Callable, Graph, GraphModule],
    replacement: Union[Callable, Graph, GraphModule],
    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
    ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
    """
    Variant of ``replace_pattern`` that can reject individual matches.

    See ``replace_pattern`` for the general documentation. Unlike that
    function, ``pattern`` and ``replacement`` may also be given as a
    ``Graph``, and the result is a list of ``ReplacedPatterns``.

    Args:
        ``match_filters``: A list of functions that take in
            (match: InternalMatch, original_graph: Graph, pattern_graph: Graph)
            and return a boolean indicating whether the match satisfies
            the condition. A match is kept only if every filter accepts it.
            See matcher_utils.py for definition of InternalMatch.

        ``ignore_literals``: Forwarded unchanged to the subgraph matcher's
            option of the same name.
    """
    replaced: List[ReplacedPatterns] = _replace_pattern(
        gm,
        pattern,
        replacement,
        match_filters=match_filters,
        ignore_literals=ignore_literals,
    )
    return replaced
def _replace_pattern(
    gm: GraphModule,
    pattern: Union[Callable, Graph, GraphModule],
    replacement: Union[Callable, Graph, GraphModule],
    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
    ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
    """
    Shared implementation of ``replace_pattern`` and
    ``replace_pattern_with_filters``.

    Matches ``pattern`` against ``gm.graph`` with ``SubgraphMatcher``, drops
    every match rejected by any of ``match_filters``, and for each remaining
    match copies the replacement graph into ``gm.graph`` in front of the
    first user of the match's returning nodes, reroutes all uses to the
    copied nodes and erases the matched nodes. ``gm`` is modified in place
    and recompiled at the end.

    Args:
        gm: GraphModule whose graph is rewritten in place.
        pattern: Subgraph to look for; a callable is symbolically traced.
        replacement: Subgraph to insert; a callable is symbolically traced.
        match_filters: Optional predicates called as
            ``f(match, original_graph, pattern_graph)``; a match is kept
            only if all of them return a truthy value.
        ignore_literals: Passed through to ``SubgraphMatcher``.

    Returns:
        One ``ReplacedPatterns`` entry per match that was replaced.
    """

    # Imported locally rather than at module level (the module-level import
    # of InternalMatch is guarded by TYPE_CHECKING).
    from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch

    if match_filters is None:
        match_filters = []

    # Get the graphs for `gm`, `pattern`, `replacement`
    original_graph: Graph = gm.graph

    if isinstance(pattern, GraphModule):
        pattern_graph = pattern.graph
    elif isinstance(pattern, Graph):
        pattern_graph = pattern
    else:
        pattern_graph = symbolic_trace(pattern).graph

    if isinstance(replacement, GraphModule):
        replacement_graph = replacement.graph
    elif isinstance(replacement, Graph):
        replacement_graph = replacement
    else:
        replacement_graph = symbolic_trace(replacement).graph

    matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
                              remove_overlapping_matches=True, ignore_literals=ignore_literals)
    _matches: List[InternalMatch] = matcher.match(original_graph)

    # Filter out matches that don't match the filter
    _matches = [
        m for m in _matches
        if all(match_filter(m, original_graph, pattern_graph)
               for match_filter in match_filters)
    ]

    replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]

    # As we progressively replace nodes, we'll need to keep track of how the match results should change
    # (maps an original returning node to the copied node that took over its uses)
    match_changed_node: Dict[Node, Node] = {}

    match_and_replacements = []
    for match in _matches:

        # Build connecting between replacement graph's input and original graph input producer node

        # Initialize `val_map` with mappings from placeholder nodes in
        # `replacement` to their corresponding node in `original_graph`
        assert len(match.placeholder_nodes) == len(replacement_placeholders)
        val_map: Dict[Node, Node] = {}
        for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
            if isinstance(gn, Node):
                # If an earlier replacement already took over `gn`'s uses,
                # feed the replacement from the node that superseded it.
                val_map[rn] = match_changed_node.get(gn, gn)
                if gn != val_map[rn]:
                    # Update match.placeholder_nodes and match.nodes_map with the node that replaced gn
                    gn_ind = match.placeholder_nodes.index(gn)
                    match.placeholder_nodes[gn_ind] = match_changed_node[gn]
                    # Reverse lookup: find the pattern node currently mapped to `gn`
                    map_key = list(match.nodes_map.keys())[list(match.nodes_map.values()).index(gn)]
                    match.nodes_map[map_key] = match_changed_node[gn]
            else:
                # Non-Node placeholder values are mapped through unchanged
                val_map[rn] = gn

        # Copy the replacement graph over
        # First collect every user of the match's returning nodes; the copy
        # is inserted in front of the earliest of them.
        user_nodes: Set[Node] = set()
        for n in match.returning_nodes:
            for user in n.users:
                user_nodes.add(user)
        assert user_nodes, "The returning_nodes should have at least one user node"

        if len(user_nodes) == 1:
            first_user_node = next(iter(user_nodes))
        else:
            # If there are multiple user nodes, we need to find the first user node
            # in the current execution order of the `original_graph`
            for n in original_graph.nodes:
                if n in user_nodes:
                    first_user_node = n
                    break

        with original_graph.inserting_before(first_user_node):  # type: ignore[possibly-undefined]
            copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map)

        # Normalize a single returned Node to a 1-tuple so it can be zipped below
        if isinstance(copied_returning_nodes, Node):
            copied_returning_nodes = (copied_returning_nodes, )

        # Get a list of nodes that have been replaced into the graph
        # NOTE(review): relies on graph_copy having added the copied nodes to
        # `val_map`; everything in it that is not a match placeholder is
        # treated as a newly inserted node.
        replacement_nodes: List[Node] = [v for v in val_map.values() if v not in match.placeholder_nodes]

        # Hook the output Node of the replacement subgraph in to the
        # original Graph at the correct location
        assert len(match.returning_nodes) == len(copied_returning_nodes)
        for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):
            gn.replace_all_uses_with(copied_node)
            # Remember the substitution so later matches that take `gn` as an
            # input are rewired to `copied_node` (see the placeholder loop above)
            match_changed_node[gn] = copied_node
        # Remove the original nodes
        # Walks the pattern in reverse order, presumably so that each matched
        # node is erased after the matched nodes that consume it.
        for node in reversed(pattern_graph.nodes):
            if node.op != "placeholder" and node.op != "output":
                gn = match.nodes_map[node]
                gm.graph.erase_node(gn)

        match_and_replacements.append(
            ReplacedPatterns(
                anchor=match.anchors[0],
                nodes_map=match.nodes_map,
                replacements=replacement_nodes
            )
        )

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()

    # If `replacement` was an nn.Module, we'll need to make sure that
    # all the submodules have been copied over correctly
    if isinstance(replacement, torch.nn.Module):
        _replace_attributes(gm, replacement)

    return match_and_replacements
|