alexnasa committed (verified) · Commit d047245 · 1 parent: 94bc39b

Update OmniAvatar/models/model_manager.py

Files changed (1)
  1. OmniAvatar/models/model_manager.py +474 -474
OmniAvatar/models/model_manager.py CHANGED
@@ -1,474 +1,474 @@
Every line of the file is rewritten in this commit, but apart from the import on line 4 the removed and added listings are textually identical. The one substantive change:

- from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs
+ from ..configs.model_config import model_loader_configs

The full updated file follows.
import os, torch, json, importlib
from typing import List
import torch.nn as nn
from ..configs.model_config import model_loader_configs
from ..utils.io_utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix, smart_load_weights


class GeneralLoRAFromPeft:

    def get_name_dict(self, lora_state_dict):
        lora_name_dict = {}
        for key in lora_state_dict:
            if ".lora_B." not in key:
                continue
            keys = key.split(".")
            if len(keys) > keys.index("lora_B") + 2:
                keys.pop(keys.index("lora_B") + 1)
            keys.pop(keys.index("lora_B"))
            if keys[0] == "diffusion_model":
                keys.pop(0)
            target_name = ".".join(keys)
            lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
        return lora_name_dict

    def match(self, model: torch.nn.Module, state_dict_lora):
        lora_name_dict = self.get_name_dict(state_dict_lora)
        model_name_dict = {name: None for name, _ in model.named_parameters()}
        matched_num = sum([i in model_name_dict for i in lora_name_dict])
        if matched_num == len(lora_name_dict):
            return "", ""
        else:
            return None

    def fetch_device_and_dtype(self, state_dict):
        device, dtype = None, None
        for name, param in state_dict.items():
            device, dtype = param.device, param.dtype
            break
        computation_device = device
        computation_dtype = dtype
        if computation_device == torch.device("cpu"):
            if torch.cuda.is_available():
                computation_device = torch.device("cuda")
        if computation_dtype == torch.float8_e4m3fn:
            computation_dtype = torch.float32
        return device, dtype, computation_device, computation_dtype

    def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
        state_dict_model = model.state_dict()
        device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
        lora_name_dict = self.get_name_dict(state_dict_lora)
        for name in lora_name_dict:
            weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
            weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2)
                weight_down = weight_down.squeeze(3).squeeze(2)
                weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                weight_lora = alpha * torch.mm(weight_up, weight_down)
            weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
            weight_patched = weight_model + weight_lora
            state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
        print(f"    {len(lora_name_dict)} tensors are updated.")
        model.load_state_dict(state_dict_model)


def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device, infer):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        print(f"    model_name: {model_name} model_class: {model_class.__name__}")
        state_dict_converter = model_class.state_dict_converter()
        if model_resource == "civitai":
            state_dict_results = state_dict_converter.from_civitai(state_dict)
        elif model_resource == "diffusers":
            state_dict_results = state_dict_converter.from_diffusers(state_dict)
        if isinstance(state_dict_results, tuple):
            model_state_dict, extra_kwargs = state_dict_results
            print(f"    This model is initialized with extra kwargs: {extra_kwargs}")
        else:
            model_state_dict, extra_kwargs = state_dict_results, {}
        torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
        with init_weights_on_device():
            model = model_class(**extra_kwargs)
        if hasattr(model, "eval"):
            model = model.eval()
        if not infer:  # re-initialize weights only for training
            model = model.to_empty(device=torch.device("cuda"))
            for name, param in model.named_parameters():
                if param.dim() > 1:  # usually only weight matrices, not biases, are initialized
                    nn.init.xavier_uniform_(param, gain=0.05)
                else:
                    nn.init.zeros_(param)
        else:
            model = model.to_empty(device=device)
        model, _, _ = smart_load_weights(model, model_state_dict)
        # model.load_state_dict(model_state_dict, assign=True, strict=False)
        model = model.to(dtype=torch_dtype, device=device)
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models


def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
            model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
        else:
            model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
        if torch_dtype == torch.float16 and hasattr(model, "half"):
            model = model.half()
        try:
            model = model.to(device=device)
        except:
            pass
        loaded_model_names.append(model_name)
        loaded_models.append(model)
    return loaded_model_names, loaded_models


def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
    print(f"    model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
    base_state_dict = base_model.state_dict()
    base_model.to("cpu")
    del base_model
    model = model_class(**extra_kwargs)
    model.load_state_dict(base_state_dict, strict=False)
    model.load_state_dict(state_dict, strict=False)
    model.to(dtype=torch_dtype, device=device)
    return model


def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
    loaded_model_names, loaded_models = [], []
    for model_name, model_class in zip(model_names, model_classes):
        while True:
            for model_id in range(len(model_manager.model)):
                base_model_name = model_manager.model_name[model_id]
                if base_model_name == model_name:
                    base_model_path = model_manager.model_path[model_id]
                    base_model = model_manager.model[model_id]
                    print(f"    Adding patch model to {base_model_name} ({base_model_path})")
                    patched_model = load_single_patch_model_from_single_file(
                        state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
                    loaded_model_names.append(base_model_name)
                    loaded_models.append(patched_model)
                    model_manager.model.pop(model_id)
                    model_manager.model_path.pop(model_id)
                    model_manager.model_name.pop(model_id)
                    break
            else:
                break
    return loaded_model_names, loaded_models


class ModelDetectorTemplate:
    def __init__(self):
        pass

    def match(self, file_path="", state_dict={}):
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        return [], []


class ModelDetectorFromSingleFile:
    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        self.keys_hash_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
        if keys_hash is not None:
            self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)

    def match(self, file_path="", state_dict={}):
        if isinstance(file_path, str) and os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return True
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, infer=False, **kwargs):
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device, infer)
            return loaded_model_names, loaded_models

        # Load models without strict matching
        # (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
        keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
        if keys_hash in self.keys_hash_dict:
            model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
            loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device, infer)
            return loaded_model_names, loaded_models

        return loaded_model_names, loaded_models


class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
    def __init__(self, model_loader_configs=[]):
        super().__init__(model_loader_configs)

    def match(self, file_path="", state_dict={}):
        if isinstance(file_path, str) and os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        # Split the state_dict and load from each component
        splited_state_dict = split_state_dict_with_prefix(state_dict)
        valid_state_dict = {}
        for sub_state_dict in splited_state_dict:
            if super().match(file_path, sub_state_dict):
                valid_state_dict.update(sub_state_dict)
        if super().match(file_path, valid_state_dict):
            loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
        else:
            loaded_model_names, loaded_models = [], []
            for sub_state_dict in splited_state_dict:
                if super().match(file_path, sub_state_dict):
                    loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
                    loaded_model_names += loaded_model_names_
                    loaded_models += loaded_models_
        return loaded_model_names, loaded_models


class ModelDetectorFromHuggingfaceFolder:
    def __init__(self, model_loader_configs=[]):
        self.architecture_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
        self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)

    def match(self, file_path="", state_dict={}):
        if not isinstance(file_path, str) or os.path.isfile(file_path):
            return False
        file_list = os.listdir(file_path)
        if "config.json" not in file_list:
            return False
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
        if "architectures" not in config and "_class_name" not in config:
            return False
        return True

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
        loaded_model_names, loaded_models = [], []
        architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
        for architecture in architectures:
            huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
            if redirected_architecture is not None:
                architecture = redirected_architecture
            model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
            loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models


class ModelDetectorFromPatchedSingleFile:
    def __init__(self, model_loader_configs=[]):
        self.keys_hash_with_shape_dict = {}
        for metadata in model_loader_configs:
            self.add_model_metadata(*metadata)

    def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
        self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)

    def match(self, file_path="", state_dict={}):
        if not isinstance(file_path, str) or os.path.isdir(file_path):
            return False
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            return True
        return False

    def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)

        # Load models with strict matching
        loaded_model_names, loaded_models = [], []
        keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
        if keys_hash_with_shape in self.keys_hash_with_shape_dict:
            model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
            loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
                state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
            loaded_model_names += loaded_model_names_
            loaded_models += loaded_models_
        return loaded_model_names, loaded_models


class ModelManager:
    def __init__(
        self,
        torch_dtype=torch.float16,
        device="cuda",
        model_id_list: List = [],
        downloading_priority: List = ["ModelScope", "HuggingFace"],
        file_path_list: List[str] = [],
        infer: bool = False
    ):
        self.torch_dtype = torch_dtype
        self.device = device
        self.model = []
        self.model_path = []
        self.model_name = []
        self.infer = infer
        downloaded_files = []
        self.model_detector = [
            ModelDetectorFromSingleFile(model_loader_configs),
            ModelDetectorFromSplitedSingleFile(model_loader_configs),
            ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
        ]
        self.load_models(downloaded_files + file_path_list)

    def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
        if isinstance(file_path, list):
            for file_path_ in file_path:
                self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
        else:
            print(f"Loading LoRA models from file: {file_path}")
            is_loaded = False
            if len(state_dict) == 0:
                state_dict = load_state_dict(file_path)
            for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
                lora = GeneralLoRAFromPeft()
                match_results = lora.match(model, state_dict)
                if match_results is not None:
                    print(f"    Adding LoRA to {model_name} ({model_path}).")
                    lora_prefix, model_resource = match_results
                    lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)

    def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
        print(f"Loading models from file: {file_path}")
        if len(state_dict) == 0:
            state_dict = load_state_dict(file_path)
        model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device, self.infer)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f"    The following models are loaded: {model_names}.")

    def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
        print(f"Loading models from folder: {file_path}")
        model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f"    The following models are loaded: {model_names}.")

    def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
        print(f"Loading patch models from file: {file_path}")
        model_names, models = load_patch_model_from_single_file(
            state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
        for model_name, model in zip(model_names, models):
            self.model.append(model)
            self.model_path.append(file_path)
            self.model_name.append(model_name)
        print(f"    The following patched models are loaded: {model_names}.")

    def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
        print(f"Loading models from: {file_path}")
        if device is None: device = self.device
        if torch_dtype is None: torch_dtype = self.torch_dtype
        if isinstance(file_path, list):
            state_dict = {}
            for path in file_path:
                state_dict.update(load_state_dict(path))
        elif os.path.isfile(file_path):
            state_dict = load_state_dict(file_path)
        else:
            state_dict = None
        for model_detector in self.model_detector:
            if model_detector.match(file_path, state_dict):
                model_names, models = model_detector.load(
                    file_path, state_dict,
                    device=device, torch_dtype=torch_dtype,
                    allowed_model_names=model_names, model_manager=self, infer=self.infer
                )
                for model_name, model in zip(model_names, models):
                    self.model.append(model)
                    self.model_path.append(file_path)
                    self.model_name.append(model_name)
                print(f"    The following models are loaded: {model_names}.")
                break
        else:
            print(f"    We cannot detect the model type. No models are loaded.")

    def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
        for file_path in file_path_list:
            self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)

    def fetch_model(self, model_name, file_path=None, require_model_path=False):
        fetched_models = []
        fetched_model_paths = []
        for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
            if file_path is not None and file_path != model_path:
                continue
            if model_name == model_name_:
                fetched_models.append(model)
                fetched_model_paths.append(model_path)
        if len(fetched_models) == 0:
            print(f"No {model_name} models available.")
            return None
        if len(fetched_models) == 1:
            print(f"Using {model_name} from {fetched_model_paths[0]}.")
        else:
            print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
        if require_model_path:
            return fetched_models[0], fetched_model_paths[0]
        else:
            return fetched_models[0]

    def to(self, device):
        for model in self.model:
            model.to(device)
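
For context, here is a minimal usage sketch of the ModelManager API defined above. It is an illustration rather than part of the commit: the checkpoint paths are hypothetical placeholders, and the "wan_video_dit" name passed to fetch_model is only an example of the kind of name a detector registers from model_loader_configs.

# Sketch: driving ModelManager from an inference script (paths are placeholders).
import torch
from OmniAvatar.models.model_manager import ModelManager

# Each path is matched against the registered detectors (single file,
# split single file, Hugging Face folder) and loaded on a hash match.
manager = ModelManager(
    torch_dtype=torch.bfloat16,
    device="cuda",
    file_path_list=[
        "checkpoints/dit.safetensors",   # hypothetical single-file checkpoint
        "checkpoints/text_encoder",      # hypothetical Hugging Face style folder
    ],
    infer=True,  # skip the training-only Xavier re-initialization branch
)

# Optionally merge a LoRA into whichever loaded model its keys match.
manager.load_lora("checkpoints/lora.safetensors", lora_alpha=1.0)

# Retrieve a loaded model by its registered name (example name only).
dit = manager.fetch_model("wan_video_dit")

Note that, per add_model_metadata above, each entry of model_loader_configs unpacks into (keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource), so detection keys off a hash of a checkpoint's state-dict keys rather than its filename.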