# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Provides some extra utilities around torch.compile, in particular a way
to fully deactivate it easily through a context manager.
Provides a simple activation checkpointing implementation that is compatible with FSDP and torch.compile.
Finally, provides some utilities for CUDA Graphing functions.
"""
from contextlib import contextmanager
from functools import wraps
import inspect
import os
import typing as tp
import torch
from torch import cuda
_compile_disabled: bool = False
@contextmanager
def no_compile():
"""Disable torch.compile locally. Now Pytorch 2.4 provides a function to do that."""
global _compile_disabled
prev_disabled = _compile_disabled
_compile_disabled = True
try:
yield
finally:
_compile_disabled = prev_disabled
def torch_compile_lazy(fun):
"""torch.compile creates a huge pool of processes, even when not using the function at all,
e.g. with Dora. This can polute stderr when doing CTRL+C. So we do it in a lazy way.
"""
if os.environ.get("NO_TORCH_COMPILE"):
return fun
fun_compiled = None
@wraps(fun)
def _wrapped(*args, **kwargs):
nonlocal fun_compiled
if _compile_disabled:
return fun(*args, **kwargs)
if fun_compiled is None:
fun_compiled = torch.compile(fun)
return fun_compiled(*args, **kwargs)
return _wrapped
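
# A minimal usage sketch (not part of the original file; `add_one` is just illustrative):
# `torch_compile_lazy` defers compilation to the first real call, and `no_compile()`
# forces the eager path locally.
def _example_torch_compile_lazy() -> None:
    @torch_compile_lazy
    def add_one(x: torch.Tensor) -> torch.Tensor:
        return x + 1

    x = torch.zeros(4)
    with no_compile():
        add_one(x)  # Eager call: compilation is skipped inside `no_compile()`.
    add_one(x)  # First compiled call: `torch.compile` only runs here, lazily.
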
class Checkpoint(torch.autograd.Function):
@staticmethod
def forward(ctx, function, *args) -> tp.Any:
to_save = []
ctx.others = []
ctx.function = function
# Sources will indicate whether the arg in position N is
# a tensor stored in ctx.save_for_backward, or inside ctx.others.
ctx.sources = []
new_args = []
for arg in args:
if isinstance(arg, torch.Tensor):
to_save.append(arg)
ctx.sources.append("tensor")
new_args.append(arg.detach())
else:
ctx.sources.append("other")
ctx.others.append(arg)
new_args.append(arg)
ctx.save_for_backward(*to_save)
# During the forward, we just make a pass with no gradient computed.
with torch.no_grad():
res = function(*new_args)
return res
@staticmethod
def backward(ctx, *grads) -> tp.Tuple[tp.Optional[torch.Tensor], ...]:
pseudo_tensors = []
with torch.set_grad_enabled(True):
            # We create leaf tensors to collect the output gradients.
            # We call them pseudo_tensors because they pretend to be the inputs
            # to `function`, without being directly connected to the original graph.
for tensor in ctx.saved_tensors:
pseudo_tensor = tensor.detach()
pseudo_tensor.requires_grad_(True)
pseudo_tensors.append(pseudo_tensor)
pseudo_tensors_copy = list(pseudo_tensors)
args = []
for source in ctx.sources:
if source == "other":
args.append(ctx.others.pop(0))
else:
assert source == "tensor"
args.append(pseudo_tensors_copy.pop(0))
res = ctx.function(*args)
            # The second forward pass, with grad enabled, connects the leaf tensors
            # in `pseudo_tensors` to the outputs of the function.
if not isinstance(res, tuple):
res = (res,)
        # Now we just ask Torch to compute the derivative of `res` given the gradients
        # coming from above (`grads`). The computed gradients will end up in the
        # `pseudo_tensors` grad attributes.
torch.autograd.backward(res, grads)
out: tp.List[tp.Optional[torch.Tensor]] = [None]
for source in ctx.sources:
            # We still need to output `None` values for non-tensor parameters.
if source == "other":
out.append(None)
else:
assert source == "tensor"
out.append(pseudo_tensors.pop(0).grad)
return tuple(out)
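
# A minimal sanity-check sketch (not part of the original file; the `Linear` layer and
# shapes are just illustrative): the recompute-in-backward trick implemented by
# `Checkpoint` should yield the same input gradients as a plain backward pass.
def _example_checkpoint_matches_plain_backward() -> None:
    layer = torch.nn.Linear(8, 8)
    x = torch.randn(2, 8)

    a = x.clone().requires_grad_(True)
    layer(a).sum().backward()
    grad_plain = a.grad.clone()

    b = x.clone().requires_grad_(True)
    Checkpoint.apply(layer, b).sum().backward()
    torch.testing.assert_close(b.grad, grad_plain)
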
def simple_checkpoint(module: torch.nn.Module, *args, **kwargs):
"""Custom implementation of checkpointing in PyTorch as the builtin implementation is broken
when using torch compile. Only supports wrapping a `nn.Module` with a forward with no `*args` or `**kwargs`.
https://github.com/pytorch/pytorch/issues/97436.
Should be resolved in nightlies, but it is quite fun and simple to code it ourselves.
"""
if hasattr(module, "_fsdp_wrapped_module"):
module_for_sig = module._fsdp_wrapped_module
else:
module_for_sig = module
sig = inspect.signature(module_for_sig.forward)
# We first flatten all arguments to use only *args, to make things easier and because
# torch.autograd.Function has weird support for kwargs.
bounded = sig.bind(*args, **kwargs)
new_args = []
for name, param in sig.parameters.items():
if param.kind in {
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
}:
raise RuntimeError("simple_checkpoint doesn't support var args.")
if name not in bounded.arguments:
break
new_args.append(bounded.arguments[name])
return Checkpoint.apply(module, *new_args)
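
# A minimal usage sketch (not part of the original file; the `Linear` layer and shapes
# are just illustrative): wrap a module call so that its activations are recomputed
# during backward instead of being stored.
def _example_simple_checkpoint() -> None:
    layer = torch.nn.Linear(16, 16)
    x = torch.randn(4, 16, requires_grad=True)
    y = simple_checkpoint(layer, x)
    y.sum().backward()
    assert x.grad is not None and layer.weight.grad is not None
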
_in_cuda_graph = False
_disable_cuda_graph = False
def in_cuda_graph() -> bool:
"""Indicate whether we are in a function that is CUDA Graphed (or will be soon)."""
return _in_cuda_graph
@contextmanager
def _set_in_cuda_graph():
global _in_cuda_graph
assert not _in_cuda_graph
_in_cuda_graph = True
try:
yield
finally:
_in_cuda_graph = False
def _is_cuda_graph_enabled() -> bool:
if _disable_cuda_graph:
return False
no_cuda_graph = os.environ.get("NO_CUDA_GRAPH", "")
if no_cuda_graph.lower() not in {"0", "no", "n", ""}:
return False
return True
@contextmanager
def no_cuda_graph():
"""Deactivate CUDA Graphing for all the calls in this context manager."""
global _disable_cuda_graph
old_value = _disable_cuda_graph
_disable_cuda_graph = True
try:
yield
finally:
_disable_cuda_graph = old_value
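
# A minimal usage sketch (not part of the original file; `graphed_step` and `x` are
# hypothetical): temporarily fall back to plain calls for a `CUDAGraphed` function
# (defined below), e.g. while shapes are still changing between steps.
def _example_no_cuda_graph(graphed_step: "CUDAGraphed", x: torch.Tensor) -> tp.Any:
    with no_cuda_graph():
        return graphed_step(x)  # Calls the underlying function directly, no capture or replay.
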
class CUDAGraphed:
"""Allow simple CUDA Graphing of a function.
Args:
func: callable, taking any number of arguments. Its tensors arguments should
be top level args, not nested in structures (tuples, dicts, etc). Keyword
arguments are NOT supported for simplicity.
warmup_steps: how many call to make normally before CUDA Graphing. In particular, this
allows torch.compiled functions to get properly compiled.
disabled: if True, just call the func directly, useful to quickly deactivate on CPU.
"""
def __init__(self, func: tp.Callable, warmup_steps: int = 1, disable: bool = False):
self.func = func
self.warmup_steps = warmup_steps
self.disable = disable
self._graph: cuda.CUDAGraph | None = None
self._output: tuple | None = None
self._args: tuple | None = None
def reset(self, warmup_steps: int = 0) -> None:
"""Reset the state, meaning the next call we get CUDA Graphed again. Useful if some
shapes have changed, or external state (e.g. KVCache) has changed."""
self.warmup_steps = warmup_steps
self._graph = None
self._output = None
self._args = None
def __call__(self, *args, **kwargs) -> tp.Any:
if kwargs:
raise RuntimeError("Named arguments not supported for now.")
if self.disable or not _is_cuda_graph_enabled() or in_cuda_graph():
return self.func(*args, **kwargs)
def _clone_tensors(args: tuple) -> tuple:
out: list = []
for arg in args:
if isinstance(arg, torch.Tensor):
arg = arg.clone()
out.append(arg)
return tuple(out)
def _match_values_copy_tensors(args: tuple, target_args: tuple) -> None:
if len(args) != len(target_args):
raise ValueError(
f"Expected {len(target_args)}, but got {args} for CUDA Graphed function."
)
for idx, (source, target) in enumerate(zip(args, target_args)):
if isinstance(target, torch.Tensor):
if not isinstance(source, torch.Tensor):
raise ValueError(
f"Argument #{idx} was a tensor, and is no longer (now {source})."
)
if source.shape != target.shape:
raise ValueError(
f"Argument #{idx} had shape {target.shape}, but got shae {source.shape}"
)
target.copy_(source)
else:
if isinstance(source, torch.Tensor):
raise ValueError(
f"Argument #{idx} was not a tensor {target}, but is now one."
)
if source is not target and source != target:
raise ValueError(
f"Argument #{idx} changed value from {target} to {source}."
)
with _set_in_cuda_graph():
            # Prevent anything below us from trying to CUDA Graph things.
if self._graph is None:
if self.warmup_steps <= 0:
self._graph = cuda.CUDAGraph()
                    # Make a copy just to ensure those tensors are not used elsewhere.
self._args = _clone_tensors(args)
with cuda.graph(self._graph):
self._output = self.func(*self._args)
                    # At this point nothing has actually run, so we replay once to produce a real output.
self._graph.replay()
return self._output
else:
self.warmup_steps -= 1
return self.func(*args)
else:
assert self._args is not None
assert self._output is not None
_match_values_copy_tensors(args, self._args)
self._graph.replay()
return self._output
def cuda_graph(func: tp.Callable, warmup_steps: int = 1):
"""Just calls `CUDAGraphed` on the given function."""
if not _is_cuda_graph_enabled():
return func
return CUDAGraphed(func, warmup_steps)
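
# A minimal usage sketch (not part of the original file), assuming a CUDA device is
# available; the layer, shapes, and helper name are just illustrative. After the warmup
# call runs normally, the function is captured once, and later calls only copy the new
# inputs into the captured buffers and replay the graph.
def _example_cuda_graphed() -> None:
    layer = torch.nn.Linear(32, 32, device="cuda")

    @torch.no_grad()
    def step(x: torch.Tensor) -> torch.Tensor:
        return layer(x)

    graphed = CUDAGraphed(step, warmup_steps=1)
    x = torch.randn(8, 32, device="cuda")
    graphed(x)  # Warmup: runs `step` eagerly.
    graphed(x)  # Captured as a CUDA Graph here, then replayed.
    graphed(torch.randn(8, 32, device="cuda"))  # Same shapes: copy inputs, then replay.
    graphed.reset()  # E.g. after a shape change; the next call re-captures immediately.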