# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import glob
import os
import re
import subprocess

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


@contextlib.contextmanager
def dist_load(path):
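    """Yield a local filesystem path for `path`, staged through /dev/shm.

    In distributed runs, local rank 0 copies the checkpoint into shared
    memory, every rank reads from that copy, and rank 0 deletes it once all
    ranks are done. Single-process runs (or paths already under /dev/shm)
    yield `path` unchanged.
    """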
if not dist.is_initialized() or dist.get_world_size() == 1 or os.path.realpath(path).startswith('/dev/shm'):
yield path
else:
from tts.utils.commons.hparams import hparams
from tts.utils.commons.trainer import LOCAL_RANK
tmpdir = '/dev/shm'
assert len(os.path.basename(path)) > 0
shm_ckpt_path = f'{tmpdir}/{hparams["exp_name"]}/{os.path.basename(path)}'
if LOCAL_RANK == 0:
subprocess.check_call(
f'mkdir -p {os.path.dirname(shm_ckpt_path)}; '
f'cp -Lr {path} {shm_ckpt_path}', shell=True)
dist.barrier()
yield shm_ckpt_path
dist.barrier()
if LOCAL_RANK == 0:
subprocess.check_call(f'rm -rf {shm_ckpt_path}', shell=True)


def torch_load_dist(path, map_location='cpu'):
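    """torch.load a checkpoint, staging it through dist_load first."""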
with dist_load(path) as tmp_path:
checkpoint = torch.load(tmp_path, map_location=map_location)
return checkpoint


def get_last_checkpoint(work_dir, steps=None):
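    """Return (checkpoint, path) for the newest checkpoint in `work_dir`,
    or for the given `steps`; (None, None) if nothing matches."""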
checkpoint = None
last_ckpt_path = None
ckpt_paths = get_all_ckpts(work_dir, steps)
if len(ckpt_paths) > 0:
last_ckpt_path = ckpt_paths[0]
checkpoint = torch_load_dist(last_ckpt_path, map_location='cpu')
return checkpoint, last_ckpt_path


def get_all_ckpts(work_dir, steps=None):
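    """List checkpoint paths in `work_dir`, newest (highest step) first."""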
if steps is None or steps == 0:
ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
else:
ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
    return sorted(glob.glob(ckpt_path_pattern),
                  key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))


def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True,
              silent=False, load_opt=False, opts=None, steps=None, checkpoint=None,
              ckpt_path='', delete_unmatch=True):
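    """Load model weights (and optionally optimizer state) from a checkpoint.

    `ckpt_base_dir` may be a checkpoint file or a directory to search.
    `cur_model` and `model_name` may be parallel lists to restore several
    sub-models from one checkpoint. Returns the checkpoint's 'global_step'
    (default 0) on success; raises if nothing is found and `force` is set.
    """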
if checkpoint is None:
if os.path.isfile(ckpt_base_dir):
base_dir = os.path.dirname(ckpt_base_dir)
ckpt_path = ckpt_base_dir
checkpoint = torch_load_dist(ckpt_base_dir, map_location='cpu')
else:
base_dir = ckpt_base_dir
if load_opt:
checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
else:
ckpt_path = f'{ckpt_base_dir}/model_only_last.ckpt'
if os.path.exists(ckpt_path):
checkpoint = torch_load_dist(ckpt_path, map_location='cpu')
else:
checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
if checkpoint is not None:
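        # Strip DistributedDataParallel ('module.') and torch.compile
        # ('_orig_mod.') prefixes so checkpoint keys match the bare module.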
state_dict_all = {
k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in checkpoint["state_dict"].items()}
if not isinstance(cur_model, list):
cur_models = [cur_model]
model_names = [model_name]
else:
cur_models = cur_model
model_names = model_name
for model_name, cur_model in zip(model_names, cur_models):
if isinstance(cur_model, DistributedDataParallel):
cur_model = cur_model.module
device = next(cur_model.parameters()).device
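            # `model_name` may be dotted (e.g. 'model.encoder'): the first
            # segment selects the checkpoint sub-dict, the remainder is a key
            # prefix stripped from every parameter name.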
if '.' not in model_name:
state_dict = state_dict_all[model_name]
else:
base_model_name = model_name.split('.')[0]
rest_model_name = model_name[len(base_model_name) + 1:]
state_dict = {
k[len(rest_model_name) + 1:]: v for k, v in state_dict_all[base_model_name].items()
if k.startswith(f'{rest_model_name}.')}
state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in state_dict.items()}
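            # For lenient loads, attempt a strict load first and fall back to
            # dropping shape-mismatched keys before retrying.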
if not strict and delete_unmatch:
try:
cur_model.load_state_dict(state_dict, strict=True)
if not silent:
print(f"| loaded '{model_name}' from '{ckpt_path}' with strict=True.")
                except RuntimeError:
cur_model_state_dict = cur_model.state_dict()
cur_model_state_dict = {k.replace('module.', '').replace('_orig_mod.', ''): v for k, v in
cur_model_state_dict.items()}
unmatched_keys = []
for key, param in state_dict.items():
if key in cur_model_state_dict:
new_param = cur_model_state_dict[key]
if new_param.shape != param.shape:
unmatched_keys.append(key)
print("| Unmatched keys: ", key, "cur model: ", new_param.shape,
"ckpt model: ", param.shape)
for key in unmatched_keys:
del state_dict[key]
load_results = cur_model.load_state_dict(state_dict, strict=strict)
cur_model.to(device)
if not silent:
print(f"| loaded '{model_name}' from '{ckpt_path}'.")
missing_keys, unexpected_keys = load_results.missing_keys, load_results.unexpected_keys
print(f"| Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")
if load_opt:
optimizer_states = checkpoint['optimizer_states']
assert len(opts) == len(optimizer_states)
for optimizer, opt_state in zip(opts, optimizer_states):
opt_state = {k.replace('_orig_mod.', ''): v for k, v in opt_state.items()}
if optimizer is None:
return
try:
optimizer.load_state_dict(opt_state)
for i, state in enumerate(optimizer.state.values()):
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device)
except ValueError:
print(f"| WARMING: optimizer {optimizer} parameters not match !!!")
return checkpoint.get('global_step', 0)
else:
e_msg = f"| ckpt not found in {base_dir}."
if force:
            raise FileNotFoundError(e_msg)
else:
print(e_msg)


def load_with_size_mismatch(model, state_dict, prefix=""):
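    """Load `state_dict` into `model`, skipping size-mismatched tensors.

    `prefix` is stripped from checkpoint keys before matching; mismatched,
    missing and unexpected keys are printed instead of raising.
    """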
    current_model_dict = model.state_dict()
    cm_keys = current_model_dict.keys()
    mismatch_keys = {k.replace(prefix, "") for k, v in state_dict.items()
                     if k.replace(prefix, "") in cm_keys
                     and v.size() != current_model_dict[k.replace(prefix, "")].size()}
    new_state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items()
                      if k.replace(prefix, "") in cm_keys
                      and v.size() == current_model_dict[k.replace(prefix, "")].size()}
    missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
    print("| mismatch keys:", mismatch_keys)
if len(missing_keys) > 0:
print(f"| missing_keys in dit: {missing_keys}")
if len(unexpected_keys) > 0:
print(f"| unexpected_keys in dit: {unexpected_keys}")