# mypy: ignore-errors
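"""VariableTracker for ``torch.optim.Optimizer`` instances.

``OptimizerVariable`` intercepts ``_init_group`` so the slow, one-time
initialization of optimizer state runs eagerly instead of being traced,
while guards are installed on ``param_groups`` and the optimizer state to
keep the compiled code valid.
"""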
import weakref
from typing import Dict, List

import torch

from ..decorators import mark_static_address
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ConstDictKeySource, GetItemSource, GlobalWeakRefSource
from ..utils import GLOBAL_KEY_PREFIX
from .base import VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable

class ArgMappingException(Exception):
    pass


class GuardInstallException(Exception):
    pass


class OptimizerVariable(UserDefinedObjectVariable):
    def __init__(
        self,
        value,
        grad_to_source=None,
        static_tensor_names=None,
        tensor_to_source=None,
        **kwargs,
    ):
        super().__init__(value, **kwargs)

        for group in self.value.param_groups:
            if "capturable" in group:
                group["capturable"] = True

            # mark params as static addresses for cudagraphs
            for p in group["params"]:
                mark_static_address(p, guard=False)

        self.grad_to_source = grad_to_source or {}
        self.tensor_to_source = tensor_to_source or {}
        self.static_tensor_names = static_tensor_names or set()

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
        if name == "_init_group":
            try:
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                ret_val = self.value._init_group(*py_args, **py_kwargs)
                self.map_sources_and_install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
                # stash a weakref to the optimizer to invalidate the code
                # if the optimizer object dies
                mangled_name = f"__optimizer_{id(self.value)}"
                tx.store_global_weakref_by_id(mangled_name, self.value)
                self.create_finalizer(tx)

                # This is currently safe only because the only actual `ret_val`s
                # returned by `_init_group` of existing optimizers are properties
                # that are invariant to the input tensors (e.g. dtype, layout).
                # Changing these would trigger a recompilation and hence never
                # result in the wrong specialization of `ret_val`.
                return ConstantVariable.create(ret_val)
            except (ArgMappingException, GuardInstallException) as _:
                # trace normally if we can't map args or install guards correctly
                pass

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        if name == "_init_group":
            # return a GetAttrVariable so the subsequent call is routed
            # through `call_method` above
            return GetAttrVariable(self, name)
        return super().var_getattr(tx, name)

    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args"""

        def map_arg(arg):
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                return self.value.param_groups[arg.source.index]

            raise ArgMappingException()

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

        return new_args, new_kwargs

    def map_sources_and_install_guards(self, tx):
        """Map param/grad/state tensors to sources and install guards on
        param_groups and the optimizer state."""
        self.grad_to_source = {}
        self.tensor_to_source = {}

        from .builder import VariableBuilder

        param_groups_vt = VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
            self.value.param_groups
        ).recursive_realize()

        for g_ind, (group, group_vt) in enumerate(
            zip(self.value.param_groups, param_groups_vt.items)
        ):
            group_source = group_vt.source
            params_vt = group_vt.getitem_const(ConstantVariable.create("params"))
            for p_ind, (p, p_vt) in enumerate(
                zip(group["params"], params_vt.unpack_var_sequence(tx))
            ):
                param_source = p_vt.source
                self.tensor_to_source[p] = param_source
                grad_source = AttrSource(
                    param_source,
                    "grad",
                )
                if p.grad is not None:
                    self.grad_to_source[p.grad] = grad_source
                else:
                    install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))

        # state guards take a long time to generate
        # so we manually generate them here
        state_source = AttrSource(self.source, "state")
        install_guard(state_source.make_guard(GuardBuilder.DICT_KEYS))
        for idx, (p, value) in enumerate(self.value.state.items()):
            tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, p)
            p_state_source = GetItemSource(
                state_source, ConstDictKeySource(state_source, idx)
            )
            install_guard(p_state_source.make_guard(GuardBuilder.DICT_KEYS))
            for k, v in value.items():
                if (
                    isinstance(v, torch.Tensor)
                    and v not in self.grad_to_source
                    and v not in self.tensor_to_source
                ):
                    self.tensor_to_source[v] = GetItemSource(p_state_source, k)
                elif v is None or isinstance(v, (bool, int, float, str)):
                    install_guard(
                        GetItemSource(p_state_source, k).make_guard(
                            GuardBuilder.CONSTANT_MATCH
                        )
                    )
                else:
                    raise GuardInstallException()

    def wrap_tensor(self, tx, tensor_value):
        """Wrap state tensor in a TensorVariable"""
        from .builder import VariableBuilder

        # If we already have a source for the tensor, use it. If we have not
        # seen the tensor before, stash a global weakref and use that as the
        # source, since it must be an optimizer tensor that we have missed.
        if tensor_value in self.tensor_to_source:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value, guard=False)
            builder = VariableBuilder(tx, self.tensor_to_source[tensor_value])
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))
        elif tensor_value in self.grad_to_source:
            builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
        else:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value, guard=False)

            global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
            builder = VariableBuilder(tx, GlobalWeakRefSource(global_name))
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))

        result = builder(tensor_value)
        return result

    def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
        """Update the args and kwargs to the traced optimizer call"""
        for arg, py_arg in zip(args, py_args):
            if isinstance(arg, ListVariable):
                assert isinstance(
                    py_arg, list
                ), "py_arg should be a list in optimizer variable"
                for i, val in enumerate(py_arg):
                    tx.output.side_effects.mutation(arg)
                    if isinstance(val, torch.Tensor):
                        arg.items.append(self.wrap_tensor(tx, val))
                    else:
                        from .builder import SourcelessBuilder, VariableBuilder

                        if arg.source:
                            arg.items.append(
                                VariableBuilder(tx, GetItemSource(arg.source, i))(val)
                            )
                        else:
                            arg.items.append(SourcelessBuilder()(tx, val))

    def create_finalizer(self, tx):
        """Drop references to the static optimizer tensors held by the graph
        module once the optimizer is garbage collected."""
        names_to_delete = self.static_tensor_names
        value = self.value
        tc = tx.output.tracing_context

        def init_finalizer(gm):
            def clear_static_tensor_refs():
                for name in names_to_delete:
                    gm._buffers.pop(name, None)
                    gm._parameters.pop(name, None)
                    if tc.params_flat:
                        tc.params_flat.clear()

            weakref.finalize(value, clear_static_tensor_refs)

        tx.output.add_graph_finalizer(init_finalizer)
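
# Rough usage sketch (illustrative only, not part of this module): when an
# optimizer step is compiled, e.g.
#
#     opt = torch.optim.Adam(model.parameters())
#
#     @torch.compile
#     def step():
#         opt.step()
#
# dynamo wraps ``opt`` in ``OptimizerVariable``; the ``_init_group`` call made
# inside ``opt.step()`` is then executed eagerly via ``get_python_args`` /
# ``update_list_args`` rather than traced, with guards installed on
# ``param_groups`` and optimizer state by ``map_sources_and_install_guards``.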