from typing import Callable, Iterable, Sequence, Union

import torch
from torch.cuda.amp import custom_bwd, custom_fwd


def checkpoint(
    func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
    inputs: Sequence[torch.Tensor],
    params: Iterable[torch.Tensor],
    flag: bool,
):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)
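
# Usage sketch (illustrative only; the ``Linear`` layer and shapes below are
# assumptions, not part of this module):
#
#     net = torch.nn.Linear(16, 16)
#     x = torch.randn(4, 16, requires_grad=True)
#     y = checkpoint(net, (x,), net.parameters(), flag=True)
#     y.sum().backward()  # net(x) is recomputed here instead of being cached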


class CheckpointFunction(torch.autograd.Function):
    """Runs ``run_function`` under ``no_grad`` in the forward pass and defers
    all gradient computation to :class:`CheckpointFunctionGradFunction`, which
    recomputes the activations when they are actually needed."""

    @staticmethod
    @custom_fwd
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.length = length
        input_tensors = list(args[:length])
        input_params = list(args[length:])
        ctx.save_for_backward(*input_tensors, *input_params)
        with torch.no_grad():
            output_tensors = ctx.run_function(*input_tensors)
        return output_tensors

    @staticmethod
    @custom_bwd
    def backward(ctx, *output_grads):
        inputs = ctx.saved_tensors
        input_tensors = inputs[: ctx.length]
        input_params = inputs[ctx.length :]
        res = CheckpointFunctionGradFunction.apply(
            ctx.run_function,
            len(input_tensors),
            len(input_params),
            *input_tensors,
            *input_params,
            *output_grads
        )
        return (None, None) + res


class CheckpointFunctionGradFunction(torch.autograd.Function):
    """Recomputes the checkpointed forward pass and returns the input and
    parameter gradients as differentiable outputs, so that second-order
    (double backward) gradients can flow through the checkpointed region."""

    @staticmethod
    @custom_fwd
    def forward(ctx, run_function, length_1, length_2, *args):
        ctx.run_function = run_function
        ctx.length_1 = length_1
        ctx.length_2 = length_2
        input_tensors = [x.detach().requires_grad_(True) for x in args[:length_1]]
        input_params = list(args[length_1 : length_1 + length_2])
        output_grads = list(args[length_1 + length_2 :])
        ctx.save_for_backward(*input_tensors, *input_params, *output_grads)

        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            input_tensors + input_params,
            output_grads,
            allow_unused=True,
        )
        return input_grads

    @staticmethod
    @custom_bwd
    def backward(ctx, *all_output_grads):
        args = ctx.saved_tensors
        input_tensors = [x.detach().requires_grad_(True) for x in args[: ctx.length_1]]
        input_params = list(args[ctx.length_1 : ctx.length_1 + ctx.length_2])
        output_grads = [
            x.detach().requires_grad_(True) for x in args[ctx.length_1 + ctx.length_2 :]
        ]

        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
            input_grads = torch.autograd.grad(
                output_tensors,
                input_tensors + input_params,
                output_grads,
                allow_unused=True,
                create_graph=True,
                retain_graph=True,
            )
        input_grads_grads = torch.autograd.grad(
            input_grads,
            input_tensors + input_params + output_grads,
            all_output_grads,
            allow_unused=True,
        )
        del input_grads
        return (None, None, None) + input_grads_grads
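

# Minimal self-check (a sketch, not part of the original module): the tiny MLP,
# shapes, and loss below are illustrative assumptions. It exercises both the
# checkpointed first-order backward and a second-order gradient, which is what
# CheckpointFunctionGradFunction exists to support.
if __name__ == "__main__":
    torch.manual_seed(0)
    net = torch.nn.Sequential(
        torch.nn.Linear(8, 8),
        torch.nn.Tanh(),
        torch.nn.Linear(8, 1),
    )
    x = torch.randn(4, 8, requires_grad=True)

    # Forward/backward with checkpointing: activations of `net` are recomputed
    # in the backward pass rather than stored.
    out = checkpoint(net, (x,), net.parameters(), flag=True)
    loss = out.pow(2).sum()
    (grad_x,) = torch.autograd.grad(loss, x, create_graph=True)

    # Second-order gradient: differentiating a function of grad_x routes
    # through CheckpointFunctionGradFunction.backward.
    grad_x.pow(2).sum().backward()
    print("x.grad norm:", x.grad.norm().item())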