from typing import List, Tuple

import torch
from torch._vmap_internals import _vmap

from . import forward_ad as fwAD

__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]


# Utility functions
def _as_tuple_nocheck(x):
    if isinstance(x, tuple):
        return x
    elif isinstance(x, list):
        return tuple(x)
    else:
        return (x,)


def _as_tuple(inp, arg_name=None, fn_name=None):
    # Ensures that inp is a tuple of Tensors
    # Returns whether or not the original inp was a tuple and the tupled version of the input
    if arg_name is None and fn_name is None:
        return _as_tuple_nocheck(inp)

    is_inp_tuple = True
    if not isinstance(inp, tuple):
        inp = (inp,)
        is_inp_tuple = False

    for i, el in enumerate(inp):
        if not isinstance(el, torch.Tensor):
            if is_inp_tuple:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" value at index {i} has type {type(el)}."
                )
            else:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" given {arg_name} has type {type(el)}."
                )

    return is_inp_tuple, inp


def _tuple_postprocess(res, to_unpack):
    # Unpacks a potentially nested tuple of Tensors
    # to_unpack should be a single boolean or a tuple of two booleans.
    # It is used to:
    # - invert _as_tuple when res should match the inp given to _as_tuple
    # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
    if isinstance(to_unpack, tuple):
        assert len(to_unpack) == 2
        if not to_unpack[1]:
            res = tuple(el[0] for el in res)
        if not to_unpack[0]:
            res = res[0]
    else:
        if not to_unpack:
            res = res[0]
    return res
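

# Illustrative sketch (the helper name `_demo_tuple_roundtrip` is hypothetical
# and only documents the convention used throughout this module): _as_tuple
# normalizes a user-provided input to a tuple of Tensors, and
# _tuple_postprocess undoes that wrapping so the caller gets back the same
# kind of container they passed in.
def _demo_tuple_roundtrip():
    single = torch.ones(2)
    is_tuple, tupled = _as_tuple(single, "inputs", "demo")
    assert is_tuple is False and isinstance(tupled, tuple)
    restored = _tuple_postprocess(tupled, is_tuple)
    assert restored is single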


def _grad_preprocess(inputs, create_graph, need_graph):
    # Preprocess the inputs to make sure they require gradient
    # inputs is a tuple of Tensors to preprocess
    # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
    # need_graph specifies if we internally want gradients to flow back to the Tensors in res
    # Note that we *always* create a new Tensor object to be able to see the difference between
    # inputs given as arguments and the same Tensors automatically captured by the user function.
    # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
    res = []
    for inp in inputs:
        if create_graph and inp.requires_grad:
            # Create at least a new Tensor object in a differentiable way
            if not inp.is_sparse:
                # Use .view_as() to get a shallow copy
                res.append(inp.view_as(inp))
            else:
                # We cannot use view for sparse Tensors so we clone
                res.append(inp.clone())
        else:
            res.append(inp.detach().requires_grad_(need_graph))
    return tuple(res)
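

# Illustrative sketch (hypothetical helper, never called by the module): the
# two behaviors of _grad_preprocess. With create_graph=True a differentiable
# shallow copy of the input is returned so gradients can flow back to the
# caller's Tensor; with create_graph=False the input is detached and only
# requires grad internally.
def _demo_grad_preprocess():
    x = torch.randn(3, requires_grad=True)
    with torch.enable_grad():
        (alias,) = _grad_preprocess((x,), create_graph=True, need_graph=True)
        (detached,) = _grad_preprocess((x,), create_graph=False, need_graph=True)
    assert alias is not x and alias.grad_fn is not None  # differentiable shallow copy
    assert detached.requires_grad and detached.grad_fn is None  # fresh detached leaf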


def _grad_postprocess(inputs, create_graph):
    # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
    # request it.
    if isinstance(inputs[0], torch.Tensor):
        if not create_graph:
            return tuple(inp.detach() for inp in inputs)
        else:
            return inputs
    else:
        return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)


def _validate_v(v, other, is_other_tuple):
    # This assumes that other is the correct shape, and v should match
    # Both are assumed to be tuples of Tensors
    if len(other) != len(v):
        if is_other_tuple:
            raise RuntimeError(
                f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}."
            )
        else:
            raise RuntimeError("The given v should contain a single Tensor.")

    for idx, (el_v, el_other) in enumerate(zip(v, other)):
        if el_v.size() != el_other.size():
            prepend = ""
            if is_other_tuple:
                prepend = f"Entry {idx} in "
            raise RuntimeError(
                f"{prepend}v has invalid size: should be {el_other.size()} but got {el_v.size()}."
            )


def _check_requires_grad(inputs, input_type, strict):
    # Used to make all the necessary checks to raise nice errors in strict mode.
    if not strict:
        return

    if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
        raise RuntimeError("Invalid input_type to _check_requires_grad")
    for i, inp in enumerate(inputs):
        if inp is None:
            # This can only be reached for grad_inputs.
            raise RuntimeError(
                f"The output of the user-provided function is independent of input {i}."
                " This is not allowed in strict mode."
            )
        if not inp.requires_grad:
            if input_type == "hessian":
                raise RuntimeError(
                    f"The hessian of the user-provided function with respect to input {i}"
                    " is independent of the input. This is not allowed in strict mode."
                    " You should ensure that your function is thrice differentiable and that"
                    " the hessian depends on the inputs."
                )
            elif input_type == "jacobian":
                raise RuntimeError(
                    "While computing the hessian, found that the jacobian of the user-provided"
                    f" function with respect to input {i} is independent of the input. This is not"
                    " allowed in strict mode. You should ensure that your function is twice"
                    " differentiable and that the jacobian depends on the inputs (this would be"
                    " violated by a linear function for example)."
                )
            elif input_type == "grad_inputs":
                raise RuntimeError(
                    f"The gradient with respect to input {i} is independent of the inputs of the"
                    " user-provided function. This is not allowed in strict mode."
                )
            else:
                raise RuntimeError(
                    f"Output {i} of the user-provided function does not require gradients."
                    " The outputs must be computed in a differentiable manner from the input"
                    " when running in strict mode."
                )


def _autograd_grad(
    outputs,
    inputs,
    grad_outputs=None,
    create_graph=False,
    retain_graph=None,
    is_grads_batched=False,
):
    # Version of autograd.grad that accepts `None` in outputs and does not compute gradients for them.
    # This has the extra constraint that inputs has to be a tuple
    assert isinstance(outputs, tuple)
    if grad_outputs is None:
        grad_outputs = (None,) * len(outputs)
    assert isinstance(grad_outputs, tuple)
    assert len(outputs) == len(grad_outputs)

    new_outputs: Tuple[torch.Tensor, ...] = tuple()
    new_grad_outputs: Tuple[torch.Tensor, ...] = tuple()
    for out, grad_out in zip(outputs, grad_outputs):
        if out is not None and out.requires_grad:
            new_outputs += (out,)
            new_grad_outputs += (grad_out,)

    if len(new_outputs) == 0:
        # No differentiable output, we don't need to call the autograd engine
        return (None,) * len(inputs)
    else:
        return torch.autograd.grad(
            new_outputs,
            inputs,
            new_grad_outputs,
            allow_unused=True,
            create_graph=create_graph,
            retain_graph=retain_graph,
            is_grads_batched=is_grads_batched,
        )
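

# Illustrative sketch (hypothetical helper, never called by the module):
# _autograd_grad silently skips outputs that are None or do not require grad,
# and, because it passes allow_unused=True, returns None for inputs that the
# remaining outputs do not depend on.
def _demo_autograd_grad():
    x = torch.randn(2, requires_grad=True)
    y = torch.randn(2, requires_grad=True)
    out = (x * 2).sum()
    # The second "output" is None and is skipped; y is unused, so its gradient
    # comes back as None.
    grad_x, grad_y = _autograd_grad((out, None), (x, y), (torch.ones(()), None))
    assert torch.equal(grad_x, torch.full((2,), 2.0)) and grad_y is None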


def _fill_in_zeros(grads, refs, strict, create_graph, stage):
    # Used to detect None in the grads and, depending on the flags, either replace them
    # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
    # strict and create_graph allow us to detect when it is appropriate to raise an error
    # stage tells us which backward call we are considering so we can give a good error message
    if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
        raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")

    res: Tuple[torch.Tensor, ...] = tuple()
    for i, grads_i in enumerate(grads):
        if grads_i is None:
            if strict:
                if stage == "back":
                    raise RuntimeError(
                        "The output of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                elif stage == "back_trick":
                    raise RuntimeError(
                        f"The gradient with respect to the input is independent of entry {i}"
                        " in the grad_outputs when using the double backward trick to compute"
                        " forward mode gradients. This is not allowed in strict mode."
                    )
                elif stage == "double_back":
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"entry {i} in the grad_jacobian. This is not allowed in strict "
                        "mode as it prevents us from using the double backward trick to "
                        "replace forward mode AD."
                    )
            grads_i = torch.zeros_like(refs[i])
        else:
            if strict and create_graph and not grads_i.requires_grad:
                if "double" not in stage:
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )

        res += (grads_i,)

    return res


# Public API
def vjp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector
            Jacobian product is computed. Must be the same size as the output
            of ``func``. This argument is optional when the output of ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vjp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vjp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vjp(exp_reducer, inputs, v)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782]),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                 [2.1288, 1.0652, 1.5483, 2.5035],
                 [2.2046, 1.1292, 1.1432, 1.3059],
                 [1.3225, 1.6652, 1.7753, 2.0152]]))

        >>> vjp(exp_reducer, inputs, v, create_graph=True)
        (tensor([5.7817, 7.2458, 5.7830, 6.7782], grad_fn=<SumBackward1>),
         tensor([[1.4458, 1.3962, 1.3042, 1.6354],
                 [2.1288, 1.0652, 1.5483, 2.5035],
                 [2.2046, 1.1292, 1.1432, 1.3059],
                 [1.3225, 1.6652, 1.7753, 2.0152]], grad_fn=<MulBackward0>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = torch.ones(2)
        >>> vjp(adder, inputs, v)
        (tensor([2.4225, 2.3340]),
         (tensor([2., 2.]), tensor([3., 3.])))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vjp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if v is not None:
            _, v = _as_tuple(v, "v", "vjp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, outputs, is_outputs_tuple)
        else:
            if len(outputs) != 1 or outputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the "
                    "user-provided function returns "
                    "a single Tensor with a single element."
                )

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
        vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    vjp = _grad_postprocess(vjp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        vjp, is_inputs_tuple
    )
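

# Illustrative sketch (hypothetical helper, never called by the module): the
# vjp returned above matches a hand-written torch.autograd.grad call with `v`
# passed as grad_outputs, assuming only the module-level `torch` import.
def _demo_vjp_equivalence():
    def f(x):
        return x.exp().sum(dim=1)

    x = torch.randn(4, 4, requires_grad=True)
    v = torch.ones(4)
    _, vjp_result = vjp(f, x, v)
    (manual,) = torch.autograd.grad(f(x), x, grad_outputs=v)
    assert torch.allclose(vjp_result, manual)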


def jvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Jacobian
            vector product is computed. Must be the same size as the input of
            ``func``. This argument is optional when the input to ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            jvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the output.

    Note:
        ``autograd.functional.jvp`` computes the jvp by using the backward of
        the backward (sometimes called the double backwards trick). This is not
        the most performant way of computing the jvp. Please consider using
        :func:`torch.func.jvp` or the
        :ref:`low-level forward-mode AD API <forward-mode-ad>` instead.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(4, 4)
        >>> v = torch.ones(4, 4)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> jvp(exp_reducer, inputs, v)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106]),
         tensor([6.3090, 4.6742, 7.9114, 8.2106]))

        >>> jvp(exp_reducer, inputs, v, create_graph=True)
        (tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SumBackward1>),
         tensor([6.3090, 4.6742, 7.9114, 8.2106], grad_fn=<SqueezeBackward1>))

        >>> def adder(x, y):
        ...     return 2 * x + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.ones(2), torch.ones(2))
        >>> jvp(adder, inputs, v)
        (tensor([2.2399, 2.5005]),
         tensor([5., 5.]))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "jvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to "
                    "the user-provided function is a single Tensor "
                    "with a single element."
                )

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        # The backward is linear so the value of grad_outputs is not important as
        # it won't appear in the double backward graph. We only need to ensure that
        # it does not contain inf or nan.
        grad_outputs = tuple(
            torch.zeros_like(out, requires_grad=True) for out in outputs
        )

        grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
        _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

    if create_graph:
        with torch.enable_grad():
            grad_res = _autograd_grad(
                grad_inputs, grad_outputs, v, create_graph=create_graph
            )
            jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
    else:
        grad_res = _autograd_grad(
            grad_inputs, grad_outputs, v, create_graph=create_graph
        )
        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    jvp = _grad_postprocess(jvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        jvp, is_outputs_tuple
    )
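

# Illustrative sketch (hypothetical helper, never called by the module): the
# "double backward trick" the Note above refers to, written out with raw
# torch.autograd.grad calls. A throwaway grad_outputs `u` (requiring grad) is
# differentiated through, which turns a second backward pass into a J v
# product; the value of `u` does not matter because the backward is linear.
def _demo_double_backward_trick():
    def f(x):
        return x.exp().sum(dim=1)

    x = torch.randn(4, 4, requires_grad=True)
    v = torch.ones(4, 4)
    out = f(x)
    u = torch.zeros_like(out, requires_grad=True)  # value is irrelevant
    (g,) = torch.autograd.grad(out, x, u, create_graph=True)  # g = J^T u
    (jv,) = torch.autograd.grad(g, u, v)  # differentiate w.r.t. u to get J v
    assert torch.allclose(jv, jvp(f, x, v)[1])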


def _construct_standard_basis_for(
    tensors: Tuple[torch.Tensor, ...], tensor_numels: Tuple[int, ...]
) -> Tuple[torch.Tensor, ...]:
    # This function:
    # - constructs a N=sum(tensor_numels) standard basis. i.e. an NxN identity matrix.
    # - Splits the identity matrix into chunks with each chunk size determined by `tensor_numels`.
    # - Each chunk corresponds to one tensor. The chunk has the same dtype and
    #   device as the tensor
    #
    # For example, with tensor_numels = [1, 2, 1], this function returns:
    # ( tensor([[1],     tensor([[0, 0],     tensor([[0],
    #           [0],             [1, 0],             [0],
    #           [0],             [0, 1],             [0],
    #           [0]])  ,         [0, 0]])  ,         [1]])  )
    #
    # Precondition: tensor_numels == tuple(tensor.numel() for tensor in tensors)
    # Precondition: tensors always has at least one element.
    #
    # See NOTE: [Computing jacobian with vmap and grad for multiple tensors]
    # for context behind this function. All the pre-conditions are guarded for
    # in torch.autograd.functional.jacobian.
    assert len(tensors) == len(tensor_numels)
    assert len(tensors) > 0
    total_numel = sum(tensor_numels)
    chunks = tuple(
        tensor.new_zeros(total_numel, tensor_numel)
        for tensor, tensor_numel in zip(tensors, tensor_numels)
    )
    diag_start_idx = 0
    for chunk, numel in zip(chunks, tensor_numels):
        chunk.diagonal(diag_start_idx).fill_(1)
        diag_start_idx -= numel
    return chunks
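

# Illustrative sketch (hypothetical helper, never called by the module):
# concatenating the chunks returned by _construct_standard_basis_for along
# dim 1 recovers the N x N identity, where N is the total number of elements
# across the tensors.
def _demo_standard_basis():
    tensors = (torch.zeros(1), torch.zeros(2), torch.zeros(1))
    numels = tuple(t.numel() for t in tensors)  # (1, 2, 1)
    chunks = _construct_standard_basis_for(tensors, numels)
    assert torch.equal(torch.cat(chunks, dim=1), torch.eye(sum(numels)))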


def _jacfwd(func, inputs, strict=False, vectorize=False):
    if strict:
        raise RuntimeError(
            "torch.autograd.functional.jacobian: `strict=True` "
            'and `strategy="forward-mode"` are not supported together (yet). '
            "Please either set `strict=False` or "
            '`strategy="reverse-mode"`.'
        )
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
    output_info = []

    if vectorize:
        # See NOTE: [Computing jacobian with vmap and grad for multiple outputs]
        input_numels = tuple(input.numel() for input in inputs)

        # Step 1: Prepare tangents
        tangents = _construct_standard_basis_for(inputs, input_numels)

        # Step 2: Compute vmap over computation with dual tensors
        def jvp(tangents):
            with fwAD.dual_level():
                dual_inputs = tuple(
                    fwAD.make_dual(input, tangent.view_as(input))
                    for input, tangent in zip(inputs, tangents)
                )
                _is_outputs_tuple, dual_outputs = _as_tuple(
                    func(*dual_inputs), "outputs"
                )
                output_info.append(_is_outputs_tuple)
                jv = []
                primal_outs = []
                for dual_out in dual_outputs:
                    primal, tangent = fwAD.unpack_dual(dual_out)
                    primal_outs.append(primal)
                    if tangent is not None:
                        jv.append(tangent)
                    else:
                        jv.append(torch.zeros_like(primal))
                output_info.append(primal_outs)
                return tuple(jv)

        outputs_before_split = _vmap(jvp)(tangents)
        is_outputs_tuple, outputs = output_info
        # Step 3: for each of the output tangents, split along dim 0
        jacobian_input_output = []
        for jac_output_i, output_i in zip(outputs_before_split, outputs):
            jacobian_output_i_output = []
            for jac, input_j in zip(jac_output_i.split(input_numels, dim=0), inputs):
                # We need to transpose the Jacobian because in forward AD, the
                # batch dimension represents that of the inputs
                jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0).reshape(
                    (*output_i.shape, *input_j.shape)
                )  # noqa: C409
                jacobian_output_i_output.append(jacobian_input_i_output_j)
            jacobian_input_output.append(jacobian_output_i_output)

        # Omit [Step 4] because everything is already transposed w/ forward AD
        return _tuple_postprocess(
            jacobian_input_output, (is_outputs_tuple, is_inputs_tuple)
        )
    else:
        raise NotImplementedError(
            "Computing Jacobian using forward-AD or forward-over-reverse Hessian is "
            "only implemented for `vectorize=True`."
        )
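

# Illustrative sketch (hypothetical helper, never called by the module): a
# single Jacobian-vector product computed with the forward-mode AD primitives
# that _jacfwd builds on. The tangent attached to the dual input comes back as
# the tangent of the output.
def _demo_forward_mode_jvp():
    def f(x):
        return x.exp().sum(dim=1)

    x = torch.randn(4, 4)
    v = torch.ones(4, 4)
    with fwAD.dual_level():
        dual_x = fwAD.make_dual(x, v)
        primal, tangent = fwAD.unpack_dual(f(dual_x))
    assert torch.allclose(tangent, (x.exp() * v).sum(dim=1))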


def jacobian(
    func,
    inputs,
    create_graph=False,
    strict=False,
    vectorize=False,
    strategy="reverse-mode",
):
    r"""Compute the Jacobian of a given function.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        create_graph (bool, optional): If ``True``, the Jacobian will be
            computed in a differentiable manner. Note that when ``strict`` is
            ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jacobian for said inputs, which is the expected mathematical value.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental.
            Please consider using :func:`torch.func.jacrev` or
            :func:`torch.func.jacfwd` instead if you are looking for something
            less experimental and more performant.
            When computing the jacobian, usually we invoke
            ``autograd.grad`` once per row of the jacobian. If this flag is
            ``True``, we perform only a single ``autograd.grad`` call with
            ``batched_grad=True`` which uses the vmap prototype feature.
            Though this should lead to performance improvements in many cases,
            because this feature is still experimental, there may be performance
            cliffs. See :func:`torch.autograd.grad`'s ``batched_grad`` parameter for
            more information.
        strategy (str, optional): Set to ``"forward-mode"`` or ``"reverse-mode"`` to
            determine whether the Jacobian will be computed with forward or reverse
            mode AD. Currently, ``"forward-mode"`` requires ``vectorize=True``.
            Defaults to ``"reverse-mode"``. If ``func`` has more outputs than
            inputs, ``"forward-mode"`` tends to be more performant. Otherwise,
            prefer to use ``"reverse-mode"``.

    Returns:
        Jacobian (Tensor or nested tuple of Tensors): if there is a single
        input and output, this will be a single Tensor containing the
        Jacobian for the linearized inputs and output. If one of the two is
        a tuple, then the Jacobian will be a tuple of Tensors. If both of
        them are tuples, then the Jacobian will be a tuple of tuples of
        Tensors where ``Jacobian[i][j]`` will contain the Jacobian of the
        ``i``\th output and ``j``\th input and will have as size the
        concatenation of the sizes of the corresponding output and the
        corresponding input and will have the same dtype and device as the
        corresponding input. If strategy is ``forward-mode``, the dtype will be
        that of the output; otherwise, the input.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def exp_reducer(x):
        ...     return x.exp().sum(dim=1)
        >>> inputs = torch.rand(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> jacobian(exp_reducer, inputs)
        tensor([[[1.4917, 2.4352],
                 [0.0000, 0.0000]],
                [[0.0000, 0.0000],
                 [2.4369, 2.3799]]])

        >>> jacobian(exp_reducer, inputs, create_graph=True)
        tensor([[[1.4917, 2.4352],
                 [0.0000, 0.0000]],
                [[0.0000, 0.0000],
                 [2.4369, 2.3799]]], grad_fn=<ViewBackward>)

        >>> def exp_adder(x, y):
        ...     return 2 * x.exp() + 3 * y
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> jacobian(exp_adder, inputs)
        (tensor([[2.8052, 0.0000],
                 [0.0000, 3.3963]]),
         tensor([[3., 0.],
                 [0., 3.]]))
    """
    assert strategy in ("forward-mode", "reverse-mode"), (
        'Expected strategy to be either "forward-mode" or "reverse-mode". Hint: If your '
        'function has more outputs than inputs, "forward-mode" tends to be more performant. '
        'Otherwise, prefer to use "reverse-mode".'
    )
    if strategy == "forward-mode":
        if create_graph:
            raise NotImplementedError(
                "torch.autograd.functional.jacobian: `create_graph=True` "
                'and `strategy="forward-mode"` are not supported together (yet). '
                "Please either set `create_graph=False` or "
                '`strategy="reverse-mode"`.'
            )
        return _jacfwd(func, inputs, strict, vectorize)

    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jacobian")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jacobian"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if vectorize:
            if strict:
                raise RuntimeError(
                    "torch.autograd.functional.jacobian: `strict=True` "
                    "and `vectorize=True` are not supported together. "
                    "Please either set `strict=False` or "
                    "`vectorize=False`."
                )

            # NOTE: [Computing jacobian with vmap and grad for multiple outputs]
            #
            # Let's consider f(x) = (x**2, x.sum()) and let x = torch.randn(3).
            # It turns out we can compute the jacobian of this function with a single
            # call to autograd.grad by using vmap over the correct grad_outputs.
            #
            # Firstly, one way to compute the jacobian is to stack x**2 and x.sum()
            # into a 4D vector. E.g., use g(x) = torch.stack([x**2, x.sum()])
            #
            # To get the first row of the jacobian, we call
            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([1, 0, 0, 0]))
            # To get the 2nd row of the jacobian, we call
            # >>> autograd.grad(g(x), x, grad_outputs=torch.tensor([0, 1, 0, 0]))
            # and so on.
            #
            # Using vmap, we can vectorize all 4 of these computations into one by
            # passing the standard basis for R^4 as the grad_output.
            # vmap(partial(autograd.grad, g(x), x))(torch.eye(4)).
            #
            # Now, how do we compute the jacobian *without stacking the output*?
            # We can just split the standard basis across the outputs. So to
            # compute the jacobian of f(x), we'd use
            # >>> autograd.grad(f(x), x, grad_outputs=_construct_standard_basis_for(...))
            # The grad_outputs looks like the following:
            # ( torch.tensor([[1, 0, 0],
            #                 [0, 1, 0],
            #                 [0, 0, 1],
            #                 [0, 0, 0]]),
            #   torch.tensor([[0],
            #                 [0],
            #                 [0],
            #                 [1]]) )
            #
            # But we're not done yet!
            # >>> vmap(partial(autograd.grad(f(x), x, grad_outputs=...)))
            # returns a Tensor of shape [4, 3]. We have to remember to split the
            # jacobian of shape [4, 3] into two:
            # - one of shape [3, 3] for the first output
            # - one of shape [   3] for the second output

            # Step 1: Construct grad_outputs by splitting the standard basis
            output_numels = tuple(output.numel() for output in outputs)
            grad_outputs = _construct_standard_basis_for(outputs, output_numels)
            flat_outputs = tuple(output.reshape(-1) for output in outputs)

            # Step 2: Call vmap + autograd.grad
            def vjp(grad_output):
                vj = list(
                    _autograd_grad(
                        flat_outputs,
                        inputs,
                        grad_output,
                        create_graph=create_graph,
                        is_grads_batched=True,
                    )
                )
                for el_idx, vj_el in enumerate(vj):
                    if vj_el is not None:
                        continue
                    vj[el_idx] = torch.zeros_like(inputs[el_idx]).expand(
                        (sum(output_numels),) + inputs[el_idx].shape
                    )
                return tuple(vj)

            jacobians_of_flat_output = vjp(grad_outputs)

            # Step 3: The returned jacobian is one big tensor per input. In this step,
            # we split each Tensor by output.
            jacobian_input_output = []
            for jac_input_i, input_i in zip(jacobians_of_flat_output, inputs):
                jacobian_input_i_output = []
                for jac, output_j in zip(
                    jac_input_i.split(output_numels, dim=0), outputs
                ):
                    jacobian_input_i_output_j = jac.view(output_j.shape + input_i.shape)
                    jacobian_input_i_output.append(jacobian_input_i_output_j)
                jacobian_input_output.append(jacobian_input_i_output)

            # Step 4: Right now, `jacobian` is a List[List[Tensor]].
            # The outer List corresponds to the number of inputs,
            # the inner List corresponds to the number of outputs.
            # We need to exchange the order of these and convert to tuples
            # before returning.
            jacobian_output_input = tuple(zip(*jacobian_input_output))

            jacobian_output_input = _grad_postprocess(
                jacobian_output_input, create_graph
            )
            return _tuple_postprocess(
                jacobian_output_input, (is_outputs_tuple, is_inputs_tuple)
            )

        jacobian: Tuple[torch.Tensor, ...] = tuple()

        for i, out in enumerate(outputs):
            # mypy complains that expression and variable have different types due to the empty list
            jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs)))  # type: ignore[assignment]
            for j in range(out.nelement()):
                vj = _autograd_grad(
                    (out.reshape(-1)[j],),
                    inputs,
                    retain_graph=True,
                    create_graph=create_graph,
                )

                for el_idx, (jac_i_el, vj_el, inp_el) in enumerate(
                    zip(jac_i, vj, inputs)
                ):
                    if vj_el is not None:
                        if strict and create_graph and not vj_el.requires_grad:
                            msg = (
                                "The jacobian of the user-provided function is "
                                f"independent of input {i}. This is not allowed in "
                                "strict mode when create_graph=True."
                            )
                            raise RuntimeError(msg)
                        jac_i_el.append(vj_el)
                    else:
                        if strict:
                            msg = (
                                f"Output {i} of the user-provided function is "
                                f"independent of input {el_idx}. This is not allowed in "
                                "strict mode."
                            )
                            raise RuntimeError(msg)
                        jac_i_el.append(torch.zeros_like(inp_el))

            jacobian += (
                tuple(
                    torch.stack(jac_i_el, dim=0).view(
                        out.size() + inputs[el_idx].size()  # type: ignore[operator]
                    )
                    for (el_idx, jac_i_el) in enumerate(jac_i)
                ),
            )

        jacobian = _grad_postprocess(jacobian, create_graph)

        return _tuple_postprocess(jacobian, (is_outputs_tuple, is_inputs_tuple))
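

# Illustrative sketch (hypothetical helper, never called by the module): for
# f(x) = x ** 2 with x of shape (3,), the Jacobian is diag(2 * x); the
# vectorize=True path must agree with the default row-by-row reverse-mode path.
def _demo_jacobian_check():
    def f(x):
        return x ** 2

    x = torch.randn(3)
    expected = torch.diag(2 * x)
    assert torch.allclose(jacobian(f, x), expected)
    assert torch.allclose(jacobian(f, x, vectorize=True), expected)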


def hessian(
    func,
    inputs,
    create_graph=False,
    strict=False,
    vectorize=False,
    outer_jacobian_strategy="reverse-mode",
):
    r"""Compute the Hessian of a given scalar function.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        create_graph (bool, optional): If ``True``, the Hessian will be computed in
            a differentiable manner. Note that when ``strict`` is ``False``, the result
            may not require gradients or may be disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we detect that there exists an input
            such that all the outputs are independent of it. If ``False``, we return a Tensor of zeros as the
            hessian for said inputs, which is the expected mathematical value.
            Defaults to ``False``.
        vectorize (bool, optional): This feature is experimental.
            Please consider using :func:`torch.func.hessian`
            instead if you are looking for something less experimental and more performant.
            When computing the hessian, usually we invoke
            ``autograd.grad`` once per row of the hessian. If this flag is
            ``True``, we use the vmap prototype feature as the backend to
            vectorize calls to ``autograd.grad`` so we only invoke it once
            instead of once per row. This should lead to performance
            improvements in many use cases, however, due to this feature
            being incomplete, there may be performance cliffs. Please
            use `torch._C._debug_only_display_vmap_fallback_warnings(True)`
            to show any performance warnings and file us issues if
            warnings exist for your use case. Defaults to ``False``.
        outer_jacobian_strategy (str, optional): The Hessian is computed by
            computing the Jacobian of a Jacobian. The inner Jacobian is always
            computed in reverse-mode AD. Setting strategy to ``"forward-mode"``
            or ``"reverse-mode"`` determines whether the outer Jacobian will be
            computed with forward or reverse mode AD. Currently, computing the outer
            Jacobian in ``"forward-mode"`` requires ``vectorize=True``. Defaults
            to ``"reverse-mode"``.

    Returns:
        Hessian (Tensor or a tuple of tuple of Tensors): if there is a single input,
        this will be a single Tensor containing the Hessian for the input.
        If it is a tuple, then the Hessian will be a tuple of tuples where
        ``Hessian[i][j]`` will contain the Hessian of the ``i``\th input
        and ``j``\th input, with size the concatenation of the size of the
        ``i``\th input and the size of the ``j``\th input. ``Hessian[i][j]`` will have
        the same dtype and device as the corresponding ``i``\th input.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> hessian(pow_reducer, inputs)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]])

        >>> hessian(pow_reducer, inputs, create_graph=True)
        tensor([[[[5.2265, 0.0000],
                  [0.0000, 0.0000]],
                 [[0.0000, 4.8221],
                  [0.0000, 0.0000]]],
                [[[0.0000, 0.0000],
                  [1.9456, 0.0000]],
                 [[0.0000, 0.0000],
                  [0.0000, 3.2550]]]], grad_fn=<ViewBackward>)

        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> hessian(pow_adder_reducer, inputs)
        ((tensor([[4., 0.],
                  [0., 4.]]),
          tensor([[0., 0.],
                  [0., 0.]])),
         (tensor([[0., 0.],
                  [0., 0.]]),
          tensor([[6., 0.],
                  [0., 6.]])))
    """
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hessian")
    assert outer_jacobian_strategy in (
        "forward-mode",
        "reverse-mode",
    ), 'Expected strategy to be either "forward-mode" or "reverse-mode".'

    def ensure_single_output_function(*inp):
        out = func(*inp)
        is_out_tuple, t_out = _as_tuple(
            out, "outputs of the user-provided function", "hessian"
        )
        _check_requires_grad(t_out, "outputs", strict=strict)

        if is_out_tuple or not isinstance(out, torch.Tensor):
            raise RuntimeError(
                "The function given to hessian should return a single Tensor"
            )

        if out.nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to hessian should contain a single element"
            )

        return out.squeeze()

    def jac_func(*inp):
        if outer_jacobian_strategy == "forward-mode":
            # _grad_preprocess requires create_graph=True and input to require_grad
            # or else the input will be detached
            inp = tuple(t.requires_grad_(True) for t in inp)
        jac = jacobian(ensure_single_output_function, inp, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)
        return jac

    res = jacobian(
        jac_func,
        inputs,
        create_graph=create_graph,
        strict=strict,
        vectorize=vectorize,
        strategy=outer_jacobian_strategy,
    )
    return _tuple_postprocess(res, (is_inputs_tuple, is_inputs_tuple))
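

# Illustrative sketch (hypothetical helper, never called by the module): for
# the quadratic form f(x) = x @ A @ x, the Hessian is A + A.T, which gives an
# easy analytic check of the computation above.
def _demo_hessian_check():
    A = torch.randn(3, 3)

    def f(x):
        return x @ A @ x

    x = torch.randn(3)
    assert torch.allclose(hessian(f, x), A + A.T)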


def vhp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between a vector ``v`` and the Hessian of a given scalar function at the point given by the inputs.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector Hessian
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result may not require gradients or may be
            disconnected from the inputs.
            Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vhp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vhp (tuple of Tensors or Tensor): result of the dot product with the
            same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> vhp(pow_reducer, inputs, v)
        (tensor(0.5591),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]]))

        >>> vhp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.5591, grad_fn=<SumBackward0>),
         tensor([[1.0689, 1.2431],
                 [3.0989, 4.4456]], grad_fn=<MulBackward0>))

        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> vhp(pow_adder_reducer, inputs, v)
        (tensor(4.8053),
         (tensor([0., 0.]),
          tensor([6., 6.])))
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vhp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "vhp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to the user-provided function "
                    "is a single Tensor with a single element."
                )

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vhp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
            raise RuntimeError(
                "The function given to vhp should return a single Tensor"
            )

        if outputs[0].nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to vhp should contain a single element"
            )

        jac = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(jac, inputs, v, create_graph=create_graph)
        vhp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "double_back")

    outputs = _grad_postprocess(outputs, create_graph)
    vhp = _grad_postprocess(vhp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        vhp, is_inputs_tuple
    )
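

# Illustrative sketch (hypothetical helper, never called by the module): for
# f(x) = x @ A @ x the Hessian is A + A.T, so vhp must return v @ (A + A.T).
def _demo_vhp_check():
    A = torch.randn(3, 3)

    def f(x):
        return x @ A @ x

    x = torch.randn(3)
    v = torch.randn(3)
    _, result = vhp(f, x, v)
    assert torch.allclose(result, v @ (A + A.T))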


def hvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the Hessian of a given scalar function and a vector ``v`` at the point given by the inputs.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor with a single element.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Hessian vector
            product is computed. Must be the same size as the input of
            ``func``. This argument is optional when ``func``'s input contains
            a single element and (if it is not provided) will be set as a
            Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result will be
            computed in a differentiable way. Note that when ``strict`` is
            ``False``, the result may not require gradients or may be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            hvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            hvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_AUTOGRAD)
        >>> def pow_reducer(x):
        ...     return x.pow(3).sum()
        >>> inputs = torch.rand(2, 2)
        >>> v = torch.ones(2, 2)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> hvp(pow_reducer, inputs, v)
        (tensor(0.1448),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]]))

        >>> hvp(pow_reducer, inputs, v, create_graph=True)
        (tensor(0.1448, grad_fn=<SumBackward0>),
         tensor([[2.0239, 1.6456],
                 [2.4988, 1.4310]], grad_fn=<MulBackward0>))

        >>> def pow_adder_reducer(x, y):
        ...     return (2 * x.pow(2) + 3 * y.pow(2)).sum()
        >>> inputs = (torch.rand(2), torch.rand(2))
        >>> v = (torch.zeros(2), torch.ones(2))
        >>> hvp(pow_adder_reducer, inputs, v)
        (tensor(2.3030),
         (tensor([0., 0.]),
          tensor([6., 6.])))

    Note:
        This function is significantly slower than `vhp` due to backward mode AD constraints.
        If your function is twice continuously differentiable, then hvp = vhp.t(). So if you
        know that your function satisfies this condition, you should use vhp instead, which is
        much faster with the current implementation.
    """
    with torch.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "hvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "hvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to the user-provided function "
                    "is a single Tensor with a single element."
                )

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "hvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if is_outputs_tuple or not isinstance(outputs[0], torch.Tensor):
            raise RuntimeError(
                "The function given to hvp should return a single Tensor"
            )

        if outputs[0].nelement() != 1:
            raise RuntimeError(
                "The Tensor returned by the function given to hvp should contain a single element"
            )

        jac = _autograd_grad(outputs, inputs, create_graph=True)
        _check_requires_grad(jac, "jacobian", strict=strict)

        grad_jac = tuple(torch.zeros_like(inp, requires_grad=True) for inp in inputs)

        double_back = _autograd_grad(jac, inputs, grad_jac, create_graph=True)
        _check_requires_grad(jac, "hessian", strict=strict)

    enable_grad = True if create_graph else torch.is_grad_enabled()
    with torch.set_grad_enabled(enable_grad):
        grad_res = _autograd_grad(double_back, grad_jac, v, create_graph=create_graph)
        hvp = _fill_in_zeros(
            grad_res, inputs, strict, create_graph, "double_back_trick"
        )

    outputs = _grad_postprocess(outputs, create_graph)
    hvp = _grad_postprocess(hvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        hvp, is_inputs_tuple
    )
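

# Illustrative sketch (hypothetical helper, never called by the module): since
# the Hessian of a twice continuously differentiable function is symmetric,
# hvp and vhp agree, as the Note above suggests; a quadratic form makes this
# easy to verify against the analytic Hessian A + A.T.
def _demo_hvp_check():
    A = torch.randn(3, 3)

    def f(x):
        return x @ A @ x

    x = torch.randn(3)
    v = torch.randn(3)
    assert torch.allclose(hvp(f, x, v)[1], (A + A.T) @ v)
    assert torch.allclose(hvp(f, x, v)[1], vhp(f, x, v)[1])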