import contextlib
from typing import List, Tuple

import torch


@contextlib.contextmanager
def optimized_execution(should_optimize):
    """Context manager that controls whether the JIT's executor will run optimizations before executing a function."""
    stored_flag = torch._C._get_graph_executor_optimize()
    torch._C._set_graph_executor_optimize(should_optimize)
    try:
        yield
    finally:
        torch._C._set_graph_executor_optimize(stored_flag)


@contextlib.contextmanager
def fuser(name):
    """Context manager that facilitates switching between backend fusers.



    Valid names:

    * ``fuser0`` - enables only legacy fuser

    * ``fuser1`` - enables only NNC

    * ``fuser2`` - enables only nvFuser

    * ``fuser3`` - enables oneDNN Graph
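
    Example (illustrative sketch; which fusers are usable depends on the build
    and hardware)::

        >>> @torch.jit.script
        ... def mul_add(a, b, c):
        ...     return a * b + c
        >>> with fuser("fuser1"):  # run with only the NNC fuser enabled
        ...     out = mul_add(torch.ones(4), torch.ones(4), torch.ones(4))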

    """
    old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
    old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
    old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
    old_nvfuser_state = torch._C._jit_nvfuser_enabled()
    old_llga_state = torch._C._jit_llga_enabled()
    if name == "fuser0":  # legacy fuser
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser1":  # NNC
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._get_graph_executor_optimize(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser2":  # nvFuser
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._jit_set_llga_enabled(False)
    elif name == "fuser3":  # oneDNN Graph
        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        old_profiling_mode = torch._C._get_graph_executor_optimize(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(True)
    elif name == "none":  # Turn Pytorch fuser off
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(False)
        torch._C._jit_set_llga_enabled(False)
    else:
        raise Exception(f"unrecognized fuser option (name: {name})")
    try:
        yield
    finally:
        if name in ["fuser1", "fuser3"]:  # NNC or oneDNN Graph
            torch._C._jit_set_profiling_executor(old_profiling_executor)  # type: ignore[possibly-undefined]
            torch._C._get_graph_executor_optimize(old_profiling_mode)  # type: ignore[possibly-undefined]
        # recover the previous values
        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
        torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
        torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
        torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
        torch._C._jit_set_llga_enabled(old_llga_state)


last_executed_optimized_graph = torch._C._last_executed_optimized_graph


def _get_differentiable_graph_node(node, diff_node):
    # Recursively collect every prim::DifferentiableGraph node reachable from
    # `node` (including nodes nested inside sub-blocks) into `diff_node`.
    if node.kind() == "prim::DifferentiableGraph":
        diff_node.append(node)
    else:
        for block in node.blocks():
            for n in block.nodes():
                _get_differentiable_graph_node(n, diff_node)


def _graph_for(self, *args, **kwargs):
    return _script_method_graph_for(self, self, *args, **kwargs)


def _script_method_graph_for(self, parent, *args, **kwargs):
    try:
        dbs = parent.get_debug_state()
        eps = list(dbs.execution_plans.values())
        assert len(eps) == 1
        graph = eps[0].graph.copy()

        # graph_executor_states for differentiable node
        fw_states = eps[0].code.differentiable_op_executor_states()
        diff_nodes: List[torch._C.Node] = []
        for n in graph.nodes():
            _get_differentiable_graph_node(n, diff_nodes)

        assert len(fw_states) == len(diff_nodes)
        # swap each differentiable graph with the optimized graph from its execution plan
        for n, state in zip(diff_nodes, fw_states):
            fw_execution_plans = list(state.execution_plans.values())
            # We can only update the subgraph when there is a unique execution
            # plan. Avoid asserting here so that nodes that cannot be updated
            # are skipped while we still make a best effort to update the rest.
            if len(fw_execution_plans) == 1:
                n.g_("Subgraph", fw_execution_plans[0].graph)

        return graph
    except Exception:
        # Fallback: run the function and return the last recorded optimized
        # graph.
        self(*args, **kwargs)
        return last_executed_optimized_graph()
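

# Illustrative usage note (a sketch of how these helpers are commonly wired
# up, not guaranteed by this module alone): `_graph_for` is typically bound as
# the `graph_for` method of scripted callables, so the optimized graph for
# particular inputs can be inspected along the lines of:
#
#     scripted_fn.graph_for(torch.ones(2))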


def set_fusion_strategy(strategy: List[Tuple[str, int]]):
    """Set the type and number of specializations that can occur during fusion.



    Usage: provide a list of pairs (type, depth) where type is one of "STATIC" or "DYNAMIC"

    and depth is an integer.



    Behavior - static vs dynamic:

        In STATIC fusion, fused ops are compiled to have fixed input shapes. The shape is determined

        based on some initial profiling runs.

        In DYNAMIC fusion, fused ops are compiled to have variable input shapes, so that multiple

        shapes are possible.



    In both cases, we also recompile on new striding behavior, device, or dtype.



    Behavior - fallback functions & depth:

        When an input doesn't match the format required by the specialized compiled op, it will run

        a fallback function. Fallback functions are recursively be compiled and specialized based

        on the observed tensor shapes. Since compilation can be slow, the "depth" parameter is provided to

        limit the number of specializations that can be compiled, before giving up on recompiling and

        falling back to a completely un-fused, un-specialized implementation.



    The list of (type, depth) pairs controls the type of specializations and the number of

    specializations. For example: [("STATIC", 2), ("DYNAMIC", 2)] indicates that the first

    two specializations will use static fusions, the following two specializations will use

    dynamic fusion, and any inputs that satisfy none of the 4 options will run an

    unfused implementation.



    NB: in the future, if more as more fusion backends are added there may be more granular

    apis for specific fusers.
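
    Example (illustrative; returns the previously active strategy)::

        >>> previous = torch.jit.set_fusion_strategy([("STATIC", 2), ("DYNAMIC", 2)])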

    """
    return torch._C._jit_set_fusion_strategy(strategy)