Harshveer commited on
Commit
c801664
·
1 Parent(s): 1ef0db4

Update modules/sd_models.py

Browse files
Files changed (1) hide show
  1. modules/sd_models.py +496 -495
modules/sd_models.py CHANGED
@@ -1,495 +1,496 @@
1
- import collections
2
- import os.path
3
- import sys
4
- import gc
5
- import torch
6
- import re
7
- import safetensors.torch
8
- from omegaconf import OmegaConf
9
- from os import mkdir
10
- from urllib import request
11
- import ldm.modules.midas as midas
12
-
13
- from ldm.util import instantiate_from_config
14
-
15
- from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
16
- from modules.paths import models_path
17
- from modules.sd_hijack_inpainting import do_inpainting_hijack
18
- from modules.timer import Timer
19
-
20
- model_dir = "Stable-diffusion"
21
- model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
22
-
23
- checkpoints_list = {}
24
- checkpoint_alisases = {}
25
- checkpoints_loaded = collections.OrderedDict()
26
-
27
-
28
- class CheckpointInfo:
29
- def __init__(self, filename):
30
- self.filename = filename
31
- abspath = os.path.abspath(filename)
32
-
33
- if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
34
- name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
35
- elif abspath.startswith(model_path):
36
- name = abspath.replace(model_path, '')
37
- else:
38
- name = os.path.basename(filename)
39
-
40
- if name.startswith("\\") or name.startswith("/"):
41
- name = name[1:]
42
-
43
- self.name = name
44
- self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
45
- self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
46
- self.hash = model_hash(filename)
47
-
48
- self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
49
- self.shorthash = self.sha256[0:10] if self.sha256 else None
50
-
51
- self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
52
-
53
- self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
54
-
55
- def register(self):
56
- checkpoints_list[self.title] = self
57
- for id in self.ids:
58
- checkpoint_alisases[id] = self
59
-
60
- def calculate_shorthash(self):
61
- self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
62
- if self.sha256 is None:
63
- return
64
-
65
- self.shorthash = self.sha256[0:10]
66
-
67
- if self.shorthash not in self.ids:
68
- self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
69
-
70
- checkpoints_list.pop(self.title)
71
- self.title = f'{self.name} [{self.shorthash}]'
72
- self.register()
73
-
74
- return self.shorthash
75
-
76
-
77
- try:
78
- # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
79
-
80
- from transformers import logging, CLIPModel
81
-
82
- logging.set_verbosity_error()
83
- except Exception:
84
- pass
85
-
86
-
87
- def setup_model():
88
- if not os.path.exists(model_path):
89
- os.makedirs(model_path)
90
-
91
- list_models()
92
- enable_midas_autodownload()
93
-
94
-
95
- def checkpoint_tiles():
96
- def convert(name):
97
- return int(name) if name.isdigit() else name.lower()
98
-
99
- def alphanumeric_key(key):
100
- return [convert(c) for c in re.split('([0-9]+)', key)]
101
-
102
- return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
103
-
104
-
105
- def list_models():
106
- checkpoints_list.clear()
107
- checkpoint_alisases.clear()
108
-
109
- cmd_ckpt = shared.cmd_opts.ckpt
110
- if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
111
- model_url = None
112
- else:
113
- model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
114
-
115
- model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
116
-
117
- if os.path.exists(cmd_ckpt):
118
- checkpoint_info = CheckpointInfo(cmd_ckpt)
119
- checkpoint_info.register()
120
-
121
- shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
122
- elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
123
- print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
124
-
125
- for filename in model_list:
126
- checkpoint_info = CheckpointInfo(filename)
127
- checkpoint_info.register()
128
-
129
-
130
- def get_closet_checkpoint_match(search_string):
131
- checkpoint_info = checkpoint_alisases.get(search_string, None)
132
- if checkpoint_info is not None:
133
- return checkpoint_info
134
-
135
- found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
136
- if found:
137
- return found[0]
138
-
139
- return None
140
-
141
-
142
- def model_hash(filename):
143
- """old hash that only looks at a small part of the file and is prone to collisions"""
144
-
145
- try:
146
- with open(filename, "rb") as file:
147
- import hashlib
148
- m = hashlib.sha256()
149
-
150
- file.seek(0x100000)
151
- m.update(file.read(0x10000))
152
- return m.hexdigest()[0:8]
153
- except FileNotFoundError:
154
- return 'NOFILE'
155
-
156
-
157
- def select_checkpoint():
158
- model_checkpoint = shared.opts.sd_model_checkpoint
159
-
160
- checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
161
- if checkpoint_info is not None:
162
- return checkpoint_info
163
-
164
- if len(checkpoints_list) == 0:
165
- print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
166
- if shared.cmd_opts.ckpt is not None:
167
- print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
168
- print(f" - directory {model_path}", file=sys.stderr)
169
- if shared.cmd_opts.ckpt_dir is not None:
170
- print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
171
- print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
172
- exit(1)
173
-
174
- checkpoint_info = next(iter(checkpoints_list.values()))
175
- if model_checkpoint is not None:
176
- print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)
177
-
178
- return checkpoint_info
179
-
180
-
181
- chckpoint_dict_replacements = {
182
- 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
183
- 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
184
- 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
185
- }
186
-
187
-
188
- def transform_checkpoint_dict_key(k):
189
- for text, replacement in chckpoint_dict_replacements.items():
190
- if k.startswith(text):
191
- k = replacement + k[len(text):]
192
-
193
- return k
194
-
195
-
196
- def get_state_dict_from_checkpoint(pl_sd):
197
- pl_sd = pl_sd.pop("state_dict", pl_sd)
198
- pl_sd.pop("state_dict", None)
199
-
200
- sd = {}
201
- for k, v in pl_sd.items():
202
- new_key = transform_checkpoint_dict_key(k)
203
-
204
- if new_key is not None:
205
- sd[new_key] = v
206
-
207
- pl_sd.clear()
208
- pl_sd.update(sd)
209
-
210
- return pl_sd
211
-
212
-
213
- def read_state_dict(checkpoint_file, print_global_state=False, map_location='cuda'):
214
- _, extension = os.path.splitext(checkpoint_file)
215
- if extension.lower() == ".safetensors":
216
- device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
217
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
218
- else:
219
- pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
220
-
221
- if print_global_state and "global_step" in pl_sd:
222
- print(f"Global Step: {pl_sd['global_step']}")
223
-
224
- sd = get_state_dict_from_checkpoint(pl_sd)
225
- return sd
226
-
227
-
228
- def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
229
- sd_model_hash = checkpoint_info.calculate_shorthash()
230
- timer.record("calculate hash")
231
-
232
- if checkpoint_info in checkpoints_loaded:
233
- # use checkpoint cache
234
- print(f"Loading weights [{sd_model_hash}] from cache")
235
- return checkpoints_loaded[checkpoint_info]
236
-
237
- print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
238
- res = read_state_dict(checkpoint_info.filename)
239
- timer.record("load weights from disk")
240
-
241
- return res
242
-
243
-
244
- def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
245
- sd_model_hash = checkpoint_info.calculate_shorthash()
246
- timer.record("calculate hash")
247
-
248
- shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
249
-
250
- if state_dict is None:
251
- state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
252
-
253
- model.load_state_dict(state_dict, strict=False)
254
- del state_dict
255
- timer.record("apply weights to model")
256
-
257
- if shared.opts.sd_checkpoint_cache > 0:
258
- # cache newly loaded model
259
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
260
-
261
- if shared.cmd_opts.opt_channelslast:
262
- model.to(memory_format=torch.channels_last)
263
- timer.record("apply channels_last")
264
-
265
- if not shared.cmd_opts.no_half:
266
- vae = model.first_stage_model
267
- depth_model = getattr(model, 'depth_model', None)
268
-
269
- # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
270
- if shared.cmd_opts.no_half_vae:
271
- model.first_stage_model = None
272
- # with --upcast-sampling, don't convert the depth model weights to float16
273
- if shared.cmd_opts.upcast_sampling and depth_model:
274
- model.depth_model = None
275
-
276
- model.half()
277
- model.first_stage_model = vae
278
- if depth_model:
279
- model.depth_model = depth_model
280
-
281
- timer.record("apply half()")
282
-
283
- devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
284
- devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
285
- devices.dtype_unet = model.model.diffusion_model.dtype
286
- devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
287
-
288
- model.first_stage_model.to(devices.dtype_vae)
289
- timer.record("apply dtype to VAE")
290
-
291
- # clean up cache if limit is reached
292
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
293
- checkpoints_loaded.popitem(last=False)
294
-
295
- model.sd_model_hash = sd_model_hash
296
- model.sd_model_checkpoint = checkpoint_info.filename
297
- model.sd_checkpoint_info = checkpoint_info
298
- shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
299
-
300
- model.logvar = model.logvar.to(devices.device) # fix for training
301
-
302
- sd_vae.delete_base_vae()
303
- sd_vae.clear_loaded_vae()
304
- vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
305
- sd_vae.load_vae(model, vae_file, vae_source)
306
- timer.record("load VAE")
307
-
308
-
309
- def enable_midas_autodownload():
310
- """
311
- Gives the ldm.modules.midas.api.load_model function automatic downloading.
312
-
313
- When the 512-depth-ema model, and other future models like it, is loaded,
314
- it calls midas.api.load_model to load the associated midas depth model.
315
- This function applies a wrapper to download the model to the correct
316
- location automatically.
317
- """
318
-
319
- midas_path = os.path.join(paths.models_path, 'midas')
320
-
321
- # stable-diffusion-stability-ai hard-codes the midas model path to
322
- # a location that differs from where other scripts using this model look.
323
- # HACK: Overriding the path here.
324
- for k, v in midas.api.ISL_PATHS.items():
325
- file_name = os.path.basename(v)
326
- midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
327
-
328
- midas_urls = {
329
- "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
330
- "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
331
- "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
332
- "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
333
- }
334
-
335
- midas.api.load_model_inner = midas.api.load_model
336
-
337
- def load_model_wrapper(model_type):
338
- path = midas.api.ISL_PATHS[model_type]
339
- if not os.path.exists(path):
340
- if not os.path.exists(midas_path):
341
- mkdir(midas_path)
342
-
343
- print(f"Downloading midas model weights for {model_type} to {path}")
344
- request.urlretrieve(midas_urls[model_type], path)
345
- print(f"{model_type} downloaded")
346
-
347
- return midas.api.load_model_inner(model_type)
348
-
349
- midas.api.load_model = load_model_wrapper
350
-
351
-
352
- def repair_config(sd_config):
353
-
354
- if not hasattr(sd_config.model.params, "use_ema"):
355
- sd_config.model.params.use_ema = False
356
-
357
- if shared.cmd_opts.no_half:
358
- sd_config.model.params.unet_config.params.use_fp16 = False
359
- elif shared.cmd_opts.upcast_sampling:
360
- sd_config.model.params.unet_config.params.use_fp16 = True
361
-
362
-
363
- sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
364
- sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
365
-
366
- def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
367
- from modules import lowvram, sd_hijack
368
- checkpoint_info = checkpoint_info or select_checkpoint()
369
-
370
- if shared.sd_model:
371
- sd_hijack.model_hijack.undo_hijack(shared.sd_model)
372
- shared.sd_model = None
373
- gc.collect()
374
- devices.torch_gc()
375
-
376
- do_inpainting_hijack()
377
-
378
- timer = Timer()
379
-
380
- if already_loaded_state_dict is not None:
381
- state_dict = already_loaded_state_dict
382
- else:
383
- state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
384
-
385
- checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
386
- clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
387
-
388
- timer.record("find config")
389
-
390
- sd_config = OmegaConf.load(checkpoint_config)
391
- repair_config(sd_config)
392
-
393
- timer.record("load config")
394
-
395
- print(f"Creating model from config: {checkpoint_config}")
396
-
397
- sd_model = None
398
- try:
399
- with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
400
- sd_model = instantiate_from_config(sd_config.model)
401
- except Exception as e:
402
- pass
403
-
404
- if sd_model is None:
405
- print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
406
- sd_model = instantiate_from_config(sd_config.model)
407
-
408
- sd_model.used_config = checkpoint_config
409
-
410
- timer.record("create model")
411
-
412
- load_model_weights(sd_model, checkpoint_info, state_dict, timer)
413
-
414
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
415
- lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
416
- else:
417
- sd_model.to(shared.device)
418
-
419
- timer.record("move model to device")
420
-
421
- sd_hijack.model_hijack.hijack(sd_model)
422
-
423
- timer.record("hijack")
424
-
425
- sd_model.eval()
426
- shared.sd_model = sd_model
427
-
428
- sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
429
-
430
- timer.record("load textual inversion embeddings")
431
-
432
- script_callbacks.model_loaded_callback(sd_model)
433
-
434
- timer.record("scripts callbacks")
435
-
436
- print(f"Model loaded in {timer.summary()}.")
437
-
438
- return sd_model
439
-
440
-
441
- def reload_model_weights(sd_model=None, info=None):
442
- from modules import lowvram, devices, sd_hijack
443
- checkpoint_info = info or select_checkpoint()
444
-
445
- if not sd_model:
446
- sd_model = shared.sd_model
447
-
448
- if sd_model is None: # previous model load failed
449
- current_checkpoint_info = None
450
- else:
451
- current_checkpoint_info = sd_model.sd_checkpoint_info
452
- if sd_model.sd_model_checkpoint == checkpoint_info.filename:
453
- return
454
-
455
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
456
- lowvram.send_everything_to_cpu()
457
- else:
458
- sd_model.to(devices.cpu)
459
-
460
- sd_hijack.model_hijack.undo_hijack(sd_model)
461
-
462
- timer = Timer()
463
-
464
- state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
465
-
466
- checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
467
-
468
- timer.record("find config")
469
-
470
- if sd_model is None or checkpoint_config != sd_model.used_config:
471
- del sd_model
472
- checkpoints_loaded.clear()
473
- load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
474
- return shared.sd_model
475
-
476
- try:
477
- load_model_weights(sd_model, checkpoint_info, state_dict, timer)
478
- except Exception as e:
479
- print("Failed to load checkpoint, restoring previous")
480
- load_model_weights(sd_model, current_checkpoint_info, None, timer)
481
- raise
482
- finally:
483
- sd_hijack.model_hijack.hijack(sd_model)
484
- timer.record("hijack")
485
-
486
- script_callbacks.model_loaded_callback(sd_model)
487
- timer.record("script callbacks")
488
-
489
- if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
490
- sd_model.to(devices.device)
491
- timer.record("move model to device")
492
-
493
- print(f"Weights loaded in {timer.summary()}.")
494
-
495
- return sd_model
 
 
1
+ import collections
2
+ import os.path
3
+ import sys
4
+ import gc
5
+ import torch
6
+ import re
7
+ import safetensors.torch
8
+ from omegaconf import OmegaConf
9
+ from os import mkdir
10
+ from urllib import request
11
+ import ldm.modules.midas as midas
12
+
13
+ from ldm.util import instantiate_from_config
14
+
15
+ from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
16
+ from modules.paths import models_path
17
+ from modules.sd_hijack_inpainting import do_inpainting_hijack
18
+ from modules.timer import Timer
19
+
20
+ model_dir = "Stable-diffusion"
21
+ model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
22
+
23
+ checkpoints_list = {}
24
+ checkpoint_alisases = {}
25
+ checkpoints_loaded = collections.OrderedDict()
26
+
27
+
28
+ class CheckpointInfo:
29
+ def __init__(self, filename):
30
+ self.filename = filename
31
+ abspath = os.path.abspath(filename)
32
+
33
+ shared.cmd_opts.ckpt_dir='/content/gdrive/MyDrive/sd/stable-diffusion-webui/models/Stable-diffusion/model.ckpt'
34
+ if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
35
+ name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
36
+ elif abspath.startswith(model_path):
37
+ name = abspath.replace(model_path, '')
38
+ else:
39
+ name = os.path.basename(filename)
40
+
41
+ if name.startswith("\\") or name.startswith("/"):
42
+ name = name[1:]
43
+
44
+ self.name = name
45
+ self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
46
+ self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
47
+ self.hash = model_hash(filename)
48
+
49
+ self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
50
+ self.shorthash = self.sha256[0:10] if self.sha256 else None
51
+
52
+ self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
53
+
54
+ self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
55
+
56
+ def register(self):
57
+ checkpoints_list[self.title] = self
58
+ for id in self.ids:
59
+ checkpoint_alisases[id] = self
60
+
61
+ def calculate_shorthash(self):
62
+ self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
63
+ if self.sha256 is None:
64
+ return
65
+
66
+ self.shorthash = self.sha256[0:10]
67
+
68
+ if self.shorthash not in self.ids:
69
+ self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
70
+
71
+ checkpoints_list.pop(self.title)
72
+ self.title = f'{self.name} [{self.shorthash}]'
73
+ self.register()
74
+
75
+ return self.shorthash
76
+
77
+
78
+ try:
79
+ # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
80
+
81
+ from transformers import logging, CLIPModel
82
+
83
+ logging.set_verbosity_error()
84
+ except Exception:
85
+ pass
86
+
87
+
88
+ def setup_model():
89
+ if not os.path.exists(model_path):
90
+ os.makedirs(model_path)
91
+
92
+ list_models()
93
+ enable_midas_autodownload()
94
+
95
+
96
+ def checkpoint_tiles():
97
+ def convert(name):
98
+ return int(name) if name.isdigit() else name.lower()
99
+
100
+ def alphanumeric_key(key):
101
+ return [convert(c) for c in re.split('([0-9]+)', key)]
102
+
103
+ return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
104
+
105
+
106
+ def list_models():
107
+ checkpoints_list.clear()
108
+ checkpoint_alisases.clear()
109
+
110
+ cmd_ckpt = shared.cmd_opts.ckpt
111
+ if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
112
+ model_url = None
113
+ else:
114
+ model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"
115
+
116
+ model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
117
+
118
+ if os.path.exists(cmd_ckpt):
119
+ checkpoint_info = CheckpointInfo(cmd_ckpt)
120
+ checkpoint_info.register()
121
+
122
+ shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
123
+ elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
124
+ print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
125
+
126
+ for filename in model_list:
127
+ checkpoint_info = CheckpointInfo(filename)
128
+ checkpoint_info.register()
129
+
130
+
131
+ def get_closet_checkpoint_match(search_string):
132
+ checkpoint_info = checkpoint_alisases.get(search_string, None)
133
+ if checkpoint_info is not None:
134
+ return checkpoint_info
135
+
136
+ found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
137
+ if found:
138
+ return found[0]
139
+
140
+ return None
141
+
142
+
143
+ def model_hash(filename):
144
+ """old hash that only looks at a small part of the file and is prone to collisions"""
145
+
146
+ try:
147
+ with open(filename, "rb") as file:
148
+ import hashlib
149
+ m = hashlib.sha256()
150
+
151
+ file.seek(0x100000)
152
+ m.update(file.read(0x10000))
153
+ return m.hexdigest()[0:8]
154
+ except FileNotFoundError:
155
+ return 'NOFILE'
156
+
157
+
158
+ def select_checkpoint():
159
+ model_checkpoint = shared.opts.sd_model_checkpoint
160
+
161
+ checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
162
+ if checkpoint_info is not None:
163
+ return checkpoint_info
164
+
165
+ if len(checkpoints_list) == 0:
166
+ print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
167
+ if shared.cmd_opts.ckpt is not None:
168
+ print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
169
+ print(f" - directory {model_path}", file=sys.stderr)
170
+ if shared.cmd_opts.ckpt_dir is not None:
171
+ print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
172
+ print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
173
+ exit(1)
174
+
175
+ checkpoint_info = next(iter(checkpoints_list.values()))
176
+ if model_checkpoint is not None:
177
+ print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)
178
+
179
+ return checkpoint_info
180
+
181
+
182
+ chckpoint_dict_replacements = {
183
+ 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
184
+ 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
185
+ 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
186
+ }
187
+
188
+
189
+ def transform_checkpoint_dict_key(k):
190
+ for text, replacement in chckpoint_dict_replacements.items():
191
+ if k.startswith(text):
192
+ k = replacement + k[len(text):]
193
+
194
+ return k
195
+
196
+
197
+ def get_state_dict_from_checkpoint(pl_sd):
198
+ pl_sd = pl_sd.pop("state_dict", pl_sd)
199
+ pl_sd.pop("state_dict", None)
200
+
201
+ sd = {}
202
+ for k, v in pl_sd.items():
203
+ new_key = transform_checkpoint_dict_key(k)
204
+
205
+ if new_key is not None:
206
+ sd[new_key] = v
207
+
208
+ pl_sd.clear()
209
+ pl_sd.update(sd)
210
+
211
+ return pl_sd
212
+
213
+
214
+ def read_state_dict(checkpoint_file, print_global_state=False, map_location='cuda'):
215
+ _, extension = os.path.splitext(checkpoint_file)
216
+ if extension.lower() == ".safetensors":
217
+ device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
218
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
219
+ else:
220
+ pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
221
+
222
+ if print_global_state and "global_step" in pl_sd:
223
+ print(f"Global Step: {pl_sd['global_step']}")
224
+
225
+ sd = get_state_dict_from_checkpoint(pl_sd)
226
+ return sd
227
+
228
+
229
+ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
230
+ sd_model_hash = checkpoint_info.calculate_shorthash()
231
+ timer.record("calculate hash")
232
+
233
+ if checkpoint_info in checkpoints_loaded:
234
+ # use checkpoint cache
235
+ print(f"Loading weights [{sd_model_hash}] from cache")
236
+ return checkpoints_loaded[checkpoint_info]
237
+
238
+ print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
239
+ res = read_state_dict(checkpoint_info.filename)
240
+ timer.record("load weights from disk")
241
+
242
+ return res
243
+
244
+
245
+ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
246
+ sd_model_hash = checkpoint_info.calculate_shorthash()
247
+ timer.record("calculate hash")
248
+
249
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
250
+
251
+ if state_dict is None:
252
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
253
+
254
+ model.load_state_dict(state_dict, strict=False)
255
+ del state_dict
256
+ timer.record("apply weights to model")
257
+
258
+ if shared.opts.sd_checkpoint_cache > 0:
259
+ # cache newly loaded model
260
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
261
+
262
+ if shared.cmd_opts.opt_channelslast:
263
+ model.to(memory_format=torch.channels_last)
264
+ timer.record("apply channels_last")
265
+
266
+ if not shared.cmd_opts.no_half:
267
+ vae = model.first_stage_model
268
+ depth_model = getattr(model, 'depth_model', None)
269
+
270
+ # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
271
+ if shared.cmd_opts.no_half_vae:
272
+ model.first_stage_model = None
273
+ # with --upcast-sampling, don't convert the depth model weights to float16
274
+ if shared.cmd_opts.upcast_sampling and depth_model:
275
+ model.depth_model = None
276
+
277
+ model.half()
278
+ model.first_stage_model = vae
279
+ if depth_model:
280
+ model.depth_model = depth_model
281
+
282
+ timer.record("apply half()")
283
+
284
+ devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
285
+ devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
286
+ devices.dtype_unet = model.model.diffusion_model.dtype
287
+ devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
288
+
289
+ model.first_stage_model.to(devices.dtype_vae)
290
+ timer.record("apply dtype to VAE")
291
+
292
+ # clean up cache if limit is reached
293
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
294
+ checkpoints_loaded.popitem(last=False)
295
+
296
+ model.sd_model_hash = sd_model_hash
297
+ model.sd_model_checkpoint = checkpoint_info.filename
298
+ model.sd_checkpoint_info = checkpoint_info
299
+ shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
300
+
301
+ model.logvar = model.logvar.to(devices.device) # fix for training
302
+
303
+ sd_vae.delete_base_vae()
304
+ sd_vae.clear_loaded_vae()
305
+ vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
306
+ sd_vae.load_vae(model, vae_file, vae_source)
307
+ timer.record("load VAE")
308
+
309
+
310
+ def enable_midas_autodownload():
311
+ """
312
+ Gives the ldm.modules.midas.api.load_model function automatic downloading.
313
+
314
+ When the 512-depth-ema model, and other future models like it, is loaded,
315
+ it calls midas.api.load_model to load the associated midas depth model.
316
+ This function applies a wrapper to download the model to the correct
317
+ location automatically.
318
+ """
319
+
320
+ midas_path = os.path.join(paths.models_path, 'midas')
321
+
322
+ # stable-diffusion-stability-ai hard-codes the midas model path to
323
+ # a location that differs from where other scripts using this model look.
324
+ # HACK: Overriding the path here.
325
+ for k, v in midas.api.ISL_PATHS.items():
326
+ file_name = os.path.basename(v)
327
+ midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
328
+
329
+ midas_urls = {
330
+ "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
331
+ "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
332
+ "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
333
+ "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
334
+ }
335
+
336
+ midas.api.load_model_inner = midas.api.load_model
337
+
338
+ def load_model_wrapper(model_type):
339
+ path = midas.api.ISL_PATHS[model_type]
340
+ if not os.path.exists(path):
341
+ if not os.path.exists(midas_path):
342
+ mkdir(midas_path)
343
+
344
+ print(f"Downloading midas model weights for {model_type} to {path}")
345
+ request.urlretrieve(midas_urls[model_type], path)
346
+ print(f"{model_type} downloaded")
347
+
348
+ return midas.api.load_model_inner(model_type)
349
+
350
+ midas.api.load_model = load_model_wrapper
351
+
352
+
353
+ def repair_config(sd_config):
354
+
355
+ if not hasattr(sd_config.model.params, "use_ema"):
356
+ sd_config.model.params.use_ema = False
357
+
358
+ if shared.cmd_opts.no_half:
359
+ sd_config.model.params.unet_config.params.use_fp16 = False
360
+ elif shared.cmd_opts.upcast_sampling:
361
+ sd_config.model.params.unet_config.params.use_fp16 = True
362
+
363
+
364
+ sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
365
+ sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
366
+
367
+ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
368
+ from modules import lowvram, sd_hijack
369
+ checkpoint_info = checkpoint_info or select_checkpoint()
370
+
371
+ if shared.sd_model:
372
+ sd_hijack.model_hijack.undo_hijack(shared.sd_model)
373
+ shared.sd_model = None
374
+ gc.collect()
375
+ devices.torch_gc()
376
+
377
+ do_inpainting_hijack()
378
+
379
+ timer = Timer()
380
+
381
+ if already_loaded_state_dict is not None:
382
+ state_dict = already_loaded_state_dict
383
+ else:
384
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
385
+
386
+ checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
387
+ clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
388
+
389
+ timer.record("find config")
390
+
391
+ sd_config = OmegaConf.load(checkpoint_config)
392
+ repair_config(sd_config)
393
+
394
+ timer.record("load config")
395
+
396
+ print(f"Creating model from config: {checkpoint_config}")
397
+
398
+ sd_model = None
399
+ try:
400
+ with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
401
+ sd_model = instantiate_from_config(sd_config.model)
402
+ except Exception as e:
403
+ pass
404
+
405
+ if sd_model is None:
406
+ print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
407
+ sd_model = instantiate_from_config(sd_config.model)
408
+
409
+ sd_model.used_config = checkpoint_config
410
+
411
+ timer.record("create model")
412
+
413
+ load_model_weights(sd_model, checkpoint_info, state_dict, timer)
414
+
415
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
416
+ lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
417
+ else:
418
+ sd_model.to(shared.device)
419
+
420
+ timer.record("move model to device")
421
+
422
+ sd_hijack.model_hijack.hijack(sd_model)
423
+
424
+ timer.record("hijack")
425
+
426
+ sd_model.eval()
427
+ shared.sd_model = sd_model
428
+
429
+ sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
430
+
431
+ timer.record("load textual inversion embeddings")
432
+
433
+ script_callbacks.model_loaded_callback(sd_model)
434
+
435
+ timer.record("scripts callbacks")
436
+
437
+ print(f"Model loaded in {timer.summary()}.")
438
+
439
+ return sd_model
440
+
441
+
442
+ def reload_model_weights(sd_model=None, info=None):
443
+ from modules import lowvram, devices, sd_hijack
444
+ checkpoint_info = info or select_checkpoint()
445
+
446
+ if not sd_model:
447
+ sd_model = shared.sd_model
448
+
449
+ if sd_model is None: # previous model load failed
450
+ current_checkpoint_info = None
451
+ else:
452
+ current_checkpoint_info = sd_model.sd_checkpoint_info
453
+ if sd_model.sd_model_checkpoint == checkpoint_info.filename:
454
+ return
455
+
456
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
457
+ lowvram.send_everything_to_cpu()
458
+ else:
459
+ sd_model.to(devices.cpu)
460
+
461
+ sd_hijack.model_hijack.undo_hijack(sd_model)
462
+
463
+ timer = Timer()
464
+
465
+ state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
466
+
467
+ checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
468
+
469
+ timer.record("find config")
470
+
471
+ if sd_model is None or checkpoint_config != sd_model.used_config:
472
+ del sd_model
473
+ checkpoints_loaded.clear()
474
+ load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
475
+ return shared.sd_model
476
+
477
+ try:
478
+ load_model_weights(sd_model, checkpoint_info, state_dict, timer)
479
+ except Exception as e:
480
+ print("Failed to load checkpoint, restoring previous")
481
+ load_model_weights(sd_model, current_checkpoint_info, None, timer)
482
+ raise
483
+ finally:
484
+ sd_hijack.model_hijack.hijack(sd_model)
485
+ timer.record("hijack")
486
+
487
+ script_callbacks.model_loaded_callback(sd_model)
488
+ timer.record("script callbacks")
489
+
490
+ if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
491
+ sd_model.to(devices.device)
492
+ timer.record("move model to device")
493
+
494
+ print(f"Weights loaded in {timer.summary()}.")
495
+
496
+ return sd_model