File size: 5,218 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch._functorch.apis as apis
import torch._functorch.eager_transforms as _impl
import torch._functorch.make_functional as _nn_impl
from torch._functorch.vmap import in_dims_t, out_dims_t
from torch._functorch.eager_transforms import argnums_t
import torch.nn as nn
import textwrap
from typing import Any, Callable, Optional, Tuple, Union
import warnings

"""

The APIs in this file are exposed as `functorch.*`. They are thin wrappers

around the torch.func.* APIs that have deprecation warnings -- we're trying

to move people to the torch.func.* equivalents.



NB: We don't use *args, **kwargs in the signatures because that changes the

documentation.

"""

def get_warning(api, new_api=None, replace_newlines=False):
    """Build the deprecation message for the ``functorch.<api>`` entry point.

    Args:
        api: Name of the deprecated functorch API; interpolated into the
            message as ``functorch.{api}``.
        new_api: Name of the replacement to recommend. When ``None``, the
            replacement defaults to ``torch.func.{api}``.
        replace_newlines: When true, every newline character is removed from
            the finished message, so it reads as a single line.

    Returns:
        The deprecation message as a string.
    """
    replacement = f'torch.func.{api}' if new_api is None else new_api
    # Every line except the last ends with a space, so the text still reads
    # correctly once the newlines are stripped.
    lines = (
        "We've integrated functorch into PyTorch. As the final step of the ",
        f"integration, functorch.{api} is deprecated as of PyTorch ",
        "2.0 and will be deleted in a future version of PyTorch >= 2.3. ",
        f"Please use {replacement} instead; see the PyTorch 2.0 release notes ",
        "and/or the torch.func migration guide for more details ",
        "https://pytorch.org/docs/master/func.migrating.html",
    )
    message = "\n".join(lines)
    return message.replace("\n", "") if replace_newlines else message


def warn_deprecated(api, new_api=None):
    """Emit the deprecation warning for the ``functorch.<api>`` entry point.

    Args:
        api: Name of the deprecated functorch API (forwarded to get_warning).
        new_api: Name of the replacement API to recommend, or ``None`` to let
            get_warning pick its default.

    The message is requested with ``replace_newlines=True`` so it is emitted
    as a single line. No category is passed to ``warnings.warn``, so its
    default category applies.
    """
    message = get_warning(api, new_api, replace_newlines=True)
    # stacklevel counts frames starting at this helper (1). Level 2 is the
    # deprecated functorch wrapper that called us, and level 3 is the code
    # that called that wrapper. Attribute the warning there so it points at
    # the caller's line rather than at this module.
    warnings.warn(message, stacklevel=3)


def setup_docs(functorch_api, torch_func_api=None, new_api_name=None):
    """Copy a docstring onto ``functorch_api`` and append a deprecation note.

    Args:
        functorch_api: The deprecated wrapper whose ``__doc__`` is assigned.
        torch_func_api: The function whose docstring is copied. When ``None``,
            the attribute of ``_impl`` with the same ``__name__`` as
            ``functorch_api`` is used.
        new_api_name: Replacement name forwarded to get_warning for the note.

    Returns:
        None. ``functorch_api.__doc__`` is left untouched when the source
        function has no docstring.
    """
    name = functorch_api.__name__
    source_api = getattr(_impl, name) if torch_func_api is None else torch_func_api
    base_doc = source_api.__doc__
    # See https://docs.python.org/3/using/cmdline.html#cmdoption-OO
    if base_doc is None:
        return

    indent = "    "
    # The message is indented once as the body of the ``.. warning::``
    # directive, then the whole directive is indented once more.
    note_body = textwrap.indent(get_warning(name, new_api_name), indent)
    directive = textwrap.indent("\n.. warning::\n\n" + note_body, indent)
    functorch_api.__doc__ = base_doc + directive

def vmap(
        func: Callable,
        in_dims: in_dims_t = 0,
        out_dims: out_dims_t = 0,
        randomness: str = 'error',
        *,
        chunk_size=None) -> Callable:
    # Deprecated functorch.vmap shim: warn (recommending torch.vmap), then
    # forward every argument unchanged to apis.vmap and return its result.
    # __doc__ is assigned by the module-level setup_docs(vmap, ...) call.
    warn_deprecated('vmap', 'torch.vmap')
    vmapped = apis.vmap(
        func,
        in_dims,
        out_dims,
        randomness,
        chunk_size=chunk_size,
    )
    return vmapped

def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    # Deprecated functorch.grad shim: emit the deprecation warning, then
    # pass the arguments unchanged to apis.grad and return its result.
    # __doc__ is assigned by the module-level setup_docs(grad, ...) call.
    warn_deprecated('grad')
    grad_fn = apis.grad(func, argnums, has_aux)
    return grad_fn

def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    # Deprecated functorch.grad_and_value shim: emit the deprecation warning,
    # then pass the arguments unchanged to apis.grad_and_value.
    # __doc__ is assigned by the module-level setup_docs(grad_and_value, ...) call.
    warn_deprecated('grad_and_value')
    grad_and_value_fn = apis.grad_and_value(func, argnums, has_aux)
    return grad_and_value_fn

def vjp(func: Callable, *primals, has_aux: bool = False):
    # Deprecated functorch.vjp shim: emit the deprecation warning, then
    # forward func, the positional primals, and has_aux to _impl.vjp.
    # __doc__ is assigned by the module-level setup_docs(vjp) call.
    warn_deprecated('vjp')
    outcome = _impl.vjp(func, *primals, has_aux=has_aux)
    return outcome

def jvp(func: Callable, primals: Any, tangents: Any, *, strict: bool = False, has_aux: bool = False):
    # Deprecated functorch.jvp shim: emit the deprecation warning, then
    # forward the arguments unchanged to _impl.jvp and return its result.
    # __doc__ is assigned by the module-level setup_docs(jvp) call.
    warn_deprecated('jvp')
    outcome = _impl.jvp(
        func,
        primals,
        tangents,
        strict=strict,
        has_aux=has_aux,
    )
    return outcome

def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
           chunk_size: Optional[int] = None,
           _preallocate_and_copy=False):
    # Deprecated functorch.jacrev shim: emit the deprecation warning, then
    # forward the arguments unchanged to _impl.jacrev and return its result.
    # __doc__ is assigned by the module-level setup_docs(jacrev) call.
    warn_deprecated('jacrev')
    jacobian_fn = _impl.jacrev(
        func,
        argnums,
        has_aux=has_aux,
        chunk_size=chunk_size,
        _preallocate_and_copy=_preallocate_and_copy,
    )
    return jacobian_fn

def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False, *, randomness: str = "error"):
    # Deprecated functorch.jacfwd shim: emit the deprecation warning, then
    # forward the arguments unchanged to _impl.jacfwd and return its result.
    # __doc__ is assigned by the module-level setup_docs(jacfwd) call.
    warn_deprecated('jacfwd')
    jacobian_fn = _impl.jacfwd(func, argnums, has_aux, randomness=randomness)
    return jacobian_fn

def hessian(func, argnums=0):
    # Deprecated functorch.hessian shim: emit the deprecation warning, then
    # delegate to _impl.hessian, passing argnums by keyword.
    # __doc__ is assigned by the module-level setup_docs(hessian) call.
    warn_deprecated('hessian')
    hessian_fn = _impl.hessian(func, argnums=argnums)
    return hessian_fn

def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
    # Deprecated functorch.functionalize shim: emit the deprecation warning,
    # then delegate to _impl.functionalize with the same `remove` setting.
    # __doc__ is assigned by the module-level setup_docs(functionalize) call.
    warn_deprecated('functionalize')
    functionalized = _impl.functionalize(func, remove=remove)
    return functionalized

def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
    # Deprecated functorch.make_functional shim: warn (recommending
    # torch.func.functional_call), then delegate to _nn_impl.make_functional.
    # __doc__ is assigned by the module-level setup_docs(make_functional, ...) call.
    warn_deprecated('make_functional', 'torch.func.functional_call')
    functional_form = _nn_impl.make_functional(model, disable_autograd_tracking)
    return functional_form

def make_functional_with_buffers(model: nn.Module, disable_autograd_tracking: bool = False):
    # Deprecated functorch.make_functional_with_buffers shim: warn
    # (recommending torch.func.functional_call), then delegate to
    # _nn_impl.make_functional_with_buffers with the same arguments.
    warn_deprecated('make_functional_with_buffers', 'torch.func.functional_call')
    functional_form = _nn_impl.make_functional_with_buffers(
        model,
        disable_autograd_tracking,
    )
    return functional_form

def combine_state_for_ensemble(models):
    # Deprecated functorch.combine_state_for_ensemble shim: warn (recommending
    # torch.func.stack_module_state), then delegate to
    # _nn_impl.combine_state_for_ensemble and return its result.
    warn_deprecated('combine_state_for_ensemble', 'torch.func.stack_module_state')
    combined = _nn_impl.combine_state_for_ensemble(models)
    return combined

# Give each deprecated wrapper the docstring of the function it delegates to,
# followed by a deprecation note naming the recommended replacement.
# vmap, grad and grad_and_value delegate to `apis`, so their source is passed
# explicitly.
setup_docs(vmap, apis.vmap, 'torch.vmap')
setup_docs(grad, apis.grad)
setup_docs(grad_and_value, apis.grad_and_value)
# With no explicit source, setup_docs looks the same name up on `_impl`.
setup_docs(vjp)
setup_docs(jvp)
setup_docs(jacrev)
setup_docs(jacfwd)
setup_docs(hessian)
setup_docs(functionalize)
# The nn helpers live in `_nn_impl` and have differently named replacements.
setup_docs(make_functional, _nn_impl.make_functional,
           'torch.func.functional_call')
# Take the docstring from make_functional_with_buffers itself, so the
# documented function matches the one this wrapper actually calls.
setup_docs(make_functional_with_buffers, _nn_impl.make_functional_with_buffers,
           'torch.func.functional_call')
setup_docs(combine_state_for_ensemble, _nn_impl.combine_state_for_ensemble,
           'torch.func.stack_module_state')