import logging
import re
from functools import cache, partial
from typing import Callable, TypeVar

import deepspeed
import pandas as pd
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.runtime.utils import clip_grad_norm_
from torch import nn

from .distributed import fix_unset_envs

logger = logging.getLogger(__name__)

T = TypeVar("T")


def flatten_dict(d):
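    """Flatten a nested dict into a single-level dict whose keys are joined with "/"."""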
    records = pd.json_normalize(d, sep="/").to_dict(orient="records")
    return records[0] if records else {}


def _get_named_modules(module, attrname, sep="/"):
    for name, module in module.named_modules():
        name = name.replace(".", sep)
        if hasattr(module, attrname):
            yield name, module


def gather_attribute(module, attrname, delete=True, prefix=None):
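    """Collect `attrname` from every submodule that defines it into a flat dict keyed by
    module path, deleting the attribute from the submodule afterwards by default."""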
    ret = {}
    for name, module in _get_named_modules(module, attrname):
        ret[name] = getattr(module, attrname)
        if delete:
            try:
                delattr(module, attrname)
            except Exception as e:
                raise RuntimeError(f"{name} {module} {attrname}") from e
    if prefix:
        ret = {prefix: ret}
    ret = flatten_dict(ret)
    # remove consecutive /
    ret = {re.sub(r"\/+", "/", k): v for k, v in ret.items()}
    return ret


def dispatch_attribute(module, attrname, value, filter_fn: Callable[[nn.Module], bool] | None = None):
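    """Set `attrname` to `value` on every submodule that already defines it,
    optionally restricted to modules accepted by `filter_fn`."""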
    for _, module in _get_named_modules(module, attrname):
        if filter_fn is None or filter_fn(module):
            setattr(module, attrname, value)


@cache
def update_deepspeed_logger():
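    """Raise the DeepSpeed logger level to WARNING to quiet its INFO output."""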
    logger = logging.getLogger("DeepSpeed")
    logger.setLevel(logging.WARNING)


@cache
def init_distributed():
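    """Initialize the DeepSpeed distributed backend exactly once (guarded by @cache)."""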
    update_deepspeed_logger()
    fix_unset_envs()
    deepspeed.init_distributed(get_accelerator().communication_backend_name())


def _try_each(*fns):
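    """Call each function in turn, returning the result of the first one that does not raise."""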
    if len(fns) == 0:
        raise RuntimeError("All functions failed")

    head, *tails = fns

    try:
        return head()
    except Exception as e:
        logger.warning(f"Tried {head} but failed: {e}, trying next")
        return _try_each(*tails)


class Engine(DeepSpeedEngine):
    def __init__(self, *args, ckpt_dir, **kwargs):
        init_distributed()
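        # `args` is the argparse namespace DeepSpeedEngine expects; it is passed as None
        # here, presumably because all configuration is supplied via **kwargs instead.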
        super().__init__(args=None, *args, **kwargs)
        self._ckpt_dir = ckpt_dir
        self._frozen_params = set()
        self._fp32_grad_norm = None

    @property
    def path(self):
        return self._ckpt_dir

    def freeze_(self):
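        """Disable gradients for all currently trainable parameters; unfreeze_() restores them."""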
        for p in self.module.parameters():
            if p.requires_grad:
                p.requires_grad_(False)
                self._frozen_params.add(p)

    def unfreeze_(self):
        for p in self._frozen_params:
            p.requires_grad_(True)
        self._frozen_params.clear()

    @property
    def global_step(self):
        return self.global_steps

    def gather_attribute(self, *args, **kwargs):
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        return dispatch_attribute(self.module, *args, **kwargs)

    def clip_fp32_gradients(self):
        self._fp32_grad_norm = clip_grad_norm_(
            parameters=self.module.parameters(),
            max_norm=self.gradient_clipping(),
            mpu=self.mpu,
        )

    def get_grad_norm(self):
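        # Prefer the norm reported by DeepSpeed; fall back to the value tracked by
        # clip_fp32_gradients() when no global grad norm is available (e.g. fp32 training).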
        grad_norm = self.get_global_grad_norm()
        if grad_norm is None:
            grad_norm = self._fp32_grad_norm
        return grad_norm

    def save_checkpoint(self, *args, **kwargs):
        if not self._ckpt_dir.exists():
            self._ckpt_dir.mkdir(parents=True, exist_ok=True)
        super().save_checkpoint(save_dir=self._ckpt_dir, *args, **kwargs)
        logger.info(f"Saved checkpoint to {self._ckpt_dir}")

    def load_checkpoint(self, *args, **kwargs):
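        # Try progressively more lenient loads: full state first, then without optimizer
        # state, without LR scheduler state, and finally with non-strict module loading.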
        fn = partial(super().load_checkpoint, *args, load_dir=self._ckpt_dir, **kwargs)
        return _try_each(
            lambda: fn(),
            lambda: fn(load_optimizer_states=False),
            lambda: fn(load_lr_scheduler_states=False),
            lambda: fn(load_optimizer_states=False, load_lr_scheduler_states=False),
            lambda: fn(
                load_optimizer_states=False,
                load_lr_scheduler_states=False,
                load_module_strict=False,
            ),
        )