"""

``torch.autograd`` provides classes and functions implementing automatic
differentiation of arbitrary scalar-valued functions. It requires minimal
changes to the existing code - you only need to declare :class:`Tensor` s
for which gradients should be computed with the ``requires_grad=True`` keyword.
As of now, we only support autograd for floating point :class:`Tensor` types (
half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble).
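
Example:
    A minimal illustrative sketch (the tensor names are hypothetical)::

        >>> x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
        >>> y = (x * x).sum()
        >>> y.backward()  # accumulates dy/dx = 2 * x into x.grad
        >>> x.grad
        tensor([2., 4., 6.])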

"""
import warnings
from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union

import torch

from torch.types import _size, _TensorOrTensors, _TensorOrTensorsOrGradEdge
from .. import _vmap_internals
from ..overrides import handle_torch_function, has_torch_function, is_tensor_like
from . import forward_ad, functional, graph
from .anomaly_mode import detect_anomaly, set_detect_anomaly
from .function import Function, NestedIOFunction
from .grad_mode import (
    _force_original_view_tracking,
    _unsafe_preserve_version_counter,
    enable_grad,
    inference_mode,
    no_grad,
    set_grad_enabled,
    set_multithreading_enabled,
)
from .gradcheck import gradcheck, gradgradcheck
from .graph import _engine_run_backward

from .variable import Variable

__all__ = ["Variable", "Function", "backward", "grad_mode"]

_OptionalTensor = Optional[torch.Tensor]
_ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor]


def _calculate_shape(
    output: torch.Tensor, grad: torch.Tensor, is_grads_batched: bool
) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
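    # Returns the shapes of `output` and `grad` used for shape-mismatch error
    # messages; nested tensors (other than the Python NestedTensor subclass)
    # report their per-component sizes via _nested_tensor_size().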
    # is_same_size ensures that both tensors are either nested or non-nested
    # circular import
    from torch.nested._internal.nested_tensor import NestedTensor

    if output.is_nested and not isinstance(output, NestedTensor):
        if is_grads_batched:
            raise RuntimeError("Batched grads are not supported with Nested Tensor.")
        out_shape = output._nested_tensor_size()
        grad_shape = grad._nested_tensor_size()

        return out_shape, grad_shape

    reg_out_shape = output.shape
    reg_grad_shape = grad.shape if not is_grads_batched else grad.shape[1:]
    return reg_out_shape, reg_grad_shape


def _make_grads(
    outputs: Sequence[torch.Tensor],
    grads: Sequence[_OptionalTensor],
    is_grads_batched: bool,
) -> Tuple[_OptionalTensor, ...]:
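    # Validates user-supplied grads against `outputs` (shapes must match and both
    # must agree on complex-ness) and fills in defaults: ones_like(out) for scalar
    # floating-point outputs that require grad, None for outputs that don't.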
    new_grads: List[_OptionalTensor] = []
    for out, grad in zip(outputs, grads):
        if isinstance(grad, torch.Tensor):
            from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq

            first_grad = grad if not is_grads_batched else grad[0]
            # TODO: We can remove this conditional once we uniformly use
            # singleton int to represent jagged dimension, so that size() call
            # on nested tensor works
            if out.is_nested or first_grad.is_nested:
                shape_matches = torch.is_same_size(out, first_grad)
            else:
                # We need to do a regular size check, without going through
                # the operator, to be able to handle unbacked symints
                # (expect_true ensures we can deal with unbacked)
                shape_matches = expect_true(sym_eq(out.size(), first_grad.size()))
            if not shape_matches:
                out_shape, grad_shape = _calculate_shape(
                    out, first_grad, is_grads_batched
                )
                if is_grads_batched:
                    raise RuntimeError(
                        "If `is_grads_batched=True`, we interpret the first "
                        "dimension of each grad_output as the batch dimension. "
                        "The sizes of the remaining dimensions are expected to match "
                        "the shape of corresponding output, but a mismatch "
                        "was detected: grad_output["
                        + str(grads.index(grad))
                        + "] has a shape of "
                        + str(grad_shape)
                        + " and output["
                        + str(outputs.index(out))
                        + "] has a shape of "
                        + str(out_shape)
                        + ". "
                        "If you only want some tensors in `grad_output` to be considered "
                        "batched, consider using vmap."
                    )
                else:
                    raise RuntimeError(
                        "Mismatch in shape: grad_output["
                        + str(grads.index(grad))
                        + "] has a shape of "
                        + str(grad_shape)
                        + " and output["
                        + str(outputs.index(out))
                        + "] has a shape of "
                        + str(out_shape)
                        + "."
                    )
            if out.dtype.is_complex != grad.dtype.is_complex:
                raise RuntimeError(
                    "For complex Tensors, both grad_output and output"
                    " are required to have the same dtype."
                    " Mismatch in dtype: grad_output["
                    + str(grads.index(grad))
                    + "] has a dtype of "
                    + str(grad.dtype)
                    + " and output["
                    + str(outputs.index(out))
                    + "] has a dtype of "
                    + str(out.dtype)
                    + "."
                )
            new_grads.append(grad)
        elif grad is None:
            if out.requires_grad:
                if out.numel() != 1:
                    raise RuntimeError(
                        "grad can be implicitly created only for scalar outputs"
                    )
                if not out.dtype.is_floating_point:
                    msg = (
                        "grad can be implicitly created only for real scalar outputs"
                        f" but got {out.dtype}"
                    )
                    raise RuntimeError(msg)
                new_grads.append(
                    torch.ones_like(out, memory_format=torch.preserve_format)
                )
            else:
                new_grads.append(None)
        else:
            raise TypeError(
                "gradients can be either Tensors or None, but got "
                + type(grad).__name__
            )
    return tuple(new_grads)


def _tensor_or_tensors_to_tuple(
    tensors: Optional[_TensorOrTensors], length: int
) -> Tuple[_OptionalTensor, ...]:
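    # Normalizes an optional tensor-or-sequence argument into a tuple;
    # None expands to a tuple of `length` None entries.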
    if tensors is None:
        return (None,) * length
    if isinstance(tensors, torch.Tensor):
        return (tensors,)
    return tuple(tensors)


def backward(
    tensors: _TensorOrTensors,
    grad_tensors: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    grad_variables: Optional[_TensorOrTensors] = None,
    inputs: Optional[_TensorOrTensorsOrGradEdge] = None,
) -> None:
    r"""Computes the sum of gradients of given tensors with respect to graph
    leaves.

    The graph is differentiated using the chain rule. If any of ``tensors``
    are non-scalar (i.e. their data has more than one element) and require
    gradient, then the Jacobian-vector product is computed; in that case the
    function additionally requires specifying ``grad_tensors``. It should be
    a sequence of matching length that contains the "vector" in the
    Jacobian-vector product, usually the gradient of the differentiated
    function w.r.t. corresponding tensors (``None`` is an acceptable value for
    all tensors that don't need gradient tensors).

    This function accumulates gradients in the leaves - you might need to zero
    ``.grad`` attributes or set them to ``None`` before calling it.
    See :ref:`Default gradient layouts<default-grad-layouts>`
    for details on the memory layout of accumulated gradients.

    .. note::
        Using this method with ``create_graph=True`` will create a reference cycle
        between the parameter and its gradient which can cause a memory leak.
        We recommend using ``autograd.grad`` when creating the graph to avoid this.
        If you have to use this function, make sure to reset the ``.grad`` fields of your
        parameters to ``None`` after use to break the cycle and avoid the leak.

    .. note::
        If you run any forward ops, create ``grad_tensors``, and/or call ``backward``
        in a user-specified CUDA stream context, see
        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

    .. note::
        When ``inputs`` are provided and a given input is not a leaf,
        the current implementation will call its grad_fn (even though it is not
        strictly needed to get these gradients).
        It is an implementation detail on which the user should not rely.
        See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

    Args:
        tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
            computed.
        grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
            the Jacobian-vector product, usually gradients w.r.t. each element of
            corresponding tensors. None values can be specified for scalar Tensors or
            ones that don't require grad. If a None value would be acceptable for all
            grad_tensors, then this argument is optional.
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
            will be freed. Note that in nearly all cases setting this option to ``True``
            is not needed and often can be worked around in a much more efficient
            way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing one to compute higher order derivative products.
            Defaults to ``False``.
        inputs (Sequence[Tensor] or Tensor or Sequence[GradientEdge], optional): Inputs w.r.t. which the gradient
            will be accumulated into ``.grad``. All other Tensors will be ignored. If
            not provided, the gradient is accumulated into all the leaf Tensors that
            were used to compute the :attr:`tensors`.
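
    Example:
        A minimal illustrative sketch (the tensor names are hypothetical)::

            >>> x = torch.tensor([1.0, 2.0], requires_grad=True)
            >>> y = (x ** 2).sum()
            >>> torch.autograd.backward(y)  # equivalent to y.backward()
            >>> x.grad
            tensor([2., 4.])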

    """
    if torch._C._are_functorch_transforms_active():
        raise RuntimeError(
            "backward() called inside a functorch transform. This is not "
            "supported, please use functorch.grad or functorch.vjp instead "
            "or call backward() outside of functorch transforms."
        )

    if grad_variables is not None:
        warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
        if grad_tensors is None:
            grad_tensors = grad_variables
        else:
            raise RuntimeError(
                "'grad_tensors' and 'grad_variables' (deprecated) "
                "arguments both passed to backward(). Please only "
                "use 'grad_tensors'."
            )
    if inputs is not None and len(inputs) == 0:
        raise RuntimeError("'inputs' argument to backward() cannot be empty.")

    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
    inputs = (
        (inputs,)
        if isinstance(inputs, (torch.Tensor, graph.GradientEdge))
        else tuple(inputs)
        if inputs is not None
        else tuple()
    )

    grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
    grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
    if retain_graph is None:
        retain_graph = create_graph

    # The reason we repeat the same comment below is that some Python versions
    # print out the first line of a multi-line function call in the traceback
    # and some print out the last line.
    _engine_run_backward(
        tensors,
        grad_tensors_,
        retain_graph,
        create_graph,
        inputs,
        allow_unreachable=True,
        accumulate_grad=True,
    )


def grad(
    outputs: _TensorOrTensors,
    inputs: _TensorOrTensorsOrGradEdge,
    grad_outputs: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    only_inputs: bool = True,
    allow_unused: Optional[bool] = None,
    is_grads_batched: bool = False,
    materialize_grads: bool = False,
) -> Tuple[torch.Tensor, ...]:
    r"""Computes and returns the sum of gradients of outputs with respect to
    the inputs.

    ``grad_outputs`` should be a sequence of length matching ``outputs``
    containing the "vector" in the vector-Jacobian product, usually the
    pre-computed gradients w.r.t. each of the outputs. If an output doesn't
    require gradients, then the gradient can be ``None``.

    .. note::
        If you run any forward ops, create ``grad_outputs``, and/or call ``grad``
        in a user-specified CUDA stream context, see
        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

    .. note::
        The ``only_inputs`` argument is deprecated and is now ignored (it defaults
        to ``True``). To accumulate gradient for other parts of the graph, please
        use ``torch.autograd.backward``.

    Args:
        outputs (sequence of Tensor): outputs of the differentiated function.
        inputs (sequence of Tensor or GradientEdge): Inputs w.r.t. which the gradient will be
            returned (and not accumulated into ``.grad``).
        grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product.
            Usually gradients w.r.t. each output. None values can be specified for scalar
            Tensors or ones that don't require grad. If a None value would be acceptable
            for all grad_outputs, then this argument is optional. Default: None.
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
            will be freed. Note that in nearly all cases setting this option to ``True``
            is not needed and often can be worked around in a much more efficient
            way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing one to compute higher order derivative products.
            Default: ``False``.
        allow_unused (Optional[bool], optional): If ``False``, specifying inputs
            that were not used when computing outputs (and therefore their grad is
            always zero) is an error. Defaults to the value of ``materialize_grads``.
        is_grads_batched (bool, optional): If ``True``, the first dimension of each
            tensor in ``grad_outputs`` will be interpreted as the batch dimension.
            Instead of computing a single vector-Jacobian product, we compute a
            batch of vector-Jacobian products, one for each "vector" in the batch.
            We use the vmap prototype feature as the backend to vectorize calls
            to the autograd engine so that this computation can be performed in a
            single call. This should lead to performance improvements when compared
            to manually looping and performing backward multiple times. Note that
            due to this feature being experimental, there may be performance
            cliffs. Please use ``torch._C._debug_only_display_vmap_fallback_warnings(True)``
            to show any performance warnings and file an issue on GitHub if warnings exist
            for your use case. Defaults to ``False``.
        materialize_grads (bool, optional): If ``True``, set the gradient for unused inputs
            to zero instead of None. This is useful when computing higher-order derivatives.
            If ``materialize_grads`` is ``True`` and ``allow_unused`` is ``False``, an error
            will be raised. Defaults to ``False``.
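
    Example:
        A minimal illustrative sketch (the tensor names are hypothetical)::

            >>> x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
            >>> y = (x ** 2).sum()
            >>> torch.autograd.grad(y, x)  # gradients are returned, not accumulated into x.grad
            (tensor([2., 4., 6.]),)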
    """
    if materialize_grads and allow_unused is False:
        raise ValueError(
            "Expected allow_unused to be True or not passed when materialize_grads=True, "
            "but got: allow_unused=False."
        )
    if allow_unused is None:
        allow_unused = materialize_grads
    t_outputs = cast(
        Tuple[torch.Tensor, ...],
        (outputs,) if is_tensor_like(outputs) else tuple(outputs),
    )
    if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
        inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
    else:
        inputs = tuple(inputs)
    t_inputs = tuple(i for i in inputs if is_tensor_like(i))
    overridable_args = t_outputs + t_inputs
    if has_torch_function(overridable_args):
        return handle_torch_function(
            grad,
            overridable_args,
            t_outputs,
            inputs,
            grad_outputs=grad_outputs,
            retain_graph=retain_graph,
            create_graph=create_graph,
            only_inputs=only_inputs,
            allow_unused=allow_unused,
            is_grads_batched=is_grads_batched,
            materialize_grads=materialize_grads,
        )

    if not only_inputs:
        warnings.warn(
            "only_inputs argument is deprecated and is ignored now "
            "(defaults to True). To accumulate gradient for other "
            "parts of the graph, please use torch.autograd.backward."
        )

    grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(t_outputs))
    grad_outputs_ = _make_grads(
        t_outputs, grad_outputs_, is_grads_batched=is_grads_batched
    )

    if retain_graph is None:
        retain_graph = create_graph

    # The reason we repeat the same comment several times below is that
    # some Python versions print out the first line of multi-line function
    # calls in the traceback and some print out the last line.
    if is_grads_batched:

        def vjp(gO):
            return _engine_run_backward(
                t_outputs,
                gO,
                retain_graph,
                create_graph,
                inputs,
                allow_unused,
                accumulate_grad=False,
            )

        result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
            grad_outputs_
        )
    else:
        result = _engine_run_backward(
            t_outputs,
            grad_outputs_,
            retain_graph,
            create_graph,
            inputs,
            allow_unused,
            accumulate_grad=False,
        )
    if materialize_grads:
        if any(
            result[i] is None and not is_tensor_like(inputs[i])
            for i in range(len(inputs))
        ):
            raise RuntimeError(
                "materialize_grads cannot be used when the given input is a GradientEdge"
            )
        result = tuple(
            output
            if output is not None
            else torch.zeros_like(input, requires_grad=True)
            for (output, input) in zip(result, inputs)
        )
    return result


# This function applies in case of gradient checkpointing for memory
# optimization. Currently, gradient checkpointing is supported only if the
# execution engine is invoked through torch.autograd.backward() and its
# inputs argument is not passed. It is not supported for torch.autograd.grad().
# This is because if inputs are specified, the gradient won't be calculated for
# anything else e.g. model parameters like weights, bias etc.
#
# This function returns whether checkpointing is valid, i.e. whether the engine
# was invoked through torch.autograd.backward (valid) rather than
# torch.autograd.grad (not valid). The implementation works by maintaining a
# thread-local variable in torch/csrc/autograd/engine.cpp which looks at the
# NodeTask in the stack and, before a NodeTask is executed in evaluate_function,
# checks whether reentrant backwards is imperative or not.
# See https://github.com/pytorch/pytorch/pull/4594 for more discussion/context
def _is_checkpoint_valid():
    return Variable._execution_engine.is_checkpoint_valid()


def variable(*args, **kwargs):
    raise RuntimeError(
        "torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead"
    )


# Monkey patching variable.Variable to fix FX codegen. FX generates a call by roughly doing
# f"{fn.__module__}.{fn.__name__}(...)". This yields torch.autograd.variable.Variable(...) in the
# output of an FX graph.  Unfortunately the module name torch.autograd.variable is shadowed by the
# deprecated function - variable(...).
variable.Variable = Variable  # type: ignore[attr-defined]

if not torch._C._autograd_init():
    raise RuntimeError("autograd initialization failed")

# Import all native method/classes
from torch._C._autograd import (
    _add_metadata_json,
    _disable_profiler,
    _disable_profiler_legacy,
    _enable_profiler,
    _enable_profiler_legacy,
    _enable_record_function,
    _get_sequence_nr,
    _kineto_step,
    _KinetoEvent,
    _pop_saved_tensors_default_hooks,
    _prepare_profiler,
    _profiler_enabled,
    _ProfilerResult,
    _push_saved_tensors_default_hooks,
    _record_function_with_args_enter,
    _record_function_with_args_exit,
    _set_empty_test_observer,
    _supported_activities,
    DeviceType,
    kineto_available,
    ProfilerEvent,
    SavedTensor,
)

from torch._C._profiler import ProfilerActivity, ProfilerConfig, ProfilerState

from . import profiler


def _register_py_tensor_class_for_device(device, cls):
    if not isinstance(cls, type):
        raise RuntimeError("cls isn't a typeinfo object")
    torch._C._register_py_class_for_device(device, cls)


is_multithreading_enabled = torch._C._is_multithreading_enabled
torch._C._add_docstr(
    is_multithreading_enabled, "Returns True if multithreading is currently enabled."
)

is_view_replay_enabled = torch._C._is_view_replay_enabled
torch._C._add_docstr(
    is_view_replay_enabled, "Returns True if view-replay is currently enabled."
)