File size: 11,349 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import collections
import copyreg
import io
import pickle
import sys
import threading
import traceback
from enum import Enum

import torch
import torch.distributed as dist
from torch._C._distributed_rpc import _get_current_rpc_agent

__all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"]

# Thread-local tensor tables used while (un)pickling torch.Tensor objects:
# ``send_tables`` collects the tensors met during serialization and
# ``recv_tables`` holds the tensors looked up by index during deserialization.
# Both attributes only exist while a serialize()/deserialize() call is active.
_thread_local_tensor_tables = threading.local()
# Pickler / Unpickler classes instantiated by _InternalRPCPickler.
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler


class RPCExecMode(Enum):
    """Type of an RPC/RRef call.

    The string value of a member is embedded in the profiling key produced by
    ``_build_rpc_profiling_key`` (``rpc_<value>#...``).
    """

    SYNC = "sync"
    ASYNC = "async"
    ASYNC_JIT = "async_jit"
    REMOTE = "remote"


class _InternalRPCPickler:
    r"""
    Serialize and deserialize data in a "binary string + tensor table" format.

    For an RPC python UDF function and its args, non-tensor data is pickled
    into a regular binary string while tensors are collected in a thread-local
    tensor table and replaced in the pickle stream by their table index. This
    format is consistent with builtin operators and args using the JIT
    pickler, and makes tensor handling in C++ much easier, e.g. attaching
    tensors to the distributed autograd graph in C++.
    """

    def __init__(self):
        # Ignore type error because dispatch_table is defined in third-party package
        self._dispatch_table = copyreg.dispatch_table.copy()  # type: ignore[attr-defined]
        self._dispatch_table[torch.Tensor] = self._tensor_reducer
        # Used for registering customized picklers.
        self._class_reducer_dict = {}

    def _register_reducer(self, obj_class, reducer):
        """Register ``reducer`` for ``obj_class``; later registrations for the same class are ignored."""
        # For the same class, only register the reducer once.
        if obj_class not in self._class_reducer_dict:
            self._class_reducer_dict[obj_class] = reducer

    @classmethod
    def _tensor_receiver(cls, tensor_index):
        """Return the tensor stored at ``tensor_index`` in the active receive table."""
        return _thread_local_tensor_tables.recv_tables[tensor_index]

    def _tensor_reducer(self, tensor):
        """Append ``tensor`` to the active send table and pickle only its index."""
        _thread_local_tensor_tables.send_tables.append(tensor)
        tensor_index = len(_thread_local_tensor_tables.send_tables) - 1
        return (_InternalRPCPickler._tensor_receiver, (tensor_index,))

    @classmethod
    def _py_rref_receiver(cls, rref_fork_data):
        """Rebuild a ``PyRRef`` from the data produced by ``_py_rref_reducer``."""
        return dist.rpc.PyRRef._deserialize(rref_fork_data)

    def _py_rref_reducer(self, py_rref):
        """Reduce a ``PyRRef`` to its serialized fork data."""
        rref_fork_data = py_rref._serialize()
        return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,))

    def _rref_reducer(self, rref):
        """Reduce an ``rpc.RRef`` exactly like a ``PyRRef``."""
        return self._py_rref_reducer(rref)

    @classmethod
    def _script_module_receiver(cls, script_module_serialized):
        """
        Given a serialized representation of a ScriptModule created with
        torch.jit.save, loads and returns the ScriptModule.
        """
        f = io.BytesIO(script_module_serialized)
        m = torch.jit.load(f)
        return m

    def _script_module_reducer(self, script_module):
        """
        Serializes a ScriptModule with torch.jit.save into an in-memory buffer.
        """
        f = io.BytesIO()
        torch.jit.save(script_module, f)
        return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),))

    def serialize(self, obj):
        r"""
        Serialize non tensor data into binary string, tensor data into
        tensor table.

        Args:
            obj: Object to pickle.

        Returns:
            A ``(binary_data, tensors)`` tuple: the pickle bytes and the list
            of tensors that were encountered while pickling ``obj``.

        Raises:
            Any exception raised by the pickler. The thread-local send table
            is restored (or removed) before the exception propagates.
        """
        f = io.BytesIO()
        p = _pickler(f)
        p.dispatch_table = self._dispatch_table
        # NOTE(review): p.dispatch_table appears to alias self._dispatch_table,
        # so the entries added below likely persist across calls - confirm.

        # rpc api could accept user picklers inheriting from _InternalRPCPickler to serialize rref,
        # user picklers could have different initialization function from _InternalRPCPickler,
        # but all the user picklers should call serialize() and use _rref_reducer to pickle rref
        # in python. also, when _internal_rpc_pickler is imported to rpc/api.py, rpc.RRef is not
        # compiled yet, it is not good place to access rpc.RRef inside _InternalRPCPickler constructor,
        # so putting rref's dispatch table here
        #
        # The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`.
        # The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`.
        # Ignore type error because dispatch_table is defined in third-party package
        p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer  # type: ignore[index]
        # An RRef created locally by RRef Python constructor is type of `rpc.RRef`.
        # Ignore type error because dispatch_table is defined in third-party package
        p.dispatch_table[dist.rpc.RRef] = self._rref_reducer  # type: ignore[index]

        # Add dispatch pickling for ScriptModule or its subclass.
        if isinstance(obj, torch.jit.ScriptModule):
            # Ignore type error because dispatch_table is defined in third-party package
            p.dispatch_table[obj.__class__] = self._script_module_reducer  # type: ignore[index]

        # Install customized picklers.
        for obj_class, reducer in self._class_reducer_dict.items():
            p.dispatch_table[obj_class] = reducer  # type: ignore[index]

        # Save _thread_local_tensor_tables.send_tables if this is a nested call.
        old_send_tables = getattr(_thread_local_tensor_tables, "send_tables", None)
        _thread_local_tensor_tables.send_tables = []

        try:
            p.dump(obj)
            tensors = _thread_local_tensor_tables.send_tables
        finally:
            # Restore the outer table when returning from a nested call,
            # otherwise clean up. Doing this in ``finally`` keeps a failed
            # dump from leaving a stale table (and its tensors) on the thread.
            if old_send_tables is not None:
                _thread_local_tensor_tables.send_tables = old_send_tables
            else:
                del _thread_local_tensor_tables.send_tables

        return (f.getvalue(), tensors)

    def deserialize(self, binary_data, tensor_table):
        r"""
        Deserialize binary string + tensor table to original obj.

        Args:
            binary_data: Pickle bytes produced by ``serialize``.
            tensor_table: Tensors referenced by index from ``binary_data``.

        Returns:
            The unpickled object. If unpickling fails with ``AttributeError``
            (a function was not found on a module/class), a new
            ``AttributeError`` instance with an explanatory message is
            *returned*, not raised.

        Raises:
            Any other exception raised by the unpickler. The thread-local
            receive table is restored (or removed) before it propagates.
        """
        # NOTE: unpickling can execute arbitrary code; ``binary_data`` must
        # only come from trusted peers.

        # Save _thread_local_tensor_tables.recv_tables if this is a nested call.
        old_recv_tables = getattr(_thread_local_tensor_tables, "recv_tables", None)
        _thread_local_tensor_tables.recv_tables = tensor_table

        try:
            unpickler = _unpickler(io.BytesIO(binary_data))
            ret = unpickler.load()
        except AttributeError as e:
            # Occurs when function is not found on module/class during
            # unpickling.
            except_str = (
                str(e)
                + """ Default RPC pickler does not serialize

            function code. Ensure that UDFs are defined on both caller and

            callee modules."""
            )
            ret = AttributeError(except_str)
            # Ensure the stack trace gets preserved
            ret.__cause__ = e
        finally:
            # Restore the outer table when returning from a nested call,
            # otherwise clean up. Doing this in ``finally`` also covers
            # unpickling errors other than AttributeError.
            if old_recv_tables is not None:
                _thread_local_tensor_tables.recv_tables = old_recv_tables
            else:
                del _thread_local_tensor_tables.recv_tables

        return ret


# Module-wide pickler instance used by serialize() and deserialize() below.
# Create _internal_rpc_pickler only once to initialize _dispatch_table only once
_internal_rpc_pickler = _InternalRPCPickler()


def serialize(obj):
    """Serialize ``obj`` with the module-wide ``_internal_rpc_pickler``.

    Returns whatever ``_InternalRPCPickler.serialize`` returns: a
    ``(binary_data, tensors)`` tuple.
    """
    payload = _internal_rpc_pickler.serialize(obj)
    return payload


def deserialize(binary_data, tensor_table):
    """Rebuild an object from ``binary_data`` and ``tensor_table``.

    Delegates to the module-wide ``_internal_rpc_pickler`` and returns its
    result unchanged.
    """
    restored = _internal_rpc_pickler.deserialize(binary_data, tensor_table)
    return restored


def _run_function(python_udf):
    r"""
    This function is exclusively called from C++.
    See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.

    Runs a Python UDF and returns its return value. If the UDF raises, or if
    ``python_udf`` is itself an ``AttributeError`` instance, the error is
    printed to stderr and returned wrapped in a ``RemoteException`` instead
    of propagating.

    Args:
        python_udf: Object exposing ``func``, ``args`` and ``kwargs``, or an
            ``AttributeError`` instance to be reported as a failure.

    Returns:
        The UDF's return value, or a ``RemoteException`` describing the error.
    """
    try:
        # An AttributeError instance is re-raised here so that it is reported
        # through the same path as a failure of the UDF itself.
        if isinstance(python_udf, AttributeError):
            raise python_udf
        return python_udf.func(*python_udf.args, **python_udf.kwargs)
    except Exception as err:
        # except str = exception info + traceback string
        worker_info = _get_current_rpc_agent().get_worker_info()
        except_str = f"On {worker_info}:\n{err!r}\n{traceback.format_exc()}"
        print(except_str, file=sys.stderr)
        return RemoteException(except_str, type(err))


def _handle_exception(result):
    """
    Re-raise the exception described by a ``RemoteException``.

    Args:
        result: Any value. Values that are not a ``RemoteException`` are
            ignored and the function returns ``None``.

    Raises:
        An instance of ``result.exception_type`` constructed from the decoded
        message.
        RuntimeError: If ``result.exception_type`` cannot be constructed from
            a single string argument.
    """
    if not isinstance(result, RemoteException):
        return

    # The message may contain backslash escapes (e.g. ``\n``); turn them back
    # into real characters. Encoding with latin-1 + backslashreplace keeps
    # non-ASCII text intact through the ``unicode_escape`` round trip, whereas
    # encoding as UTF-8 would re-read every multi-byte character as Latin-1
    # and garble it.
    try:
        exception_msg = result.msg.encode("latin-1", "backslashreplace").decode(
            "unicode_escape"
        )
    except UnicodeDecodeError:
        # Malformed escapes (for example a Windows path such as ``C:\Users``)
        # must not hide the remote error: fall back to the raw message.
        exception_msg = result.msg

    # We wrap exception re-creation here in case some exception classes
    # cannot be constructed directly from a string.
    try:
        exc = result.exception_type(exception_msg)
    except BaseException as e:
        raise RuntimeError(
            f"Failed to create original exception type. Error msg was {str(e)}"
            f" Original exception on remote side was {exception_msg}"
        ) from e

    # Guard kept for callables that return None instead of an exception.
    if exc is not None:
        raise exc


def _build_rpc_profiling_key(exec_type, func_name, current_worker_name, dst_worker_name):
    """
    Builds the key that RPC calls are profiled with using the autograd profiler.
    This will be the name of the corresponding Event recorded in the profiler.

    The key has the form
    ``rpc_<exec_type.value>#<func_name>(<current_worker_name> -> <dst_worker_name>)``.

    Args:
        exec_type (RPCExecMode): Type of RPC/RRef call
        func_name (str): Name of function being profiled.
        current_worker_name (str): Name of current worker.
        dst_worker_name (str): Name of the destination worker.

    Returns:
        String representing profiling key
    """
    mode = exec_type.value
    return f"rpc_{mode}#{func_name}({current_worker_name} -> {dst_worker_name})"


def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name):
    """
    This function should be called from RPC/RRef functions to create a
    RecordFunction object for profiling. This function also runs the before
    callbacks that start the profiling, though the user is responsible for
    running the appropriate callbacks when the function to be profiled finishes.

    Args:
        exec_type (RPCExecMode): Type of RPC/RRef call
        func_name (str): Name of function being profiled.
        current_worker_name (str): Name of current worker.
        dest_worker_name (str): Name of the destination worker.

    Returns:
        An instance of `torch.autograd._RecordFunction`.

    Raises:
        AssertionError: If the autograd profiler is not enabled.
    """
    assert torch.autograd._profiler_enabled(), "Autograd profiler should be enabled."
    mode = exec_type.value
    name = str(func_name)
    profile_key = f"rpc_{mode}#{name}({current_worker_name} -> {dest_worker_name})"
    record = torch.autograd._RecordFunction()  # type: ignore[attr-defined]
    # Run the "before" callbacks so profiling starts under ``profile_key``.
    torch.autograd._run_before_callbacks(record, profile_key)  # type: ignore[attr-defined]
    return record


# A user-defined function call: _run_function() invokes
# ``func(*args, **kwargs)``.
PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"])
# Built by _run_function() when the UDF raises: ``msg`` is the formatted error
# text and ``exception_type`` is the class of the original exception, which
# _handle_exception() uses to re-create and raise it.
RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"])