# Simple gradient checkpointing. Works with distributed data parallel
import torch as t


def checkpoint(func, inputs, params, flag):
    """Run ``func(*inputs)``, optionally without storing intermediate activations.

    Args:
        func: Callable invoked as ``func(*inputs)``.
        inputs: Sequence of positional arguments for ``func``.
        params: Iterable of extra tensors that ``func`` uses but does not take
            as arguments; gradients for them are computed in the backward pass.
        flag: If truthy, route the call through ``CheckpointFunction`` so the
            forward pass runs under ``no_grad`` and is recomputed in backward.
            If falsy, call ``func`` directly.

    Returns:
        Whatever ``func(*inputs)`` returns.
    """
    inputs = tuple(inputs)  # accept lists as well as tuples
    if not flag:
        return func(*inputs)
    args = inputs + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *args)


class CheckpointFunction(t.autograd.Function):
    """Autograd function that recomputes its forward pass during backward."""

    @staticmethod
    def forward(ctx, run_function, length, *args):
        """Evaluate ``run_function`` on the first ``length`` args under ``no_grad``.

        The remaining ``args`` are kept on ``ctx`` as the parameters for which
        ``backward`` computes gradients.
        """
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with t.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        """Re-run the forward pass with grad enabled and back-propagate through it.

        Returns one gradient slot per ``forward`` argument: ``None`` for
        ``run_function`` and ``length``, then one entry per input and parameter
        (``None`` where no gradient applies).
        """
        # Detach so the recomputed graph starts at these tensors, while keeping
        # each tensor's original requires_grad flag. Non-tensor arguments are
        # passed through unchanged.
        detached_inputs = [
            x.detach().requires_grad_(x.requires_grad)
            if isinstance(x, t.Tensor)
            else x
            for x in ctx.input_tensors
        ]
        with t.enable_grad():
            output_tensors = ctx.run_function(*detached_inputs)

        candidates = detached_inputs + ctx.input_params
        # autograd.grad raises if asked to differentiate w.r.t. a tensor that
        # does not require grad, so only the differentiable ones are passed in.
        wanted = [
            i
            for i, x in enumerate(candidates)
            if isinstance(x, t.Tensor) and x.requires_grad
        ]

        if isinstance(output_tensors, tuple):
            outputs = output_tensors
        else:
            outputs = (output_tensors,)
        # Likewise skip outputs that are not part of the recomputed graph.
        pairs = [
            (out, grad)
            for out, grad in zip(outputs, output_grads)
            if isinstance(out, t.Tensor) and out.requires_grad
        ]

        input_grads = [None] * len(candidates)
        if wanted and pairs:
            diff_outputs, diff_grads = zip(*pairs)
            computed = t.autograd.grad(
                diff_outputs,
                [candidates[i] for i in wanted],
                diff_grads,
                allow_unused=True,
            )
            for i, grad in zip(wanted, computed):
                input_grads[i] = grad

        # Drop the references to the saved inputs and the recomputed outputs.
        del ctx.input_tensors
        del output_tensors
        return (None, None) + tuple(input_grads)