# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import glob
import os
import re
import subprocess

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


@contextlib.contextmanager
def dist_load(path):
    """Stage a checkpoint into /dev/shm once per node so that every local
    rank reads from shared memory instead of hitting slow/remote storage."""
    if not dist.is_initialized() or dist.get_world_size() == 1 or os.path.realpath(path).startswith('/dev/shm'):
        # Single-process run, or the file already lives in shared memory.
        yield path
    else:
        from tts.utils.commons.hparams import hparams
        from tts.utils.commons.trainer import LOCAL_RANK
        tmpdir = '/dev/shm'
        assert len(os.path.basename(path)) > 0
        shm_ckpt_path = f'{tmpdir}/{hparams["exp_name"]}/{os.path.basename(path)}'
        if LOCAL_RANK == 0:
            # Only the first rank on each node copies; -L dereferences symlinks.
            subprocess.check_call(
                f'mkdir -p {os.path.dirname(shm_ckpt_path)}; '
                f'cp -Lr {path} {shm_ckpt_path}', shell=True)
        dist.barrier()  # wait until the copy is visible to all local ranks
        try:
            yield shm_ckpt_path
        finally:
            # Clean up even if the caller raised while using the staged copy.
            dist.barrier()
            if LOCAL_RANK == 0:
                subprocess.check_call(f'rm -rf {shm_ckpt_path}', shell=True)


def torch_load_dist(path, map_location='cpu'):
    """torch.load wrapped in dist_load, so large checkpoints are read from
    shared memory in multi-process runs."""
    with dist_load(path) as tmp_path:
        checkpoint = torch.load(tmp_path, map_location=map_location)
    return checkpoint
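
# Usage sketch (illustrative only; the checkpoint path and the 'model' key are
# assumptions, mirroring how load_ckpt below indexes checkpoint['state_dict']).
# Every rank calls torch_load_dist with the same path; rank 0 on each node
# stages the file into /dev/shm and all local ranks read the staged copy:
#
#     ckpt = torch_load_dist('/path/to/model_ckpt_steps_10000.ckpt')
#     model.load_state_dict(ckpt['state_dict']['model'])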


def get_last_checkpoint(work_dir, steps=None):
    checkpoint = None
    last_ckpt_path = None
    ckpt_paths = get_all_ckpts(work_dir, steps)
    if len(ckpt_paths) > 0:
        last_ckpt_path = ckpt_paths[0]
        checkpoint = torch_load_dist(last_ckpt_path, map_location='cpu')
    return checkpoint, last_ckpt_path


def get_all_ckpts(work_dir, steps=None):
    if steps is None or steps == 0:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
    # Sort by step count, newest first; the pattern must be a raw string so
    # that '\d' is not treated as an invalid escape sequence.
    return sorted(glob.glob(ckpt_path_pattern),
                  key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))
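
# Illustration of the resulting order (paths are hypothetical): checkpoints are
# sorted by step count, newest first, so get_last_checkpoint can take index 0.
#
#     get_all_ckpts('/path/to/exp')
#     # -> ['/path/to/exp/model_ckpt_steps_3000.ckpt',
#     #     '/path/to/exp/model_ckpt_steps_2000.ckpt',
#     #     '/path/to/exp/model_ckpt_steps_1000.ckpt']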


def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True,
              silent=False, load_opt=False, opts=None, steps=None, checkpoint=None, ckpt_path='', delete_unmatch=True):
    """Load model (and optionally optimizer) states from `ckpt_base_dir`, which
    may be a checkpoint file or a directory of checkpoints. Returns the stored
    global step on success."""
    if checkpoint is None:
        if os.path.isfile(ckpt_base_dir):
            base_dir = os.path.dirname(ckpt_base_dir)
            ckpt_path = ckpt_base_dir
            checkpoint = torch_load_dist(ckpt_base_dir, map_location='cpu')
        else:
            base_dir = ckpt_base_dir
            if load_opt:
                checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
            else:
                ckpt_path = f'{ckpt_base_dir}/model_only_last.ckpt'
                if os.path.exists(ckpt_path):
                    checkpoint = torch_load_dist(ckpt_path, map_location='cpu')
                else:
                    checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
    if checkpoint is not None:
        # Strip DistributedDataParallel ('module.') and torch.compile
        # ('_orig_mod.') prefixes from the stored keys.
        state_dict_all = {
            k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items()}
        if not isinstance(cur_model, list):
            cur_models = [cur_model]
            model_names = [model_name]
        else:
            cur_models = cur_model
            model_names = model_name
        for model_name, cur_model in zip(model_names, cur_models):
            if isinstance(cur_model, DistributedDataParallel):
                cur_model = cur_model.module
            device = next(cur_model.parameters()).device
            if '.' not in model_name:
                state_dict = state_dict_all[model_name]
            else:
                base_model_name = model_name.split('.')[0]
                rest_model_name = model_name[len(base_model_name) + 1:]
                state_dict = {
                    k[len(rest_model_name) + 1:]: v for k, v in state_dict_all[base_model_name].items()
                    if k.startswith(f'{rest_model_name}.')}
            state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()}
            if not strict and delete_unmatch:
                try:
                    cur_model.load_state_dict(state_dict, strict=True)
                    if not silent:
                        print(f"| loaded '{model_name}' from '{ckpt_path}' with strict=True.")
                except RuntimeError:
                    # Strict loading failed; drop checkpoint parameters whose
                    # shapes differ from the current model, then retry below.
                    cur_model_state_dict = cur_model.state_dict()
                    cur_model_state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in
                                            cur_model_state_dict.items()}
                    unmatched_keys = []
                    for key, param in state_dict.items():
                        if key in cur_model_state_dict:
                            new_param = cur_model_state_dict[key]
                            if new_param.shape != param.shape:
                                unmatched_keys.append(key)
                                print("| Unmatched keys: ", key, "cur model: ", new_param.shape,
                                      "ckpt model: ", param.shape)
                    for key in unmatched_keys:
                        del state_dict[key]
            load_results = cur_model.load_state_dict(state_dict, strict=strict)
            cur_model.to(device)
            if not silent:
                print(f"| loaded '{model_name}' from '{ckpt_path}'.")
                missing_keys, unexpected_keys = load_results.missing_keys, load_results.unexpected_keys
                print(f"| Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")
        if load_opt:
            optimizer_states = checkpoint['optimizer_states']
            assert len(opts) == len(optimizer_states)
            for optimizer, opt_state in zip(opts, optimizer_states):
                opt_state = {k.replace('_orig_mod.', ''): v for k, v in opt_state.items()}
                if optimizer is None:
                    continue
                try:
                    optimizer.load_state_dict(opt_state)
                    # Move restored optimizer state tensors onto the model's device.
                    for state in optimizer.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.to(device)
                except ValueError:
                    print(f"| WARNING: optimizer {optimizer} parameters do not match!")
        return checkpoint.get('global_step', 0)
    else:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            raise FileNotFoundError(e_msg)
        else:
            print(e_msg)
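
# Hedged usage sketch (the directory and model class are hypothetical). A
# typical non-strict resume drops shape-mismatched weights first, loads the
# rest, and returns the stored global step:
#
#     model = MyModel()
#     global_step = load_ckpt(model, '/path/to/exp_dir', model_name='model',
#                             force=False, strict=False)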


def load_with_size_mismatch(model, state_dict, prefix=""):
    """Load `state_dict` into `model`, stripping `prefix` from checkpoint keys
    and skipping parameters whose shapes differ from the current model."""
    current_model_dict = model.state_dict()
    cm_keys = current_model_dict.keys()

    def strip_prefix(k):
        # Strip the prefix only at the start of the key, rather than replacing
        # every occurrence inside the key.
        return k[len(prefix):] if prefix and k.startswith(prefix) else k

    mismatch_keys = {strip_prefix(k) for k, v in state_dict.items()
                     if strip_prefix(k) in cm_keys and v.size() != current_model_dict[strip_prefix(k)].size()}
    new_state_dict = {strip_prefix(k): v for k, v in state_dict.items()
                      if strip_prefix(k) in cm_keys and v.size() == current_model_dict[strip_prefix(k)].size()}
    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
    print("| mismatch keys: ", mismatch_keys)
    if len(missing_keys) > 0:
        print(f"| missing_keys in dit: {missing_keys}")
    if len(unexpected_keys) > 0:
        print(f"| unexpected_keys in dit: {unexpected_keys}")