Spaces:
Running
Running
File size: 13,368 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 |
from typing import Any
import torch
from torch.utils._contextlib import (
_DecoratorContextManager,
_NoParamDecoratorContextManager,
F,
)
# Public API of this module: the names exported by a star-import.
# Underscore-prefixed helpers defined in this file are not listed here.
__all__ = [
"no_grad",
"enable_grad",
"set_grad_enabled",
"inference_mode",
"set_multithreading_enabled",
]
class no_grad(_NoParamDecoratorContextManager):
r"""Context-manager that disables gradient calculation.
Disabling gradient calculation is useful for inference, when you are sure
that you will not call :meth:`Tensor.backward()`. It will reduce memory
consumption for computations that would otherwise have `requires_grad=True`.
In this mode, the result of every computation will have
`requires_grad=False`, even when the inputs have `requires_grad=True`.
There is an exception! All factory functions, or functions that create
a new Tensor and take a requires_grad kwarg, will NOT be affected by
this mode.
This context manager is thread local; it will not affect computation
in other threads.
Also functions as a decorator.
.. note::
No-grad is one of several mechanisms that can enable or
disable gradients locally see :ref:`locally-disable-grad-doc` for
more information on how they compare.
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
If you want to disable forward AD for a computation, you can unpack
your dual tensors.
Example::
>>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
>>> @torch.no_grad
... def tripler(x):
... return x * 3
>>> z = tripler(x)
>>> z.requires_grad
False
>>> # factory function exception
>>> with torch.no_grad():
... a = torch.nn.Parameter(torch.rand(10))
>>> a.requires_grad
True
"""
def __init__(self) -> None:
if not torch._jit_internal.is_scripting():
super().__init__()
self.prev = False
def __enter__(self) -> None:
self.prev = torch.is_grad_enabled()
torch.set_grad_enabled(False)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch.set_grad_enabled(self.prev)
class enable_grad(_NoParamDecoratorContextManager):
r"""Context-manager that enables gradient calculation.
Enables gradient calculation, if it has been disabled via :class:`~no_grad`
or :class:`~set_grad_enabled`.
This context manager is thread local; it will not affect computation
in other threads.
Also functions as a decorator.
.. note::
enable_grad is one of several mechanisms that can enable or
disable gradients locally see :ref:`locally-disable-grad-doc` for
more information on how they compare.
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
Example::
>>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... with torch.enable_grad():
... y = x * 2
>>> y.requires_grad
True
>>> y.backward()
>>> x.grad
tensor([2.])
>>> @torch.enable_grad()
... def doubler(x):
... return x * 2
>>> with torch.no_grad():
... z = doubler(x)
>>> z.requires_grad
True
>>> @torch.enable_grad
... def tripler(x):
... return x * 3
>>> with torch.no_grad():
... z = tripler(x)
>>> z.requires_grad
True
"""
def __enter__(self) -> None:
self.prev = torch.is_grad_enabled()
torch._C._set_grad_enabled(True)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_grad_enabled(self.prev)
class set_grad_enabled(_DecoratorContextManager):
r"""Context-manager that sets gradient calculation on or off.
``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
It can be used as a context-manager or as a function.
This context manager is thread local; it will not affect computation
in other threads.
Args:
mode (bool): Flag whether to enable grad (``True``), or disable
(``False``). This can be used to conditionally enable
gradients.
.. note::
set_grad_enabled is one of several mechanisms that can enable or
disable gradients locally see :ref:`locally-disable-grad-doc` for
more information on how they compare.
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
Example::
>>> # xdoctest: +SKIP
>>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
... y = x * 2
>>> y.requires_grad
False
>>> _ = torch.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
True
>>> _ = torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
"""
def __init__(self, mode: bool) -> None:
self.prev = torch.is_grad_enabled()
self.mode = mode
torch._C._set_grad_enabled(mode)
def __call__(self, orig_func: F) -> F:
torch._C._set_grad_enabled(self.prev)
return super().__call__(orig_func)
def __enter__(self) -> None:
torch._C._set_grad_enabled(self.mode)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_grad_enabled(self.prev)
def clone(self) -> "set_grad_enabled":
r"""
Create a copy of this class
"""
return self.__class__(self.mode)
class inference_mode(_DecoratorContextManager):
r"""Context-manager that enables or disables inference mode.
InferenceMode is a new context manager analogous to :class:`~no_grad`
to be used when you are certain your operations will have no interactions
with autograd (e.g., model training). Code run under this mode gets better
performance by disabling view tracking and version counter bumps. Note that
unlike some other mechanisms that locally enable or disable grad,
entering inference_mode also disables to :ref:`forward-mode AD <forward-mode-ad>`.
This context manager is thread local; it will not affect computation
in other threads.
Also functions as a decorator.
.. note::
Inference mode is one of several mechanisms that can enable or
disable gradients locally see :ref:`locally-disable-grad-doc` for
more information on how they compare.
Args:
mode (bool or function): Either a boolean flag whether to enable or
disable inference mode or a Python function to decorate with
inference mode enabled
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
>>> import torch
>>> x = torch.ones(1, 2, 3, requires_grad=True)
>>> with torch.inference_mode():
... y = x * x
>>> y.requires_grad
False
>>> # xdoctest: +SKIP("want string isnt quite right")
>>> y._version
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Inference tensors do not track version counter.
>>> @torch.inference_mode()
... def func(x):
... return x * x
>>> out = func(x)
>>> out.requires_grad
False
>>> @torch.inference_mode
... def doubler(x):
... return x * 2
>>> out = doubler(x)
>>> out.requires_grad
False
"""
def __init__(self, mode: bool = True) -> None:
if not torch._jit_internal.is_scripting():
super().__init__()
self.mode = mode
def __new__(cls, mode=True):
if isinstance(mode, bool):
return super().__new__(cls)
return cls()(mode)
def __enter__(self) -> None:
self._inference_mode_context = torch._C._InferenceMode(self.mode)
self._inference_mode_context.__enter__()
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
def clone(self) -> "inference_mode":
r"""
Create a copy of this class
"""
return self.__class__(self.mode)
def _enter_inference_mode(mode):
mode_context = torch._C._InferenceMode(mode)
mode_context.__enter__()
return mode_context
def _exit_inference_mode(mode):
mode.__exit__(None, None, None)
class set_multithreading_enabled(_DecoratorContextManager):
r"""Context-manager that sets multithreaded backwards on or off.
``set_multithreading_enabled`` will enable or disable multithreaded backwards based on its argument :attr:`mode`.
It can be used as a context-manager or as a function.
This context manager is thread local; it will not affect computation
in other threads.
Args:
mode (bool): Flag whether to enable multithreaded backwards (``True``), or disable
(``False``).
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
"""
def __init__(self, mode: bool) -> None:
self.prev = torch._C._is_multithreading_enabled()
torch._C._set_multithreading_enabled(mode)
self.mode = mode
def __enter__(self) -> None:
pass
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_multithreading_enabled(self.prev)
def clone(self) -> "set_multithreading_enabled":
r"""
Create a copy of this class
"""
return self.__class__(self.mode)
class _force_original_view_tracking(_DecoratorContextManager):
r"""Context-manager that sets whether or not to always enable view-replay in autograd.
``set_view_replay_enabled`` will enable or disable view-replay based on its argument :attr:`mode`.
It can be used as a context-manager or as a function.
This context manager is thread local; it will not affect computation
in other threads.
When a tensor view is mutated, the autograd engine needs to decide whether or not
to regenerate the "updated view" by either replaying the chain of views from the updated base,
or with a single call to as_strided.
If set_view_replay_enabled is set to True, then autograd will always use view replay.
Otherwise, it will fall back to its existing logic.
Args:
mode (bool): Flag whether to enable view-replay (``True``), or disable
(``False``).
"""
def __init__(self, mode: bool) -> None:
self.prev = torch._C._is_view_replay_enabled()
torch._C._set_view_replay_enabled(mode)
self.mode = mode
def __enter__(self) -> None:
pass
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
torch._C._set_view_replay_enabled(self.prev)
def clone(self):
return self.__class__(self.mode)
class _unsafe_preserve_version_counter(_DecoratorContextManager):
r"""DO NOT USE THIS UNLESS YOU KNOW EXACTLY WHAT YOU'RE DOING.
This context manager can lead to arbitrary silent-correctness issues in any other part of your code
(even the ones not touched directly by the context manager)!
Ordinarily, autograd will track mutations to tensors by incrementing it's `._version` attribute.
This is generally important for correctness, as for example, mutating a tensor that autograd has saved
for the backwards pass can result in incorrect gradients, and autograd uses the version counter to detect
and error out in this situation.
However, there are rare instances where it might be useful to hide mutations from autograd. For example:
if a tensor is very large, and you'd like to free its memory by storing it elsewhere, and re-populate
the tensor right before it is needed by autograd.
Args:
tensor (torch.Tensor): the tensor in question, that you would like to preserve the version counter of.
.. note::
This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
"""
def __init__(self, tensor: torch.Tensor) -> None:
self.tensor = tensor
self.prev_version = tensor._version
def __enter__(self) -> None:
pass
def __exit__(self, *args) -> None:
torch._C._autograd._unsafe_set_version_counter(self.tensor, self.prev_version)
|