# Simple gradient checkpointing. Works with distributed data parallel.
import torch as t


def checkpoint(func, inputs, params, flag):
    # Evaluate func(*inputs). If flag is set, trade compute for memory:
    # intermediate activations are discarded during the forward pass and
    # recomputed during backward. params are passed through explicitly so
    # autograd (and DDP) still see them as used.
    if flag:
        args = inputs + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)
class CheckpointFunction(t.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Run without building a graph, so no activations are cached.
        with t.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach the saved inputs and force requires_grad, so the
        # recomputation builds a fresh graph rooted at them (inputs that
        # did not require grad would otherwise make autograd.grad fail).
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with t.enable_grad():
            # Shallow copies guard against run_function mutating the
            # detached inputs' storage in place.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = t.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del output_tensors
        # forward received (run_function, length, *args); the first two
        # arguments get no gradient.
        return (None, None) + input_grads
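

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal example of wrapping a small block with checkpoint(); the names
# `block` and `x` are hypothetical, chosen only for this sketch.
if __name__ == "__main__":
    block = t.nn.Sequential(t.nn.Linear(16, 16), t.nn.ReLU(), t.nn.Linear(16, 16))
    x = t.randn(4, 16, requires_grad=True)

    # With flag=True, activations inside `block` are recomputed during the
    # backward pass instead of being cached in the forward pass.
    y = checkpoint(block, (x,), block.parameters(), True)
    y.sum().backward()
    print(x.grad.shape)  # gradients flow back to the input as usual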