import warnings

import torch

__all__ = ["detect_anomaly", "set_detect_anomaly"]


class detect_anomaly:
    r"""Context-manager that enable anomaly detection for the autograd engine.



    This does two things:



    - Running the forward pass with detection enabled will allow the backward

      pass to print the traceback of the forward operation that created the failing

      backward function.

    - If ``check_nan`` is ``True``, any backward computation that generate "nan"

      value will raise an error. Default ``True``.



    .. warning::

        This mode should be enabled only for debugging as the different tests

        will slow down your program execution.



    Example:



        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ANOMALY)

        >>> import torch

        >>> from torch import autograd

        >>> class MyFunc(autograd.Function):

        ...     @staticmethod

        ...     def forward(ctx, inp):

        ...         return inp.clone()

        ...     @staticmethod

        ...     def backward(ctx, gO):

        ...         # Error during the backward pass

        ...         raise RuntimeError("Some error in backward")

        ...         return gO.clone()

        >>> def run_fn(a):

        ...     out = MyFunc.apply(a)

        ...     return out.sum()

        >>> inp = torch.rand(10, 10, requires_grad=True)

        >>> out = run_fn(inp)

        >>> out.backward()

            Traceback (most recent call last):

              File "<stdin>", line 1, in <module>

              File "/your/pytorch/install/torch/_tensor.py", line 93, in backward

                torch.autograd.backward(self, gradient, retain_graph, create_graph)

              File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward

                allow_unreachable=True)  # allow_unreachable flag

              File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply

                return self._forward_cls.backward(self, *args)

              File "<stdin>", line 8, in backward

            RuntimeError: Some error in backward

        >>> with autograd.detect_anomaly():

        ...     inp = torch.rand(10, 10, requires_grad=True)

        ...     out = run_fn(inp)

        ...     out.backward()

            Traceback of forward call that caused the error:

              File "tmp.py", line 53, in <module>

                out = run_fn(inp)

              File "tmp.py", line 44, in run_fn

                out = MyFunc.apply(a)

            Traceback (most recent call last):

              File "<stdin>", line 4, in <module>

              File "/your/pytorch/install/torch/_tensor.py", line 93, in backward

                torch.autograd.backward(self, gradient, retain_graph, create_graph)

              File "/your/pytorch/install/torch/autograd/__init__.py", line 90, in backward

                allow_unreachable=True)  # allow_unreachable flag

              File "/your/pytorch/install/torch/autograd/function.py", line 76, in apply

                return self._forward_cls.backward(self, *args)

              File "<stdin>", line 8, in backward

            RuntimeError: Some error in backward



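
        With ``check_nan=True`` (the default), backward computations that
        produce "nan" values are reported as well. A minimal sketch (the exact
        error message may differ across PyTorch versions):

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ANOMALY)
        >>> with autograd.detect_anomaly():
        ...     x = torch.tensor([-1.0], requires_grad=True)
        ...     out = x.sqrt().sum()  # sqrt(-1) is nan in the forward pass
        ...     out.backward()  # the nan gradient is caught during backward
            RuntimeError: Function 'SqrtBackward0' returned nan values in its 0th output.
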
    """

    def __init__(self, check_nan=True) -> None:
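        # Save the current anomaly-mode settings so that __exit__ can restore them.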
        self.prev = torch.is_anomaly_enabled()
        self.check_nan = check_nan
        self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
        warnings.warn(
            "Anomaly Detection has been enabled. "
            "This mode will increase the runtime "
            "and should only be enabled for debugging.",
            stacklevel=2,
        )

    def __enter__(self) -> None:
        torch.set_anomaly_enabled(True, self.check_nan)

    def __exit__(self, *args: object) -> None:
        torch.set_anomaly_enabled(self.prev, self.prev_check_nan)


class set_detect_anomaly:
    r"""Context-manager that sets the anomaly detection for the autograd engine on or off.



    ``set_detect_anomaly`` will enable or disable the autograd anomaly detection

    based on its argument :attr:`mode`.

    It can be used as a context-manager or as a function.



    See ``detect_anomaly`` above for details of the anomaly detection behaviour.



    Args:

        mode (bool): Flag whether to enable anomaly detection (``True``),

                     or disable (``False``).

        check_nan (bool): Flag whether to raise an error when the backward

                          generate "nan"



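
    Example:
        A minimal usage sketch, both as a context-manager and as a plain
        function call:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ANOMALY)
        >>> import torch
        >>> from torch import autograd
        >>> inp = torch.rand(10, 10, requires_grad=True)
        >>> # As a context-manager: the previous setting is restored on exit.
        >>> with autograd.set_detect_anomaly(True):
        ...     out = (inp * 2).sum()
        ...     out.backward()
        >>> # As a plain call: the setting stays in effect until changed again.
        >>> _ = autograd.set_detect_anomaly(False)
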
    """

    def __init__(self, mode: bool, check_nan: bool = True) -> None:
        self.prev = torch.is_anomaly_enabled()
        self.prev_check_nan = torch.is_anomaly_check_nan_enabled()
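        # The new mode is applied eagerly here (rather than in __enter__) so the
        # class also works when called as a plain function; __exit__ restores the
        # previous settings when it is used as a context-manager.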
        torch.set_anomaly_enabled(mode, check_nan)

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: object) -> None:
        torch.set_anomaly_enabled(self.prev, self.prev_check_nan)