Spaces:
Running
Running
File size: 14,297 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 |
import logging
import warnings
from copy import deepcopy
from typing import Any, Callable, Collection, Dict, List, Mapping, Optional, Union, overload
import torch
import torch.nn as nn
from torch import optim
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
__all__: List[str] = []
logger = logging.getLogger(__name__)
class _NamedOptimizer(optim.Optimizer):
"""
``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by parameter key.
We replace the original key (number) in an optim to the
fully qualified name (FQN) string. User can initialize the optim as they
initialize a PyTorch optim, the only difference is that they also need to
pass in the FQN of each parameters.
Args:
named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]):
Mapping from FQN to parameter.
optimizer_class (optim.Optimizer):
The class of optimizer to instantiate.
param_groups (Collection[Mapping[str, Any]]):
`param_groups` to pass to optimizer if specified.
The key of the inner map needs to be FQNs.
Default: None
module (nn.Module): the module whose parameters to updated
by the optimizer.
args: arguments to pass to the optimizer constructor.
kwargs: arguments to pass to the optimizer constructor.
Example::
>>> # xdoctest: +SKIP("distributed")
>>> from torch import optim
>>> from torch.distributed.optim import _NamedOptimizer
>>>
>>> # Define the named optimizer.
>>> m = Model(...)
>>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD)
>>> # Forward pass + backward pass.
>>> named_optim.step()
>>> ...
>>> # Call state_dict for the named optimizer returns a FQN state_dict.
>>> named_optim.state_dict()
Warning: This API is still in development and subject to change.
TODO: Add tutorial for _NamedOptimizer.
TODO: Add documentation in the docstring for the public attributes
like self.param_groups and self.named_parameters.
"""
def __init__(
self,
named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]],
optimizer_class: optim.Optimizer,
param_groups: Optional[Collection[Mapping[str, Any]]] = None,
module: Optional[nn.Module] = None,
*args,
**kwargs,
) -> None:
torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer")
self.param_groups: Collection[Mapping[str, Any]] = param_groups # type: ignore[assignment]
self._param_groups_check()
self.named_parameters = dict(named_parameters)
params_for_optimizer = (
self.named_parameters.values() if param_groups is None else param_groups
)
self._optimizer = optimizer_class( # type: ignore[operator]
params_for_optimizer,
*args,
**kwargs,
)
self.module = module
if param_groups is None:
self.ordered_param_keys = list(self.named_parameters.keys())
else:
warnings.warn(
"Since we pass in param_groups, we will use param_groups to "
"initialize the optimizer, not all parameters of the module."
)
param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type]
ordered_param_keys = []
for group in param_groups:
for param in group["params"]:
if param not in param_to_key:
raise ValueError(
f"Expect param name {param} found in param group but is missing."
)
ordered_param_keys.append(param_to_key[param])
self.ordered_param_keys = ordered_param_keys
# Update param_groups from optimizer.
self.param_groups = self._optimizer.param_groups
def _param_groups_check(self):
if self.param_groups is not None:
for param_group in self.param_groups:
assert isinstance(param_group, dict), "param group must be a dict"
assert "params" in param_group, "param group must contain key params"
params = param_group["params"]
if isinstance(params, torch.Tensor):
params = [params]
params = list(params)
for param in params:
if not isinstance(param, torch.Tensor):
raise TypeError(
"optimizer can only optimize Tensors, "
"but one of the params is " + torch.typename(param)
)
param_group["params"] = params
def state_dict(self) -> Dict[str, Any]:
"""
Return the ``state_dict`` of the optimizer.
Instead of using number to index
parameters, we will use module fully qualified name (FQN) as the key.
"""
state_dict = self._optimizer.state_dict()
param_groups = state_dict["param_groups"]
ret_state = {
self.ordered_param_keys[st_key]: state_val
for st_key, state_val in state_dict["state"].items()
}
ret_groups = []
for group in param_groups:
param_keys = []
for param in group["params"]:
param_keys.append(self.ordered_param_keys[param])
ret_group = {"params": sorted(param_keys)}
for k, v in group.items():
if k != "params":
ret_group[k] = deepcopy(v)
ret_groups.append(ret_group)
return self._post_state_dict({"state": ret_state, "param_groups": ret_groups})
@overload
def step(self, closure: None = ...) -> None:
...
@overload
def step(self, closure: Callable[[], float]) -> float:
...
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
"""
Perform a single optimization step.
This will call :meth:`torch.optim.Optimizer.step` on the wrapped
optimizer.
"""
return self._optimizer.step(closure=closure)
@property
def state(self) -> Mapping[torch.Tensor, Any]: # type: ignore[override]
return self._optimizer.state
def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
"""
Define the default behavior to load a state_dict for ``_NamedOptimizer``.
Sample Code
```
my_model = MyModule()
optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad)
...
optim_state_dict = optimizer.state_dict()
...
...
optimizer.load_state_dict(optim_state_dict)
...
```
Args:
state_dict (Dict[str, Any]) : A ``state_dict`` to load into the optimizer.
Note that this state dict update is performed in place.
.. note:: PyTorch is using lazy init to initialize the optim states.
So it is possible that there is no optim state when user call
``load_state_dict`` and for ``_NamedOptimizer`` we make it stricter
that users can only call ``load_state_dict`` after the state is initialized.
By doing this, we can validate the optim ``state_dict`` to be loaded.
"""
new_state_dict = self._optimizer.state_dict()
state_dict = self._pre_load_state_dict(state_dict)
state = state_dict["state"]
new_state = new_state_dict["state"]
if len(new_state) == 0:
raise ValueError(
"Expects the optim to be initialized before load but found not initialized."
)
for idx, param_key in enumerate(self.ordered_param_keys):
# When the conditional training is performed, not all parameters are updated in the optim.
if param_key not in state.keys():
continue
if len(state[param_key]) != len(new_state[idx]):
raise ValueError(
f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}"
)
# Iterate through all optimizer states.
for state_key, state_val in new_state[idx].items():
if state_key not in state[param_key]:
raise ValueError(
f"Expects state {state_key} for parameter {param_key} but not found."
)
src_state_val = state[param_key][state_key]
if isinstance(state_val, ShardedTensor):
assert isinstance(src_state_val, ShardedTensor)
num_shards = len(state_val.local_shards())
num_new_shards = len(src_state_val.local_shards())
if num_shards != num_new_shards:
raise ValueError(
f"Expects equal number of shards as {num_new_shards} but found {num_shards} for {param_key}/{state_key}"
)
for shard, src_shard in zip(
state_val.local_shards(), src_state_val.local_shards()
):
shard.tensor.detach().copy_(src_shard.tensor)
elif isinstance(state_val, torch.Tensor):
assert isinstance(src_state_val, torch.Tensor)
state_val.detach().copy_(src_state_val)
else:
new_state[idx][state_key] = deepcopy(src_state_val)
# Load param_groups of state_dict
src_param_groups = state_dict["param_groups"]
new_param_groups = new_state_dict["param_groups"]
src_group_map = {}
for group in src_param_groups:
param_keys = list(group["params"])
src_group_map[_gen_param_group_key(param_keys)] = group
new_group_map = {}
for new_group in new_param_groups:
param_keys = []
for param_key in new_group["params"]:
param_keys.append(self.ordered_param_keys[param_key]) # type: ignore[call-overload]
new_group_map[_gen_param_group_key(param_keys)] = new_group
for group_key, new_group in new_group_map.items():
# When not all parameters are used in training or receive gradient, aka., not all parameters
# would be in the param_group. Thus we skip the group_key here.
if group_key not in src_group_map:
continue
src_group = src_group_map[group_key]
if len(src_group) != len(new_group):
raise ValueError(
f"Expects equal param_group size as {len(new_group)} for group {group_key} but found {len(src_group)}."
)
for k in src_group:
if k not in new_group:
raise ValueError(
f"Expects group key {k} to be in group {group_key} in `state_dict` but is missing."
)
if k != "params":
new_group[k] = deepcopy(src_group[k])
self._optimizer.load_state_dict(new_state_dict)
def add_param_group(self, param_group: Mapping[str, Any]) -> None:
"""
Add a param group to the :class:`_NamedOptimizer` s `param_groups`.
Warning: This API is still in development and subject to change.
"""
assert isinstance(param_group, dict), "param group must be a dict"
params = param_group["params"]
if isinstance(params, torch.Tensor):
param_group["params"] = [params]
else:
param_group["params"] = list(params)
param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type]
for param in param_group["params"]:
if param not in param_to_key:
raise ValueError("some parameters are not in the module")
self.ordered_param_keys.append(param_to_key[param])
self._optimizer.add_param_group(param_group)
# Update param_groups from optimizer.
self.param_groups = self._optimizer.param_groups
def init_state(self) -> None:
"""
Run a dummy optimizer step, which allows to initialize optimizer state because we do lazy init for most optimizers.
This allows doing in-place loading of optimizer state from a checkpoint.
"""
for param in self.named_parameters.values():
if param.requires_grad:
t = torch.zeros_like(param)
param.grad = torch.autograd.Variable(t)
# Calling ``step`` will load the initial state for optimizer states.
self.step(closure=None)
def _pre_load_state_dict(self, state_dict) -> Dict[str, Any]:
# TODO(chienchin): This API should be FSDP agnostic and should support
# general user hooks.
if isinstance(self.module, FSDP):
return FSDP.optim_state_dict_to_load(
self.module, self._optimizer, state_dict, is_named_optimizer=True
)
return state_dict
def _post_state_dict(self, state_dict) -> Dict[str, Any]:
# TODO(chienchin): This API should be FSDP agnostic and should support
# general user hooks.
if isinstance(self.module, FSDP):
FSDP.optim_state_dict(self.module, self._optimizer, state_dict)
return state_dict
def _gen_param_group_key(param_keys: List[str]) -> str:
"""Concatenate all param keys as a unique indentifier for one param group."""
return "/".join(sorted(param_keys))
|