File size: 4,337 Bytes
88b5dc0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import logging
import re
from functools import cache, partial
from typing import Callable, TypeVar
import deepspeed
import pandas as pd
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.runtime.utils import clip_grad_norm_
from torch import nn
from .distributed import fix_unset_envs
# Module-level logger, named after this module per the standard convention.
logger = logging.getLogger(__name__)
# Generic type variable. NOTE(review): not referenced anywhere in this file's
# visible code — presumably used by annotations elsewhere; confirm before removing.
T = TypeVar("T")
def flatten_dict(d):
    """Flatten a nested dict into a single-level dict with "/"-joined keys.

    Returns an empty dict when the input normalizes to no records.
    """
    frame = pd.json_normalize(d, sep="/")
    rows = frame.to_dict(orient="records")
    if not rows:
        return {}
    return rows[0]
def _get_named_modules(module, attrname, sep="/"):
for name, module in module.named_modules():
name = name.replace(".", sep)
if hasattr(module, attrname):
yield name, module
def gather_attribute(module, attrname, delete=True, prefix=None):
    """Collect *attrname* from every submodule defining it, keyed by module path.

    By default the attribute is removed from each submodule after collection.
    An optional *prefix* nests the result one level before flattening; keys
    are flattened with "/" and runs of slashes are collapsed.
    """
    collected = {}
    for path, submodule in _get_named_modules(module, attrname):
        collected[path] = getattr(submodule, attrname)
        if not delete:
            continue
        try:
            delattr(submodule, attrname)
        except Exception as e:
            raise RuntimeError(f"{path} {submodule} {attrname}") from e
    if prefix:
        collected = {prefix: collected}
    flat = flatten_dict(collected)
    # remove consecutive /
    return {re.sub(r"\/+", "/", key): value for key, value in flat.items()}
def dispatch_attribute(module, attrname, value, filter_fn: Callable[[nn.Module], bool] | None = None):
    """Assign *value* to *attrname* on every submodule that already defines it.

    When *filter_fn* is given, only submodules it accepts are updated.
    """
    for _, submodule in _get_named_modules(module, attrname):
        if filter_fn is not None and not filter_fn(submodule):
            continue
        setattr(submodule, attrname, value)
@cache
def update_deepspeed_logger():
    """Raise the "DeepSpeed" logger threshold to WARNING (runs at most once)."""
    ds_logger = logging.getLogger("DeepSpeed")
    ds_logger.setLevel(logging.WARNING)
@cache
def init_distributed():
    """One-time distributed setup: quiet DeepSpeed logging, repair env vars,
    then initialize DeepSpeed with the accelerator's communication backend.
    """
    update_deepspeed_logger()
    fix_unset_envs()
    backend = get_accelerator().communication_backend_name()
    deepspeed.init_distributed(backend)
def _try_each(*fns, e=None):
if len(fns) == 0:
raise RuntimeError("All functions failed")
head, *tails = fns
try:
return head()
except Exception as e:
logger.warning(f"Tried {head} but failed: {e}, trying next")
return _try_each(*tails)
class Engine(DeepSpeedEngine):
    """DeepSpeedEngine bound to a fixed checkpoint directory.

    Adds reversible parameter freezing, an fp32 gradient-clipping fallback,
    and checkpoint loading that relaxes strictness step by step when a full
    restore fails.
    """
    def __init__(self, *args, ckpt_dir, **kwargs):
        """Initialize distributed state, then the parent engine.

        Args:
            ckpt_dir: Path-like directory (must support ``.exists()`` and
                ``.mkdir()``) used by save_checkpoint/load_checkpoint.
            *args, **kwargs: Forwarded to ``DeepSpeedEngine.__init__``.
        """
        # Distributed backend must be up before the parent engine initializes.
        init_distributed()
        # args=None fills DeepSpeedEngine's argparse-namespace parameter;
        # NOTE(review): a positional value in *args that maps onto `args`
        # would raise "multiple values" — presumably callers pass keywords.
        super().__init__(args=None, *args, **kwargs)
        self._ckpt_dir = ckpt_dir
        # Parameters disabled by freeze_(), so unfreeze_() restores exactly them.
        self._frozen_params = set()
        # Set by clip_fp32_gradients(); fallback for get_grad_norm() when
        # DeepSpeed's get_global_grad_norm() returns None.
        self._fp32_grad_norm = None
    @property
    def path(self):
        """Checkpoint directory this engine reads from and writes to."""
        return self._ckpt_dir
    def freeze_(self):
        """Disable grads on all currently-trainable parameters, remembering them."""
        for p in self.module.parameters():
            if p.requires_grad:
                p.requires_grad_(False)
                self._frozen_params.add(p)
    def unfreeze_(self):
        """Re-enable grads on parameters frozen by freeze_() and forget them."""
        for p in self._frozen_params:
            p.requires_grad_(True)
        self._frozen_params.clear()
    @property
    def global_step(self):
        """Alias for DeepSpeed's ``global_steps`` counter."""
        return self.global_steps
    def gather_attribute(self, *args, **kwargs):
        """Module-level ``gather_attribute`` applied to the wrapped module."""
        return gather_attribute(self.module, *args, **kwargs)
    def dispatch_attribute(self, *args, **kwargs):
        """Module-level ``dispatch_attribute`` applied to the wrapped module."""
        return dispatch_attribute(self.module, *args, **kwargs)
    def clip_fp32_gradients(self):
        """Clip fp32 gradients to the configured max norm, recording the norm."""
        self._fp32_grad_norm = clip_grad_norm_(
            parameters=self.module.parameters(),
            max_norm=self.gradient_clipping(),
            mpu=self.mpu,
        )
    def get_grad_norm(self):
        """Gradient norm from DeepSpeed, falling back to the fp32 clip result."""
        grad_norm = self.get_global_grad_norm()
        if grad_norm is None:
            grad_norm = self._fp32_grad_norm
        return grad_norm
    def save_checkpoint(self, *args, **kwargs):
        """Save into the bound checkpoint directory, creating it if needed."""
        if not self._ckpt_dir.exists():
            self._ckpt_dir.mkdir(parents=True, exist_ok=True)
        super().save_checkpoint(save_dir=self._ckpt_dir, *args, **kwargs)
        logger.info(f"Saved checkpoint to {self._ckpt_dir}")
    def load_checkpoint(self, *args, **kwargs):
        """Load from the bound directory with progressively laxer settings:
        full restore -> skip optimizer states -> skip LR scheduler states ->
        skip both -> skip both and load the module non-strictly. Raises
        RuntimeError (via _try_each) when every variant fails.
        """
        fn = partial(super().load_checkpoint, *args, load_dir=self._ckpt_dir, **kwargs)
        return _try_each(
            lambda: fn(),
            lambda: fn(load_optimizer_states=False),
            lambda: fn(load_lr_scheduler_states=False),
            lambda: fn(load_optimizer_states=False, load_lr_scheduler_states=False),
            lambda: fn(
                load_optimizer_states=False,
                load_lr_scheduler_states=False,
                load_module_strict=False,
            ),
        )
|