Spaces:
Sleeping
Sleeping
import torch.fx as fx | |
from torch.fx.node import Argument, Target | |
from torch.nn.utils.fusion import fuse_conv_bn_eval | |
from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.fx.passes.shape_prop import ShapeProp | |
import copy | |
from collections import defaultdict | |
import torch.utils.mkldnn as th_mkldnn | |
import operator | |
import time | |
import logging | |
from enum import Enum | |
def _parent_name(target: str) -> Tuple[str, str]:
    """
    Split a qualified name into its parent path and its final atom.

    For example, `foo.bar.baz` -> (`foo.bar`, `baz`). A name without any
    dot has the empty string as its parent: `baz` -> (``, `baz`).
    """
    # rpartition splits on the LAST dot and yields an empty head when the
    # separator is absent, which is exactly the contract described above.
    parent, _sep, name = target.rpartition('.')
    return parent, name
# Only handles two-element patterns: (module feeding the node, module of the node).
def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]):
    """
    Check whether `node` and its first positional argument are `call_module`
    nodes whose modules have exactly the types listed in `pattern`.

    `pattern` is paired positionally with `(node.args[0], node)`. Types are
    compared with `is`, so subclasses of a pattern type do not match.
    Returns False when `node` has no positional arguments.
    """
    if not node.args:
        return False
    candidates: Tuple[Any, fx.Node] = (node.args[0], node)
    for expected_type, candidate in zip(pattern, candidates):
        # The checks are ordered so that each one guards the attribute
        # access performed by the next.
        is_known_module_call = (
            isinstance(candidate, fx.Node)
            and candidate.op == 'call_module'
            and isinstance(candidate.target, str)
            and candidate.target in modules
        )
        if not is_known_module_call:
            return False
        if type(modules[candidate.target]) is not expected_type:
            return False
    return True
def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
    """
    Swap the module that `node` calls for `new_module`.

    Both the `modules` lookup table (keyed by qualified name) and the
    attribute on the parent module are updated so that they stay in sync.
    """
    target = node.target
    assert isinstance(target, str)
    parent_path, attr_name = _parent_name(target)
    modules[target] = new_module
    setattr(modules[parent_path], attr_name, new_module)
def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module:
    """
    Fuses convolution/BN layers for inference purposes. Will deepcopy your
    model by default, but can modify the model inplace as well.

    Args:
        model: module to optimize.
        inplace: when False (the default) the model is deep-copied first.
        no_trace: when True and `model` is already an `fx.GraphModule`, it is
            used as-is instead of being symbolically traced again.

    Returns:
        A new `fx.GraphModule` in which every matched Conv -> BatchNorm pair
        has been replaced by a single fused convolution.
    """
    conv_bn_pairs = [
        (nn.Conv1d, nn.BatchNorm1d),
        (nn.Conv2d, nn.BatchNorm2d),
        (nn.Conv3d, nn.BatchNorm3d),
    ]
    if not inplace:
        model = copy.deepcopy(model)
    # Tracing is skipped only when the caller asked for it AND the model is
    # already a GraphModule.
    if no_trace and isinstance(model, torch.fx.GraphModule):
        fx_model = model
    else:
        fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    for conv_bn in conv_bn_pairs:
        for bn_node in new_graph.nodes:
            if not matches_module_pattern(conv_bn, bn_node, modules):
                continue
            conv_node = bn_node.args[0]
            if len(conv_node.users) > 1:
                # Output of conv is used by other nodes, so it must stay unfused.
                continue
            bn = modules[bn_node.target]
            if not bn.track_running_stats:
                continue
            conv = modules[conv_node.target]
            fused_conv = fuse_conv_bn_eval(conv, bn)
            replace_node_module(conv_node, modules, fused_conv)
            # Route the BN node's consumers to the (now fused) conv and drop BN.
            bn_node.replace_all_uses_with(conv_node)
            new_graph.erase_node(bn_node)
    return fx.GraphModule(fx_model, new_graph)
def remove_dropout(model: nn.Module) -> nn.Module:
    """
    Removes all dropout layers from the module.

    The model is symbolically traced and every call to an `nn.Dropout`
    submodule is replaced by that call's single positional input.
    """
    traced = fx.symbolic_trace(model)

    class DropoutRemover(torch.fx.Transformer):
        def call_module(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
            if not isinstance(self.submodules[target], nn.Dropout):
                return super().call_module(target, args, kwargs)
            # Forward the dropout's input unchanged.
            assert len(args) == 1
            return args[0]

    return DropoutRemover(traced).transform()
def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]):
    """
    Build a standalone GraphModule that executes a subgraph of an existing graph.

    Args:
        orig_module: module used as the root of the returned GraphModule.
        nodes: the subgraph's nodes, copied in the order given.
        inputs: nodes that become placeholders (named after the original
            nodes) in the new graph.
        outputs: nodes whose copies are returned, as a list, by the new graph.
    """
    new_graph = fx.Graph()
    # Maps each original node to its counterpart in the new graph.
    env: Dict[fx.Node, fx.Node] = {
        boundary: new_graph.placeholder(boundary.name) for boundary in inputs
    }
    for original in nodes:
        env[original] = new_graph.node_copy(original, env.__getitem__)
    new_graph.output([env[result] for result in outputs])
    new_graph.lint()
    return fx.GraphModule(orig_module, new_graph)
# Module classes and callables that the MKLDNN layout pass always treats as
# MKLDNN-capable.
mkldnn_supported = [
    # nn.Module classes
    nn.Conv2d,
    nn.Linear,
    nn.BatchNorm2d,
    nn.ReLU,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
    # torch functions
    torch.relu,
    torch.transpose,
    torch.sigmoid,
    # torch.nn.functional functions
    F.relu,
    F.avg_pool2d,
    F.adaptive_avg_pool2d,
]

# These are operators that may not be convertible into MKLDNN ops (e.g. the
# args are scalar values). Thus, we only include them in the subgraph if their
# arguments are already in MKLDNN.
# TODO: Determine whether this can be removed after type inference.
mkldnn_supported_unknown = [operator.add, operator.mul]

# Factories that build the MKLDNN counterpart of a module. Each factory is
# called as `factory(module, dtype)`; the BatchNorm wrapper ignores the dtype.
mkldnn_map = {
    nn.Conv2d: th_mkldnn.MkldnnConv2d,
    nn.Linear: th_mkldnn.MkldnnLinear,
    nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a),
}
def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]):
    """
    Pre-convert every convertible module called by `nodes` into MKLDNN form.

    A module is convertible when its exact type is a key of `mkldnn_map`.
    The converted module replaces the original both in `modules` and on its
    parent module.

    Returns:
        A dict mapping each new MKLDNN module to a deep copy of the module it
        replaced, so that the conversion can later be undone.
    """
    old_modules: Dict[nn.Module, nn.Module] = {}
    for node in nodes:
        if node.op != 'call_module':
            continue
        assert isinstance(node.target, str)
        original = modules[node.target]
        to_mkldnn = mkldnn_map.get(type(original))
        if to_mkldnn is None:
            continue
        converted = to_mkldnn(original, torch.float)
        assert isinstance(converted, nn.Module)
        old_modules[converted] = copy.deepcopy(original)
        replace_node_module(node, modules, converted)
    return old_modules
def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]):
    """
    Undo `modules_to_mkldnn` for the modules called by `nodes`.

    Every `call_module` node whose current module is a key of `old_modules`
    gets the corresponding saved original put back in its place.
    """
    for node in nodes:
        if node.op != 'call_module':
            continue
        assert (isinstance(node.target, str))
        current = modules[node.target]
        if current in old_modules:
            replace_node_module(node, modules, old_modules[current])
class MklSubgraph:
    """
    Record of one candidate MKLDNN region inside an FX graph.

    `optimize_for_inference` fills the three node lists: `start_nodes`
    receives the `to_mkldnn` nodes entering the region, `end_nodes` the
    `to_dense` nodes leaving it, and `nodes` everything in between.
    """

    def __init__(self, fx_graph: fx.Graph):
        # The graph this region belongs to.
        self.fx_graph = fx_graph
        # All three lists start empty and are appended to by the caller.
        self.nodes: List[fx.Node] = []
        self.start_nodes: List[fx.Node] = []
        self.end_nodes: List[fx.Node] = []
def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
    """
    This generates a heuristic that can be passed into `optimize_for_inference` that
    determines whether a subgraph should be run in MKL by running it with the example_inputs.

    Example usage:
        heuristic = gen_mkl_autotuner(example_inputs, iters=10)
        fast_model = optimization.optimize_for_inference(model, heuristic)

    Args:
        example_inputs: forwarded as a single positional argument to
            `ShapeProp.propagate` on the first call of the heuristic.
        iters: number of timed runs per variant.
        warmup: number of untimed runs performed before timing starts.

    Returns:
        A callable taking an `MklSubgraph` and returning True when the MKLDNN
        variant of that subgraph ran faster than the dense one.
    """
    # Both are filled lazily on the first heuristic call and then reused.
    fx_model = None
    old_modules = None

    def node_shape(node):
        # Prefer a `shape` attribute set directly on the node; fall back to
        # the `tensor_meta` entry that ShapeProp records in `node.meta`.
        shape = getattr(node, 'shape', None)
        if shape is None:
            shape = node.meta['tensor_meta'].shape
        return shape

    def benchmark(f):
        for _ in range(warmup):
            f()
        # perf_counter is monotonic and high resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        begin = time.perf_counter()
        for _ in range(iters):
            f()
        return time.perf_counter() - begin

    def use_mkl_heuristic(graph: MklSubgraph) -> bool:
        nonlocal fx_model, old_modules
        input_nodes = graph.start_nodes
        if fx_model is None:
            fx_model = graph.fx_graph.owning_module
            old_modules = graph.fx_graph.old_modules  # type: ignore[attr-defined]
            ShapeProp(fx_model).propagate(example_inputs)
        sample_inputs = [torch.randn(node_shape(node)) for node in input_nodes]
        # The subgraph's outputs are the values feeding each `to_dense` node.
        output_args = cast(List[fx.Node], [node.args[0] for node in graph.end_nodes])
        submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args)

        mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])])

        # Put the original (non-MKLDNN) modules back before timing the dense variant.
        reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules)
        no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
        return mkl_time < no_mkl_time

    return use_mkl_heuristic
def use_mkl_length(graph: MklSubgraph) -> bool:
    """
    Heuristic for `optimize_for_inference`: run a subgraph in MKL only when
    it contains more than 2 nodes.
    """
    min_nodes_for_mkl = 3
    return len(graph.nodes) >= min_nodes_for_mkl
class UnionFind:
    """
    Disjoint-set structure over the integers `0 .. n-1`, with path
    compression and union by size.

    Elements do not belong to any set until `make_set` has been called on them.
    """

    def __init__(self, n):
        # parent[v] is None until make_set(v) is called.
        self.parent: List[Optional[int]] = [None] * n
        self.size: List[int] = [0] * n

    def make_set(self, v: int):
        """Put `v` into a new singleton set."""
        self.parent[v] = v
        self.size[v] = 1

    def find(self, v: int) -> int:
        """
        Return the representative of the set containing `v`, compressing the
        path on the way. Fails an assertion if `v` was never passed to `make_set`.
        """
        par = self.parent[v]
        if v == par:
            return v
        assert par is not None
        self.parent[v] = self.find(par)
        return cast(int, self.parent[v])

    def join(self, a: int, b: int) -> int:
        """
        Merge the sets containing `a` and `b`.

        Returns the representative of the merged set, whether or not the two
        elements were already in the same set.
        """
        a, b = self.find(a), self.find(b)
        if a == b:
            return a
        # Attach the smaller tree under the larger one to keep trees shallow.
        if self.size[a] < self.size[b]:
            a, b = b, a
        self.parent[b] = a
        self.size[a] += self.size[b]
        return a
def optimize_for_inference(
    model: torch.nn.Module,
    pass_config: Optional[Dict[str, Any]] = None,
    tracer: Type[fx.Tracer] = fx.Tracer
) -> torch.nn.Module:
    """
    Performs a set of optimization passes to optimize a model for the
    purposes of inference. Specifically, the passes that are run are:
    1. Conv/BN fusion
    2. Dropout removal
    3. MKL layout optimizations

    The third optimization takes a function `use_mkl_heuristic` that's used
    to determine whether a subgraph should be explicitly run in MKL layout.

    Note: As FX does not currently handle aliasing, this pass currently
    assumes nothing aliases. If that isn't true, use at your own risk.

    Args:
        model: the module to optimize.
        pass_config: optional overrides for the default configuration. Keys:
            "conv_bn_fuse" (bool), "remove_dropout" (bool) and
            "mkldnn_layout_optimize", which is either False (skip the pass) or
            a dict holding a "heuristic" callable that takes an `MklSubgraph`
            and returns a bool.
        tracer: tracer class used to trace the model for the MKLDNN pass.

    Returns:
        The optimized module.

    Raises:
        RuntimeError: if "mkldnn_layout_optimize" is neither False nor a dict
            containing a "heuristic" entry.
    """
    default_pass_config = {
        "conv_bn_fuse": True,
        "remove_dropout": True,
        "mkldnn_layout_optimize": {'heuristic': use_mkl_length},
    }
    if pass_config is None:
        pass_config = {}
    default_pass_config.update(pass_config)

    if default_pass_config["conv_bn_fuse"]:
        model = fuse(model)

    if default_pass_config["remove_dropout"]:
        model = remove_dropout(model)

    if default_pass_config["mkldnn_layout_optimize"] is False:
        return model

    if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict):
        raise RuntimeError("mkldnn_layout_optimize config is not a dict")

    if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]:
        raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config")

    use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"]

    cur_tracer = tracer()
    fx_graph = cur_tracer.trace(copy.deepcopy(model))
    # The GraphModule is built for its side effect: it becomes
    # `fx_graph.owning_module`, which `gen_mkl_autotuner`'s heuristic reads.
    fx_model = fx.GraphModule(cur_tracer.root, fx_graph)
    # NOTE(review): the lookup table is built from `model`, not from the
    # traced deep copy, so module replacements below are applied to `model`
    # itself - confirm this is intended when both earlier passes are disabled.
    modules: Dict[str, nn.Module] = dict(model.named_modules())

    class MklSupport(Enum):
        NO = 1
        YES = 2
        UNKNOWN = 3

    # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node.
    # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node.
    # However, if it's in `mkldnn_supported_unknown`, then we only treat it as
    # a MKLDNN node if its inputs are MKLDNN nodes.
    for node in list(fx_graph.nodes):
        supports_mkldnn = MklSupport.NO
        if node.op == 'call_module':
            cur_module = modules[node.target]
            if type(cur_module) in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
                sample_parameter = next(cur_module.parameters(), None)
                if sample_parameter is not None:
                    assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules"
                    assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules"
        elif node.op == 'call_function':
            if node.target in mkldnn_supported:
                supports_mkldnn = MklSupport.YES
            elif node.target in mkldnn_supported_unknown:
                supports_mkldnn = MklSupport.UNKNOWN

        if supports_mkldnn != MklSupport.NO:
            if supports_mkldnn == MklSupport.UNKNOWN:
                # Positional args may be plain Python scalars (e.g. `x * 2`),
                # which have no `.target`; only fx Nodes can be `to_dense` calls.
                if not any(isinstance(arg, fx.Node) and arg.target == 'to_dense' for arg in node.args):
                    continue
            with fx_graph.inserting_before(node):
                mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, )))

            node.args = cast(Tuple[fx.node.Argument], mkldnn_args)

            with fx_graph.inserting_after(node):
                dense_x = fx_graph.create_node('call_method', 'to_dense', (node,))
                node.replace_all_uses_with(dense_x)
                # replace_all_uses_with also rewired dense_x's own input to
                # itself, so point it back at `node`.
                dense_x.args = (node,)

    # Does pre-conversion of all modules into MKLDNN (when possible)
    old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules)
    # Stashed on the graph so that `gen_mkl_autotuner`'s heuristic can retrieve it.
    fx_graph.old_modules = old_modules  # type: ignore[attr-defined]

    # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b
    for node in fx_graph.nodes:
        if node.op == 'call_method' and node.target == 'to_dense':
            prv_node = node.args[0]
            users = list(node.users)
            for user in users:
                if user.op == 'call_method' and user.target == 'to_mkldnn':
                    user.replace_all_uses_with(prv_node)
                    fx_graph.erase_node(user)
            if len(node.users) == 0:
                fx_graph.erase_node(node)

    num_nodes = len(fx_graph.nodes)
    uf = UnionFind(num_nodes)

    def get_color(n):
        if hasattr(n, 'color'):  # Current node is part of a MKL subgraph
            return uf.find(n.color)
        if hasattr(n, 'start_color'):  # Current node is input to MKL subgraph
            return uf.find(n.start_color)
        return None

    # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists
    # of input nodes (which are only `to_mkldnn` calls), output nodes
    # (`to_dense` calls), and intermediate nodes, which are run entirely on
    # MKLDNN layout tensors.
    #
    # Specifically, this code does a flood fill on a directed acyclic graph
    # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes).
    # If every node only had one input, this would be sufficient. However, in
    # the case that a node has multiple inputs coming from different start
    # nodes (i.e. colors), we need to join these 2 colors into 1. That's done
    # using a Disjoint Set Union.
    for cur_idx, node in enumerate(fx_graph.nodes):
        if node.op == 'call_method' and node.target == 'to_mkldnn':
            node.start_color = cur_idx
            uf.make_set(cur_idx)
        elif node.op == 'call_method' and node.target == 'to_dense':
            assert get_color(node.args[0]) is not None
            node.end_color = get_color(node.args[0])
        else:
            cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None]

            if len(cur_colors) == 0:
                continue
            assert not any(i is None for i in cur_colors)
            cur_colors = sorted(cur_colors)
            node.color = cur_colors[0]
            for other_color in cur_colors[1:]:
                uf.join(cur_colors[0], other_color)

    # Group the colored nodes into one MklSubgraph per disjoint-set root.
    mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph))
    for node in fx_graph.nodes:
        if hasattr(node, 'color'):
            mkldnn_graphs[uf.find(node.color)].nodes.append(node)
        if hasattr(node, 'start_color'):
            mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node)
        if hasattr(node, 'end_color'):
            mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node)

    # Now that we have all the subgraphs, we need to decide which MKLDNN
    # subgraphs we actually want to keep in MKLDNN.
    for graph in mkldnn_graphs.values():
        if not use_mkl_heuristic(graph):
            # Strip the layout conversions around the rejected subgraph and
            # restore the modules that were pre-converted for it.
            for node in graph.start_nodes + graph.end_nodes:
                prv = node.args[0]
                node.replace_all_uses_with(prv)
                fx_graph.erase_node(node)
            reset_modules(graph.nodes, modules, old_modules)

    # Only method calls are layout conversions; checking the op keeps a
    # submodule that happens to be named `to_dense`/`to_mkldnn` out of the count.
    mkldnn_conversions = sum(
        1 for node in fx_graph.nodes
        if node.op == 'call_method' and node.target in ('to_mkldnn', 'to_dense')
    )

    logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions)
    fx_graph.lint()
    result = fx.GraphModule(model, fx_graph)
    return result