# Simple gradient checkpointing. Works with distributed data parallel.
import torch as t


def checkpoint(func, inputs, params, flag):
    # Evaluate func(*inputs) without storing intermediate activations,
    # recomputing them during the backward pass instead. `params` holds the
    # parameters func depends on; `flag` toggles checkpointing on or off.
    if flag:
        args = inputs + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(t.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Run forward without building a graph, so intermediate activations
        # are freed immediately instead of being kept for backward.
        with t.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach the saved inputs from the outer graph, preserving their
        # requires_grad flags, so the recomputation below builds a fresh
        # graph rooted at these inputs.
        for i in range(len(ctx.input_tensors)):
            temp = ctx.input_tensors[i]
            ctx.input_tensors[i] = temp.detach()
            ctx.input_tensors[i].requires_grad = temp.requires_grad
        # Recompute the forward pass with grad tracking enabled and take
        # gradients with respect to both the inputs and the parameters.
        with t.enable_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        input_grads = t.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del output_tensors
        # forward() received (run_function, length, *args); the first two
        # arguments get no gradient.
        return (None, None) + input_grads
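

# Usage sketch (not part of the original file): `net`, the shapes, and the
# flag value below are illustrative assumptions. It checks that the
# checkpointed path produces the same input gradient as the plain path.
# Note that passing the parameters through `apply` (rather than capturing
# them in a closure) means autograd sees a gradient for every parameter in
# the same backward call, which appears to be why the header comment claims
# compatibility with distributed data parallel.
import torch.nn as nn

if __name__ == "__main__":
    net = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))
    x = t.randn(2, 8, requires_grad=True)

    # Checkpointed forward/backward: activations inside `net` are freed
    # after forward and recomputed during backward.
    y = checkpoint(net, (x,), net.parameters(), flag=True)
    y.sum().backward()
    g_ckpt = x.grad.clone()

    # Plain forward/backward for comparison.
    x.grad = None
    net(x).sum().backward()
    assert t.allclose(g_ckpt, x.grad)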