Spaces:
Running
Running
File size: 5,311 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import collections
import functools
import torch
# NumPy is an optional dependency. Record whether it could be imported so that
# later code can guard every use of ``np`` behind ``HAS_NUMPY``.
try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    # Bind BOTH names on the fallback path: leaving ``HAS_NUMPY`` undefined
    # would turn any later ``if HAS_NUMPY ...`` check into a NameError.
    HAS_NUMPY = False
    np = None  # type: ignore[assignment]
from typing import Any
__all__ = ["autocast", "custom_fwd", "custom_bwd"]
class autocast(torch.amp.autocast_mode.autocast):
    r"""CUDA-flavoured shorthand for :class:`torch.autocast`.

    ``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)``
    """

    def __init__(
        self,
        enabled: bool = True,
        dtype: torch.dtype = torch.float16,
        cache_enabled: bool = True,
    ):
        """Configure the context manager for the ``"cuda"`` device type.

        Args:
            enabled: Forwarded to the base class as ``enabled``.
            dtype: Forwarded to the base class as ``dtype``.
            cache_enabled: Forwarded to the base class as ``cache_enabled``.
        """
        if torch._jit_internal.is_scripting():
            # While scripting, the base initializer is skipped and only these
            # three attributes are recorded (``cache_enabled`` is not stored).
            self._enabled = enabled
            self.device = "cuda"
            self.fast_dtype = dtype
            return
        super().__init__(
            device_type="cuda",
            dtype=dtype,
            enabled=enabled,
            cache_enabled=cache_enabled,
        )

    def __enter__(self):
        """Enter the region; returns ``self`` when scripting, else defers to the base class."""
        if torch._jit_internal.is_scripting():
            return self
        return super().__enter__()

    # TODO: discuss a unified TorchScript-friendly API for autocast
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        """Leave the region; does nothing when scripting, else defers to the base class."""
        if torch._jit_internal.is_scripting():
            return
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        """Decorator form; returns ``func`` unchanged when scripting, else defers to the base class."""
        if torch._jit_internal.is_scripting():
            return func
        return super().__call__(func)
# Casts Tensors and containers of Tensors. Strings, bytes and np.ndarrays are passed
# through untouched because they would otherwise be detected as "Iterables" and unpacked.
def _cast(value, dtype):
    """Recursively cast eligible tensors found inside ``value`` to ``dtype``.

    A tensor is eligible when it is floating point, lives on a CUDA device and
    is not ``torch.float64``; every other tensor is returned as-is.

    Args:
        value: A tensor, a (possibly nested) container of tensors, or any
            other object.
        dtype: Target ``torch.dtype`` for eligible tensors.

    Returns:
        * tensors: the cast tensor, or the same tensor if it is not eligible;
        * ``str`` / ``bytes`` / ``np.ndarray``: the same object;
        * mappings: a new ``dict`` whose keys and values were both cast;
        * lists / tuples (including namedtuples): a new container of the same type;
        * any other iterable: a lazy generator over the cast elements;
        * everything else: the same object.
    """
    if isinstance(value, torch.Tensor):
        is_eligible = (
            value.is_floating_point()
            and value.is_cuda
            and (value.dtype is not torch.float64)
        )
        return value.to(dtype) if is_eligible else value
    if isinstance(value, (str, bytes)):
        return value
    # ``np`` is bound to None at import time when NumPy is missing. Checking it
    # directly (rather than ``HAS_NUMPY``, which the import guard only binds on
    # success) keeps this function usable without NumPy installed.
    if np is not None and isinstance(value, np.ndarray):
        return value
    if isinstance(value, collections.abc.Mapping):
        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
    if isinstance(value, collections.abc.Iterable):
        iterable = (_cast(v, dtype) for v in value)
        if isinstance(value, tuple) and hasattr(value, "_fields"):
            # namedtuple constructors take one positional argument per field,
            # so the elements must be unpacked instead of passed as one iterable.
            return type(value)(*iterable)
        if isinstance(value, (list, tuple)):
            return type(value)(iterable)
        return iterable
    return value
# custom_fwd is a decorator that may or may not be used with arguments, following
# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
# this works:
#     @custom_fwd
#     def forward(...):
# this also works:
#     @custom_fwd(cast_inputs=torch.float)
#     def forward(...):
def custom_fwd(fwd=None, *, cast_inputs=None):
    """
    Create a helper decorator for ``forward`` methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    See the :ref:`example page<amp-custom-examples>` for more detail.

    The wrapper records two attributes on its first positional argument:
    ``_dtype`` (the current autocast GPU dtype) and ``_fwd_used_autocast``
    (whether ``forward`` itself ran under autocast).

    Args:
        cast_inputs (:class:`torch.dtype` or None, optional, default=None):  If not ``None``,
            when ``forward`` runs in an autocast-enabled region, casts incoming
            floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected),
            then executes ``forward`` with autocast disabled.
            If ``None``, ``forward``'s internal ops execute with the current autocast state.

    .. note::
        If the decorated ``forward`` is called outside an autocast-enabled region,
        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
    """
    if fwd is None:
        # Used as ``@custom_fwd(cast_inputs=...)``: hand back a decorator that
        # already carries the keyword argument.
        return functools.partial(custom_fwd, cast_inputs=cast_inputs)

    @functools.wraps(fwd)
    def decorate_fwd(*args, **kwargs):
        ctx = args[0]
        ctx._dtype = torch.get_autocast_gpu_dtype()
        autocast_active = torch.is_autocast_enabled()

        if cast_inputs is None:
            # No casting requested: run under whatever autocast state is active.
            ctx._fwd_used_autocast = autocast_active
            return fwd(*args, **kwargs)

        # Casting requested: forward itself never runs under autocast.
        ctx._fwd_used_autocast = False
        if not autocast_active:
            return fwd(*args, **kwargs)
        with autocast(enabled=False):
            return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))

    return decorate_fwd
# Autograd ensures incoming gradients are the same type as forward outputs.  Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd):
    """Create a helper decorator for backward methods of custom autograd functions.

    Autograd functions are subclasses of :class:`torch.autograd.Function`.
    Ensures that ``backward`` executes with the same autocast state as ``forward``.
    See the :ref:`example page<amp-custom-examples>` for more detail.

    The wrapper reads ``_fwd_used_autocast`` and ``_dtype`` from its first
    positional argument and runs ``bwd`` inside an ``autocast`` region
    configured with those values.
    """

    @functools.wraps(bwd)
    def decorate_bwd(*args, **kwargs):
        ctx = args[0]
        replayed_state = autocast(enabled=ctx._fwd_used_autocast, dtype=ctx._dtype)
        with replayed_state:
            return bwd(*args, **kwargs)

    return decorate_bwd
|