Spaces:
Sleeping
Sleeping
# Simple gradient checkpointing. Works with distributed data parallel.
import torch as t
def checkpoint(func, inputs, params, flag):
    """Evaluate ``func(*inputs)``, optionally with gradient checkpointing.

    Args:
        func: callable to evaluate.
        inputs: positional arguments for ``func``. Any iterable is accepted
            (list or tuple); it is normalized to a tuple.
        params: extra tensors whose gradients must be produced when
            checkpointing (presumably the parameters ``func`` closes over --
            confirm at call sites). Only used when ``flag`` is true.
        flag: when true, route the call through ``CheckpointFunction`` so the
            forward pass runs without building a graph and is recomputed
            during backward; when false, simply call ``func(*inputs)``.

    Returns:
        The result of ``func(*inputs)``.
    """
    # Normalize once: the original `inputs + tuple(params)` raised TypeError
    # for list inputs (list + tuple is not allowed).
    inputs = tuple(inputs)
    if not flag:
        return func(*inputs)
    args = inputs + tuple(params)
    # `len(inputs)` tells CheckpointFunction where inputs end and params begin.
    return CheckpointFunction.apply(func, len(inputs), *args)
class CheckpointFunction(t.autograd.Function):
    """Autograd function that recomputes ``run_function`` during backward.

    ``forward`` runs ``run_function`` under ``t.no_grad()``, so no graph of
    intermediate results is recorded. ``backward`` re-runs it with grad
    enabled on detached copies of the inputs and differentiates the
    recomputed outputs. The inputs are kept on ``ctx`` as plain attributes
    rather than through ``ctx.save_for_backward``.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        """Run ``run_function(*args[:length])`` without recording a graph.

        Args:
            ctx: autograd context object.
            run_function: callable applied to the first ``length`` args.
            length: number of leading entries of ``args`` that are inputs to
                ``run_function``; the remaining entries are extra tensors
                (params) whose gradients ``backward`` also returns.
            *args: input tensors followed by param tensors.

        Returns:
            Whatever ``run_function`` returns.
        """
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with t.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        """Recompute the forward pass and return gradients for all forward args.

        Returns:
            ``(None, None)`` for ``run_function`` and ``length``, followed by
            one gradient per entry of ``*args`` given to ``forward``. The
            gradient is ``None`` for tensors that do not require grad, or
            that the recomputed outputs do not depend on.

        Note:
            ``ctx.input_tensors`` is deleted at the end, so a second backward
            through the same graph (e.g. with ``retain_graph=True``) would fail
            with AttributeError.
        """
        # Detached copies are fresh leaves: the recomputed graph starts at
        # them instead of reaching back into the caller's graph.
        leaves = []
        for original in ctx.input_tensors:
            leaf = original.detach()
            leaf.requires_grad = original.requires_grad
            leaves.append(leaf)
        with t.enable_grad():
            output_tensors = ctx.run_function(*leaves)

        candidates = leaves + ctx.input_params
        # t.autograd.grad raises if any requested tensor does not require
        # grad (e.g. a frozen parameter), so only request the ones that do
        # and report None for the rest.
        wanted = [x for x in candidates if x.requires_grad]
        grads = iter(
            t.autograd.grad(output_tensors, wanted, output_grads, allow_unused=True)
            if wanted
            else ()
        )
        input_grads = tuple(
            next(grads) if x.requires_grad else None for x in candidates
        )

        del ctx.input_tensors
        del output_tensors
        return (None, None) + input_grads