import functools
import inspect
import itertools
import warnings
from collections import OrderedDict
from typing import Any, List, Optional, Tuple

import torch
import torch._C as _C
import torch._functorch as _functorch
import torch.utils.hooks as hooks
from torch._C import _functions
from torch._functorch.autograd_function import custom_function_call

__all__ = [
    "FunctionCtx",
    "BackwardCFunction",
    "FunctionMeta",
    "Function",
    "once_differentiable",
    "traceable",
    "InplaceFunction",
    "NestedIOFunction",
]

# Unique id provider for each class inheriting from Function
# This is incremented in FunctionMeta during class definition
AUTOGRAD_FUNCTION_COUNTER = itertools.count()


# Formerly known as: _ContextMethodMixin
class FunctionCtx:
    def save_for_backward(self, *tensors: torch.Tensor):
        r"""Save given tensors for a future call to :func:`~Function.backward`.



        ``save_for_backward`` should be called at most once, only from inside the

        :func:`forward` method, and only with tensors.



        All tensors intended to be used in the backward pass should be saved

        with ``save_for_backward`` (as opposed to directly on ``ctx``) to prevent

        incorrect gradients and memory leaks, and enable the application of saved

        tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.



        Note that if intermediary tensors, tensors that are neither inputs

        nor outputs of :func:`forward`, are saved for backward, your custom Function

        may not support double backward.

        Custom Functions that do not support double backward should decorate their

        :func:`backward` method with ``@once_differentiable`` so that performing

        double backward raises an error. If you'd like to support double backward,

        you can either recompute intermediaries based on the inputs during backward

        or return the intermediaries as the outputs of the custom Function. See the

        `double backward tutorial <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_

        for more details.



        In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors`

        attribute. Before returning them to the user, a check is made to ensure

        they weren't used in any in-place operation that modified their content.



        Arguments can also be ``None``. This is a no-op.



        See :ref:`extending-autograd` for more details on how to use this method.



        Example::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)

            >>> class Func(Function):

            >>>     @staticmethod

            >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):

            >>>         w = x * z

            >>>         out = x * y + y * z + w * y

            >>>         ctx.save_for_backward(x, y, w, out)

            >>>         ctx.z = z  # z is not a tensor

            >>>         return out

            >>>

            >>>     @staticmethod

            >>>     @once_differentiable

            >>>     def backward(ctx, grad_out):

            >>>         x, y, w, out = ctx.saved_tensors

            >>>         z = ctx.z

            >>>         gx = grad_out * (y + y * z)

            >>>         gy = grad_out * (x + z + w)

            >>>         gz = None

            >>>         return gx, gy, gz

            >>>

            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)

            >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)

            >>> c = 4

            >>> d = Func.apply(a, b, c)



        """
        self.to_save = tensors

    def save_for_forward(self, *tensors: torch.Tensor):
        r"""Save given tensors for a future call to :func:`~Function.jvp`.



        ``save_for_forward`` should be only called once, from inside the :func:`forward`

        method, and only be called with tensors.



        In :func:`jvp`, saved objects can be accessed through the :attr:`saved_tensors`

        attribute.



        Arguments can also be ``None``. This is a no-op.



        See :ref:`extending-autograd` for more details on how to use this method.



        Example::

            >>> # xdoctest: +SKIP

            >>> class Func(torch.autograd.Function):

            >>>     @staticmethod

            >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):

            >>>         ctx.save_for_backward(x, y)

            >>>         ctx.save_for_forward(x, y)

            >>>         ctx.z = z

            >>>         return x * y * z

            >>>

            >>>     @staticmethod

            >>>     def jvp(ctx, x_t, y_t, _):

            >>>         x, y = ctx.saved_tensors

            >>>         z = ctx.z

            >>>         return z * (y * x_t + x * y_t)

            >>>

            >>>     @staticmethod

            >>>     def vjp(ctx, grad_out):

            >>>         x, y = ctx.saved_tensors

            >>>         z = ctx.z

            >>>         return z * grad_out * y, z * grad_out * x, None

            >>>

            >>>     a = torch.tensor(1., requires_grad=True, dtype=torch.double)

            >>>     t = torch.tensor(1., dtype=torch.double)

            >>>     b = torch.tensor(2., requires_grad=True, dtype=torch.double)

            >>>     c = 4

            >>>

            >>>     with fwAD.dual_level():

            >>>         a_dual = fwAD.make_dual(a, t)

            >>>         d = Func.apply(a_dual, b, c)



        """
        for tensor in tensors:
            assert isinstance(tensor, torch.Tensor) or tensor is None, (
                "save_for_forward expects all arguments to be tensors; you should "
                "save non-tensors as attributes on ctx."
            )

        self.saved_for_forward = tensors

    def mark_dirty(self, *args: torch.Tensor):
        r"""Mark given tensors as modified in an in-place operation.



        **This should be called at most once, only from inside the**

        :func:`forward` **method, and all arguments should be inputs.**



        Every tensor that's been modified in-place in a call to :func:`forward`

        should be given to this function, to ensure correctness of our checks.

        It doesn't matter whether the function is called before or after

        modification.



        Examples::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)

            >>> class Inplace(Function):

            >>>     @staticmethod

            >>>     def forward(ctx, x):

            >>>         x_npy = x.numpy() # x_npy shares storage with x

            >>>         x_npy += 1

            >>>         ctx.mark_dirty(x)

            >>>         return x

            >>>

            >>>     @staticmethod

            >>>     @once_differentiable

            >>>     def backward(ctx, grad_output):

            >>>         return grad_output

            >>>

            >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()

            >>> b = a * a

            >>> Inplace.apply(a)  # This would lead to wrong gradients!

            >>>                   # but the engine would not know unless we mark_dirty

            >>> # xdoctest: +SKIP

            >>> b.backward() # RuntimeError: one of the variables needed for gradient

            >>>              # computation has been modified by an inplace operation



        """
        self.dirty_tensors = args

    def mark_shared_storage(self, *pairs):
        warnings.warn(
            "mark_shared_storage is deprecated. "
            "Tensors with shared storages are automatically tracked. Note "
            "that calls to `set_()` are not tracked"
        )

    def mark_non_differentiable(self, *args: torch.Tensor):
        r"""Mark outputs as non-differentiable.



        **This should be called at most once, only from inside the**

        :func:`forward` **method, and all arguments should be tensor outputs.**



        This will mark outputs as not requiring gradients, increasing the

        efficiency of backward computation. You still need to accept a gradient

        for each output in :meth:`~Function.backward`, but it's always going to

        be a zero tensor with the same shape as the shape of a corresponding

        output.



        This is used e.g. for indices returned from a sort. See example::

            >>> class Func(Function):

            >>>     @staticmethod

            >>>     def forward(ctx, x):

            >>>         sorted, idx = x.sort()

            >>>         ctx.mark_non_differentiable(idx)

            >>>         ctx.save_for_backward(x, idx)

            >>>         return sorted, idx

            >>>

            >>>     @staticmethod

            >>>     @once_differentiable

            >>>     def backward(ctx, g1, g2):  # still need to accept g2

            >>>         x, idx = ctx.saved_tensors

            >>>         grad_input = torch.zeros_like(x)

            >>>         grad_input.index_add_(0, idx, g1)

            >>>         return grad_input



        """
        self.non_differentiable = args

    def set_materialize_grads(self, value: bool):
        r"""Set whether to materialize grad tensors. Default is ``True``.



        **This should be called only from inside the** :func:`forward` **method**



        If ``True``, undefined grad tensors will be expanded to tensors full of zeros

        prior to calling the :func:`backward` and :func:`jvp` methods.



        Example::

            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)

            >>> class SimpleFunc(Function):

            >>>     @staticmethod

            >>>     def forward(ctx, x):

            >>>         return x.clone(), x.clone()

            >>>

            >>>     @staticmethod

            >>>     @once_differentiable

            >>>     def backward(ctx, g1, g2):

            >>>         return g1 + g2  # No check for None necessary

            >>>

            >>> # We modify SimpleFunc to handle non-materialized grad outputs

            >>> class Func(Function):

            >>>     @staticmethod

            >>>     def forward(ctx, x):

            >>>         ctx.set_materialize_grads(False)

            >>>         ctx.save_for_backward(x)

            >>>         return x.clone(), x.clone()

            >>>

            >>>     @staticmethod

            >>>     @once_differentiable

            >>>     def backward(ctx, g1, g2):

            >>>         x, = ctx.saved_tensors

            >>>         grad_input = torch.zeros_like(x)

            >>>         if g1 is not None:  # We must check for None now

            >>>             grad_input += g1

            >>>         if g2 is not None:

            >>>             grad_input += g2

            >>>         return grad_input

            >>>

            >>> a = torch.tensor(1., requires_grad=True)

            >>> b, _ = Func.apply(a)  # induces g2 to be undefined



        """
        self.materialize_grads = value


# DO NOT USE: This is only defined to be able to load old serialized models
_ContextMethodMixin = FunctionCtx


class _HookMixin:
    @staticmethod
    def _register_hook(backward_hooks, hook):
        if backward_hooks is None:
            backward_hooks = OrderedDict()
        handle = hooks.RemovableHandle(backward_hooks)
        backward_hooks[handle.id] = hook
        return backward_hooks, handle


class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
    r"""

    This class is used for internal autograd work. Do not use.

    """

    def apply(self, *args):
        r"""

        Apply method used when executing this Node during the backward

        """
        # _forward_cls is defined by derived class
        # The user should define either backward or vjp but never both.
        backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
        vjp_fn = self._forward_cls.vjp  # type: ignore[attr-defined]
        if backward_fn is not Function.backward and vjp_fn is not Function.vjp:
            raise RuntimeError(
                "Implementing both 'backward' and 'vjp' for a custom "
                "Function is not allowed. You should only implement one "
                "of them."
            )
        user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
        return user_fn(self, *args)

    def apply_jvp(self, *args):
        r"""

        Apply method used when executing forward mode AD during the forward

        """
        # _forward_cls is defined by derived class
        return self._forward_cls.jvp(self, *args)  # type: ignore[attr-defined]

    def _compiled_autograd_key(self):
        return self._forward_cls._compiled_autograd_key(self)  # type: ignore[attr-defined]


def _warn_traceable_deprecated():
    warnings.warn(
        "The is_traceable field on torch.autograd.Function is deprecated "
        "and will be removed in PyTorch 2.4.",
        stacklevel=3,
    )


class FunctionMeta(type):
    """Function metaclass.



    This metaclass sets up the following properties:

        _backward_cls: The Function class corresponding to the differentiated

            version of this function (which is generated on the fly by this

            metaclass).

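    For example (an illustrative sketch; ``MyFn`` is a hypothetical subclass,
    not part of the API)::

        class MyFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * 2

        # The metaclass generated the node class on the fly:
        assert MyFn._backward_cls.__name__ == "MyFnBackward"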
    """

    def __init__(cls, name, bases, attrs):
        backward_fn = type(
            name + "Backward", (BackwardCFunction,), {"_forward_cls": cls}
        )
        backward_fn._autograd_function_id = next(AUTOGRAD_FUNCTION_COUNTER)  # type: ignore[attr-defined]
        backward_fn._compiled_autograd_should_lift = attrs.get(  # type: ignore[attr-defined]
            "_compiled_autograd_should_lift", True
        )
        cls._backward_cls = backward_fn

        if "is_traceable" in attrs and attrs["is_traceable"] is True:
            _warn_traceable_deprecated()

        super().__init__(name, bases, attrs)

    def __getattribute__(cls, name):
        if name == "is_traceable":
            _warn_traceable_deprecated()
        return super().__getattribute__(name)

    def __setattr__(cls, name, value):
        if name == "is_traceable" and value is True:
            warnings.warn(
                "The is_traceable field on torch.autograd.Function is deprecated "
                "and will be removed in PyTorch 2.4.",
                stacklevel=2,
            )
        return super().__setattr__(name, value)


class _SingleLevelFunction(
    _C._FunctionBase, FunctionCtx, _HookMixin, metaclass=FunctionMeta
):
    @staticmethod
    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
        r"""Define the forward of the custom autograd Function.



        This function is to be overridden by all subclasses.

        There are two ways to define forward:



        Usage 1 (Combined forward and ctx)::



            @staticmethod

            def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:

                pass



        - It must accept a context ctx as the first argument, followed by any

          number of arguments (tensors or other types).

        - See :ref:`combining-forward-context` for more details



        Usage 2 (Separate forward and ctx)::



            @staticmethod

            def forward(*args: Any, **kwargs: Any) -> Any:

                pass



            @staticmethod

            def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:

                pass



        - The forward no longer accepts a ctx argument.

        - Instead, you must also override the :meth:`torch.autograd.Function.setup_context`

          staticmethod to handle setting up the ``ctx`` object.

          ``output`` is the output of the forward, ``inputs`` are a Tuple of inputs

          to the forward.

        - See :ref:`extending-autograd` for more details



        The context can be used to store arbitrary data that can be then

        retrieved during the backward pass. Tensors should not be stored

        directly on `ctx` (though this is not currently enforced for

        backward compatibility). Instead, tensors should be saved either with

        :func:`ctx.save_for_backward` if they are intended to be used in

        ``backward`` (equivalently, ``vjp``) or :func:`ctx.save_for_forward`

        if they are intended to be used for in ``jvp``.

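        Example of Usage 2 (a minimal, illustrative sketch; ``MyMul`` below is a
        hypothetical Function, not part of the API)::

            >>> # xdoctest: +SKIP
            >>> # illustrative only
            >>> class MyMul(Function):
            >>>     @staticmethod
            >>>     def forward(x, y):  # note: no ctx argument
            >>>         return x * y
            >>>
            >>>     @staticmethod
            >>>     def setup_context(ctx, inputs, output):
            >>>         x, y = inputs
            >>>         ctx.save_for_backward(x, y)
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_out):
            >>>         x, y = ctx.saved_tensors
            >>>         return grad_out * y, grad_out * x
            >>>
            >>> a = torch.tensor(2., requires_grad=True)
            >>> b = torch.tensor(3., requires_grad=True)
            >>> out = MyMul.apply(a, b)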
        """
        raise NotImplementedError(
            "You must implement the forward function for custom autograd.Function."
        )

    @staticmethod
    def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> Any:
        r"""There are two ways to define the forward pass of an autograd.Function.



        Either:



        1. Override forward with the signature ``forward(ctx, *args, **kwargs)``.

           ``setup_context`` is not overridden. Setting up the ctx for backward

           happens inside the ``forward``.

        2. Override forward with the signature ``forward(*args, **kwargs)`` and

           override ``setup_context``. Setting up the ctx for backward happens

           inside ``setup_context`` (as opposed to inside the ``forward``)



        See :meth:`torch.autograd.Function.forward` and :ref:`extending-autograd` for more details.

        """
        raise NotImplementedError("setup_context is not implemented.")

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        r"""Define a formula for differentiating the operation with backward mode automatic differentiation.



        This function is to be overridden by all subclasses.

        (Defining this function is equivalent to defining the ``vjp`` function.)



        It must accept a context :attr:`ctx` as the first argument, followed by

        as many outputs as the :func:`forward` returned (None will be passed in

        for non tensor outputs of the forward function),

        and it should return as many tensors, as there were inputs to

        :func:`forward`. Each argument is the gradient w.r.t the given output,

        and each returned value should be the gradient w.r.t. the

        corresponding input. If an input is not a Tensor or is a Tensor not

        requiring grads, you can just pass None as a gradient for that input.



        The context can be used to retrieve tensors saved during the forward

        pass. It also has an attribute :attr:`ctx.needs_input_grad` as a tuple

        of booleans representing whether each input needs gradient. E.g.,

        :func:`backward` will have ``ctx.needs_input_grad[0] = True`` if the

        first input to :func:`forward` needs gradient computed w.r.t. the

        output.

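        Example (an illustrative sketch; ``MyAddMul`` below is hypothetical, not
        part of the API) of using ``ctx.needs_input_grad`` to skip work for
        inputs that do not require gradients::

            >>> # xdoctest: +SKIP
            >>> # illustrative only
            >>> class MyAddMul(Function):
            >>>     @staticmethod
            >>>     def forward(ctx, x, y, scale: float):
            >>>         ctx.scale = scale
            >>>         return (x + y) * scale
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_out):
            >>>         gx = grad_out * ctx.scale if ctx.needs_input_grad[0] else None
            >>>         gy = grad_out * ctx.scale if ctx.needs_input_grad[1] else None
            >>>         return gx, gy, None  # no gradient for the non-tensor ``scale``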
        """
        raise NotImplementedError(
            "You must implement either the backward or vjp method for "
            "your custom autograd.Function to use it with backward "
            "mode AD."
        )

    # vjp and backward are alias of each other
    vjp = backward

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        r"""Define a formula for differentiating the operation with forward mode automatic differentiation.



        This function is to be overridden by all subclasses.

        It must accept a context :attr:`ctx` as the first argument, followed by

        as many inputs as the :func:`forward` got (None will be passed in

        for non tensor inputs of the forward function),

        and it should return as many tensors as there were outputs to

        :func:`forward`. Each argument is the gradient w.r.t the given input,

        and each returned value should be the gradient w.r.t. the

        corresponding output. If an output is not a Tensor or the function is not

        differentiable with respect to that output, you can just pass None as a

        gradient for that input.



        You can use the :attr:`ctx` object to pass any value from the forward to this

        functions.

        """
        raise NotImplementedError(
            "You must implement the jvp function for custom "
            "autograd.Function to use it with forward mode AD."
        )


class Function(_SingleLevelFunction):
    r"""Base class to create custom `autograd.Function`.



    To create a custom `autograd.Function`, subclass this class and implement

    the :meth:`forward` and :meth:`backward` static methods. Then, to use your custom

    op in the forward pass, call the class method ``apply``. Do not call

    :meth:`forward` directly.



    To ensure correctness and best performance, make sure you are calling the

    correct methods on ``ctx`` and validating your backward function using

    :func:`torch.autograd.gradcheck`.



    See :ref:`extending-autograd` for more details on how to use this class.



    Examples::



        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)

        >>> class Exp(Function):

        >>>     @staticmethod

        >>>     def forward(ctx, i):

        >>>         result = i.exp()

        >>>         ctx.save_for_backward(result)

        >>>         return result

        >>>

        >>>     @staticmethod

        >>>     def backward(ctx, grad_output):

        >>>         result, = ctx.saved_tensors

        >>>         return grad_output * result

        >>>

        >>> # Use it by calling the apply method:

        >>> # xdoctest: +SKIP

        >>> output = Exp.apply(input)

    """

    def __init__(self, *args, **kwargs):
        cls = self.__class__
        warnings.warn(
            f"{cls} should not be instantiated. Methods on autograd functions"
            "are all static, so you should invoke them on the class itself. "
            "Instantiating an autograd function will raise an "
            "error in a future version of PyTorch.",
            DeprecationWarning,
            stacklevel=2,
        )

    def __call__(self, *args, **kwargs):
        raise RuntimeError(
            "Legacy autograd function with non-static forward method is deprecated. "
            "Please use new-style autograd function with static forward method. "
            "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)"
        )

    # for the tracer
    is_traceable = False

    """

    Bool that specifies if PyTorch should attempt to autogenerate

    :func:`torch.vmap` support for this autograd.Function. You may set this to

    True only if this autograd.Function's forward, backward, and jvp (if they

    exist) are written using PyTorch operations; otherwise, please override

    :meth:`torch.autograd.Function.vmap` to add support for :func:`torch.vmap`.



    Please see :ref:`func-autograd-function` for more details.

    """
    generate_vmap_rule = False

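    # Illustrative sketch of the flag described above (``MySquare`` below is
    # hypothetical, not part of the API): opting in to the autogenerated vmap
    # rule only requires setting the flag on a subclass whose methods are
    # written with PyTorch operations:
    #
    #     class MySquare(Function):
    #         generate_vmap_rule = True
    #
    #         @staticmethod
    #         def forward(x):
    #             return x ** 2
    #
    #         @staticmethod
    #         def setup_context(ctx, inputs, output):
    #             ctx.save_for_backward(inputs[0])
    #
    #         @staticmethod
    #         def backward(ctx, grad_out):
    #             (x,) = ctx.saved_tensors
    #             return 2 * x * grad_out
    #
    #     torch.vmap(MySquare.apply)(torch.randn(3))
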
    @staticmethod
    def vmap(info, in_dims, *args):
        r"""Define the behavior for this autograd.Function underneath :func:`torch.vmap`.



        For a :func:`torch.autograd.Function` to support

        :func:`torch.vmap`, you must either override this static method, or set

        ``generate_vmap_rule`` to ``True`` (you may not do both).



        If you choose to override this staticmethod: it must accept



        - an ``info`` object as the first argument. ``info.batch_size``

          specifies the size of the dimension being vmapped over,

          while ``info.randomness`` is the randomness option passed to

          :func:`torch.vmap`.

        - an ``in_dims`` tuple as the second argument.

          For each arg in ``args``, ``in_dims`` has a corresponding

          ``Optional[int]``. It is ``None`` if the arg is not a Tensor or if

          the arg is not being vmapped over, otherwise, it is an integer

          specifying what dimension of the Tensor is being vmapped over.

        - ``*args``, which is the same as the args to :meth:`~Function.forward`.



        The return of the vmap staticmethod is a tuple of ``(output, out_dims)``.

        Similar to ``in_dims``, ``out_dims`` should be of the same structure as

        ``output`` and contain one ``out_dim`` per output that specifies if the

        output has the vmapped dimension and what index it is in.



        Please see :ref:`func-autograd-function` for more details.

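        Example (a minimal, illustrative sketch; ``MyScale`` below is a
        hypothetical Function, not part of the API)::

            >>> # xdoctest: +SKIP
            >>> # illustrative only
            >>> class MyScale(Function):
            >>>     @staticmethod
            >>>     def forward(x, scale: float):
            >>>         return x * scale
            >>>
            >>>     @staticmethod
            >>>     def setup_context(ctx, inputs, output):
            >>>         ctx.scale = inputs[1]
            >>>
            >>>     @staticmethod
            >>>     def backward(ctx, grad_out):
            >>>         return grad_out * ctx.scale, None
            >>>
            >>>     @staticmethod
            >>>     def vmap(info, in_dims, x, scale):
            >>>         # The op is elementwise, so the batched tensor can be used
            >>>         # directly; the batch dimension stays where it was.
            >>>         x_bdim, _ = in_dims
            >>>         return x * scale, x_bdim
            >>>
            >>> x = torch.randn(3, 4)
            >>> torch.vmap(MyScale.apply, in_dims=(0, None))(x, 2.0)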
        """
        raise NotImplementedError(
            "To use autograd.Function with vmap, you must either override the "
            "vmap staticmethod or set generate_vmap_rule=True."
        )

    @classmethod
    def apply(cls, *args, **kwargs):
        def bind_default_args(func, *args, **kwargs):
            signature = inspect.signature(func)
            bound_args = signature.bind(*args, **kwargs)
            bound_args.apply_defaults()

            return bound_args.args

        is_setup_ctx_defined = cls.setup_context != _SingleLevelFunction.setup_context
        if is_setup_ctx_defined:
            args = bind_default_args(cls.forward, *args, **kwargs)

        if not torch._C._are_functorch_transforms_active():
            # See NOTE: [functorch vjp and autograd interaction]
            args = _functorch.utils.unwrap_dead_wrappers(args)
            return super().apply(*args, **kwargs)  # type: ignore[misc]

        if not is_setup_ctx_defined:
            raise RuntimeError(
                "In order to use an autograd.Function with functorch transforms "
                "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
                "staticmethod. For more details, please see "
                "https://pytorch.org/docs/master/notes/extending.func.html"
            )

        return custom_function_call(cls, *args, **kwargs)

    @staticmethod
    def _compiled_autograd_key(ctx):
        return (ctx._autograd_function_id,)


def once_differentiable(fn):
    @functools.wraps(fn)
    def wrapper(ctx, *args):
        with torch.no_grad():
            outputs = fn(ctx, *args)

        if not torch.is_grad_enabled():
            return outputs

        # If any of the inputs have requires_grad=True, we force the outputs
        # to have requires_grad=True but point to a grad_fn which throws an
        # error message during (double) back-propagation.
        # XXX: this is only an approximation of requires_grad - there's no way
        # to figure out if fn didn't use ctx.saved_tensors and as a result
        # some Tensors might require grad, even if no args do.
        # Unfortunately, this leads to unexpected error messages ("no nodes
        # require computing gradients"), but I don't have a better idea.
        # These functions would raise an error in backward anyway.
        requires_grad = any(
            isinstance(arg, torch.Tensor) and arg.requires_grad for arg in args
        )
        if not requires_grad:
            return outputs

        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        err_fn = _functions.DelayedError(
            b"trying to differentiate twice a function that was marked "
            b"with @once_differentiable",
            len(outputs),
        )

        # Create aliases of each output that has requires_grad=True. We need
        # at least one of the inputs to err_fn to require grad so that the
        # output will have a grad_fn.
        def fake_requires_grad(var):
            if var is not None:
                var = var.detach()
                var.requires_grad = True
            return var

        return err_fn(*[fake_requires_grad(v) for v in outputs])

    return wrapper


def traceable(fn_cls):
    r"""Mark Function as traceable for the JIT.



    Traceable functions have additional restrictions - they can't pass any

    data-dependent values to backward (e.g. Prod passes the output, which makes

    it non-traceable), and their backward should be implemented entirely in terms

    of operations on autograd Tensors in all cases.



    DON'T USE THIS DECORATOR. IT IS FOR INTERNAL USE ONLY AND SHOULD BE HANDLED WITH

    CARE (or can give incorrect results otherwise).

    """
    warnings.warn(
        "torch.autograd.function.traceable is deprecated "
        "and will be removed in PyTorch 2.4.",
        stacklevel=2,
    )
    fn_cls.is_traceable = True
    return fn_cls


class InplaceFunction(Function):
    r"""

    This class is here only for backward compatibility reasons.

    Use :class:`Function` instead of this for any new use case.

    """

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace


def _nested_map(condition, fn, condition_msg=None):
    def _map(obj):
        if condition(obj):
            return fn(obj)
        elif obj is None:
            return None
        elif isinstance(obj, (list, tuple)):
            mapped = (_map(x) for x in obj)
            if hasattr(obj, "_fields"):
                # obj is namedtuple
                return type(obj)(*mapped)
            return type(obj)(mapped)
        elif isinstance(obj, dict):
            return {x: _map(obj[x]) for x in obj}
        else:
            raise ValueError(
                "Auto nesting doesn't know how to process "
                "an input object of type "
                + torch.typename(obj)
                + (
                    ". Accepted types: " + condition_msg + ", or lists/tuples of them"
                    if condition_msg
                    else ""
                )
            )

    return _map


def _jit_unwrap_structured(obj):
    if hasattr(obj, "_jit_unwrap"):
        return obj._jit_unwrap()
    return obj


def _iter_filter(condition, allow_unknown=False, condition_msg=None, conversion=None):
    def _iter(obj):
        if conversion is not None:
            obj = conversion(obj)
        if condition(obj):
            yield obj
        elif obj is None:
            return
        elif isinstance(obj, (list, tuple)):
            for o in obj:
                yield from _iter(o)
        elif isinstance(obj, dict):
            # We only accept primitive key types, so we needn't inspect them
            for o in obj.values():
                yield from _iter(o)
        elif allow_unknown:
            yield obj
        else:
            raise ValueError(
                "Auto nesting doesn't know how to process "
                "an input object of type "
                + torch.typename(obj)
                + (
                    ". Accepted types: " + condition_msg + ", or lists/tuples of them"
                    if condition_msg
                    else ""
                )
            )

    return _iter


def _unflatten(input, proto):
    # unflatten a list or tuple input into a nested list/tuple structure
    # specified by proto
    def unflatten_helper(input, proto):
        res: List[Optional[torch.Tensor]] = []
        if hasattr(proto, "_jit_wrap"):
            return proto._jit_wrap(input)
        if not isinstance(proto, (list, tuple)):
            return input[0], input[1:]
        for e in proto:
            if e is None:
                res.append(e)
            else:
                res_e, input = unflatten_helper(input, e)
                res.append(res_e)
        return type(proto)(res), input

    return unflatten_helper(input, proto)[0]

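# Illustrative sketch of how the helpers above fit together (``t1``/``p1`` etc.
# are made-up placeholders): given
#     flat = (t1, t2, t3)       # e.g. produced by tuple(_iter_tensors(proto))
#     proto = (p1, [p2, p3])    # the original nested structure
# _unflatten(flat, proto) consumes ``flat`` left to right while walking
# ``proto`` and returns (t1, [t2, t3]), i.e. the flat tensors re-nested into
# the same shape as ``proto``.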

_iter_jit_values = _iter_filter(
    lambda o: o is None or isinstance(o, torch._C.Value),
    condition_msg="jit's Values or None",
)
_iter_tensors = _iter_filter(
    lambda x: isinstance(x, torch.Tensor),
    condition_msg="Tensors",
    conversion=_jit_unwrap_structured,
)
_iter_tensors_permissive = _iter_filter(
    lambda x: isinstance(x, torch.Tensor),
    allow_unknown=True,
    condition_msg="Tensors (permissive)",
)
_iter_None_tensors = _iter_filter(
    lambda o: o is None or isinstance(o, torch.Tensor), condition_msg="Tensors or None"
)
_map_tensor_data = _nested_map(
    lambda x: isinstance(x, torch.Tensor), lambda o: o.data, condition_msg="Tensors"
)


class NestedIOFunction(Function):
    r"""

    This class is here only for backward compatibility reasons.

    Use :class:`Function` instead of this for any new use case.

    """
    # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
    # superclass (Function) but are instance methods here, which mypy reports as incompatible.

    def _do_forward(self, *input):
        self._nested_input = input
        flat_input = tuple(_iter_tensors(input))
        flat_output = super()._do_forward(*flat_input)  # type: ignore[misc]
        nested_output = self._nested_output
        nested_tensors = _unflatten(flat_output, self._nested_output)
        return nested_tensors

    def _do_backward(self, gradients, retain_variables):
        self.retain_variables = retain_variables
        result = super()._do_backward(gradients, retain_variables)  # type: ignore[misc]
        if not retain_variables:
            del self._nested_output
            del self._to_save_nested
        return result

    def backward(self, *gradients: Any) -> Any:  # type: ignore[override]
        r"""

        Shared backward utility.

        """
        nested_gradients = _unflatten(gradients, self._nested_output)
        result = self.backward_extended(*nested_gradients)  # type: ignore[func-returns-value]
        return tuple(_iter_None_tensors(result))

    __call__ = _do_forward

    def forward(self, *args: Any) -> Any:  # type: ignore[override]
        r"""

        Shared forward utility.

        """
        nested_tensors = _map_tensor_data(self._nested_input)
        result = self.forward_extended(*nested_tensors)  # type: ignore[func-returns-value]
        del self._nested_input
        self._nested_output = result
        return tuple(_iter_tensors(result))

    def save_for_backward(self, *args: Any) -> None:
        r"""

        See :meth:`Function.save_for_backward`.

        """
        self.to_save = tuple(_iter_tensors(args))
        self._to_save_nested = args

    @property
    def saved_tensors(self):
        r"""

        See :meth:`Function.saved_tensors`.

        """
        flat_tensors = super().saved_tensors  # type: ignore[misc]
        return _unflatten(flat_tensors, self._to_save_nested)

    def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
        r"""

        See :meth:`Function.mark_dirty`.

        """
        self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))

    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
        r"""

        See :meth:`Function.mark_non_differentiable`.

        """
        self.non_differentiable = tuple(_iter_tensors((args, kwargs)))

    def forward_extended(self, *input: Any) -> None:
        r"""

        User defined forward.

        """
        raise NotImplementedError

    def backward_extended(self, *grad_output: Any) -> None:
        r"""

        User defined backward.

        """
        raise NotImplementedError