yiren98 committed
Commit
c1bc1cb
·
1 Parent(s): 3889dde
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +12 -0
  2. LICENSE +21 -0
  3. README.md +2 -2
  4. flux_inference_recraft.py +442 -0
  5. flux_minimal_inference.py +576 -0
  6. flux_minimal_inference_asylora.py +583 -0
  7. flux_train_network.py +588 -0
  8. flux_train_network_asylora.py +591 -0
  9. flux_train_recraft.py +713 -0
  10. gradio_app.py +372 -0
  11. gradio_app_asy.py +329 -0
  12. library/__init__.py +0 -0
  13. library/adafactor_fused.py +138 -0
  14. library/attention_processors.py +227 -0
  15. library/config_util.py +716 -0
  16. library/custom_offloading_utils.py +227 -0
  17. library/custom_train_functions.py +559 -0
  18. library/deepspeed_utils.py +139 -0
  19. library/device_utils.py +84 -0
  20. library/flux_models.py +1237 -0
  21. library/flux_train_utils.py +582 -0
  22. library/flux_train_utils_recraft.py +659 -0
  23. library/flux_utils.py +472 -0
  24. library/huggingface_util.py +84 -0
  25. library/hypernetwork.py +223 -0
  26. library/ipex/__init__.py +180 -0
  27. library/ipex/attention.py +177 -0
  28. library/ipex/diffusers.py +312 -0
  29. library/ipex/gradscaler.py +183 -0
  30. library/ipex/hijacks.py +313 -0
  31. library/lpw_stable_diffusion.py +1233 -0
  32. library/model_util.py +1356 -0
  33. library/original_unet.py +1919 -0
  34. library/sai_model_spec.py +334 -0
  35. library/sd3_models.py +1413 -0
  36. library/sd3_train_utils.py +945 -0
  37. library/sd3_utils.py +302 -0
  38. library/sdxl_lpw_stable_diffusion.py +1271 -0
  39. library/sdxl_model_util.py +583 -0
  40. library/sdxl_original_control_net.py +272 -0
  41. library/sdxl_original_unet.py +1292 -0
  42. library/sdxl_train_util.py +382 -0
  43. library/slicing_vae.py +682 -0
  44. library/strategy_base.py +570 -0
  45. library/strategy_flux.py +271 -0
  46. library/strategy_sd.py +171 -0
  47. library/strategy_sd3.py +420 -0
  48. library/strategy_sdxl.py +306 -0
  49. library/train_util.py +0 -0
  50. library/utils.py +582 -0
.gitignore ADDED
@@ -0,0 +1,12 @@
+ __pycache__
+ test
+ *.egg-info
+ .vscode
+ .gradio
+ wandb
+ Merge
+ asy_results
+ recraft_results
+ drop
+ SplitAsy
+ example*
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Show Lab
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -4,8 +4,8 @@ emoji: 🖼
  colorFrom: purple
  colorTo: red
  sdk: gradio
- sdk_version: 5.0.1
- app_file: app.py
+ sdk_version: 4.38.0
+ app_file: gradio_app_asy.py
  pinned: false
  license: mit
  short_description: Generate high quality images from prmopts
flux_inference_recraft.py ADDED
@@ -0,0 +1,442 @@
1
+ import argparse
2
+ import copy
3
+ import math
4
+ import random
5
+ from typing import Any
6
+ import pdb
7
+ import os
8
+
9
+ import time
10
+ from PIL import Image, ImageOps
11
+
12
+ import torch
13
+ from accelerate import Accelerator
14
+ from library.device_utils import clean_memory_on_device
15
+ from safetensors.torch import load_file
16
+ from networks import lora_flux
17
+
18
+ from library import flux_models, flux_train_utils_recraft as flux_train_utils, flux_utils, sd3_train_utils, \
19
+ strategy_base, strategy_flux, train_util
20
+ from torchvision import transforms
21
+ import train_network
22
+ from library.utils import setup_logging
23
+ from diffusers.utils import load_image
24
+ import numpy as np
25
+
26
+ setup_logging()
27
+ import logging
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ def load_target_model(
33
+ fp8_base: bool,
34
+ pretrained_model_name_or_path: str,
35
+ disable_mmap_load_safetensors: bool,
36
+ clip_l_path: str,
37
+ fp8_base_unet: bool,
38
+ t5xxl_path: str,
39
+ ae_path: str,
40
+ weight_dtype: torch.dtype,
41
+ accelerator: Accelerator
42
+ ):
43
+ # Determine the loading data type
44
+ loading_dtype = None if fp8_base else weight_dtype
45
+
46
+ # Load the main model to the accelerator's device
47
+ _, model = flux_utils.load_flow_model(
48
+ pretrained_model_name_or_path,
49
+ # loading_dtype,
50
+ torch.float8_e4m3fn,
51
+ # accelerator.device, # Changed from "cpu" to accelerator.device
52
+ "cpu",
53
+ disable_mmap=disable_mmap_load_safetensors
54
+ )
55
+
56
+ if fp8_base:
57
+ # Check dtype of the model
58
+ if model.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
59
+ raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
60
+ elif model.dtype == torch.float8_e4m3fn:
61
+ logger.info("Loaded fp8 FLUX model")
62
+
63
+ # Load the CLIP model to the accelerator's device
64
+ clip_l = flux_utils.load_clip_l(
65
+ clip_l_path,
66
+ weight_dtype,
67
+ # accelerator.device, # Changed from "cpu" to accelerator.device
68
+ "cpu",
69
+ disable_mmap=disable_mmap_load_safetensors
70
+ )
71
+ clip_l.eval()
72
+
73
+ # Determine the loading data type for T5XXL
74
+ if fp8_base and not fp8_base_unet:
75
+ loading_dtype_t5xxl = None # as is
76
+ else:
77
+ loading_dtype_t5xxl = weight_dtype
78
+
79
+ # Load the T5XXL model to the accelerator's device
80
+ t5xxl = flux_utils.load_t5xxl(
81
+ t5xxl_path,
82
+ loading_dtype_t5xxl,
83
+ # accelerator.device, # Changed from "cpu" to accelerator.device
84
+ "cpu",
85
+ disable_mmap=disable_mmap_load_safetensors
86
+ )
87
+ t5xxl.eval()
88
+
89
+ if fp8_base and not fp8_base_unet:
90
+ # Check dtype of the T5XXL model
91
+ if t5xxl.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
92
+ raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
93
+ elif t5xxl.dtype == torch.float8_e4m3fn:
94
+ logger.info("Loaded fp8 T5XXL model")
95
+
96
+ # Load the AE model to the accelerator's device
97
+ ae = flux_utils.load_ae(
98
+ ae_path,
99
+ weight_dtype,
100
+ # accelerator.device, # Changed from "cpu" to accelerator.device
101
+ "cpu",
102
+ disable_mmap=disable_mmap_load_safetensors
103
+ )
104
+
105
+ # # Wrap models with Accelerator for potential distributed setups
106
+ # model, clip_l, t5xxl, ae = accelerator.prepare(model, clip_l, t5xxl, ae)
107
+
108
+ return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
109
+
110
+
111
+ import torchvision.transforms as transforms
112
+
113
+
114
+ class ResizeWithPadding:
115
+ def __init__(self, size, fill=255):
116
+ self.size = size
117
+ self.fill = fill
118
+
119
+ def __call__(self, img):
120
+ if isinstance(img, np.ndarray):
121
+ img = Image.fromarray(img)
122
+ elif not isinstance(img, Image.Image):
123
+ raise TypeError("Input must be a PIL Image or a NumPy array")
124
+
125
+ width, height = img.size
126
+
127
+ if width == height:
128
+ img = img.resize((self.size, self.size), Image.LANCZOS)
129
+ else:
130
+ max_dim = max(width, height)
131
+
132
+ new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
133
+ new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
134
+
135
+ img = new_img.resize((self.size, self.size), Image.LANCZOS)
136
+
137
+ return img
138
+
139
+
140
+ def sample(args, accelerator, vae, text_encoder, flux, output_dir, sample_images, sample_prompts):
141
+ def encode_images_to_latents(vae, images):
142
+ # Get image dimensions
143
+ b, c, h, w = images.shape
144
+ num_split = 2 if args.frame_num == 4 else 3
145
+ # Split the image into three parts
146
+ img_parts = [images[:, :, :, i * w // num_split:(i + 1) * w // num_split] for i in range(num_split)]
147
+ # Encode each part
148
+ latents = [vae.encode(img) for img in img_parts]
149
+ # Concatenate latents in the latent space to reconstruct the full image
150
+ latents = torch.cat(latents, dim=-1)
151
+ return latents
152
+
153
+ def encode_images_to_latents2(vae, images):
154
+ latents = vae.encode(images)
155
+ return latents
156
+
157
+ # Directly use precomputed conditions
158
+ conditions = {}
159
+ with torch.no_grad():
160
+ for image_path, prompt_dict in zip(sample_images, sample_prompts):
161
+ prompt = prompt_dict.get("prompt", "")
162
+ if prompt not in conditions:
163
+ logger.info(f"Cache conditions for image: {image_path} with prompt: {prompt}")
164
+ resize_transform = ResizeWithPadding(size=512, fill=255) if args.frame_num == 4 else ResizeWithPadding(size=352, fill=255)
165
+ img_transforms = transforms.Compose([
166
+ resize_transform,
167
+ transforms.ToTensor(),
168
+ transforms.Normalize([0.5], [0.5]),
169
+ ])
170
+ # Load and preprocess image
171
+ image = img_transforms(np.array(load_image(image_path), dtype=np.uint8)).unsqueeze(0).to(
172
+ # accelerator.device, # Move image to CUDA
173
+ vae.device,
174
+ dtype=vae.dtype
175
+ )
176
+ latents = encode_images_to_latents2(vae, image)
177
+
178
+ # Log the shape of latents
179
+ logger.debug(f"Encoded latents shape for prompt '{prompt}': {latents.shape}")
180
+ # Store conditions on CUDA
181
+ # conditions[prompt] = latents[:,:,latents.shape[2]//2:latents.shape[2], :latents.shape[3]//2].to("cpu")
182
+ conditions[prompt] = latents.to("cpu")
183
+
184
+ sample_conditions = conditions
185
+
186
+ if sample_conditions is not None:
187
+ conditions = {k: v for k, v in sample_conditions.items()} # Already on CUDA
188
+
189
+ sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
190
+ text_encoder[0].to(accelerator.device)
191
+ text_encoder[1].to(accelerator.device)
192
+
193
+ tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
194
+ text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
195
+
196
+ with accelerator.autocast(), torch.no_grad():
197
+ for prompt_dict in sample_prompts:
198
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
199
+ if p not in sample_prompts_te_outputs:
200
+ logger.info(f"Cache Text Encoder outputs for prompt: {p}")
201
+ tokens_and_masks = tokenize_strategy.tokenize(p)
202
+ sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
203
+ tokenize_strategy, text_encoder, tokens_and_masks, True
204
+ )
205
+
206
+ logger.info(f"Generating image")
207
+ save_dir = output_dir
208
+ os.makedirs(save_dir, exist_ok=True)
209
+
210
+ with torch.no_grad(), accelerator.autocast():
211
+ for prompt_dict in sample_prompts:
212
+ sample_image_inference(
213
+ args,
214
+ accelerator,
215
+ flux,
216
+ text_encoder,
217
+ vae,
218
+ save_dir,
219
+ prompt_dict,
220
+ sample_prompts_te_outputs,
221
+ None,
222
+ conditions
223
+ )
224
+
225
+ clean_memory_on_device(accelerator.device)
226
+
227
+
228
+ def sample_image_inference(
229
+ args,
230
+ accelerator: Accelerator,
231
+ flux: flux_models.Flux,
232
+ text_encoder,
233
+ ae: flux_models.AutoEncoder,
234
+ save_dir,
235
+ prompt_dict,
236
+ sample_prompts_te_outputs,
237
+ prompt_replacement,
238
+ sample_images_ae_outputs
239
+ ):
240
+ # Extract parameters from prompt_dict
241
+ sample_steps = prompt_dict.get("sample_steps", 20)
242
+ width = prompt_dict.get("width", 1024) if args.frame_num == 4 else prompt_dict.get("width", 1056)
243
+ height = prompt_dict.get("height", 1024) if args.frame_num == 4 else prompt_dict.get("height", 1056)
244
+ scale = prompt_dict.get("scale", 1.0)
245
+ seed = prompt_dict.get("seed")
246
+ prompt: str = prompt_dict.get("prompt", "")
247
+
248
+ if prompt_replacement is not None:
249
+ prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
250
+
251
+ if seed is not None:
252
+ torch.manual_seed(seed)
253
+ torch.cuda.manual_seed(seed)
254
+ else:
255
+ # True random sample image generation
256
+ torch.seed()
257
+ torch.cuda.seed()
258
+
259
+ # Ensure height and width are divisible by 16
260
+ height = max(64, height - height % 16)
261
+ width = max(64, width - width % 16)
262
+ logger.info(f"prompt: {prompt}")
263
+ logger.info(f"height: {height}")
264
+ logger.info(f"width: {width}")
265
+ logger.info(f"sample_steps: {sample_steps}")
266
+ logger.info(f"scale: {scale}")
267
+ if seed is not None:
268
+ logger.info(f"seed: {seed}")
269
+
270
+ # Encode prompts
271
+ # Assuming that TokenizeStrategy and TextEncodingStrategy are compatible with Accelerator
272
+ text_encoder_conds = []
273
+ if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
274
+ text_encoder_conds = sample_prompts_te_outputs[prompt]
275
+ logger.info(f"Using cached text encoder outputs for prompt: {prompt}")
276
+
277
+ if sample_images_ae_outputs and prompt in sample_images_ae_outputs:
278
+ ae_outputs = sample_images_ae_outputs[prompt]
279
+ else:
280
+ ae_outputs = None
281
+
282
+ # ae_outputs = torch.load('ae_outputs.pth', map_location='cuda:0')
283
+
284
+ # text_encoder_conds = torch.load('text_encoder_conds.pth', map_location='cuda:0')
285
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
286
+
287
+ # Print debug info
288
+ logger.debug(
289
+ f"l_pooled shape: {l_pooled.shape}, t5_out shape: {t5_out.shape}, txt_ids shape: {txt_ids.shape}, t5_attn_mask shape: {t5_attn_mask.shape}")
290
+
291
+ # Sample the image
292
+ weight_dtype = ae.dtype # TODO: give dtype as argument
293
+ packed_latent_height = height // 16
294
+ packed_latent_width = width // 16
295
+
296
+ # Print debug info
297
+ logger.debug(f"packed_latent_height: {packed_latent_height}, packed_latent_width: {packed_latent_width}")
298
+
299
+ # Prepare the noise tensor on CUDA
300
+ noise = torch.randn(
301
+ 1,
302
+ packed_latent_height * packed_latent_width,
303
+ 16 * 2 * 2,
304
+ device=accelerator.device,
305
+ dtype=weight_dtype,
306
+ generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
307
+ )
308
+
309
+ timesteps = flux_train_utils.get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
310
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(
311
+ accelerator.device, dtype=weight_dtype
312
+ )
313
+ t5_attn_mask = t5_attn_mask.to(accelerator.device)
314
+
315
+ clip_l, t5xxl = text_encoder
316
+ # ae.to("cpu")
317
+ clip_l.to("cpu")
318
+ t5xxl.to("cpu")
319
+
320
+ clean_memory_on_device(accelerator.device)
321
+ flux.to("cuda")
322
+
323
+ for param in flux.parameters():
324
+ param.requires_grad = False
325
+
326
+ # Run denoising
327
+ with accelerator.autocast(), torch.no_grad():
328
+ x = flux_train_utils.denoise(args, flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps,
329
+ guidance=scale, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs)
330
+
331
+ # Print the shape of x
332
+ logger.debug(f"x shape after denoise: {x.shape}")
333
+
334
+ x = x.float()
335
+ x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
336
+
337
+ # Convert the latents back to an image
338
+ # clean_memory_on_device(accelerator.device)
339
+ ae.to(accelerator.device)
340
+ with accelerator.autocast(), torch.no_grad():
341
+ x = ae.decode(x)
342
+ ae.to("cpu")
343
+ clean_memory_on_device(accelerator.device)
344
+
345
+ x = x.clamp(-1, 1)
346
+ x = x.permute(0, 2, 3, 1)
347
+ image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
348
+
349
+ # Generate a unique filename
350
+ ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
351
+ seed_suffix = "" if seed is None else f"_{seed}"
352
+ i: int = prompt_dict.get("enum", 0) # Ensure 'enum' exists
353
+ img_filename = f"{ts_str}{seed_suffix}_{i}.png" # Added 'i' to filename for uniqueness
354
+ image.save(os.path.join(save_dir, img_filename))
355
+
356
+
357
+ def setup_argparse():
358
+ parser = argparse.ArgumentParser(description="FLUX-Controlnet-Inpainting Inference Script")
359
+
360
+ # Paths
361
+ parser.add_argument('--base_flux_checkpoint', type=str, required=True,
362
+ help='Path to BASE_FLUX_CHECKPOINT')
363
+ parser.add_argument('--lora_weights_path', type=str, required=True,
364
+ help='Path to LORA_WEIGHTS_PATH')
365
+ parser.add_argument('--clip_l_path', type=str, required=True,
366
+ help='Path to CLIP_L_PATH')
367
+ parser.add_argument('--t5xxl_path', type=str, required=True,
368
+ help='Path to T5XXL_PATH')
369
+ parser.add_argument('--ae_path', type=str, required=True,
370
+ help='Path to AE_PATH')
371
+ parser.add_argument('--sample_images_file', type=str, required=True,
372
+ help='Path to SAMPLE_IMAGES_FILE')
373
+ parser.add_argument('--sample_prompts_file', type=str, required=True,
374
+ help='Path to SAMPLE_PROMPTS_FILE')
375
+ parser.add_argument('--output_dir', type=str, required=True,
376
+ help='Directory to save OUTPUT_DIR')
377
+ parser.add_argument('--frame_num', type=int, choices=[4, 9], required=True,
378
+ help="The number of steps in the generated step diagram (choose 4 or 9)")
379
+
380
+ return parser.parse_args()
381
+
382
+
383
+ def main(args):
384
+ accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
385
+
386
+ BASE_FLUX_CHECKPOINT = args.base_flux_checkpoint
387
+ LORA_WEIGHTS_PATH = args.lora_weights_path
388
+ CLIP_L_PATH = args.clip_l_path
389
+ T5XXL_PATH = args.t5xxl_path
390
+ AE_PATH = args.ae_path
391
+
392
+ SAMPLE_IMAGES_FILE = args.sample_images_file
393
+ SAMPLE_PROMPTS_FILE = args.sample_prompts_file
394
+ OUTPUT_DIR = args.output_dir
395
+
396
+ with open(SAMPLE_IMAGES_FILE, "r", encoding="utf-8") as f:
397
+ image_lines = f.readlines()
398
+ sample_images = [line.strip() for line in image_lines if line.strip() and not line.strip().startswith("#")]
399
+
400
+ sample_prompts = train_util.load_prompts(SAMPLE_PROMPTS_FILE)
401
+
402
+ # Load models onto CUDA via Accelerator
403
+ _, [clip_l, t5xxl], ae, model = load_target_model(
404
+ fp8_base=True,
405
+ pretrained_model_name_or_path=BASE_FLUX_CHECKPOINT,
406
+ disable_mmap_load_safetensors=False,
407
+ clip_l_path=CLIP_L_PATH,
408
+ fp8_base_unet=False,
409
+ t5xxl_path=T5XXL_PATH,
410
+ ae_path=AE_PATH,
411
+ weight_dtype=torch.bfloat16,
412
+ accelerator=accelerator
413
+ )
414
+
415
+ model.eval()
416
+ clip_l.eval()
417
+ t5xxl.eval()
418
+ ae.eval()
419
+
420
+ # LoRA
421
+ multiplier = 1.0
422
+ weights_sd = load_file(LORA_WEIGHTS_PATH)
423
+ lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd,
424
+ True)
425
+
426
+ lora_model.apply_to([clip_l, t5xxl], model)
427
+ info = lora_model.load_state_dict(weights_sd, strict=True)
428
+ logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
429
+ lora_model.eval()
430
+ lora_model.to("cuda")
431
+
432
+ # Set text encoders
433
+ text_encoder = [clip_l, t5xxl]
434
+
435
+ sample(args, accelerator, vae=ae, text_encoder=text_encoder, flux=model, output_dir=OUTPUT_DIR,
436
+ sample_images=sample_images, sample_prompts=sample_prompts)
437
+
438
+
439
+ if __name__ == "__main__":
440
+ args = setup_argparse()
441
+
442
+ main(args)
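
A minimal, self-contained sketch of the padding-then-resize behaviour implemented by ResizeWithPadding in flux_inference_recraft.py above (added here for illustration, not part of the commit; it assumes only that Pillow is installed): non-square inputs are centred on a white square canvas so the aspect ratio survives the final LANCZOS resize.

from PIL import Image

def resize_with_padding(img: Image.Image, size: int, fill: int = 255) -> Image.Image:
    # Mirrors ResizeWithPadding.__call__: centre non-square images on a square
    # canvas filled with `fill`, then resize to (size, size) with LANCZOS.
    width, height = img.size
    if width != height:
        max_dim = max(width, height)
        canvas = Image.new("RGB", (max_dim, max_dim), (fill, fill, fill))
        canvas.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
        img = canvas
    return img.resize((size, size), Image.LANCZOS)

demo = Image.new("RGB", (300, 200), (0, 128, 255))  # non-square dummy input
print(resize_with_padding(demo, 512).size)          # -> (512, 512)
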
flux_minimal_inference.py ADDED
@@ -0,0 +1,576 @@
1
+ # Minimum Inference Code for FLUX
2
+
3
+ import argparse
4
+ import datetime
5
+ import math
6
+ import os
7
+ import random
8
+ from typing import Callable, List, Optional
9
+ import einops
10
+ import numpy as np
11
+
12
+ import torch
13
+ from tqdm import tqdm
14
+ from PIL import Image
15
+ import accelerate
16
+ from transformers import CLIPTextModel
17
+ from safetensors.torch import load_file
18
+
19
+ from library import device_utils
20
+ from library.device_utils import init_ipex, get_preferred_device
21
+ from networks import oft_flux
22
+
23
+ init_ipex()
24
+
25
+
26
+ from library.utils import setup_logging, str_to_dtype
27
+
28
+ setup_logging()
29
+ import logging
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+ import networks.lora_flux as lora_flux
34
+ from library import flux_models, flux_utils, sd3_utils, strategy_flux
35
+
36
+
37
+ def time_shift(mu: float, sigma: float, t: torch.Tensor):
38
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
39
+
40
+
41
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
42
+ m = (y2 - y1) / (x2 - x1)
43
+ b = y1 - m * x1
44
+ return lambda x: m * x + b
45
+
46
+
47
+ def get_schedule(
48
+ num_steps: int,
49
+ image_seq_len: int,
50
+ base_shift: float = 0.5,
51
+ max_shift: float = 1.15,
52
+ shift: bool = True,
53
+ ) -> list[float]:
54
+ # extra step for zero
55
+ timesteps = torch.linspace(1, 0, num_steps + 1)
56
+
57
+ # shifting the schedule to favor high timesteps for higher signal images
58
+ if shift:
59
+ # estimate mu via linear interpolation between two points
60
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
61
+ timesteps = time_shift(mu, 1.0, timesteps)
62
+
63
+ return timesteps.tolist()
64
+
65
+
66
+ def denoise(
67
+ model: flux_models.Flux,
68
+ img: torch.Tensor,
69
+ img_ids: torch.Tensor,
70
+ txt: torch.Tensor,
71
+ txt_ids: torch.Tensor,
72
+ vec: torch.Tensor,
73
+ timesteps: list[float],
74
+ guidance: float = 4.0,
75
+ t5_attn_mask: Optional[torch.Tensor] = None,
76
+ neg_txt: Optional[torch.Tensor] = None,
77
+ neg_vec: Optional[torch.Tensor] = None,
78
+ neg_t5_attn_mask: Optional[torch.Tensor] = None,
79
+ cfg_scale: Optional[float] = None,
80
+ ):
81
+ # this is ignored for schnell
82
+ logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
83
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
84
+
85
+ # prepare classifier free guidance
86
+ if neg_txt is not None and neg_vec is not None:
87
+ b_img_ids = torch.cat([img_ids, img_ids], dim=0)
88
+ b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
89
+ b_txt = torch.cat([neg_txt, txt], dim=0)
90
+ b_vec = torch.cat([neg_vec, vec], dim=0)
91
+ if t5_attn_mask is not None and neg_t5_attn_mask is not None:
92
+ b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
93
+ else:
94
+ b_t5_attn_mask = None
95
+ else:
96
+ b_img_ids = img_ids
97
+ b_txt_ids = txt_ids
98
+ b_txt = txt
99
+ b_vec = vec
100
+ b_t5_attn_mask = t5_attn_mask
101
+
102
+ for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
103
+ t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
104
+
105
+ # classifier free guidance
106
+ if neg_txt is not None and neg_vec is not None:
107
+ b_img = torch.cat([img, img], dim=0)
108
+ else:
109
+ b_img = img
110
+
111
+ pred = model(
112
+ img=b_img,
113
+ img_ids=b_img_ids,
114
+ txt=b_txt,
115
+ txt_ids=b_txt_ids,
116
+ y=b_vec,
117
+ timesteps=t_vec,
118
+ guidance=guidance_vec,
119
+ txt_attention_mask=b_t5_attn_mask,
120
+ )
121
+
122
+ # classifier free guidance
123
+ if neg_txt is not None and neg_vec is not None:
124
+ pred_uncond, pred = torch.chunk(pred, 2, dim=0)
125
+ pred = pred_uncond + cfg_scale * (pred - pred_uncond)
126
+
127
+ img = img + (t_prev - t_curr) * pred
128
+
129
+ return img
130
+
131
+
132
+ def do_sample(
133
+ accelerator: Optional[accelerate.Accelerator],
134
+ model: flux_models.Flux,
135
+ img: torch.Tensor,
136
+ img_ids: torch.Tensor,
137
+ l_pooled: torch.Tensor,
138
+ t5_out: torch.Tensor,
139
+ txt_ids: torch.Tensor,
140
+ num_steps: int,
141
+ guidance: float,
142
+ t5_attn_mask: Optional[torch.Tensor],
143
+ is_schnell: bool,
144
+ device: torch.device,
145
+ flux_dtype: torch.dtype,
146
+ neg_l_pooled: Optional[torch.Tensor] = None,
147
+ neg_t5_out: Optional[torch.Tensor] = None,
148
+ neg_t5_attn_mask: Optional[torch.Tensor] = None,
149
+ cfg_scale: Optional[float] = None,
150
+ ):
151
+ logger.info(f"num_steps: {num_steps}")
152
+ timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
153
+
154
+ # denoise initial noise
155
+ if accelerator:
156
+ with accelerator.autocast(), torch.no_grad():
157
+ x = denoise(
158
+ model,
159
+ img,
160
+ img_ids,
161
+ t5_out,
162
+ txt_ids,
163
+ l_pooled,
164
+ timesteps,
165
+ guidance,
166
+ t5_attn_mask,
167
+ neg_t5_out,
168
+ neg_l_pooled,
169
+ neg_t5_attn_mask,
170
+ cfg_scale,
171
+ )
172
+ else:
173
+ with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
174
+ x = denoise(
175
+ model,
176
+ img,
177
+ img_ids,
178
+ t5_out,
179
+ txt_ids,
180
+ l_pooled,
181
+ timesteps,
182
+ guidance,
183
+ t5_attn_mask,
184
+ neg_t5_out,
185
+ neg_l_pooled,
186
+ neg_t5_attn_mask,
187
+ cfg_scale,
188
+ )
189
+
190
+ return x
191
+
192
+
193
+ def generate_image(
194
+ model,
195
+ clip_l: CLIPTextModel,
196
+ t5xxl,
197
+ ae,
198
+ prompt: str,
199
+ seed: Optional[int],
200
+ image_width: int,
201
+ image_height: int,
202
+ steps: Optional[int],
203
+ guidance: float,
204
+ negative_prompt: Optional[str],
205
+ cfg_scale: float,
206
+ ):
207
+ seed = seed if seed is not None else random.randint(0, 2**32 - 1)
208
+ logger.info(f"Seed: {seed}")
209
+
210
+ # make first noise with packed shape
211
+ # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
212
+ packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
213
+ noise_dtype = torch.float32 if is_fp8(dtype) else dtype
214
+ noise = torch.randn(
215
+ 1,
216
+ packed_latent_height * packed_latent_width,
217
+ 16 * 2 * 2,
218
+ device=device,
219
+ dtype=noise_dtype,
220
+ generator=torch.Generator(device=device).manual_seed(seed),
221
+ )
222
+
223
+ # prepare img and img ids
224
+
225
+ # this is needed only for img2img
226
+ # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
227
+ # if img.shape[0] == 1 and bs > 1:
228
+ # img = repeat(img, "1 ... -> bs ...", bs=bs)
229
+
230
+ # txt2img only needs img_ids
231
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
232
+
233
+ # prepare fp8 models
234
+ if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
235
+ logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
236
+ clip_l.to(clip_l_dtype) # fp8
237
+ clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
238
+ clip_l.fp8_prepared = True
239
+
240
+ if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
241
+ logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
242
+
243
+ def prepare_fp8(text_encoder, target_dtype):
244
+ def forward_hook(module):
245
+ def forward(hidden_states):
246
+ hidden_gelu = module.act(module.wi_0(hidden_states))
247
+ hidden_linear = module.wi_1(hidden_states)
248
+ hidden_states = hidden_gelu * hidden_linear
249
+ hidden_states = module.dropout(hidden_states)
250
+
251
+ hidden_states = module.wo(hidden_states)
252
+ return hidden_states
253
+
254
+ return forward
255
+
256
+ for module in text_encoder.modules():
257
+ if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
258
+ # print("set", module.__class__.__name__, "to", target_dtype)
259
+ module.to(target_dtype)
260
+ if module.__class__.__name__ in ["T5DenseGatedActDense"]:
261
+ # print("set", module.__class__.__name__, "hooks")
262
+ module.forward = forward_hook(module)
263
+
264
+ t5xxl.to(t5xxl_dtype)
265
+ prepare_fp8(t5xxl.encoder, torch.bfloat16)
266
+ t5xxl.fp8_prepared = True
267
+
268
+ # prepare embeddings
269
+ logger.info("Encoding prompts...")
270
+ clip_l = clip_l.to(device)
271
+ t5xxl = t5xxl.to(device)
272
+
273
+ def encode(prpt: str):
274
+ tokens_and_masks = tokenize_strategy.tokenize(prpt)
275
+ with torch.no_grad():
276
+ if is_fp8(clip_l_dtype):
277
+ with accelerator.autocast():
278
+ l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
279
+ else:
280
+ with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
281
+ l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
282
+
283
+ if is_fp8(t5xxl_dtype):
284
+ with accelerator.autocast():
285
+ _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
286
+ tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
287
+ )
288
+ else:
289
+ with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
290
+ _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
291
+ tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
292
+ )
293
+ return l_pooled, t5_out, txt_ids, t5_attn_mask
294
+
295
+ l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
296
+ if negative_prompt:
297
+ neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
298
+ else:
299
+ neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
300
+
301
+ # NaN check
302
+ if torch.isnan(l_pooled).any():
303
+ raise ValueError("NaN in l_pooled")
304
+ if torch.isnan(t5_out).any():
305
+ raise ValueError("NaN in t5_out")
306
+
307
+ if args.offload:
308
+ clip_l = clip_l.cpu()
309
+ t5xxl = t5xxl.cpu()
310
+ # del clip_l, t5xxl
311
+ device_utils.clean_memory()
312
+
313
+ # generate image
314
+ logger.info("Generating image...")
315
+ model = model.to(device)
316
+ if steps is None:
317
+ steps = 4 if is_schnell else 50
318
+
319
+ img_ids = img_ids.to(device)
320
+ t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
321
+
322
+ x = do_sample(
323
+ accelerator,
324
+ model,
325
+ noise,
326
+ img_ids,
327
+ l_pooled,
328
+ t5_out,
329
+ txt_ids,
330
+ steps,
331
+ guidance,
332
+ t5_attn_mask,
333
+ is_schnell,
334
+ device,
335
+ flux_dtype,
336
+ neg_l_pooled,
337
+ neg_t5_out,
338
+ neg_t5_attn_mask,
339
+ cfg_scale,
340
+ )
341
+ if args.offload:
342
+ model = model.cpu()
343
+ # del model
344
+ device_utils.clean_memory()
345
+
346
+ # unpack
347
+ x = x.float()
348
+ x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
349
+
350
+ # decode
351
+ logger.info("Decoding image...")
352
+ ae = ae.to(device)
353
+ with torch.no_grad():
354
+ if is_fp8(ae_dtype):
355
+ with accelerator.autocast():
356
+ x = ae.decode(x)
357
+ else:
358
+ with torch.autocast(device_type=device.type, dtype=ae_dtype):
359
+ x = ae.decode(x)
360
+ if args.offload:
361
+ ae = ae.cpu()
362
+
363
+ x = x.clamp(-1, 1)
364
+ x = x.permute(0, 2, 3, 1)
365
+ img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
366
+
367
+ # save image
368
+ output_dir = args.output_dir
369
+ os.makedirs(output_dir, exist_ok=True)
370
+ output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
371
+ img.save(output_path)
372
+
373
+ logger.info(f"Saved image to {output_path}")
374
+
375
+
376
+ if __name__ == "__main__":
377
+ target_height = 768 # 1024
378
+ target_width = 1360 # 1024
379
+
380
+ # steps = 50 # 28 # 50
381
+ # guidance_scale = 5
382
+ # seed = 1 # None # 1
383
+
384
+ device = get_preferred_device()
385
+
386
+ parser = argparse.ArgumentParser()
387
+ parser.add_argument("--ckpt_path", type=str, required=True)
388
+ parser.add_argument("--clip_l", type=str, required=False)
389
+ parser.add_argument("--t5xxl", type=str, required=False)
390
+ parser.add_argument("--ae", type=str, required=False)
391
+ parser.add_argument("--apply_t5_attn_mask", action="store_true")
392
+ parser.add_argument("--prompt", type=str, default="A photo of a cat")
393
+ parser.add_argument("--output_dir", type=str, default=".")
394
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
395
+ parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
396
+ parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
397
+ parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
398
+ parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
399
+ parser.add_argument("--seed", type=int, default=None)
400
+ parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
401
+ parser.add_argument("--guidance", type=float, default=3.5)
402
+ parser.add_argument("--negative_prompt", type=str, default=None)
403
+ parser.add_argument("--cfg_scale", type=float, default=1.0)
404
+ parser.add_argument("--offload", action="store_true", help="Offload to CPU")
405
+ parser.add_argument(
406
+ "--lora_weights",
407
+ type=str,
408
+ nargs="*",
409
+ default=[],
410
+ help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
411
+ )
412
+ parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
413
+ parser.add_argument("--width", type=int, default=target_width)
414
+ parser.add_argument("--height", type=int, default=target_height)
415
+ parser.add_argument("--interactive", action="store_true")
416
+ args = parser.parse_args()
417
+
418
+ seed = args.seed
419
+ steps = args.steps
420
+ guidance_scale = args.guidance
421
+
422
+ def is_fp8(dt):
423
+ return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
424
+
425
+ dtype = str_to_dtype(args.dtype)
426
+ clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
427
+ t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
428
+ ae_dtype = str_to_dtype(args.ae_dtype, dtype)
429
+ flux_dtype = str_to_dtype(args.flux_dtype, dtype)
430
+
431
+ logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
432
+
433
+ loading_device = "cpu" if args.offload else device
434
+
435
+ use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
436
+ if any(use_fp8):
437
+ accelerator = accelerate.Accelerator(mixed_precision="bf16")
438
+ else:
439
+ accelerator = None
440
+
441
+ # load clip_l
442
+ logger.info(f"Loading clip_l from {args.clip_l}...")
443
+ clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
444
+ clip_l.eval()
445
+
446
+ logger.info(f"Loading t5xxl from {args.t5xxl}...")
447
+ t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
448
+ t5xxl.eval()
449
+
450
+ # if is_fp8(clip_l_dtype):
451
+ # clip_l = accelerator.prepare(clip_l)
452
+ # if is_fp8(t5xxl_dtype):
453
+ # t5xxl = accelerator.prepare(t5xxl)
454
+
455
+ # DiT
456
+ is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
457
+ model.eval()
458
+ logger.info(f"Casting model to {flux_dtype}")
459
+ model.to(flux_dtype) # make sure model is dtype
460
+ # if is_fp8(flux_dtype):
461
+ # model = accelerator.prepare(model)
462
+ # if args.offload:
463
+ # model = model.to("cpu")
464
+
465
+ t5xxl_max_length = 256 if is_schnell else 512
466
+ tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
467
+ encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
468
+
469
+ # AE
470
+ ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
471
+ ae.eval()
472
+ # if is_fp8(ae_dtype):
473
+ # ae = accelerator.prepare(ae)
474
+
475
+ # LoRA
476
+ lora_models: List[lora_flux.LoRANetwork] = []
477
+ for weights_file in args.lora_weights:
478
+ if ";" in weights_file:
479
+ weights_file, multiplier = weights_file.split(";")
480
+ multiplier = float(multiplier)
481
+ else:
482
+ multiplier = 1.0
483
+
484
+ weights_sd = load_file(weights_file)
485
+ is_lora = is_oft = False
486
+ for key in weights_sd.keys():
487
+ if key.startswith("lora"):
488
+ is_lora = True
489
+ if key.startswith("oft"):
490
+ is_oft = True
491
+ if is_lora or is_oft:
492
+ break
493
+
494
+ module = lora_flux if is_lora else oft_flux
495
+ lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
496
+
497
+ if args.merge_lora_weights:
498
+ lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
499
+ else:
500
+ lora_model.apply_to([clip_l, t5xxl], model)
501
+ info = lora_model.load_state_dict(weights_sd, strict=True)
502
+ logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
503
+ lora_model.eval()
504
+ lora_model.to(device)
505
+
506
+ lora_models.append(lora_model)
507
+
508
+ if not args.interactive:
509
+ generate_image(
510
+ model,
511
+ clip_l,
512
+ t5xxl,
513
+ ae,
514
+ args.prompt,
515
+ args.seed,
516
+ args.width,
517
+ args.height,
518
+ args.steps,
519
+ args.guidance,
520
+ args.negative_prompt,
521
+ args.cfg_scale,
522
+ )
523
+ else:
524
+ # loop for interactive
525
+ width = target_width
526
+ height = target_height
527
+ steps = None
528
+ guidance = args.guidance
529
+ cfg_scale = args.cfg_scale
530
+
531
+ while True:
532
+ print(
533
+ "Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
534
+ " --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
535
+ )
536
+ prompt = input()
537
+ if prompt == "":
538
+ break
539
+
540
+ # parse options
541
+ options = prompt.split("--")
542
+ prompt = options[0].strip()
543
+ seed = None
544
+ negative_prompt = None
545
+ for opt in options[1:]:
546
+ try:
547
+ opt = opt.strip()
548
+ if opt.startswith("w"):
549
+ width = int(opt[1:].strip())
550
+ elif opt.startswith("h"):
551
+ height = int(opt[1:].strip())
552
+ elif opt.startswith("s"):
553
+ steps = int(opt[1:].strip())
554
+ elif opt.startswith("d"):
555
+ seed = int(opt[1:].strip())
556
+ elif opt.startswith("g"):
557
+ guidance = float(opt[1:].strip())
558
+ elif opt.startswith("m"):
559
+ mutipliers = opt[1:].strip().split(",")
560
+ if len(mutipliers) != len(lora_models):
561
+ logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
562
+ continue
563
+ for i, lora_model in enumerate(lora_models):
564
+ lora_model.set_multiplier(float(mutipliers[i]))
565
+ elif opt.startswith("n"):
566
+ negative_prompt = opt[1:].strip()
567
+ if negative_prompt == "-":
568
+ negative_prompt = ""
569
+ elif opt.startswith("c"):
570
+ cfg_scale = float(opt[1:].strip())
571
+ except ValueError as e:
572
+ logger.error(f"Invalid option: {opt}, {e}")
573
+
574
+ generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
575
+
576
+ logger.info("Done!")
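
For reference, the timestep schedule used by both minimal-inference scripts comes from get_schedule(): it starts from num_steps + 1 uniform timesteps in [1, 0] and, for FLUX.1 dev (shift=True), warps them with time_shift() so that more steps are spent near t = 1 as the packed image sequence grows. The standalone sketch below (added for illustration, not part of the commit) reproduces the two formulas to show the effect for a 1024x1024 image, whose packed sequence length is (1024 // 16) ** 2 = 4096.

import math
import torch

def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    # Same formula as in the scripts above: exp(mu) / (exp(mu) + (1/t - 1) ** sigma)
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def get_schedule(num_steps: int, image_seq_len: int,
                 base_shift: float = 0.5, max_shift: float = 1.15,
                 shift: bool = True) -> list:
    timesteps = torch.linspace(1, 0, num_steps + 1)  # extra entry for t = 0
    if shift:
        # linear interpolation of mu between (256, base_shift) and (4096, max_shift),
        # i.e. the closed form of get_lin_function()
        m = (max_shift - base_shift) / (4096 - 256)
        mu = m * image_seq_len + (base_shift - m * 256)
        timesteps = time_shift(mu, 1.0, timesteps)
    return timesteps.tolist()

seq_len = (1024 // 16) ** 2                   # packed latent length for a 1024x1024 image
print(get_schedule(4, seq_len, shift=False))  # uniform: [1.0, 0.75, 0.5, 0.25, 0.0]
print(get_schedule(4, seq_len, shift=True))   # biased toward t = 1 (FLUX.1 dev)
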
flux_minimal_inference_asylora.py ADDED
@@ -0,0 +1,583 @@
1
+
2
+ # Minimum Inference Code for FLUX
3
+
4
+ import argparse
5
+ import datetime
6
+ import math
7
+ import os
8
+ import random
9
+ from typing import Callable, List, Optional
10
+ import einops
11
+ import numpy as np
12
+
13
+ import torch
14
+ from tqdm import tqdm
15
+ from PIL import Image
16
+ import accelerate
17
+ from transformers import CLIPTextModel
18
+ from safetensors.torch import load_file
19
+
20
+ from library import device_utils
21
+ from library.device_utils import init_ipex, get_preferred_device
22
+ from networks import oft_flux
23
+
24
+ init_ipex()
25
+
26
+
27
+ from library.utils import setup_logging, str_to_dtype
28
+
29
+ setup_logging()
30
+ import logging
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ import networks.asylora_flux as lora_flux
35
+ from library import flux_models, flux_utils, sd3_utils, strategy_flux
36
+
37
+
38
+ def time_shift(mu: float, sigma: float, t: torch.Tensor):
39
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
40
+
41
+
42
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
43
+ m = (y2 - y1) / (x2 - x1)
44
+ b = y1 - m * x1
45
+ return lambda x: m * x + b
46
+
47
+
48
+ def get_schedule(
49
+ num_steps: int,
50
+ image_seq_len: int,
51
+ base_shift: float = 0.5,
52
+ max_shift: float = 1.15,
53
+ shift: bool = True,
54
+ ) -> list[float]:
55
+ # extra step for zero
56
+ timesteps = torch.linspace(1, 0, num_steps + 1)
57
+
58
+ # shifting the schedule to favor high timesteps for higher signal images
59
+ if shift:
60
+ # estimate mu via linear interpolation between two points
61
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
62
+ timesteps = time_shift(mu, 1.0, timesteps)
63
+
64
+ return timesteps.tolist()
65
+
66
+
67
+ def denoise(
68
+ model: flux_models.Flux,
69
+ img: torch.Tensor,
70
+ img_ids: torch.Tensor,
71
+ txt: torch.Tensor,
72
+ txt_ids: torch.Tensor,
73
+ vec: torch.Tensor,
74
+ timesteps: list[float],
75
+ guidance: float = 4.0,
76
+ t5_attn_mask: Optional[torch.Tensor] = None,
77
+ neg_txt: Optional[torch.Tensor] = None,
78
+ neg_vec: Optional[torch.Tensor] = None,
79
+ neg_t5_attn_mask: Optional[torch.Tensor] = None,
80
+ cfg_scale: Optional[float] = None,
81
+ ):
82
+ # this is ignored for schnell
83
+ logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
84
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
85
+
86
+ # prepare classifier free guidance
87
+ if neg_txt is not None and neg_vec is not None:
88
+ b_img_ids = torch.cat([img_ids, img_ids], dim=0)
89
+ b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
90
+ b_txt = torch.cat([neg_txt, txt], dim=0)
91
+ b_vec = torch.cat([neg_vec, vec], dim=0)
92
+ if t5_attn_mask is not None and neg_t5_attn_mask is not None:
93
+ b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
94
+ else:
95
+ b_t5_attn_mask = None
96
+ else:
97
+ b_img_ids = img_ids
98
+ b_txt_ids = txt_ids
99
+ b_txt = txt
100
+ b_vec = vec
101
+ b_t5_attn_mask = t5_attn_mask
102
+
103
+ for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
104
+ t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
105
+
106
+ # classifier free guidance
107
+ if neg_txt is not None and neg_vec is not None:
108
+ b_img = torch.cat([img, img], dim=0)
109
+ else:
110
+ b_img = img
111
+
112
+ pred = model(
113
+ img=b_img,
114
+ img_ids=b_img_ids,
115
+ txt=b_txt,
116
+ txt_ids=b_txt_ids,
117
+ y=b_vec,
118
+ timesteps=t_vec,
119
+ guidance=guidance_vec,
120
+ txt_attention_mask=b_t5_attn_mask,
121
+ )
122
+
123
+ # classifier free guidance
124
+ if neg_txt is not None and neg_vec is not None:
125
+ pred_uncond, pred = torch.chunk(pred, 2, dim=0)
126
+ pred = pred_uncond + cfg_scale * (pred - pred_uncond)
127
+
128
+ img = img + (t_prev - t_curr) * pred
129
+
130
+ return img
131
+
132
+
133
+ def do_sample(
134
+ accelerator: Optional[accelerate.Accelerator],
135
+ model: flux_models.Flux,
136
+ img: torch.Tensor,
137
+ img_ids: torch.Tensor,
138
+ l_pooled: torch.Tensor,
139
+ t5_out: torch.Tensor,
140
+ txt_ids: torch.Tensor,
141
+ num_steps: int,
142
+ guidance: float,
143
+ t5_attn_mask: Optional[torch.Tensor],
144
+ is_schnell: bool,
145
+ device: torch.device,
146
+ flux_dtype: torch.dtype,
147
+ neg_l_pooled: Optional[torch.Tensor] = None,
148
+ neg_t5_out: Optional[torch.Tensor] = None,
149
+ neg_t5_attn_mask: Optional[torch.Tensor] = None,
150
+ cfg_scale: Optional[float] = None,
151
+ ):
152
+ logger.info(f"num_steps: {num_steps}")
153
+ timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
154
+
155
+ # denoise initial noise
156
+ if accelerator:
157
+ with accelerator.autocast(), torch.no_grad():
158
+ x = denoise(
159
+ model,
160
+ img,
161
+ img_ids,
162
+ t5_out,
163
+ txt_ids,
164
+ l_pooled,
165
+ timesteps,
166
+ guidance,
167
+ t5_attn_mask,
168
+ neg_t5_out,
169
+ neg_l_pooled,
170
+ neg_t5_attn_mask,
171
+ cfg_scale,
172
+ )
173
+ else:
174
+ with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
175
+ x = denoise(
176
+ model,
177
+ img,
178
+ img_ids,
179
+ t5_out,
180
+ txt_ids,
181
+ l_pooled,
182
+ timesteps,
183
+ guidance,
184
+ t5_attn_mask,
185
+ neg_t5_out,
186
+ neg_l_pooled,
187
+ neg_t5_attn_mask,
188
+ cfg_scale,
189
+ )
190
+
191
+ return x
192
+
193
+
194
+ def generate_image(
195
+ model,
196
+ clip_l: CLIPTextModel,
197
+ t5xxl,
198
+ ae,
199
+ prompt: str,
200
+ seed: Optional[int],
201
+ image_width: int,
202
+ image_height: int,
203
+ steps: Optional[int],
204
+ guidance: float,
205
+ negative_prompt: Optional[str],
206
+ cfg_scale: float,
207
+ ):
208
+ seed = seed if seed is not None else random.randint(0, 2**32 - 1)
209
+ logger.info(f"Seed: {seed}")
210
+
211
+ # make first noise with packed shape
212
+ # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
213
+ packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
214
+ noise_dtype = torch.float32 if is_fp8(dtype) else dtype
215
+ noise = torch.randn(
216
+ 1,
217
+ packed_latent_height * packed_latent_width,
218
+ 16 * 2 * 2,
219
+ device=device,
220
+ dtype=noise_dtype,
221
+ generator=torch.Generator(device=device).manual_seed(seed),
222
+ )
223
+
224
+ # prepare img and img ids
225
+
226
+ # this is needed only for img2img
227
+ # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
228
+ # if img.shape[0] == 1 and bs > 1:
229
+ # img = repeat(img, "1 ... -> bs ...", bs=bs)
230
+
231
+ # txt2img only needs img_ids
232
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
233
+
234
+ # prepare fp8 models
235
+ if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
236
+ logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
237
+ clip_l.to(clip_l_dtype) # fp8
238
+ clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
239
+ clip_l.fp8_prepared = True
240
+
241
+ if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
242
+ logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
243
+
244
+ def prepare_fp8(text_encoder, target_dtype):
245
+ def forward_hook(module):
246
+ def forward(hidden_states):
247
+ hidden_gelu = module.act(module.wi_0(hidden_states))
248
+ hidden_linear = module.wi_1(hidden_states)
249
+ hidden_states = hidden_gelu * hidden_linear
250
+ hidden_states = module.dropout(hidden_states)
251
+
252
+ hidden_states = module.wo(hidden_states)
253
+ return hidden_states
254
+
255
+ return forward
256
+
257
+ for module in text_encoder.modules():
258
+ if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
259
+ # print("set", module.__class__.__name__, "to", target_dtype)
260
+ module.to(target_dtype)
261
+ if module.__class__.__name__ in ["T5DenseGatedActDense"]:
262
+ # print("set", module.__class__.__name__, "hooks")
263
+ module.forward = forward_hook(module)
264
+
265
+ t5xxl.to(t5xxl_dtype)
266
+ prepare_fp8(t5xxl.encoder, torch.bfloat16)
267
+ t5xxl.fp8_prepared = True
268
+
269
+ # prepare embeddings
270
+ logger.info("Encoding prompts...")
271
+ clip_l = clip_l.to(device)
272
+ t5xxl = t5xxl.to(device)
273
+
274
+ def encode(prpt: str):
275
+ tokens_and_masks = tokenize_strategy.tokenize(prpt)
276
+ with torch.no_grad():
277
+ if is_fp8(clip_l_dtype):
278
+ with accelerator.autocast():
279
+ l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
280
+ else:
281
+ with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
282
+ l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
283
+
284
+ if is_fp8(t5xxl_dtype):
285
+ with accelerator.autocast():
286
+ _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
287
+ tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
288
+ )
289
+ else:
290
+ with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
291
+ _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
292
+ tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
293
+ )
294
+ return l_pooled, t5_out, txt_ids, t5_attn_mask
295
+
296
+ l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
297
+ if negative_prompt:
298
+ neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
299
+ else:
300
+ neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
301
+
302
+ # NaN check
303
+ if torch.isnan(l_pooled).any():
304
+ raise ValueError("NaN in l_pooled")
305
+ if torch.isnan(t5_out).any():
306
+ raise ValueError("NaN in t5_out")
307
+
308
+ if args.offload:
309
+ clip_l = clip_l.cpu()
310
+ t5xxl = t5xxl.cpu()
311
+ # del clip_l, t5xxl
312
+ device_utils.clean_memory()
313
+
314
+ # generate image
315
+ logger.info("Generating image...")
316
+ model = model.to(device)
317
+ if steps is None:
318
+ steps = 4 if is_schnell else 50
319
+
320
+ img_ids = img_ids.to(device)
321
+ t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
322
+
323
+ x = do_sample(
324
+ accelerator,
325
+ model,
326
+ noise,
327
+ img_ids,
328
+ l_pooled,
329
+ t5_out,
330
+ txt_ids,
331
+ steps,
332
+ guidance,
333
+ t5_attn_mask,
334
+ is_schnell,
335
+ device,
336
+ flux_dtype,
337
+ neg_l_pooled,
338
+ neg_t5_out,
339
+ neg_t5_attn_mask,
340
+ cfg_scale,
341
+ )
342
+ if args.offload:
343
+ model = model.cpu()
344
+ # del model
345
+ device_utils.clean_memory()
346
+
347
+ # unpack
348
+ x = x.float()
349
+ x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
350
+
351
+ # decode
352
+ logger.info("Decoding image...")
353
+ ae = ae.to(device)
354
+ with torch.no_grad():
355
+ if is_fp8(ae_dtype):
356
+ with accelerator.autocast():
357
+ x = ae.decode(x)
358
+ else:
359
+ with torch.autocast(device_type=device.type, dtype=ae_dtype):
360
+ x = ae.decode(x)
361
+ if args.offload:
362
+ ae = ae.cpu()
363
+
364
+ x = x.clamp(-1, 1)
365
+ x = x.permute(0, 2, 3, 1)
366
+ img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
367
+
368
+ # save image
369
+ output_dir = args.output_dir
370
+ os.makedirs(output_dir, exist_ok=True)
371
+ output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
372
+ img.save(output_path)
373
+
374
+ logger.info(f"Saved image to {output_path}")
375
+
376
+
377
+ if __name__ == "__main__":
378
+ target_height = 768 # 1024
379
+ target_width = 1360 # 1024
380
+
381
+ # steps = 50 # 28 # 50
382
+ # guidance_scale = 5
383
+ # seed = 1 # None # 1
384
+
385
+ device = get_preferred_device()
386
+
387
+ parser = argparse.ArgumentParser()
388
+ parser.add_argument("--lora_ups_num", type=int, required=True)
389
+ parser.add_argument("--lora_up_cur", type=int, required=True)
390
+ parser.add_argument("--ckpt_path", type=str, required=True)
391
+ parser.add_argument("--clip_l", type=str, required=False)
392
+ parser.add_argument("--t5xxl", type=str, required=False)
393
+ parser.add_argument("--ae", type=str, required=False)
394
+ parser.add_argument("--apply_t5_attn_mask", action="store_true")
395
+ parser.add_argument("--prompt", type=str, default="A photo of a cat")
396
+ parser.add_argument("--output_dir", type=str, default=".")
397
+ parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
398
+ parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
399
+ parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
400
+ parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
401
+ parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
402
+ parser.add_argument("--seed", type=int, default=None)
403
+ parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
404
+ parser.add_argument("--guidance", type=float, default=3.5)
405
+ parser.add_argument("--negative_prompt", type=str, default=None)
406
+ parser.add_argument("--cfg_scale", type=float, default=1.0)
407
+ parser.add_argument("--offload", action="store_true", help="Offload to CPU")
408
+ parser.add_argument(
409
+ "--lora_weights",
410
+ type=str,
411
+ nargs="*",
412
+ default=[],
413
+ help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
414
+ )
415
+ parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
416
+ parser.add_argument("--width", type=int, default=target_width)
417
+ parser.add_argument("--height", type=int, default=target_height)
418
+ parser.add_argument("--interactive", action="store_true")
419
+ args = parser.parse_args()
420
+
421
+ seed = args.seed
422
+ steps = args.steps
423
+ guidance_scale = args.guidance
424
+ lora_ups_num = args.lora_ups_num
425
+ lora_up_cur = args.lora_up_cur
426
+
427
+ def is_fp8(dt):
428
+ return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
429
+
430
+ dtype = str_to_dtype(args.dtype)
431
+ clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
432
+ t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
433
+ ae_dtype = str_to_dtype(args.ae_dtype, dtype)
434
+ flux_dtype = str_to_dtype(args.flux_dtype, dtype)
435
+
436
+ logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
437
+
438
+ loading_device = "cpu" if args.offload else device
439
+
440
+ use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
441
+ if any(use_fp8):
442
+ accelerator = accelerate.Accelerator(mixed_precision="bf16")
443
+ else:
444
+ accelerator = None
445
+
446
+ # load clip_l
447
+ logger.info(f"Loading clip_l from {args.clip_l}...")
448
+ clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
449
+ clip_l.eval()
450
+
451
+ logger.info(f"Loading t5xxl from {args.t5xxl}...")
452
+ t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
453
+ t5xxl.eval()
454
+
455
+ # if is_fp8(clip_l_dtype):
456
+ # clip_l = accelerator.prepare(clip_l)
457
+ # if is_fp8(t5xxl_dtype):
458
+ # t5xxl = accelerator.prepare(t5xxl)
459
+
460
+ # DiT
461
+ is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
462
+ model.eval()
463
+ logger.info(f"Casting model to {flux_dtype}")
464
+ model.to(flux_dtype) # make sure model is dtype
465
+ # if is_fp8(flux_dtype):
466
+ # model = accelerator.prepare(model)
467
+ # if args.offload:
468
+ # model = model.to("cpu")
469
+
470
+ t5xxl_max_length = 256 if is_schnell else 512
471
+ tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
472
+ encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
473
+
474
+ # AE
475
+ ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
476
+ ae.eval()
477
+ # if is_fp8(ae_dtype):
478
+ # ae = accelerator.prepare(ae)
479
+
480
+ # LoRA
481
+ lora_models: List[lora_flux.LoRANetwork] = []
482
+ for weights_file in args.lora_weights:
483
+ if ";" in weights_file:
484
+ weights_file, multiplier = weights_file.split(";")
485
+ multiplier = float(multiplier)
486
+ else:
487
+ multiplier = 1.0
488
+
489
+ weights_sd = load_file(weights_file)
490
+ is_lora = is_oft = False
491
+ for key in weights_sd.keys():
492
+ if key.startswith("lora"):
493
+ is_lora = True
494
+ if key.startswith("oft"):
495
+ is_oft = True
496
+ if is_lora or is_oft:
497
+ break
498
+
499
+ module = lora_flux if is_lora else oft_flux
500
+ lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True, lora_ups_num)
501
+ for sub_lora in lora_model.unet_loras:
502
+ sub_lora.set_lora_up_cur(lora_up_cur-1)
503
+
504
+ if args.merge_lora_weights:
505
+ lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
506
+ else:
507
+ lora_model.apply_to([clip_l, t5xxl], model)
508
+ info = lora_model.load_state_dict(weights_sd, strict=True)
509
+ logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
510
+ lora_model.eval()
511
+ lora_model.to(device)
512
+
513
+ lora_models.append(lora_model)
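For reference, each `--lora_weights` entry follows the `path;multiplier` convention parsed in the loop above; a minimal sketch with a placeholder file name (not a file from this repo):

spec = "my_lora.safetensors;0.8"   # hypothetical entry passed via --lora_weights
if ";" in spec:
    path, multiplier = spec.split(";")
    multiplier = float(multiplier)
else:
    path, multiplier = spec, 1.0
print(path, multiplier)            # my_lora.safetensors 0.8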
514
+
515
+ if not args.interactive:
516
+ generate_image(
517
+ model,
518
+ clip_l,
519
+ t5xxl,
520
+ ae,
521
+ args.prompt,
522
+ args.seed,
523
+ args.width,
524
+ args.height,
525
+ args.steps,
526
+ args.guidance,
527
+ args.negative_prompt,
528
+ args.cfg_scale,
529
+ )
530
+ else:
531
+ # loop for interactive
532
+ width = target_width
533
+ height = target_height
534
+ steps = None
535
+ guidance = args.guidance
536
+ cfg_scale = args.cfg_scale
537
+
538
+ while True:
539
+ print(
540
+ "Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
541
+ " --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
542
+ )
543
+ prompt = input()
544
+ if prompt == "":
545
+ break
546
+
547
+ # parse options
548
+ options = prompt.split("--")
549
+ prompt = options[0].strip()
550
+ seed = None
551
+ negative_prompt = None
552
+ for opt in options[1:]:
553
+ try:
554
+ opt = opt.strip()
555
+ if opt.startswith("w"):
556
+ width = int(opt[1:].strip())
557
+ elif opt.startswith("h"):
558
+ height = int(opt[1:].strip())
559
+ elif opt.startswith("s"):
560
+ steps = int(opt[1:].strip())
561
+ elif opt.startswith("d"):
562
+ seed = int(opt[1:].strip())
563
+ elif opt.startswith("g"):
564
+ guidance = float(opt[1:].strip())
565
+ elif opt.startswith("m"):
566
+ multipliers = opt[1:].strip().split(",")
567
+ if len(multipliers) != len(lora_models):
568
+ logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
569
+ continue
570
+ for i, lora_model in enumerate(lora_models):
571
+ lora_model.set_multiplier(float(multipliers[i]))
572
+ elif opt.startswith("n"):
573
+ negative_prompt = opt[1:].strip()
574
+ if negative_prompt == "-":
575
+ negative_prompt = ""
576
+ elif opt.startswith("c"):
577
+ cfg_scale = float(opt[1:].strip())
578
+ except ValueError as e:
579
+ logger.error(f"Invalid option: {opt}, {e}")
580
+
581
+ generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
582
+
583
+ logger.info("Done!")
flux_train_network.py ADDED
@@ -0,0 +1,588 @@
1
+ import argparse
2
+ import copy
3
+ import math
4
+ import random
5
+ from typing import Any, Optional, Union
6
+
7
+ import torch
8
+ from accelerate import Accelerator
9
+
10
+ from library.device_utils import clean_memory_on_device, init_ipex
11
+
12
+ init_ipex()
13
+
14
+ import train_network
15
+ from library import (
16
+ flux_models,
17
+ flux_train_utils,
18
+ flux_utils,
19
+ sd3_train_utils,
20
+ strategy_base,
21
+ strategy_flux,
22
+ train_util,
23
+ )
24
+ from library.utils import setup_logging
25
+
26
+ setup_logging()
27
+ import logging
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ class FluxNetworkTrainer(train_network.NetworkTrainer):
33
+ def __init__(self):
34
+ super().__init__()
35
+ self.sample_prompts_te_outputs = None
36
+ self.is_schnell: Optional[bool] = None
37
+ self.is_swapping_blocks: bool = False
38
+
39
+ def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
40
+ super().assert_extra_args(args, train_dataset_group, val_dataset_group)
41
+ # sdxl_train_util.verify_sdxl_training_args(args)
42
+
43
+ if args.fp8_base_unet:
44
+ args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
45
+
46
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
47
+ logger.warning(
48
+ "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
49
+ )
50
+ args.cache_text_encoder_outputs = True
51
+
52
+ if args.cache_text_encoder_outputs:
53
+ assert (
54
+ train_dataset_group.is_text_encoder_output_cacheable()
55
+ ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
56
+
57
+ # prepare CLIP-L/T5XXL training flags
58
+ self.train_clip_l = not args.network_train_unet_only
59
+ self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
60
+
61
+ if args.max_token_length is not None:
62
+ logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
63
+
64
+ assert (
65
+ args.blocks_to_swap is None or args.blocks_to_swap == 0
66
+ ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
67
+
68
+ # deprecated split_mode option
69
+ if args.split_mode:
70
+ if args.blocks_to_swap is not None:
71
+ logger.warning(
72
+ "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
73
+ " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
74
+ )
75
+ else:
76
+ logger.warning(
77
+ "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
78
+ " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
79
+ )
80
+ args.blocks_to_swap = 18 # 18 is safe for most cases
81
+
82
+ train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
83
+ if val_dataset_group is not None:
84
+ val_dataset_group.verify_bucket_reso_steps(32) # TODO check this
85
+
86
+ def load_target_model(self, args, weight_dtype, accelerator):
87
+ # currently offload to cpu for some models
88
+
89
+ # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
90
+ loading_dtype = None if args.fp8_base else weight_dtype
91
+
92
+ # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
93
+ self.is_schnell, model = flux_utils.load_flow_model(
94
+ args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
95
+ )
96
+ if args.fp8_base:
97
+ # check dtype of model
98
+ if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
99
+ raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
100
+ elif model.dtype == torch.float8_e4m3fn:
101
+ logger.info("Loaded fp8 FLUX model")
102
+ else:
103
+ logger.info(
104
+ "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
105
+ " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
106
+ )
107
+ model.to(torch.float8_e4m3fn)
108
+
109
+ # if args.split_mode:
110
+ # model = self.prepare_split_model(model, weight_dtype, accelerator)
111
+
112
+ self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
113
+ if self.is_swapping_blocks:
114
+ # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
115
+ logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
116
+ model.enable_block_swap(args.blocks_to_swap, accelerator.device)
117
+
118
+ clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
119
+ clip_l.eval()
120
+
121
+ # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
122
+ if args.fp8_base and not args.fp8_base_unet:
123
+ loading_dtype = None # as is
124
+ else:
125
+ loading_dtype = weight_dtype
126
+
127
+ # loading t5xxl to cpu takes a long time, so we should load to gpu in future
128
+ t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
129
+ t5xxl.eval()
130
+ if args.fp8_base and not args.fp8_base_unet:
131
+ # check dtype of model
132
+ if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
133
+ raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
134
+ elif t5xxl.dtype == torch.float8_e4m3fn:
135
+ logger.info("Loaded fp8 T5XXL model")
136
+
137
+ ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
138
+
139
+ return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
140
+
141
+ def get_tokenize_strategy(self, args):
142
+ _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
143
+
144
+ if args.t5xxl_max_token_length is None:
145
+ if is_schnell:
146
+ t5xxl_max_token_length = 256
147
+ else:
148
+ t5xxl_max_token_length = 512
149
+ else:
150
+ t5xxl_max_token_length = args.t5xxl_max_token_length
151
+
152
+ logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
153
+ return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
154
+
155
+ def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
156
+ return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
157
+
158
+ def get_latents_caching_strategy(self, args):
159
+ latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
160
+ return latents_caching_strategy
161
+
162
+ def get_text_encoding_strategy(self, args):
163
+ return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
164
+
165
+ def post_process_network(self, args, accelerator, network, text_encoders, unet):
166
+ # check t5xxl is trained or not
167
+ self.train_t5xxl = network.train_t5xxl
168
+
169
+ if self.train_t5xxl and args.cache_text_encoder_outputs:
170
+ raise ValueError(
171
+ "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
172
+ )
173
+
174
+ def get_models_for_text_encoding(self, args, accelerator, text_encoders):
175
+ if args.cache_text_encoder_outputs:
176
+ if self.train_clip_l and not self.train_t5xxl:
177
+ return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
178
+ else:
179
+ return None # no text encoders are needed for encoding because both are cached
180
+ else:
181
+ return text_encoders # both CLIP-L and T5XXL are needed for encoding
182
+
183
+ def get_text_encoders_train_flags(self, args, text_encoders):
184
+ return [self.train_clip_l, self.train_t5xxl]
185
+
186
+ def get_text_encoder_outputs_caching_strategy(self, args):
187
+ if args.cache_text_encoder_outputs:
188
+ # if either text encoder is trained, we need tokenization, so is_partial is True
189
+ return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
190
+ args.cache_text_encoder_outputs_to_disk,
191
+ args.text_encoder_batch_size,
192
+ args.skip_cache_check,
193
+ is_partial=self.train_clip_l or self.train_t5xxl,
194
+ apply_t5_attn_mask=args.apply_t5_attn_mask,
195
+ )
196
+ else:
197
+ return None
198
+
199
+ def cache_text_encoder_outputs_if_needed(
200
+ self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
201
+ ):
202
+ if args.cache_text_encoder_outputs:
203
+ if not args.lowram:
204
+ # reduce memory consumption
205
+ logger.info("move vae and unet to cpu to save memory")
206
+ org_vae_device = vae.device
207
+ org_unet_device = unet.device
208
+ vae.to("cpu")
209
+ unet.to("cpu")
210
+ clean_memory_on_device(accelerator.device)
211
+
212
+ # When TE is not be trained, it will not be prepared so we need to use explicit autocast
213
+ logger.info("move text encoders to gpu")
214
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
215
+ text_encoders[1].to(accelerator.device)
216
+
217
+ if text_encoders[1].dtype == torch.float8_e4m3fn:
218
+ # if we load fp8 weights, the model is already fp8, so we use it as is
219
+ self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
220
+ else:
221
+ # otherwise, we need to convert it to target dtype
222
+ text_encoders[1].to(weight_dtype)
223
+
224
+ with accelerator.autocast():
225
+ dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
226
+
227
+ # cache sample prompts
228
+ if args.sample_prompts is not None:
229
+ logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
230
+
231
+ tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
232
+ text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
233
+
234
+ prompts = train_util.load_prompts(args.sample_prompts)
235
+ sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
236
+ with accelerator.autocast(), torch.no_grad():
237
+ for prompt_dict in prompts:
238
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
239
+ if p not in sample_prompts_te_outputs:
240
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
241
+ tokens_and_masks = tokenize_strategy.tokenize(p)
242
+ sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
243
+ tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
244
+ )
245
+ self.sample_prompts_te_outputs = sample_prompts_te_outputs
246
+
247
+ accelerator.wait_for_everyone()
248
+
249
+ # move back to cpu
250
+ if not self.is_train_text_encoder(args):
251
+ logger.info("move CLIP-L back to cpu")
252
+ text_encoders[0].to("cpu")
253
+ logger.info("move t5XXL back to cpu")
254
+ text_encoders[1].to("cpu")
255
+ clean_memory_on_device(accelerator.device)
256
+
257
+ if not args.lowram:
258
+ logger.info("move vae and unet back to original device")
259
+ vae.to(org_vae_device)
260
+ unet.to(org_unet_device)
261
+ else:
262
+ # keep the Text Encoders on the GPU because their outputs are fetched at every step
263
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype)
264
+ text_encoders[1].to(accelerator.device)
265
+
266
+ # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
267
+ # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
268
+
269
+ # # get size embeddings
270
+ # orig_size = batch["original_sizes_hw"]
271
+ # crop_size = batch["crop_top_lefts"]
272
+ # target_size = batch["target_sizes_hw"]
273
+ # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
274
+
275
+ # # concat embeddings
276
+ # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
277
+ # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
278
+ # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
279
+
280
+ # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
281
+ # return noise_pred
282
+
283
+ def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
284
+ text_encoders = text_encoder # for compatibility
285
+ text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
286
+
287
+ flux_train_utils.sample_images(
288
+ accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
289
+ )
290
+ # return
291
+
292
+ """
293
+ class FluxUpperLowerWrapper(torch.nn.Module):
294
+ def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
295
+ super().__init__()
296
+ self.flux_upper = flux_upper
297
+ self.flux_lower = flux_lower
298
+ self.target_device = device
299
+
300
+ def prepare_block_swap_before_forward(self):
301
+ pass
302
+
303
+ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
304
+ self.flux_lower.to("cpu")
305
+ clean_memory_on_device(self.target_device)
306
+ self.flux_upper.to(self.target_device)
307
+ img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
308
+ self.flux_upper.to("cpu")
309
+ clean_memory_on_device(self.target_device)
310
+ self.flux_lower.to(self.target_device)
311
+ return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
312
+
313
+ wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
314
+ clean_memory_on_device(accelerator.device)
315
+ flux_train_utils.sample_images(
316
+ accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
317
+ )
318
+ clean_memory_on_device(accelerator.device)
319
+ """
320
+
321
+ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
322
+ noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
323
+ self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
324
+ return noise_scheduler
325
+
326
+ def encode_images_to_latents(self, args, accelerator, vae, images):
327
+ return vae.encode(images)
328
+
329
+ def shift_scale_latents(self, args, latents):
330
+ return latents
331
+
332
+ def get_noise_pred_and_target(
333
+ self,
334
+ args,
335
+ accelerator,
336
+ noise_scheduler,
337
+ latents,
338
+ batch,
339
+ text_encoder_conds,
340
+ unet: flux_models.Flux,
341
+ network,
342
+ weight_dtype,
343
+ train_unet,
344
+ is_train=True
345
+ ):
346
+ # Sample noise that we'll add to the latents
347
+ noise = torch.randn_like(latents)
348
+ bsz = latents.shape[0]
349
+
350
+ # get noisy model input and timesteps
351
+ noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
352
+ args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
353
+ )
354
+
355
+ # pack latents and get img_ids
356
+ packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
357
+ packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
358
+ img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
359
+
360
+ # get guidance
361
+ # ensure guidance_scale in args is float
362
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
363
+
364
+ # ensure the hidden state will require grad
365
+ if args.gradient_checkpointing:
366
+ noisy_model_input.requires_grad_(True)
367
+ for t in text_encoder_conds:
368
+ if t is not None and t.dtype.is_floating_point:
369
+ t.requires_grad_(True)
370
+ img_ids.requires_grad_(True)
371
+ guidance_vec.requires_grad_(True)
372
+
373
+ # Predict the noise residual
374
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
375
+ if not args.apply_t5_attn_mask:
376
+ t5_attn_mask = None
377
+
378
+ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
379
+ # if not args.split_mode:
380
+ # normal forward
381
+ with torch.set_grad_enabled(is_train), accelerator.autocast():
382
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
383
+ model_pred = unet(
384
+ img=img,
385
+ img_ids=img_ids,
386
+ txt=t5_out,
387
+ txt_ids=txt_ids,
388
+ y=l_pooled,
389
+ timesteps=timesteps / 1000,
390
+ guidance=guidance_vec,
391
+ txt_attention_mask=t5_attn_mask,
392
+ )
393
+ """
394
+ else:
395
+ # split forward to reduce memory usage
396
+ assert network.train_blocks == "single", "train_blocks must be single for split mode"
397
+ with accelerator.autocast():
398
+ # move flux lower to cpu, and then move flux upper to gpu
399
+ unet.to("cpu")
400
+ clean_memory_on_device(accelerator.device)
401
+ self.flux_upper.to(accelerator.device)
402
+
403
+ # upper model does not require grad
404
+ with torch.no_grad():
405
+ intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
406
+ img=packed_noisy_model_input,
407
+ img_ids=img_ids,
408
+ txt=t5_out,
409
+ txt_ids=txt_ids,
410
+ y=l_pooled,
411
+ timesteps=timesteps / 1000,
412
+ guidance=guidance_vec,
413
+ txt_attention_mask=t5_attn_mask,
414
+ )
415
+
416
+ # move flux upper back to cpu, and then move flux lower to gpu
417
+ self.flux_upper.to("cpu")
418
+ clean_memory_on_device(accelerator.device)
419
+ unet.to(accelerator.device)
420
+
421
+ # lower model requires grad
422
+ intermediate_img.requires_grad_(True)
423
+ intermediate_txt.requires_grad_(True)
424
+ vec.requires_grad_(True)
425
+ pe.requires_grad_(True)
426
+
427
+ with torch.set_grad_enabled(is_train and train_unet):
428
+ model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
429
+ """
430
+
431
+ return model_pred
432
+
433
+ model_pred = call_dit(
434
+ img=packed_noisy_model_input,
435
+ img_ids=img_ids,
436
+ t5_out=t5_out,
437
+ txt_ids=txt_ids,
438
+ l_pooled=l_pooled,
439
+ timesteps=timesteps,
440
+ guidance_vec=guidance_vec,
441
+ t5_attn_mask=t5_attn_mask,
442
+ )
443
+
444
+ # unpack latents
445
+ model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
446
+
447
+ # apply model prediction type
448
+ model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
449
+
450
+ # flow matching loss: this is different from SD3
451
+ target = noise - latents
452
+
453
+ # differential output preservation
454
+ if "custom_attributes" in batch:
455
+ diff_output_pr_indices = []
456
+ for i, custom_attributes in enumerate(batch["custom_attributes"]):
457
+ if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
458
+ diff_output_pr_indices.append(i)
459
+
460
+ if len(diff_output_pr_indices) > 0:
461
+ network.set_multiplier(0.0)
462
+ unet.prepare_block_swap_before_forward()
463
+ with torch.no_grad():
464
+ model_pred_prior = call_dit(
465
+ img=packed_noisy_model_input[diff_output_pr_indices],
466
+ img_ids=img_ids[diff_output_pr_indices],
467
+ t5_out=t5_out[diff_output_pr_indices],
468
+ txt_ids=txt_ids[diff_output_pr_indices],
469
+ l_pooled=l_pooled[diff_output_pr_indices],
470
+ timesteps=timesteps[diff_output_pr_indices],
471
+ guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
472
+ t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
473
+ )
474
+ network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
475
+
476
+ model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
477
+ model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
478
+ args,
479
+ model_pred_prior,
480
+ noisy_model_input[diff_output_pr_indices],
481
+ sigmas[diff_output_pr_indices] if sigmas is not None else None,
482
+ )
483
+ target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
484
+
485
+ return model_pred, target, timesteps, weighting
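For context on `target = noise - latents` above (a sketch, not part of the diff): with the rectified-flow interpolation x_t = (1 - t) * x0 + t * noise, the velocity along the straight path is d x_t / d t = noise - x0, so the network regresses noise minus the clean latents instead of the usual epsilon target:

import torch

x0 = torch.randn(1, 16, 32, 32)       # clean latents
noise = torch.randn_like(x0)
t = torch.rand(1)

x_t = (1 - t) * x0 + t * noise        # noisy sample on the straight path
velocity = noise - x0                 # the flow-matching target used above
# integrating the constant velocity from t to 1 recovers the noise endpoint
assert torch.allclose(x_t + (1 - t) * velocity, noise, atol=1e-5)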
486
+
487
+ def post_process_loss(self, loss, args, timesteps, noise_scheduler):
488
+ return loss
489
+
490
+ def get_sai_model_spec(self, args):
491
+ return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
492
+
493
+ def update_metadata(self, metadata, args):
494
+ metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
495
+ metadata["ss_weighting_scheme"] = args.weighting_scheme
496
+ metadata["ss_logit_mean"] = args.logit_mean
497
+ metadata["ss_logit_std"] = args.logit_std
498
+ metadata["ss_mode_scale"] = args.mode_scale
499
+ metadata["ss_guidance_scale"] = args.guidance_scale
500
+ metadata["ss_timestep_sampling"] = args.timestep_sampling
501
+ metadata["ss_sigmoid_scale"] = args.sigmoid_scale
502
+ metadata["ss_model_prediction_type"] = args.model_prediction_type
503
+ metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
504
+
505
+ def is_text_encoder_not_needed_for_training(self, args):
506
+ return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
507
+
508
+ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
509
+ if index == 0: # CLIP-L
510
+ return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
511
+ else: # T5XXL
512
+ text_encoder.encoder.embed_tokens.requires_grad_(True)
513
+
514
+ def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
515
+ if index == 0: # CLIP-L
516
+ logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
517
+ text_encoder.to(te_weight_dtype) # fp8
518
+ text_encoder.text_model.embeddings.to(dtype=weight_dtype)
519
+ else: # T5XXL
520
+
521
+ def prepare_fp8(text_encoder, target_dtype):
522
+ def forward_hook(module):
523
+ def forward(hidden_states):
524
+ hidden_gelu = module.act(module.wi_0(hidden_states))
525
+ hidden_linear = module.wi_1(hidden_states)
526
+ hidden_states = hidden_gelu * hidden_linear
527
+ hidden_states = module.dropout(hidden_states)
528
+
529
+ hidden_states = module.wo(hidden_states)
530
+ return hidden_states
531
+
532
+ return forward
533
+
534
+ for module in text_encoder.modules():
535
+ if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
536
+ # print("set", module.__class__.__name__, "to", target_dtype)
537
+ module.to(target_dtype)
538
+ if module.__class__.__name__ in ["T5DenseGatedActDense"]:
539
+ # print("set", module.__class__.__name__, "hooks")
540
+ module.forward = forward_hook(module)
541
+
542
+ if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
543
+ logger.info(f"T5XXL already prepared for fp8")
544
+ else:
545
+ logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
546
+ text_encoder.to(te_weight_dtype) # fp8
547
+ prepare_fp8(text_encoder, weight_dtype)
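The prepare_fp8 helper above keeps norm/embedding layers in a higher-precision dtype and swaps in a plain Python forward for the gated MLP; a generic, self-contained sketch of that forward-patching pattern on a toy module (the T5 classes themselves are untouched here):

import torch

class ToyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.wi = torch.nn.Linear(8, 8)
        self.wo = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.wo(torch.nn.functional.gelu(self.wi(x)))

def make_patched_forward(module):
    def forward(x):
        # a patched forward can upcast or reorder intermediate ops here,
        # which is what the fp8 hook does for T5DenseGatedActDense
        return module.wo(torch.nn.functional.gelu(module.wi(x)))
    return forward

mlp = ToyMLP()
mlp.forward = make_patched_forward(mlp)   # instance attribute shadows the class method
print(mlp(torch.randn(2, 8)).shape)       # torch.Size([2, 8])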
548
+
549
+ def prepare_unet_with_accelerator(
550
+ self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
551
+ ) -> torch.nn.Module:
552
+ if not self.is_swapping_blocks:
553
+ return super().prepare_unet_with_accelerator(args, accelerator, unet)
554
+
555
+ # when swapping blocks, let accelerate skip device placement and move the non-swapped blocks to the device manually
556
+ flux: flux_models.Flux = unet
557
+ flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
558
+ accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
559
+ accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
560
+
561
+ return flux
562
+
563
+
564
+ def setup_parser() -> argparse.ArgumentParser:
565
+ parser = train_network.setup_parser()
566
+ train_util.add_dit_training_arguments(parser)
567
+ flux_train_utils.add_flux_train_arguments(parser)
568
+
569
+ parser.add_argument(
570
+ "--split_mode",
571
+ action="store_true",
572
+ # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
573
+ # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
574
+ help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
575
+ " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
576
+ )
577
+ return parser
578
+
579
+
580
+ if __name__ == "__main__":
581
+ parser = setup_parser()
582
+
583
+ args = parser.parse_args()
584
+ train_util.verify_command_line_training_args(args)
585
+ args = train_util.read_config_from_file(args, parser)
586
+
587
+ trainer = FluxNetworkTrainer()
588
+ trainer.train(args)
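A note on the latent packing used in get_noise_pred_and_target above (a sketch assuming flux_utils.pack_latents performs a 2x2 patchify, matching the `b, c, h*2, w*2 -> b, h*w, c*4` comment; the function name below is illustrative):

import torch

def pack_latents_sketch(latents: torch.Tensor) -> torch.Tensor:
    b, c, h2, w2 = latents.shape
    h, w = h2 // 2, w2 // 2
    x = latents.reshape(b, c, h, 2, w, 2)   # split spatial dims into 2x2 patches
    x = x.permute(0, 2, 4, 1, 3, 5)         # b, h, w, c, 2, 2
    return x.reshape(b, h * w, c * 4)       # one token per patch, c*4 features

print(pack_latents_sketch(torch.randn(2, 16, 64, 64)).shape)  # torch.Size([2, 1024, 64])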
flux_train_network_asylora.py ADDED
@@ -0,0 +1,591 @@
1
+ # coding=utf-8
2
+
3
+
4
+ import argparse
5
+ import copy
6
+ import math
7
+ import random
8
+ from typing import Any, Optional
9
+
10
+ import torch
11
+ from accelerate import Accelerator
12
+ from library.device_utils import init_ipex, clean_memory_on_device
13
+
14
+ init_ipex()
15
+
16
+ from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
17
+ import train_network_asylora
18
+ from library.utils import setup_logging
19
+
20
+ setup_logging()
21
+ import logging
22
+ import re
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class FluxNetworkTrainer(train_network_asylora.NetworkTrainer):
28
+ def __init__(self):
29
+ super().__init__()
30
+ self.sample_prompts_te_outputs = None
31
+ self.is_schnell: Optional[bool] = None
32
+ self.is_swapping_blocks: bool = False
33
+
34
+ def assert_extra_args(self, args, train_dataset_group):
35
+ super().assert_extra_args(args, train_dataset_group)
36
+ # sdxl_train_util.verify_sdxl_training_args(args)
37
+
38
+ if args.fp8_base_unet:
39
+ args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
40
+
41
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
42
+ logger.warning(
43
+ "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
44
+ )
45
+ args.cache_text_encoder_outputs = True
46
+
47
+ if args.cache_text_encoder_outputs:
48
+ assert (
49
+ train_dataset_group.is_text_encoder_output_cacheable()
50
+ ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
51
+
52
+ # prepare CLIP-L/T5XXL training flags
53
+ self.train_clip_l = not args.network_train_unet_only
54
+ self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
55
+
56
+ if args.max_token_length is not None:
57
+ logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
58
+
59
+ assert (
60
+ args.blocks_to_swap is None or args.blocks_to_swap == 0
61
+ ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
62
+
63
+ # deprecated split_mode option
64
+ if args.split_mode:
65
+ if args.blocks_to_swap is not None:
66
+ logger.warning(
67
+ "split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
68
+ " / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
69
+ )
70
+ else:
71
+ logger.warning(
72
+ "split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
73
+ " / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
74
+ )
75
+ args.blocks_to_swap = 18 # 18 is safe for most cases
76
+
77
+ train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
78
+
79
+ def load_target_model(self, args, weight_dtype, accelerator):
80
+ # currently offload to cpu for some models
81
+
82
+ # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
83
+ loading_dtype = None if args.fp8_base else weight_dtype
84
+
85
+ # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
86
+ self.is_schnell, model = flux_utils.load_flow_model(
87
+ args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
88
+ )
89
+ if args.fp8_base:
90
+ # check dtype of model
91
+ if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
92
+ raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
93
+ elif model.dtype == torch.float8_e4m3fn:
94
+ logger.info("Loaded fp8 FLUX model")
95
+ else:
96
+ logger.info(
97
+ "Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
98
+ " / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
99
+ )
100
+ model.to(torch.float8_e4m3fn)
101
+
102
+ # if args.split_mode:
103
+ # model = self.prepare_split_model(model, weight_dtype, accelerator)
104
+
105
+ self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
106
+ if self.is_swapping_blocks:
107
+ # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
108
+ logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
109
+ model.enable_block_swap(args.blocks_to_swap, accelerator.device)
110
+
111
+ clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
112
+ clip_l.eval()
113
+
114
+ # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
115
+ if args.fp8_base and not args.fp8_base_unet:
116
+ loading_dtype = None # as is
117
+ else:
118
+ loading_dtype = weight_dtype
119
+
120
+ # loading t5xxl to cpu takes a long time, so we should load to gpu in future
121
+ t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
122
+ t5xxl.eval()
123
+ if args.fp8_base and not args.fp8_base_unet:
124
+ # check dtype of model
125
+ if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
126
+ raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
127
+ elif t5xxl.dtype == torch.float8_e4m3fn:
128
+ logger.info("Loaded fp8 T5XXL model")
129
+
130
+ ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
131
+
132
+ return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
133
+
134
+ def get_tokenize_strategy(self, args):
135
+ _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
136
+
137
+ if args.t5xxl_max_token_length is None:
138
+ if is_schnell:
139
+ t5xxl_max_token_length = 256
140
+ else:
141
+ t5xxl_max_token_length = 512
142
+ else:
143
+ t5xxl_max_token_length = args.t5xxl_max_token_length
144
+
145
+ logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
146
+ return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
147
+
148
+ def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
149
+ return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
150
+
151
+ def get_latents_caching_strategy(self, args):
152
+ latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
153
+ return latents_caching_strategy
154
+
155
+ def get_text_encoding_strategy(self, args):
156
+ return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
157
+
158
+ def post_process_network(self, args, accelerator, network, text_encoders, unet):
159
+ # check t5xxl is trained or not
160
+ self.train_t5xxl = network.train_t5xxl
161
+
162
+ if self.train_t5xxl and args.cache_text_encoder_outputs:
163
+ raise ValueError(
164
+ "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
165
+ )
166
+
167
+ def get_models_for_text_encoding(self, args, accelerator, text_encoders):
168
+ if args.cache_text_encoder_outputs:
169
+ if self.train_clip_l and not self.train_t5xxl:
170
+ return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
171
+ else:
172
+ return None # no text encoders are needed for encoding because both are cached
173
+ else:
174
+ return text_encoders # both CLIP-L and T5XXL are needed for encoding
175
+
176
+ def get_text_encoders_train_flags(self, args, text_encoders):
177
+ return [self.train_clip_l, self.train_t5xxl]
178
+
179
+ def get_text_encoder_outputs_caching_strategy(self, args):
180
+ if args.cache_text_encoder_outputs:
181
+ # if either text encoder is trained, we need tokenization, so is_partial is True
182
+ return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
183
+ args.cache_text_encoder_outputs_to_disk,
184
+ args.text_encoder_batch_size,
185
+ args.skip_cache_check,
186
+ is_partial=self.train_clip_l or self.train_t5xxl,
187
+ apply_t5_attn_mask=args.apply_t5_attn_mask,
188
+ )
189
+ else:
190
+ return None
191
+
192
+ def cache_text_encoder_outputs_if_needed(
193
+ self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
194
+ ):
195
+ if args.cache_text_encoder_outputs:
196
+ if not args.lowram:
197
+ # reduce memory consumption
198
+ logger.info("move vae and unet to cpu to save memory")
199
+ org_vae_device = vae.device
200
+ org_unet_device = unet.device
201
+ vae.to("cpu")
202
+ unet.to("cpu")
203
+ clean_memory_on_device(accelerator.device)
204
+
205
+ # When the TE is not trained, it will not be prepared by accelerate, so we need to use explicit autocast
206
+ logger.info("move text encoders to gpu")
207
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
208
+ text_encoders[1].to(accelerator.device)
209
+
210
+ if text_encoders[1].dtype == torch.float8_e4m3fn:
211
+ # if we load fp8 weights, the model is already fp8, so we use it as is
212
+ self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
213
+ else:
214
+ # otherwise, we need to convert it to target dtype
215
+ text_encoders[1].to(weight_dtype)
216
+
217
+ with accelerator.autocast():
218
+ dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
219
+
220
+ # cache sample prompts
221
+ if args.sample_prompts is not None:
222
+ logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
223
+
224
+ tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
225
+ text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
226
+
227
+ prompts = train_util.load_prompts(args.sample_prompts)
228
+ sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
229
+ with accelerator.autocast(), torch.no_grad():
230
+ for prompt_dict in prompts:
231
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
232
+ if p not in sample_prompts_te_outputs:
233
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
234
+ tokens_and_masks = tokenize_strategy.tokenize(p)
235
+ sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
236
+ tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
237
+ )
238
+ self.sample_prompts_te_outputs = sample_prompts_te_outputs
239
+
240
+ accelerator.wait_for_everyone()
241
+
242
+ # move back to cpu
243
+ if not self.is_train_text_encoder(args):
244
+ logger.info("move CLIP-L back to cpu")
245
+ text_encoders[0].to("cpu")
246
+ logger.info("move t5XXL back to cpu")
247
+ text_encoders[1].to("cpu")
248
+ clean_memory_on_device(accelerator.device)
249
+
250
+ if not args.lowram:
251
+ logger.info("move vae and unet back to original device")
252
+ vae.to(org_vae_device)
253
+ unet.to(org_unet_device)
254
+ else:
255
+ # keep the Text Encoders on the GPU because their outputs are fetched at every step
256
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype)
257
+ text_encoders[1].to(accelerator.device)
258
+
259
+ # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
260
+ # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
261
+
262
+ # # get size embeddings
263
+ # orig_size = batch["original_sizes_hw"]
264
+ # crop_size = batch["crop_top_lefts"]
265
+ # target_size = batch["target_sizes_hw"]
266
+ # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
267
+
268
+ # # concat embeddings
269
+ # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
270
+ # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
271
+ # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
272
+
273
+ # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
274
+ # return noise_pred
275
+
276
+ def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
277
+ text_encoders = text_encoder # for compatibility
278
+ text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
279
+
280
+ flux_train_utils.sample_images(
281
+ accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
282
+ )
283
+ # return
284
+
285
+ """
286
+ class FluxUpperLowerWrapper(torch.nn.Module):
287
+ def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
288
+ super().__init__()
289
+ self.flux_upper = flux_upper
290
+ self.flux_lower = flux_lower
291
+ self.target_device = device
292
+
293
+ def prepare_block_swap_before_forward(self):
294
+ pass
295
+
296
+ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
297
+ self.flux_lower.to("cpu")
298
+ clean_memory_on_device(self.target_device)
299
+ self.flux_upper.to(self.target_device)
300
+ img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
301
+ self.flux_upper.to("cpu")
302
+ clean_memory_on_device(self.target_device)
303
+ self.flux_lower.to(self.target_device)
304
+ return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
305
+
306
+ wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
307
+ clean_memory_on_device(accelerator.device)
308
+ flux_train_utils.sample_images(
309
+ accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
310
+ )
311
+ clean_memory_on_device(accelerator.device)
312
+ """
313
+
314
+ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
315
+ noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
316
+ self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
317
+ return noise_scheduler
318
+
319
+ def encode_images_to_latents(self, args, accelerator, vae, images):
320
+ return vae.encode(images)
321
+
322
+ def shift_scale_latents(self, args, latents):
323
+ return latents
324
+
325
+ def get_noise_pred_and_target(
326
+ self,
327
+ args,
328
+ accelerator,
329
+ noise_scheduler,
330
+ latents,
331
+ batch,
332
+ text_encoder_conds,
333
+ unet: flux_models.Flux,
334
+ network,
335
+ weight_dtype,
336
+ train_unet,
337
+ ):
338
+ # Sample noise that we'll add to the latents
339
+ noise = torch.randn_like(latents)
340
+ bsz = latents.shape[0]
341
+
342
+ # get noisy model input and timesteps
343
+ noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
344
+ args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
345
+ )
346
+
347
+ # pack latents and get img_ids
348
+ packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
349
+ packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
350
+ img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
351
+
352
+ # get guidance
353
+ # ensure guidance_scale in args is float
354
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
355
+
356
+ # ensure the hidden state will require grad
357
+ if args.gradient_checkpointing:
358
+ noisy_model_input.requires_grad_(True)
359
+ for t in text_encoder_conds:
360
+ if t is not None and t.dtype.is_floating_point:
361
+ t.requires_grad_(True)
362
+ img_ids.requires_grad_(True)
363
+ guidance_vec.requires_grad_(True)
364
+
365
+ # Predict the noise residual
366
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
367
+ if not args.apply_t5_attn_mask:
368
+ t5_attn_mask = None
369
+
370
+ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
371
+ # if not args.split_mode:
372
+ # normal forward
373
+ with accelerator.autocast():
374
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
375
+ model_pred = unet(
376
+ img=img,
377
+ img_ids=img_ids,
378
+ txt=t5_out,
379
+ txt_ids=txt_ids,
380
+ y=l_pooled,
381
+ timesteps=timesteps / 1000,
382
+ guidance=guidance_vec,
383
+ txt_attention_mask=t5_attn_mask
384
+ )
385
+ """
386
+ else:
387
+ # split forward to reduce memory usage
388
+ assert network.train_blocks == "single", "train_blocks must be single for split mode"
389
+ with accelerator.autocast():
390
+ # move flux lower to cpu, and then move flux upper to gpu
391
+ unet.to("cpu")
392
+ clean_memory_on_device(accelerator.device)
393
+ self.flux_upper.to(accelerator.device)
394
+
395
+ # upper model does not require grad
396
+ with torch.no_grad():
397
+ intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
398
+ img=packed_noisy_model_input,
399
+ img_ids=img_ids,
400
+ txt=t5_out,
401
+ txt_ids=txt_ids,
402
+ y=l_pooled,
403
+ timesteps=timesteps / 1000,
404
+ guidance=guidance_vec,
405
+ txt_attention_mask=t5_attn_mask,
406
+ )
407
+
408
+ # move flux upper back to cpu, and then move flux lower to gpu
409
+ self.flux_upper.to("cpu")
410
+ clean_memory_on_device(accelerator.device)
411
+ unet.to(accelerator.device)
412
+
413
+ # lower model requires grad
414
+ intermediate_img.requires_grad_(True)
415
+ intermediate_txt.requires_grad_(True)
416
+ vec.requires_grad_(True)
417
+ pe.requires_grad_(True)
418
+ model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
419
+ """
420
+
421
+ return model_pred
422
+
423
+ # get the dataset category index from the caption text
424
+ # lora_category = batch["captions"][0].split(",")[0][3:]
425
+ # assert lora_category.isdigit(), f"lora_category is not an integer, got: {lora_category}, {batch['captions'][0]}"
426
+ # lora_category = int(lora_category)
427
+
428
+ prompt_cur = batch["captions"][0]
429
+ match = re.search(r'--lora_up_cur (\d+)', prompt_cur)
430
+ assert match, "Pattern '--lora_up_cur' not found"
431
+ lora_category = int(match.group(1))
432
+
433
+ for lora in network.unet_loras:
434
+ lora.set_lora_up_cur(lora_category-1)
435
+
436
+ model_pred = call_dit(
437
+ img=packed_noisy_model_input,
438
+ img_ids=img_ids,
439
+ t5_out=t5_out,
440
+ txt_ids=txt_ids,
441
+ l_pooled=l_pooled,
442
+ timesteps=timesteps,
443
+ guidance_vec=guidance_vec,
444
+ t5_attn_mask=t5_attn_mask
445
+ )
446
+
447
+ # unpack latents
448
+ model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
449
+
450
+ # apply model prediction type
451
+ model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
452
+
453
+ # flow matching loss: this is different from SD3
454
+ target = noise - latents
455
+
456
+ # differential output preservation
457
+ if "custom_attributes" in batch:
458
+ diff_output_pr_indices = []
459
+ for i, custom_attributes in enumerate(batch["custom_attributes"]):
460
+ if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
461
+ diff_output_pr_indices.append(i)
462
+
463
+ if len(diff_output_pr_indices) > 0:
464
+ network.set_multiplier(0.0)
465
+ unet.prepare_block_swap_before_forward()
466
+ with torch.no_grad():
467
+ model_pred_prior = call_dit(
468
+ img=packed_noisy_model_input[diff_output_pr_indices],
469
+ img_ids=img_ids[diff_output_pr_indices],
470
+ t5_out=t5_out[diff_output_pr_indices],
471
+ txt_ids=txt_ids[diff_output_pr_indices],
472
+ l_pooled=l_pooled[diff_output_pr_indices],
473
+ timesteps=timesteps[diff_output_pr_indices],
474
+ guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
475
+ t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
476
+ )
477
+ network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
478
+
479
+ model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
480
+ model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
481
+ args,
482
+ model_pred_prior,
483
+ noisy_model_input[diff_output_pr_indices],
484
+ sigmas[diff_output_pr_indices] if sigmas is not None else None,
485
+ )
486
+ target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
487
+
488
+ return model_pred, target, timesteps, None, weighting
489
+
490
+ def post_process_loss(self, loss, args, timesteps, noise_scheduler):
491
+ return loss
492
+
493
+ def get_sai_model_spec(self, args):
494
+ return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
495
+
496
+ def update_metadata(self, metadata, args):
497
+ metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
498
+ metadata["ss_weighting_scheme"] = args.weighting_scheme
499
+ metadata["ss_logit_mean"] = args.logit_mean
500
+ metadata["ss_logit_std"] = args.logit_std
501
+ metadata["ss_mode_scale"] = args.mode_scale
502
+ metadata["ss_guidance_scale"] = args.guidance_scale
503
+ metadata["ss_timestep_sampling"] = args.timestep_sampling
504
+ metadata["ss_sigmoid_scale"] = args.sigmoid_scale
505
+ metadata["ss_model_prediction_type"] = args.model_prediction_type
506
+ metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
507
+
508
+ def is_text_encoder_not_needed_for_training(self, args):
509
+ return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
510
+
511
+ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
512
+ if index == 0: # CLIP-L
513
+ return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
514
+ else: # T5XXL
515
+ text_encoder.encoder.embed_tokens.requires_grad_(True)
516
+
517
+ def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
518
+ if index == 0: # CLIP-L
519
+ logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
520
+ text_encoder.to(te_weight_dtype) # fp8
521
+ text_encoder.text_model.embeddings.to(dtype=weight_dtype)
522
+ else: # T5XXL
523
+
524
+ def prepare_fp8(text_encoder, target_dtype):
525
+ def forward_hook(module):
526
+ def forward(hidden_states):
527
+ hidden_gelu = module.act(module.wi_0(hidden_states))
528
+ hidden_linear = module.wi_1(hidden_states)
529
+ hidden_states = hidden_gelu * hidden_linear
530
+ hidden_states = module.dropout(hidden_states)
531
+
532
+ hidden_states = module.wo(hidden_states)
533
+ return hidden_states
534
+
535
+ return forward
536
+
537
+ for module in text_encoder.modules():
538
+ if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
539
+ # print("set", module.__class__.__name__, "to", target_dtype)
540
+ module.to(target_dtype)
541
+ if module.__class__.__name__ in ["T5DenseGatedActDense"]:
542
+ # print("set", module.__class__.__name__, "hooks")
543
+ module.forward = forward_hook(module)
544
+
545
+ if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
546
+ logger.info(f"T5XXL already prepared for fp8")
547
+ else:
548
+ logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
549
+ text_encoder.to(te_weight_dtype) # fp8
550
+ prepare_fp8(text_encoder, weight_dtype)
551
+
552
+ def prepare_unet_with_accelerator(
553
+ self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
554
+ ) -> torch.nn.Module:
555
+ if not self.is_swapping_blocks:
556
+ return super().prepare_unet_with_accelerator(args, accelerator, unet)
557
+
558
+ # when swapping blocks, we cannot place the whole model on the device at once
559
+ flux: flux_models.Flux = unet
560
+ flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
561
+ accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
562
+ accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
563
+
564
+ return flux
565
+
566
+
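Note: the device_placement list handed to accelerator.prepare above is what keeps the block-swapped model off the GPU. A minimal standalone sketch of that pattern, with a toy module standing in for the FLUX model:

    import torch.nn as nn
    from accelerate import Accelerator

    accelerator = Accelerator()
    toy = nn.Linear(8, 8)
    # wrap for mixed-precision / distributed use, but skip automatic device placement
    toy = accelerator.prepare(toy, device_placement=[False])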
567
+ def setup_parser() -> argparse.ArgumentParser:
568
+ parser = train_network_asylora.setup_parser()
569
+ train_util.add_dit_training_arguments(parser)
570
+ flux_train_utils.add_flux_train_arguments(parser)
571
+
572
+ parser.add_argument(
573
+ "--split_mode",
574
+ action="store_true",
575
+ # help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
576
+ # + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
577
+ help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
578
+ " / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
579
+ )
580
+ return parser
581
+
582
+
583
+ if __name__ == "__main__":
584
+ parser = setup_parser()
585
+
586
+ args = parser.parse_args()
587
+ train_util.verify_command_line_training_args(args)
588
+ args = train_util.read_config_from_file(args, parser)
589
+
590
+ trainer = FluxNetworkTrainer()
591
+ trainer.train(args)
flux_train_recraft.py ADDED
@@ -0,0 +1,713 @@
1
+ import argparse
2
+ import copy
3
+ import math
4
+ import random
5
+ from typing import Any, Optional
6
+ import pdb
7
+
8
+ import torch
9
+ from accelerate import Accelerator
10
+ from library.device_utils import init_ipex, clean_memory_on_device
11
+
12
+ init_ipex()
13
+
14
+ from library import flux_models, flux_train_utils_recraft as flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
15
+ from torchvision import transforms
16
+ import train_network
17
+ from library.utils import setup_logging
18
+ from diffusers.utils import load_image
19
+ import numpy as np
20
+ from PIL import Image, ImageOps
21
+
22
+ setup_logging()
23
+ import logging
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ # NUM_SPLIT = 2
28
+
29
+ class ResizeWithPadding:
30
+ def __init__(self, size, fill=255):
31
+ self.size = size
32
+ self.fill = fill
33
+
34
+ def __call__(self, img):
35
+ if isinstance(img, np.ndarray):
36
+ img = Image.fromarray(img)
37
+ elif not isinstance(img, Image.Image):
38
+ raise TypeError("Input must be a PIL Image or a NumPy array")
39
+
40
+ width, height = img.size
41
+
42
+ if width == height:
43
+ img = img.resize((self.size, self.size), Image.LANCZOS)
44
+ else:
45
+ max_dim = max(width, height)
46
+
47
+ new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
48
+ new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
49
+
50
+ img = new_img.resize((self.size, self.size), Image.LANCZOS)
51
+
52
+ return img
53
+
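A quick standalone check of ResizeWithPadding (a hypothetical non-square test image; white padding makes it square before resizing):

    from PIL import Image

    pad_resize = ResizeWithPadding(size=512, fill=255)
    sample = Image.new("RGB", (300, 200), "red")   # dummy non-square input
    out = pad_resize(sample)
    assert out.size == (512, 512)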
54
+ class FluxNetworkTrainer(train_network.NetworkTrainer):
55
+ def __init__(self):
56
+ super().__init__()
57
+ self.sample_prompts_te_outputs = None
58
+ self.sample_conditions = None
59
+ self.is_schnell: Optional[bool] = None
60
+
61
+ def assert_extra_args(self, args, train_dataset_group):
62
+ super().assert_extra_args(args, train_dataset_group)
63
+ # sdxl_train_util.verify_sdxl_training_args(args)
64
+
65
+ if args.fp8_base_unet:
66
+ args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
67
+
68
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
69
+ logger.warning(
70
+ "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
71
+ )
72
+ args.cache_text_encoder_outputs = True
73
+
74
+ if args.cache_text_encoder_outputs:
75
+ assert (
76
+ train_dataset_group.is_text_encoder_output_cacheable()
77
+ ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
78
+
79
+ # prepare CLIP-L/T5XXL training flags
80
+ self.train_clip_l = not args.network_train_unet_only
81
+ self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
82
+
83
+ if args.max_token_length is not None:
84
+ logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
85
+
86
+ assert not args.split_mode or not args.cpu_offload_checkpointing, (
87
+ "split_mode and cpu_offload_checkpointing cannot be used together"
88
+ " / split_modeとcpu_offload_checkpointingは同時に使用できません"
89
+ )
90
+
91
+ train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
92
+
93
+ def load_target_model(self, args, weight_dtype, accelerator):
94
+ # currently offload to cpu for some models
95
+
96
+ # if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
97
+ loading_dtype = None if args.fp8_base else weight_dtype
98
+
99
+ # if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
100
+ self.is_schnell, model = flux_utils.load_flow_model(
101
+ args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
102
+ )
103
+ if args.fp8_base:
104
+ # check dtype of model
105
+ if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
106
+ raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
107
+ elif model.dtype == torch.float8_e4m3fn:
108
+ logger.info("Loaded fp8 FLUX model")
109
+
110
+ if args.split_mode:
111
+ model = self.prepare_split_model(model, weight_dtype, accelerator)
112
+
113
+ clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
114
+ clip_l.eval()
115
+
116
+ # if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
117
+ if args.fp8_base and not args.fp8_base_unet:
118
+ loading_dtype = None # as is
119
+ else:
120
+ loading_dtype = weight_dtype
121
+
122
+ # loading t5xxl to cpu takes a long time, so we should load to gpu in future
123
+ t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
124
+ t5xxl.eval()
125
+ if args.fp8_base and not args.fp8_base_unet:
126
+ # check dtype of model
127
+ if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
128
+ raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
129
+ elif t5xxl.dtype == torch.float8_e4m3fn:
130
+ logger.info("Loaded fp8 T5XXL model")
131
+
132
+ ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
133
+
134
+ return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
135
+
136
+ def prepare_split_model(self, model, weight_dtype, accelerator):
137
+ from accelerate import init_empty_weights
138
+
139
+ logger.info("prepare split model")
140
+ with init_empty_weights():
141
+ flux_upper = flux_models.FluxUpper(model.params)
142
+ flux_lower = flux_models.FluxLower(model.params)
143
+ sd = model.state_dict()
144
+
145
+ # lower (trainable)
146
+ logger.info("load state dict for lower")
147
+ flux_lower.load_state_dict(sd, strict=False, assign=True)
148
+ flux_lower.to(dtype=weight_dtype)
149
+
150
+ # upper (frozen)
151
+ logger.info("load state dict for upper")
152
+ flux_upper.load_state_dict(sd, strict=False, assign=True)
153
+
154
+ logger.info("prepare upper model")
155
+ target_dtype = torch.float8_e4m3fn if args.fp8_base else weight_dtype
156
+ flux_upper.to(accelerator.device, dtype=target_dtype)
157
+ flux_upper.eval()
158
+
159
+ if args.fp8_base:
160
+ # this is required to run on fp8
161
+ flux_upper = accelerator.prepare(flux_upper)
162
+
163
+ flux_upper.to("cpu")
164
+
165
+ self.flux_upper = flux_upper
166
+ del model # we don't need model anymore
167
+ clean_memory_on_device(accelerator.device)
168
+
169
+ logger.info("split model prepared")
170
+
171
+ return flux_lower
172
+
173
+ def get_tokenize_strategy(self, args):
174
+ _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
175
+
176
+ if args.t5xxl_max_token_length is None:
177
+ if is_schnell:
178
+ t5xxl_max_token_length = 256
179
+ else:
180
+ t5xxl_max_token_length = 512
181
+ else:
182
+ t5xxl_max_token_length = args.t5xxl_max_token_length
183
+
184
+ logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
185
+ return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
186
+
187
+ def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
188
+ return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
189
+
190
+ def get_latents_caching_strategy(self, args):
191
+ latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
192
+ return latents_caching_strategy
193
+
194
+ def get_text_encoding_strategy(self, args):
195
+ return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
196
+
197
+ def post_process_network(self, args, accelerator, network, text_encoders, unet):
198
+ # check t5xxl is trained or not
199
+ self.train_t5xxl = network.train_t5xxl
200
+
201
+ if self.train_t5xxl and args.cache_text_encoder_outputs:
202
+ raise ValueError(
203
+ "T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
204
+ )
205
+
206
+ def get_models_for_text_encoding(self, args, accelerator, text_encoders):
207
+ if args.cache_text_encoder_outputs:
208
+ if self.train_clip_l and not self.train_t5xxl:
209
+ return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
210
+ else:
211
+ return None # no text encoders are needed for encoding because both are cached
212
+ else:
213
+ return text_encoders # both CLIP-L and T5XXL are needed for encoding
214
+
215
+ def get_text_encoders_train_flags(self, args, text_encoders):
216
+ return [self.train_clip_l, self.train_t5xxl]
217
+
218
+ def get_text_encoder_outputs_caching_strategy(self, args):
219
+ if args.cache_text_encoder_outputs:
220
+ # if either text encoder is trained, tokenization is still needed, so is_partial is True
221
+ return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
222
+ args.cache_text_encoder_outputs_to_disk,
223
+ args.text_encoder_batch_size,
224
+ args.skip_cache_check,
225
+ is_partial=self.train_clip_l or self.train_t5xxl,
226
+ apply_t5_attn_mask=args.apply_t5_attn_mask,
227
+ )
228
+ else:
229
+ return None
230
+
231
+ def cache_text_encoder_outputs_if_needed(
232
+ self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
233
+ ):
234
+ if args.cache_text_encoder_outputs:
235
+ if not args.lowram:
236
+ # reduce memory consumption
237
+ logger.info("move vae and unet to cpu to save memory")
238
+ org_vae_device = vae.device
239
+ org_unet_device = unet.device
240
+ vae.to("cpu")
241
+ unet.to("cpu")
242
+ clean_memory_on_device(accelerator.device)
243
+
244
+ # When the TE is not trained, it is not prepared by the accelerator, so explicit autocast is needed
245
+ logger.info("move text encoders to gpu")
246
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
247
+ text_encoders[1].to(accelerator.device)
248
+
249
+ if text_encoders[1].dtype == torch.float8_e4m3fn:
250
+ # if we load fp8 weights, the model is already fp8, so we use it as is
251
+ self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
252
+ else:
253
+ # otherwise, we need to convert it to target dtype
254
+ text_encoders[1].to(weight_dtype)
255
+
256
+ with accelerator.autocast():
257
+ dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
258
+
259
+ # cache sample prompts
260
+ if args.sample_prompts is not None:
261
+ logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
262
+
263
+ tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
264
+ text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
265
+
266
+ prompts = train_util.load_prompts(args.sample_prompts)
267
+ sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
268
+ with accelerator.autocast(), torch.no_grad():
269
+ for prompt_dict in prompts:
270
+ for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
271
+ if p not in sample_prompts_te_outputs:
272
+ logger.info(f"cache Text Encoder outputs for prompt: {p}")
273
+ tokens_and_masks = tokenize_strategy.tokenize(p)
274
+ sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
275
+ tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
276
+ )
277
+ self.sample_prompts_te_outputs = sample_prompts_te_outputs
278
+
279
+ # cache condition latents for sample images
280
+ if args.sample_images is not None:
281
+ logger.info(f"cache conditions for sample images: {args.sample_images}")
282
+
283
+ # lc03lc
284
+ resize_transform = ResizeWithPadding(size=512, fill=255) if args.frame_num == 4 else ResizeWithPadding(size=352, fill=255)
285
+ img_transforms = transforms.Compose([
286
+ resize_transform,
287
+ transforms.ToTensor(),
288
+ transforms.Normalize([0.5], [0.5]),
289
+ ])
290
+
291
+ if args.sample_images.endswith(".txt"):
292
+ with open(args.sample_images, "r", encoding="utf-8") as f:
293
+ lines = f.readlines()
294
+ sample_images = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
295
+ else:
296
+ raise NotImplementedError(f"sample_images file format not supported: {args.sample_images}")
297
+
298
+ prompts = train_util.load_prompts(args.sample_prompts)
299
+ conditions = {} # key: prompt, value: latents
300
+
301
+ with torch.no_grad():
302
+ for image, prompt_dict in zip(sample_images, prompts):
303
+ prompt = prompt_dict.get("prompt", "")
304
+ if prompt not in conditions:
305
+ logger.info(f"cache conditions for image: {image} with prompt: {prompt}")
306
+ image = img_transforms(np.array(load_image(image), dtype=np.uint8)).unsqueeze(0).to(vae.device, dtype=vae.dtype)
307
+ latents = self.encode_images_to_latents2(args, accelerator, vae, image)
308
+ # lc03lc
309
+ conditions[prompt] = latents
310
+ # if args.frame_num == 4:
311
+ # conditions[prompt] = latents[:,:,2*latents.shape[2]//3:latents.shape[2], 2*latents.shape[3]//3:latents.shape[3]].to("cpu")
312
+ # else:
313
+ # conditions[prompt] = latents[:,:,latents.shape[2]//2:latents.shape[2], :latents.shape[3]//2].to("cpu")
314
+
315
+ self.sample_conditions = conditions
316
+
317
+ accelerator.wait_for_everyone()
318
+
319
+ # move back to cpu
320
+ if not self.is_train_text_encoder(args):
321
+ logger.info("move CLIP-L back to cpu")
322
+ text_encoders[0].to("cpu")
323
+ logger.info("move t5XXL back to cpu")
324
+ text_encoders[1].to("cpu")
325
+ clean_memory_on_device(accelerator.device)
326
+
327
+ if not args.lowram:
328
+ logger.info("move vae and unet back to original device")
329
+ vae.to(org_vae_device)
330
+ unet.to(org_unet_device)
331
+ else:
332
+ # outputs are fetched from the text encoders every step, so keep them on the GPU
333
+ text_encoders[0].to(accelerator.device, dtype=weight_dtype)
334
+ text_encoders[1].to(accelerator.device)
335
+
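The per-frame sizes above (512 for 4-frame, 352 for 9-frame) match the grid geometry used elsewhere in this commit; a small sanity check of that arithmetic:

    # 2*2 grid of 512-px frames -> 1024-px canvas; 3*3 grid of 352-px frames -> 1056-px canvas
    for frame_num, frame_size in [(4, 512), (9, 352)]:
        grid = int(frame_num ** 0.5)
        assert grid * frame_size in (1024, 1056)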
336
+ # def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
337
+ # noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
338
+
339
+ # # get size embeddings
340
+ # orig_size = batch["original_sizes_hw"]
341
+ # crop_size = batch["crop_top_lefts"]
342
+ # target_size = batch["target_sizes_hw"]
343
+ # embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
344
+
345
+ # # concat embeddings
346
+ # encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
347
+ # vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
348
+ # text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
349
+
350
+ # noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
351
+ # return noise_pred
352
+
353
+ def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
354
+ text_encoders = text_encoder # for compatibility
355
+ text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
356
+ # use the pre-computed conditions directly
357
+ conditions = None
358
+ if self.sample_conditions is not None:
359
+ conditions = {k: v.to(accelerator.device) for k, v in self.sample_conditions.items()}
360
+
361
+ if not args.split_mode:
362
+ flux_train_utils.sample_images(
363
+ accelerator, args, epoch, global_step, flux, ae, text_encoder, self.sample_prompts_te_outputs, None, conditions
364
+ )
365
+ return
366
+
367
+ class FluxUpperLowerWrapper(torch.nn.Module):
368
+ def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
369
+ super().__init__()
370
+ self.flux_upper = flux_upper
371
+ self.flux_lower = flux_lower
372
+ self.target_device = device
373
+
374
+ def prepare_block_swap_before_forward(self):
375
+ pass
376
+
377
+ def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
378
+ self.flux_lower.to("cpu")
379
+ clean_memory_on_device(self.target_device)
380
+ self.flux_upper.to(self.target_device)
381
+ img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
382
+ self.flux_upper.to("cpu")
383
+ clean_memory_on_device(self.target_device)
384
+ self.flux_lower.to(self.target_device)
385
+ return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
386
+
387
+ wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
388
+ clean_memory_on_device(accelerator.device)
389
+ flux_train_utils.sample_images(
390
+ accelerator, args, epoch, global_step, wrapper, ae, text_encoder, self.sample_prompts_te_outputs, conditions
391
+ )
392
+ clean_memory_on_device(accelerator.device)
393
+
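FluxUpperLowerWrapper keeps only one half of the model on the GPU at a time. A minimal, generic sketch of that CPU/GPU ping-pong, with toy modules standing in for flux_upper and flux_lower:

    import torch
    import torch.nn as nn

    upper, lower = nn.Linear(8, 8), nn.Linear(8, 8)
    dev = "cuda" if torch.cuda.is_available() else "cpu"

    x = torch.randn(1, 8)
    lower.to("cpu"); upper.to(dev)   # only the upper half resides on the device
    h = upper(x.to(dev))
    upper.to("cpu"); lower.to(dev)   # swap halves before running the rest
    y = lower(h)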
394
+ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
395
+ noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
396
+ self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
397
+ return noise_scheduler
398
+
399
+ def encode_images_to_latents(self, args, accelerator, vae, images):
400
+ # get the image dimensions
401
+ b, c, h, w = images.shape
402
+
403
+ # num_split = NUM_SPLIT
404
+ num_split = 2 if args.frame_num == 4 else 3
405
+ # split the image along the width into num_split parts
406
+ img_parts = [images[:,:,:,i*w//num_split:(i+1)*w//num_split] for i in range(num_split)]
407
+ # encode each part separately
408
+ latents = [vae.encode(img) for img in img_parts]
409
+ # concatenate the parts back into the full image in latent space
410
+ latents = torch.cat(latents, dim=-1)
411
+
412
+ return latents
413
+
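The width arithmetic behind the split-and-concat encoding above, as a standalone check (assuming the FLUX autoencoder's 8x spatial downsampling):

    w, num_split = 1056, 3
    split_w = w // num_split                 # 352 pixels per sub-image
    latent_w = split_w // 8                  # 44 latent columns per sub-image
    assert latent_w * num_split == w // 8    # concatenation along dim=-1 restores the full width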
414
+ def encode_images_to_latents2(self, args, accelerator, vae, images):
415
+ # get the image dimensions
416
+ b, c, h, w = images.shape
417
+ # num_split = NUM_SPLIT
418
+ num_split = 2 if args.frame_num == 4 else 3
419
+ latents = vae.encode(images)
420
+ return latents
421
+
422
+ def encode_images_to_latents3(self, args, accelerator, vae, images):
423
+ b, c, h, w = images.shape
424
+ # Number of splits along each dimension
425
+ num_split = 3
426
+ # Check if the image can be evenly divided into 3x3 grid
427
+ assert h % num_split == 0 and w % num_split == 0, "Image dimensions must be divisible by 3."
428
+
429
+ # Height and width of each split
430
+ split_h, split_w = h // num_split, w // num_split
431
+
432
+ # Store latents for each split
433
+ latents = []
434
+
435
+ for i in range(num_split):
436
+ for j in range(num_split):
437
+ # Extract the (i, j) sub-image
438
+ img_part = images[:, :, i * split_h:(i + 1) * split_h, j * split_w:(j + 1) * split_w]
439
+ # Encode the sub-image using VAE
440
+ latent = vae.encode(img_part)
441
+ # Append the latent
442
+ latents.append(latent)
443
+
444
+ # Combine latents into a 3x3 grid in the latent space
445
+ # Latents list -> Tensor [num_split^2, b, latent_dim, h', w']
446
+ latents = torch.stack(latents, dim=0)
447
+
448
+ # Reshape into a 3x3 grid
449
+ # Shape: [num_split, num_split, b, latent_dim, h', w']
450
+ latents = latents.view(num_split, num_split, b, *latents.shape[2:])
451
+
452
+ # Combine the 3x3 grid along height and width in latent space
453
+ # Concatenate along width for each row, then concatenate rows along height
454
+ latents = torch.cat([torch.cat(latents[i], dim=-1) for i in range(num_split)], dim=-2)
455
+
456
+ # Final shape: [b, latent_dim, h', w']
457
+ return latents
458
+
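A dummy-tensor check of the 3x3 reassembly in encode_images_to_latents3, with random tensors standing in for vae.encode outputs:

    import torch

    b, c, hh, ww, n = 2, 16, 44, 44, 3
    parts = [torch.randn(b, c, hh, ww) for _ in range(n * n)]   # row-major sub-latents
    grid = torch.stack(parts, dim=0).view(n, n, b, c, hh, ww)
    full = torch.cat([torch.cat(list(grid[i]), dim=-1) for i in range(n)], dim=-2)
    assert full.shape == (b, c, n * hh, n * ww)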
459
+ def shift_scale_latents(self, args, latents):
460
+ return latents
461
+
462
+ def get_noise_pred_and_target(
463
+ self,
464
+ args,
465
+ accelerator,
466
+ noise_scheduler,
467
+ latents,
468
+ batch,
469
+ text_encoder_conds,
470
+ unet: flux_models.Flux,
471
+ network,
472
+ weight_dtype,
473
+ train_unet,
474
+ ):
475
+ # Sample noise that we'll add to the latents
476
+ noise = torch.randn_like(latents)
477
+ bsz = latents.shape[0]
478
+
479
+ # get noisy model input and timesteps
480
+ noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
481
+ args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
482
+ )
483
+
484
+ # pack latents and get img_ids
485
+ # TODO(yiren): check whether this packing needs modification
486
+ packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
487
+ packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
488
+ img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
489
+
490
+ # get guidance
491
+ # ensure guidance_scale in args is float
492
+ guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
493
+
494
+ # ensure the hidden state will require grad
495
+ if args.gradient_checkpointing:
496
+ noisy_model_input.requires_grad_(True)
497
+ for t in text_encoder_conds:
498
+ if t is not None and t.dtype.is_floating_point:
499
+ t.requires_grad_(True)
500
+ img_ids.requires_grad_(True)
501
+ guidance_vec.requires_grad_(True)
502
+
503
+ # Predict the noise residual
504
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
505
+ if not args.apply_t5_attn_mask:
506
+ t5_attn_mask = None
507
+
508
+ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
509
+ if not args.split_mode:
510
+ # normal forward
511
+ with accelerator.autocast():
512
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
513
+ model_pred = unet(
514
+ img=img,
515
+ img_ids=img_ids,
516
+ txt=t5_out,
517
+ txt_ids=txt_ids,
518
+ y=l_pooled,
519
+ timesteps=timesteps / 1000,
520
+ guidance=guidance_vec,
521
+ txt_attention_mask=t5_attn_mask,
522
+ )
523
+ else:
524
+ # split forward to reduce memory usage
525
+ assert network.train_blocks == "single", "train_blocks must be single for split mode"
526
+ with accelerator.autocast():
527
+ # move flux lower to cpu, and then move flux upper to gpu
528
+ unet.to("cpu")
529
+ clean_memory_on_device(accelerator.device)
530
+ self.flux_upper.to(accelerator.device)
531
+
532
+ # upper model does not require grad
533
+ with torch.no_grad():
534
+ intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
535
+ img=packed_noisy_model_input,
536
+ img_ids=img_ids,
537
+ txt=t5_out,
538
+ txt_ids=txt_ids,
539
+ y=l_pooled,
540
+ timesteps=timesteps / 1000,
541
+ guidance=guidance_vec,
542
+ txt_attention_mask=t5_attn_mask,
543
+ )
544
+
545
+ # move flux upper back to cpu, and then move flux lower to gpu
546
+ self.flux_upper.to("cpu")
547
+ clean_memory_on_device(accelerator.device)
548
+ unet.to(accelerator.device)
549
+
550
+ # lower model requires grad
551
+ intermediate_img.requires_grad_(True)
552
+ intermediate_txt.requires_grad_(True)
553
+ vec.requires_grad_(True)
554
+ pe.requires_grad_(True)
555
+ model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
556
+
557
+ return model_pred
558
+
559
+ model_pred = call_dit(
560
+ img=packed_noisy_model_input,
561
+ img_ids=img_ids,
562
+ t5_out=t5_out,
563
+ txt_ids=txt_ids,
564
+ l_pooled=l_pooled,
565
+ timesteps=timesteps,
566
+ guidance_vec=guidance_vec,
567
+ t5_attn_mask=t5_attn_mask,
568
+ )
569
+
570
+ # unpack latents
571
+ model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
572
+
573
+ # apply model prediction type
574
+ model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
575
+
576
+ # flow matching loss: this is different from SD3
577
+ target = noise - latents
578
+
579
+ # differential output preservation
580
+ if "custom_attributes" in batch:
581
+ diff_output_pr_indices = []
582
+ for i, custom_attributes in enumerate(batch["custom_attributes"]):
583
+ if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
584
+ diff_output_pr_indices.append(i)
585
+
586
+ if len(diff_output_pr_indices) > 0:
587
+ network.set_multiplier(0.0)
588
+ with torch.no_grad():
589
+ model_pred_prior = call_dit(
590
+ img=packed_noisy_model_input[diff_output_pr_indices],
591
+ img_ids=img_ids[diff_output_pr_indices],
592
+ t5_out=t5_out[diff_output_pr_indices],
593
+ txt_ids=txt_ids[diff_output_pr_indices],
594
+ l_pooled=l_pooled[diff_output_pr_indices],
595
+ timesteps=timesteps[diff_output_pr_indices],
596
+ guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
597
+ t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
598
+ )
599
+ network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
600
+
601
+ model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
602
+ model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
603
+ args,
604
+ model_pred_prior,
605
+ noisy_model_input[diff_output_pr_indices],
606
+ sigmas[diff_output_pr_indices] if sigmas is not None else None,
607
+ )
608
+ target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
609
+
610
+ # eliminate the loss over the condition region of the image (copy the prediction into the target there)
611
+ h, w = target.shape[2], target.shape[3]
612
+ # num_split = NUM_SPLIT
613
+ num_split = 2 if args.frame_num == 4 else 3
614
+ # target[:, :, :, :w//num_split] = model_pred[:, :, :, :w//num_split]
615
+ # target[:, :, :, :w//num_split] = model_pred[:, :, :, :w//num_split]
616
+ target[:, :, 2*h//num_split:h, 2*w//num_split:w] = model_pred[:, :, 2*h//num_split:h, 2*w//num_split:w]
617
+
618
+
619
+ return model_pred, target, timesteps, None, weighting
620
+
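The region masking above works by copying the prediction into the target, so the residual (and therefore the loss) vanishes over that block. A tiny illustration with random tensors:

    import torch

    pred = torch.randn(1, 16, 6, 6)
    target = torch.randn(1, 16, 6, 6)
    target[:, :, 4:, 4:] = pred[:, :, 4:, 4:]   # this block contributes nothing to the loss
    assert torch.all((pred - target)[:, :, 4:, 4:] == 0)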
621
+ def post_process_loss(self, loss, args, timesteps, noise_scheduler):
622
+ return loss
623
+
624
+ def get_sai_model_spec(self, args):
625
+ return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
626
+
627
+ def update_metadata(self, metadata, args):
628
+ metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
629
+ metadata["ss_weighting_scheme"] = args.weighting_scheme
630
+ metadata["ss_logit_mean"] = args.logit_mean
631
+ metadata["ss_logit_std"] = args.logit_std
632
+ metadata["ss_mode_scale"] = args.mode_scale
633
+ metadata["ss_guidance_scale"] = args.guidance_scale
634
+ metadata["ss_timestep_sampling"] = args.timestep_sampling
635
+ metadata["ss_sigmoid_scale"] = args.sigmoid_scale
636
+ metadata["ss_model_prediction_type"] = args.model_prediction_type
637
+ metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
638
+
639
+ def is_text_encoder_not_needed_for_training(self, args):
640
+ return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
641
+
642
+ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
643
+ if index == 0: # CLIP-L
644
+ return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
645
+ else: # T5XXL
646
+ text_encoder.encoder.embed_tokens.requires_grad_(True)
647
+
648
+ def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
649
+ if index == 0: # CLIP-L
650
+ logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
651
+ text_encoder.to(te_weight_dtype) # fp8
652
+ text_encoder.text_model.embeddings.to(dtype=weight_dtype)
653
+ else: # T5XXL
654
+
655
+ def prepare_fp8(text_encoder, target_dtype):
656
+ def forward_hook(module):
657
+ def forward(hidden_states):
658
+ hidden_gelu = module.act(module.wi_0(hidden_states))
659
+ hidden_linear = module.wi_1(hidden_states)
660
+ hidden_states = hidden_gelu * hidden_linear
661
+ hidden_states = module.dropout(hidden_states)
662
+
663
+ hidden_states = module.wo(hidden_states)
664
+ return hidden_states
665
+
666
+ return forward
667
+
668
+ for module in text_encoder.modules():
669
+ if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
670
+ # print("set", module.__class__.__name__, "to", target_dtype)
671
+ module.to(target_dtype)
672
+ if module.__class__.__name__ in ["T5DenseGatedActDense"]:
673
+ # print("set", module.__class__.__name__, "hooks")
674
+ module.forward = forward_hook(module)
675
+
676
+ if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
677
+ logger.info(f"T5XXL already prepared for fp8")
678
+ else:
679
+ logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
680
+ text_encoder.to(te_weight_dtype) # fp8
681
+ prepare_fp8(text_encoder, weight_dtype)
682
+
683
+
684
+ def setup_parser() -> argparse.ArgumentParser:
685
+ parser = train_network.setup_parser()
686
+ flux_train_utils.add_flux_train_arguments(parser)
687
+
688
+ parser.add_argument(
689
+ "--split_mode",
690
+ action="store_true",
691
+ help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
692
+ + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
693
+ )
694
+
695
+ parser.add_argument(
696
+ '--frame_num',
697
+ type=int,
698
+ choices=[4, 9],
699
+ required=True,
700
+ help="The number of steps in the generated step diagram (choose 4 or 9)"
701
+ )
702
+ return parser
703
+
704
+
705
+ if __name__ == "__main__":
706
+ parser = setup_parser()
707
+
708
+ args = parser.parse_args()
709
+ train_util.verify_command_line_training_args(args)
710
+ args = train_util.read_config_from_file(args, parser)
711
+
712
+ trainer = FluxNetworkTrainer()
713
+ trainer.train(args)
gradio_app.py ADDED
@@ -0,0 +1,372 @@
1
+ # import spaces
2
+ import gradio as gr
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from accelerate import Accelerator
7
+ import os
8
+ import time
9
+ from torchvision import transforms
10
+ from safetensors.torch import load_file
11
+ from networks import lora_flux
12
+ from library import flux_utils, flux_train_utils_recraft as flux_train_utils, strategy_flux
13
+ import logging
14
+ from huggingface_hub import login
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ # Set up logger
20
+ logger = logging.getLogger(__name__)
21
+ logging.basicConfig(level=logging.DEBUG)
22
+
23
+ accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
24
+
25
+ # hf_token = os.getenv("HF_TOKEN")
26
+ # login(token=hf_token)
27
+
28
+ # # Model paths dynamically retrieved using selected model
29
+ # model_paths = {
30
+ # 'Wood Sculpture': {
31
+ # 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
32
+ # 'BASE_FILE': "flux_merge_lora/flux_merge_4f_wood-fp16.safetensors",
33
+ # 'LORA_REPO': "showlab/makeanything",
34
+ # 'LORA_FILE': "recraft/recraft_4f_wood_sculpture.safetensors",
35
+ # "Frame": 4
36
+ # },
37
+ # 'LEGO': {
38
+ # 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
39
+ # 'BASE_FILE': "flux_merge_lora/flux_merge_9f_lego-fp16.safetensors",
40
+ # 'LORA_REPO': "showlab/makeanything",
41
+ # 'LORA_FILE': "recraft/recraft_9f_lego.safetensors",
42
+ # "Frame": 9
43
+ # },
44
+ # 'Sketch': {
45
+ # 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
46
+ # 'BASE_FILE': "flux_merge_lora/flux_merge_9f_portrait-fp16.safetensors",
47
+ # 'LORA_REPO': "showlab/makeanything",
48
+ # 'LORA_FILE': "recraft/recraft_9f_sketch.safetensors",
49
+ # "Frame": 9
50
+ # },
51
+ # 'Portrait': {
52
+ # 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
53
+ # 'BASE_FILE': "flux_merge_lora/flux_merge_9f_sketch-fp16.safetensors",
54
+ # 'LORA_REPO': "showlab/makeanything",
55
+ # 'LORA_FILE': "recraft/recraft_9f_portrait.safetensors",
56
+ # "Frame": 9
57
+ # }
58
+ # }
59
+
60
+ # # Common paths
61
+ # clip_repo_id = "comfyanonymous/flux_text_encoders"
62
+ # t5xxl_file = "t5xxl_fp16.safetensors"
63
+ # clip_l_file = "clip_l.safetensors"
64
+ # ae_repo_id = "black-forest-labs/FLUX.1-dev"
65
+ # ae_file = "ae.safetensors"
66
+
67
+ model_paths = {
68
+ 'Wood Sculpture': {
69
+ 'BASE_FLUX_CHECKPOINT': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/flux_merge_lora/flux_merge_4f_wood_sculpture-fp8_e4m3fn.safetensors",
70
+ 'LORA_WEIGHTS_PATH': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/recraft/recraft_4f_wood_sculpture.safetensors",
71
+ 'Frame': 4
72
+ },
73
+ 'LEGO': {
74
+ 'BASE_FLUX_CHECKPOINT': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/flux_merge_lora/flux_merge_9f_lego-fp8_e4m3fn.safetensors",
75
+ 'LORA_WEIGHTS_PATH': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/recraft/recraft_9f_lego.safetensors",
76
+ 'Frame': 9
77
+ },
78
+ 'Sketch': {
79
+ 'BASE_FLUX_CHECKPOINT': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/flux_merge_lora/flux_merge_9f_sketch-fp8_e4m3fn.safetensors",
80
+ 'LORA_WEIGHTS_PATH': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/recraft/recraft_9f_sketch.safetensors",
81
+ 'Frame': 9
82
+ },
83
+ 'Portrait': {
84
+ 'BASE_FLUX_CHECKPOINT': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/flux_merge_lora/flux_merge_9f_portrait-fp8_e4m3fn.safetensors",
85
+ 'LORA_WEIGHTS_PATH': "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/recraft/recraft_9f_portrait.safetensors",
86
+ 'Frame': 9
87
+ }
88
+ }
89
+ CLIP_L_PATH = "/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/clip_l.safetensors"
90
+ T5XXL_PATH = "/tiamat-NAS/songyiren/FYP/liucheng/ComfyUI/models/clip/t5xxl_fp16.safetensors"
91
+ AE_PATH = "/tiamat-vePFS/share_data/storage/huggingface/models/black-forest-labs/FLUX.1-dev/ae.safetensors"
92
+
93
+
94
+ # Model placeholders
95
+ model = None
96
+ clip_l = None
97
+ t5xxl = None
98
+ ae = None
99
+ lora_model = None
100
+
101
+ # Function to load a file from Hugging Face Hub
102
+ def download_file(repo_id, file_name):
103
+ return hf_hub_download(repo_id=repo_id, filename=file_name)
104
+
105
+ # Load model function with dynamic paths based on the selected model
106
+ def load_target_model(selected_model):
107
+ global model, clip_l, t5xxl, ae, lora_model
108
+ model_path = model_paths[selected_model]
109
+ BASE_FLUX_CHECKPOINT = model_path['BASE_FLUX_CHECKPOINT']
110
+ LORA_WEIGHTS_PATH = model_path['LORA_WEIGHTS_PATH']
111
+
112
+ logger.info("Loading models...")
113
+ try:
114
+ if model is None or clip_l is None or t5xxl is None or ae is None:
115
+ clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
116
+ clip_l.eval()
117
+ t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
118
+ t5xxl.eval()
119
+ ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
120
+ logger.info("Models loaded successfully.")
121
+ # Load models
122
+ _, model = flux_utils.load_flow_model(
123
+ BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
124
+ )
125
+ # Load LoRA weights
126
+ multiplier = 1.0
127
+ weights_sd = load_file(LORA_WEIGHTS_PATH)
128
+ lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
129
+ lora_model.apply_to([clip_l, t5xxl], model)
130
+ info = lora_model.load_state_dict(weights_sd, strict=True)
131
+ logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
132
+ lora_model.eval()
133
+
134
+ logger.info("Models loaded successfully.")
135
+ return "Models loaded successfully. Using Recraft: {}".format(selected_model)
136
+
137
+ except Exception as e:
138
+ logger.error(f"Error loading models: {e}")
139
+ return f"Error loading models: {e}"
140
+
141
+ # Image pre-processing (resize and padding)
142
+ class ResizeWithPadding:
143
+ def __init__(self, size, fill=255):
144
+ self.size = size
145
+ self.fill = fill
146
+
147
+ def __call__(self, img):
148
+ if isinstance(img, np.ndarray):
149
+ img = Image.fromarray(img)
150
+ elif not isinstance(img, Image.Image):
151
+ raise TypeError("Input must be a PIL Image or a NumPy array")
152
+
153
+ width, height = img.size
154
+ max_dim = max(width, height)
155
+ new_img = Image.new("RGB", (max_dim, max_dim), (self.fill, self.fill, self.fill))
156
+ new_img.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
157
+ img = new_img.resize((self.size, self.size), Image.LANCZOS)
158
+ return img
159
+
160
+ # The function that generates an image from a prompt and a conditional image
161
+ # @spaces.GPU(duration=180)
162
+ def infer(prompt, sample_image, recraft_model, seed=0):
163
+ global model, clip_l, t5xxl, ae, lora_model
164
+ if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
165
+ logger.error("Models not loaded. Please load the models first.")
166
+ return None
167
+
168
+ model_path = model_paths[recraft_model]
169
+ frame_num = model_path['Frame']
170
+
171
+ logger.info(f"Started generating image with prompt: {prompt}")
172
+
173
+ lora_model.to("cuda")
174
+
175
+ model.eval()
176
+ clip_l.eval()
177
+ t5xxl.eval()
178
+ ae.eval()
179
+
180
+ # # Load models
181
+ # model, [clip_l, t5xxl], ae = load_target_model()
182
+
183
+ # # LoRA
184
+ # multiplier = 1.0
185
+ # weights_sd = load_file(LORA_WEIGHTS_PATH)
186
+ # lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd,
187
+ # True)
188
+
189
+ # lora_model.apply_to([clip_l, t5xxl], model)
190
+ # info = lora_model.load_state_dict(weights_sd, strict=True)
191
+ # logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
192
+ # lora_model.eval()
193
+ # lora_model.to(device)
194
+
195
+ logger.info(f"Using seed: {seed}")
196
+
197
+ # Preprocess the conditional image
198
+ resize_transform = ResizeWithPadding(size=512) if frame_num == 4 else ResizeWithPadding(size=352)
199
+ img_transforms = transforms.Compose([
200
+ resize_transform,
201
+ transforms.ToTensor(),
202
+ transforms.Normalize([0.5], [0.5]),
203
+ ])
204
+ image = img_transforms(np.array(sample_image, dtype=np.uint8)).unsqueeze(0).to(
205
+ device=device,
206
+ dtype=torch.bfloat16
207
+ )
208
+ logger.debug("Conditional image preprocessed.")
209
+
210
+ # Encode the image to latents
211
+ ae.to(device)
212
+ latents = ae.encode(image)
213
+ logger.debug("Image encoded to latents.")
214
+
215
+ conditions = {}
216
+ # conditions[prompt] = latents.to("cpu")
217
+ conditions[prompt] = latents
218
+
219
+
220
+ # ae.to("cpu")
221
+ clip_l.to(device)
222
+ t5xxl.to(device)
223
+
224
+ # Encode the prompt
225
+ tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
226
+ text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
227
+ tokens_and_masks = tokenize_strategy.tokenize(prompt)
228
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, True)
229
+
230
+ logger.debug("Prompt encoded.")
231
+
232
+ # Prepare the noise and other parameters
233
+ width = 1024 if frame_num == 4 else 1056
234
+ height = 1024 if frame_num == 4 else 1056
235
+
236
+ height = max(64, height - height % 16)
237
+ width = max(64, width - width % 16)
238
+
239
+ packed_latent_height = height // 16
240
+ packed_latent_width = width // 16
241
+
242
+ torch.manual_seed(seed)
243
+ noise = torch.randn(1, packed_latent_height * packed_latent_width, 16 * 2 * 2, device=device, dtype=torch.float16)
244
+ logger.debug("Noise prepared.")
245
+
246
+ # Generate the image
247
+ timesteps = flux_train_utils.get_schedule(20, noise.shape[1], shift=True) # Sample steps = 20
248
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device)
249
+
250
+ t5_attn_mask = t5_attn_mask.to(device)
251
+ ae_outputs = conditions[prompt]
252
+
253
+ logger.debug("Image generation parameters set.")
254
+
255
+ args = lambda: None
256
+ args.frame_num = frame_num
257
+
258
+ # clip_l.to("cpu")
259
+ # t5xxl.to("cpu")
260
+
261
+ model.to(device)
262
+
263
+ print(f"Model device: {model.device}")
264
+ print(f"Noise device: {noise.device}")
265
+ print(f"Image IDs device: {img_ids.device}")
266
+ print(f"T5 output device: {t5_out.device}")
267
+ print(f"Text IDs device: {txt_ids.device}")
268
+ print(f"L pooled device: {l_pooled.device}")
269
+
270
+ # Run the denoising process
271
+ with accelerator.autocast(), torch.no_grad():
272
+ x = flux_train_utils.denoise(
273
+ args, model, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=1.0, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs
274
+ )
275
+ logger.debug("Denoising process completed.")
276
+
277
+ # Decode the final image
278
+ x = x.float()
279
+ x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
280
+ # model.to("cpu")
281
+ ae.to(device)
282
+ with accelerator.autocast(), torch.no_grad():
283
+ x = ae.decode(x)
284
+ logger.debug("Latents decoded into image.")
285
+ # ae.to("cpu")
286
+
287
+ # Convert the tensor to an image
288
+ x = x.clamp(-1, 1)
289
+ x = x.permute(0, 2, 3, 1)
290
+ generated_image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
291
+
292
+ logger.info("Image generation completed.")
293
+ return generated_image
294
+
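The packed-latent shapes used in infer follow FLUX's 2x2 packing of 16-channel latent patches into tokens; a quick recap of the arithmetic for the 9-frame case:

    height = width = 1056
    packed_h, packed_w = height // 16, width // 16   # 66 x 66 tokens (latent 132x132, packed 2x2)
    tokens = packed_h * packed_w                     # 4356 tokens
    feat = 16 * 2 * 2                                # 64 features per token
    # the sampled noise therefore has shape (1, tokens, feat), matching torch.randn(...) above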
295
+ # Gradio interface
296
+ with gr.Blocks() as demo:
297
+ gr.Markdown("## Recraft Generation")
298
+
299
+ with gr.Row():
300
+ with gr.Column(scale=1):
301
+ # Dropdown for selecting the recraft model
302
+ recraft_model = gr.Dropdown(
303
+ label="Select Recraft Model",
304
+ choices=["Wood Sculpture", "LEGO", "Sketch", "Portrait"],
305
+ value="Wood Sculpture"
306
+ )
307
+
308
+ # Load Model Button
309
+ load_button = gr.Button("Load Model")
310
+
311
+ with gr.Column(scale=1):
312
+ # Status message box
313
+ status_box = gr.Textbox(label="Status", placeholder="Model loading status", interactive=False, value="Model not loaded", lines=3)
314
+
315
+ with gr.Row():
316
+ with gr.Column(scale=0.5):
317
+ # Input for the prompt
318
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=8)
319
+ seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=42)
320
+
321
+ with gr.Column(scale=0.5):
322
+ # File upload for image
323
+ sample_image = gr.Image(label="Upload a Conditional Image", type="pil")
324
+ run_button = gr.Button("Generate Image")
325
+
326
+ with gr.Column(scale=1):
327
+ # Output result
328
+ result_image = gr.Image(label="Generated Image", interactive=False)
329
+
330
+ # Load model button action
331
+ load_button.click(fn=load_target_model, inputs=[recraft_model], outputs=[status_box])
332
+
333
+ # Run Button
334
+ run_button.click(fn=infer, inputs=[prompt, sample_image, recraft_model, seed], outputs=[result_image])
335
+
336
+ gr.Markdown("### Examples")
337
+ examples = [
338
+ [
339
+ "sks14, 2*2 puzzle of 4 sub-images, step-by-step wood sculpture carving process", # prompt
340
+ "./gradio_examples/wood_sculpture.png",
341
+ "Wood Sculpture", # recraft_model
342
+ 12345 # seed
343
+ ],
344
+ [
345
+ "sks1, 3*3 puzzle of 9 sub-images, step-by-step lego model construction process", # prompt
346
+ "./gradio_examples/lego.png",
347
+ "LEGO", # recraft_model
348
+ 42 # seed
349
+ ],
350
+ [
351
+ "sks6, 3*3 puzzle of 9 sub-images, step-by-step portrait painting process", # prompt
352
+ "./gradio_examples/portrait.png",
353
+ "Portrait", # recraft_model
354
+ 999 # seed
355
+ ],
356
+ [
357
+ "sks10, 3*3 puzzle of 9 sub-images, step-by-step sketch painting process,", # prompt
358
+ "./gradio_examples/sketch.png",
359
+ "Sketch",
360
+ 2023
361
+ ]
362
+ ]
363
+
364
+ gr.Examples(
365
+ examples=examples,
366
+ inputs=[prompt, sample_image, recraft_model, seed],
367
+ outputs=[result_image],
368
+ cache_examples=False
369
+ )
370
+
371
+ # Launch the Gradio app
372
+ demo.launch(server_port=8289, server_name="0.0.0.0", share=True)
gradio_app_asy.py ADDED
@@ -0,0 +1,329 @@
1
+ # import spaces
2
+ import gradio as gr
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from accelerate import Accelerator
7
+ import os
8
+ import time
9
+ import math
10
+ import json
11
+ from torchvision import transforms
12
+ from safetensors.torch import load_file
13
+ from networks import asylora_flux as lora_flux
14
+ from library import flux_utils, strategy_flux
15
+ import flux_minimal_inference_asylora as flux_train_utils
16
+ import logging
17
+ from huggingface_hub import login
18
+ from huggingface_hub import hf_hub_download
19
+
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+
22
+ # Set up logger
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.DEBUG)
25
+
26
+ accelerator = Accelerator(mixed_precision='bf16', device_placement=True)
27
+
28
+ # hf_token = os.getenv("HF_TOKEN")
29
+ # login(token=hf_token)
30
+
31
+ # # Model paths dynamically retrieved using selected model
32
+ # model_paths = {
33
+ # 'Wood Sculpture': {
34
+ # 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
35
+ # 'BASE_FILE': "flux_merge_lora/flux_merge_4f_wood-fp16.safetensors",
36
+ # 'LORA_REPO': "showlab/makeanything",
37
+ # 'LORA_FILE': "recraft/recraft_4f_wood_sculpture.safetensors",
38
+ # "Frame": 4
39
+ # },
40
+ # 'LEGO': {
41
+ # 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
42
+ # 'BASE_FILE': "flux_merge_lora/flux_merge_9f_lego-fp16.safetensors",
43
+ # 'LORA_REPO': "showlab/makeanything",
44
+ # 'LORA_FILE': "recraft/recraft_9f_lego.safetensors",
45
+ # "Frame": 9
46
+ # },
47
+ # 'Sketch': {
48
+ # 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
49
+ # 'BASE_FILE': "flux_merge_lora/flux_merge_9f_portrait-fp16.safetensors",
50
+ # 'LORA_REPO': "showlab/makeanything",
51
+ # 'LORA_FILE': "recraft/recraft_9f_sketch.safetensors",
52
+ # "Frame": 9
53
+ # },
54
+ # 'Portrait': {
55
+ # 'BASE_FLUX_CHECKPOINT': "showlab/makeanything",
56
+ # 'BASE_FILE': "flux_merge_lora/flux_merge_9f_sketch-fp16.safetensors",
57
+ # 'LORA_REPO': "showlab/makeanything",
58
+ # 'LORA_FILE': "recraft/recraft_9f_portrait.safetensors",
59
+ # "Frame": 9
60
+ # }
61
+ # }
62
+
63
+ # # Common paths
64
+ # clip_repo_id = "comfyanonymous/flux_text_encoders"
65
+ # t5xxl_file = "t5xxl_fp16.safetensors"
66
+ # clip_l_file = "clip_l.safetensors"
67
+ # ae_repo_id = "black-forest-labs/FLUX.1-dev"
68
+ # ae_file = "ae.safetensors"
69
+
70
+ domain_index = {
71
+ 'LEGO': 1, 'Cook': 2, 'Painting': 3, 'Icon': 4, 'Landscape illustration': 5,
72
+ 'Portrait': 6, 'Transformer': 7, 'Sand art': 8, 'Illustration': 9, 'Sketch': 10,
73
+ 'Clay toys': 11, 'Clay sculpture': 12, 'Zbrush Modeling': 13, 'Wood sculpture': 14,
74
+ 'Ink painting': 15, 'Pencil sketch': 16, 'Fabric toys': 17, 'Oil painting': 18,
75
+ 'Jade Carving': 19, 'Line draw': 20, 'Emoji': 21
76
+ }
77
+
78
+ lora_paths = {
79
+ "9 frame": "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/asymmetric_lora/asymmetric_lora_9f_general.safetensors",
80
+ "4 frame": "/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/asymmetric_lora/asymmetric_lora_4f_general.safetensors"
81
+ }
82
+ BASE_FLUX_CHECKPOINT = "/tiamat-NAS/songyiren/FYP/liucheng/ComfyUI/models/unet/flux1-dev-fp8.safetensors"
83
+ # LORA_WEIGHTS_PATH="/tiamat-NAS/songyiren/FYP/liucheng/makeanything_models/makeanything/asymmetric_lora/asymmetric_lora_9f_general.safetensors"
84
+ CLIP_L_PATH = "/tiamat-NAS/hailong/storage_backup/models/stabilityai/stable-diffusion-3-medium/text_encoders/clip_l.safetensors"
85
+ T5XXL_PATH = "/tiamat-NAS/songyiren/FYP/liucheng/ComfyUI/models/clip/t5xxl_fp8_e4m3fn.safetensors"
86
+ AE_PATH = "/tiamat-vePFS/share_data/storage/huggingface/models/black-forest-labs/FLUX.1-dev/ae.safetensors"
87
+
88
+
89
+ # Model placeholders
90
+ model = None
91
+ clip_l = None
92
+ t5xxl = None
93
+ ae = None
94
+ lora_model = None
95
+
96
+ # Function to load a file from Hugging Face Hub
97
+ def download_file(repo_id, file_name):
98
+ return hf_hub_download(repo_id=repo_id, filename=file_name)
99
+
100
+ # Load model function with dynamic paths based on the selected model
101
+ def load_target_model(frame, domain):
102
+ global model, clip_l, t5xxl, ae, lora_model
103
+
104
+ logger.info("Loading models...")
105
+ # try:
106
+ if model is None or clip_l is None or t5xxl is None or ae is None:
107
+ clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
108
+ clip_l.eval()
109
+ t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
110
+ t5xxl.eval()
111
+ ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
112
+ logger.info("Models loaded successfully.")
113
+ # Load models
114
+ _, model = flux_utils.load_flow_model(
115
+ BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
116
+ )
117
+ # Load LoRA weights
118
+ LORA_WEIGHTS_PATH = lora_paths[frame]
119
+ multiplier = 1.0
120
+ weights_sd = load_file(LORA_WEIGHTS_PATH)
121
+ lora_ups_num = 10 if frame=="9 frame" else 21
122
+ lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True, lora_ups_num=lora_ups_num)
123
+ for sub_lora in lora_model.unet_loras:
124
+ sub_lora.set_lora_up_cur(domain_index[domain]-1)
125
+
126
+ lora_model.apply_to([clip_l, t5xxl], model)
127
+ info = lora_model.load_state_dict(weights_sd, strict=True)
128
+ logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
129
+ lora_model.eval()
130
+
131
+ logger.info("Models loaded successfully.")
132
+ return "Models loaded successfully. Using Frame: {}, Damain: {}".format(frame, domain)
133
+
134
+ # except Exception as e:
135
+ # logger.error(f"Error loading models: {e}")
136
+ # return f"Error loading models: {e}"
137
+
138
+ # The function to generate image from a prompt and conditional image
139
+ # @spaces.GPU(duration=180)
140
+ def infer(prompt, frame, seed=0):
141
+ global model, clip_l, t5xxl, ae, lora_model
142
+ if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
143
+ logger.error("Models not loaded. Please load the models first.")
144
+ return None
145
+
146
+ frame_num = int(frame[0:1])
147
+
148
+ logger.info(f"Started generating image with prompt: {prompt}")
149
+
150
+ lora_model.to("cuda")
151
+
152
+ model.eval()
153
+ clip_l.eval()
154
+ t5xxl.eval()
155
+ ae.eval()
156
+
157
+ logger.info(f"Using seed: {seed}")
158
+
159
+ # ae.to("cpu")
160
+ clip_l.to(device)
161
+ t5xxl.to(device)
162
+
163
+ # Encode the prompt
164
+ tokenize_strategy = strategy_flux.FluxTokenizeStrategy(512)
165
+ text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(True)
166
+ tokens_and_masks = tokenize_strategy.tokenize(prompt)
167
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, True)
168
+
169
+ logger.debug("Prompt encoded.")
170
+
171
+ # Prepare the noise and other parameters
172
+ width = 1024 if frame_num == 4 else 1056
173
+ height = 1024 if frame_num == 4 else 1056
174
+
175
+ packed_latent_height, packed_latent_width = math.ceil(height / 16), math.ceil(width / 16)
176
+
177
+ torch.manual_seed(seed)
178
+ noise = torch.randn(1, packed_latent_height * packed_latent_width, 16 * 2 * 2, device=device, dtype=torch.float16)
179
+ logger.debug("Noise prepared.")
180
+
181
+
182
+ # Generate the image
183
+ timesteps = flux_train_utils.get_schedule(20, noise.shape[1], shift=True) # Sample steps = 20
184
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(device)
185
+
186
+ t5_attn_mask = t5_attn_mask.to(device)
187
+
188
+ logger.debug("Image generation parameters set.")
189
+
190
+ args = lambda: None
191
+ args.frame_num = frame_num
192
+
193
+ # clip_l.to("cpu")
194
+ # t5xxl.to("cpu")
195
+
196
+ model.to(device)
197
+
198
+ print(f"Model device: {model.device}")
199
+ print(f"Noise device: {noise.device}")
200
+ print(f"Image IDs device: {img_ids.device}")
201
+ print(f"T5 output device: {t5_out.device}")
202
+ print(f"Text IDs device: {txt_ids.device}")
203
+ print(f"L pooled device: {l_pooled.device}")
204
+
205
+ # Run the denoising process
206
+ with accelerator.autocast(), torch.no_grad():
207
+ x = flux_train_utils.denoise(
208
+ model,
209
+ noise,
210
+ img_ids,
211
+ t5_out,
212
+ txt_ids,
213
+ l_pooled,
214
+ timesteps,
215
+ guidance=4.0,
216
+ t5_attn_mask=t5_attn_mask,
217
+ cfg_scale=1.0,
218
+ )
219
+ logger.debug("Denoising process completed.")
220
+
221
+ # Decode the final image
222
+ x = x.float()
223
+ x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
224
+ # model.to("cpu")
225
+ ae.to(device)
226
+ with accelerator.autocast(), torch.no_grad():
227
+ x = ae.decode(x)
228
+ logger.debug("Latents decoded into image.")
229
+ # ae.to("cpu")
230
+
231
+ # Convert the tensor to an image
232
+ x = x.clamp(-1, 1)
233
+ x = x.permute(0, 2, 3, 1)
234
+ generated_image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
235
+
236
+ logger.info("Image generation completed.")
237
+ return generated_image
238
+
239
+ def update_domains(floor):
240
+ domains_dict = {
241
+ "4 frame": [
242
+ "LEGO", "Cook", "Painting", "Icon", "Landscape illustration",
243
+ "Portrait", "Transformer", "Sand art", "Illustration", "Sketch",
244
+ "Clay toys", "Clay sculpture", "Zbrush Modeling", "Wood sculpture", "Ink painting",
245
+ "Pencil sketch", "Fabric toys", "Oil painting", "Jade Carving", "Line draw", "Emoji"
246
+ ],
247
+ "9 frame": [
248
+ "LEGO", "Cook", "Painting", "Icon", "Landscape illustration",
249
+ "Portrait", "Transformer", "Sand art", "Illustration", "Sketch"
250
+ ]
251
+ }
252
+ return gr.Dropdown.update(choices=domains_dict[floor], label="Select Domains")
253
+
254
+ # Gradio interface
255
+ with gr.Blocks() as demo:
256
+ gr.Markdown("## Asymmetric LoRA Generation")
257
+
258
+ with gr.Row():
259
+ with gr.Column(scale=1):
260
+ with gr.Row():
261
+ with gr.Column(scale=1):
262
+ frame_selector = gr.Radio(choices=["4 frame", "9 frame"], label="Select Frame")
263
+ with gr.Column(scale=2):
264
+ domain_selector = gr.Dropdown(choices=[], label="Select Domains")
265
+
266
+ # Load Model Button
267
+ load_button = gr.Button("Load Model")
268
+
269
+ with gr.Column(scale=1):
270
+ # Status message box
271
+ status_box = gr.Textbox(label="Status", placeholder="Model loading status", interactive=False, value="Model not loaded", lines=3)
272
+
273
+ with gr.Row():
274
+ with gr.Column(scale=1):
275
+ # Input for the prompt
276
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here", lines=8)
277
+ with gr.Row():
278
+ seed = gr.Slider(0, np.iinfo(np.int32).max, step=1, label="Seed", value=42)
279
+ run_button = gr.Button("Generate Image")
280
+
281
+ with gr.Column(scale=1):
282
+ # Output result
283
+ result_image = gr.Image(label="Generated Image", interactive=False)
284
+
285
+ frame_selector.change(update_domains, inputs=frame_selector, outputs=domain_selector)
286
+
287
+ # Load model button action
288
+ load_button.click(fn=load_target_model, inputs=[frame_selector, domain_selector], outputs=[status_box])
289
+
290
+ # Run Button
291
+ run_button.click(fn=infer, inputs=[prompt, frame_selector, seed], outputs=[result_image])
292
+
293
+ # gr.Markdown("### Examples")
294
+ # examples = [
295
+ # [
296
+ # "sks14, 2*2 puzzle of 4 sub-images, step-by-step wood sculpture carving process", # prompt
297
+ # "./gradio_examples/wood_sculpture.png",
298
+ # "Wood Sculpture", # recraft_model
299
+ # 12345 # seed
300
+ # ],
301
+ # [
302
+ # "sks1, 3*3 puzzle of 9 sub-images, step-by-step lego model construction process", # prompt
303
+ # "./gradio_examples/lego.png",
304
+ # "LEGO", # recraft_model
305
+ # 42 # seed
306
+ # ],
307
+ # [
308
+ # "sks6, 3*3 puzzle of 9 sub-images, step-by-step portrait painting process", # prompt
309
+ # "./gradio_examples/portrait.png",
310
+ # "Portrait", # recraft_model
311
+ # 999 # seed
312
+ # ],
313
+ # [
314
+ # "sks10, 3*3 puzzle of 9 sub-images, step-by-step sketch painting process,", # prompt
315
+ # "./gradio_examples/sketch.png",
316
+ # "Sketch",
317
+ # 2023
318
+ # ]
319
+ # ]
320
+
321
+ # gr.Examples(
322
+ # examples=examples,
323
+ # inputs=[prompt, sample_image, recraft_model, seed],
324
+ # outputs=[result_image],
325
+ # cache_examples=False
326
+ # )
327
+
328
+ # Launch the Gradio app
329
+ demo.launch(server_port=8289, server_name="0.0.0.0", share=True)
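The noise tensor shape in infer() above follows FLUX's packed-latent convention: the VAE downsamples each side by 8 and the 2x2 patch packing halves it again, so each side shrinks by a factor of 16 and every token carries 16 * 2 * 2 = 64 channels. A minimal sketch of that arithmetic, included only as a shape sanity check:

    import math

    for side in (1024, 1056):                      # the two sizes used above (4-frame / 9-frame)
        packed = math.ceil(side / 16)              # packed_latent_height == packed_latent_width
        seq_len = packed * packed                  # second dimension of the noise tensor
        channels = 16 * 2 * 2                      # 16 latent channels x 2x2 packed patch
        print(side, packed, seq_len, channels)     # 1024 -> 64, 4096, 64; 1056 -> 66, 4356, 64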
library/__init__.py ADDED
File without changes
library/adafactor_fused.py ADDED
@@ -0,0 +1,138 @@
1
+ import math
2
+ import torch
3
+ from transformers import Adafactor
4
+
5
+ # stochastic rounding for bfloat16
6
+ # The implementation was provided by 2kpr. Thank you very much!
7
+
8
+ def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
9
+ """
10
+ copies source into target using stochastic rounding
11
+
12
+ Args:
13
+ target: the target tensor with dtype=bfloat16
14
+ source: the target tensor with dtype=float32
15
+ """
16
+ # create a random 16 bit integer
17
+ result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
18
+
19
+ # add the random number to the lower 16 bit of the mantissa
20
+ result.add_(source.view(dtype=torch.int32))
21
+
22
+ # mask off the lower 16 bit of the mantissa
23
+ result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
24
+
25
+ # copy the higher 16 bit into the target tensor
26
+ target.copy_(result.view(dtype=torch.float32))
27
+
28
+ del result
29
+
30
+
31
+ @torch.no_grad()
32
+ def adafactor_step_param(self, p, group):
33
+ if p.grad is None:
34
+ return
35
+ grad = p.grad
36
+ if grad.dtype in {torch.float16, torch.bfloat16}:
37
+ grad = grad.float()
38
+ if grad.is_sparse:
39
+ raise RuntimeError("Adafactor does not support sparse gradients.")
40
+
41
+ state = self.state[p]
42
+ grad_shape = grad.shape
43
+
44
+ factored, use_first_moment = Adafactor._get_options(group, grad_shape)
45
+ # State Initialization
46
+ if len(state) == 0:
47
+ state["step"] = 0
48
+
49
+ if use_first_moment:
50
+ # Exponential moving average of gradient values
51
+ state["exp_avg"] = torch.zeros_like(grad)
52
+ if factored:
53
+ state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
54
+ state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
55
+ else:
56
+ state["exp_avg_sq"] = torch.zeros_like(grad)
57
+
58
+ state["RMS"] = 0
59
+ else:
60
+ if use_first_moment:
61
+ state["exp_avg"] = state["exp_avg"].to(grad)
62
+ if factored:
63
+ state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
64
+ state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
65
+ else:
66
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
67
+
68
+ p_data_fp32 = p
69
+ if p.dtype in {torch.float16, torch.bfloat16}:
70
+ p_data_fp32 = p_data_fp32.float()
71
+
72
+ state["step"] += 1
73
+ state["RMS"] = Adafactor._rms(p_data_fp32)
74
+ lr = Adafactor._get_lr(group, state)
75
+
76
+ beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
77
+ update = (grad**2) + group["eps"][0]
78
+ if factored:
79
+ exp_avg_sq_row = state["exp_avg_sq_row"]
80
+ exp_avg_sq_col = state["exp_avg_sq_col"]
81
+
82
+ exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
83
+ exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
84
+
85
+ # Approximation of exponential moving average of square of gradient
86
+ update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
87
+ update.mul_(grad)
88
+ else:
89
+ exp_avg_sq = state["exp_avg_sq"]
90
+
91
+ exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
92
+ update = exp_avg_sq.rsqrt().mul_(grad)
93
+
94
+ update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
95
+ update.mul_(lr)
96
+
97
+ if use_first_moment:
98
+ exp_avg = state["exp_avg"]
99
+ exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
100
+ update = exp_avg
101
+
102
+ if group["weight_decay"] != 0:
103
+ p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
104
+
105
+ p_data_fp32.add_(-update)
106
+
107
+ # if p.dtype in {torch.float16, torch.bfloat16}:
108
+ # p.copy_(p_data_fp32)
109
+
110
+ if p.dtype == torch.bfloat16:
111
+ copy_stochastic_(p, p_data_fp32)
112
+ elif p.dtype == torch.float16:
113
+ p.copy_(p_data_fp32)
114
+
115
+
116
+ @torch.no_grad()
117
+ def adafactor_step(self, closure=None):
118
+ """
119
+ Performs a single optimization step
120
+
121
+ Arguments:
122
+ closure (callable, optional): A closure that reevaluates the model
123
+ and returns the loss.
124
+ """
125
+ loss = None
126
+ if closure is not None:
127
+ loss = closure()
128
+
129
+ for group in self.param_groups:
130
+ for p in group["params"]:
131
+ adafactor_step_param(self, p, group)
132
+
133
+ return loss
134
+
135
+
136
+ def patch_adafactor_fused(optimizer: Adafactor):
137
+ optimizer.step_param = adafactor_step_param.__get__(optimizer)
138
+ optimizer.step = adafactor_step.__get__(optimizer)
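For context, a rough usage sketch of the patched optimizer: patch_adafactor_fused() replaces step()/step_param() on a transformers Adafactor instance so each parameter can be stepped, and its gradient freed, as soon as that gradient is ready, with the stochastic bfloat16 rounding above applied when copying back. The toy model and hyperparameters below are placeholders, not the project's training settings:

    import torch
    import torch.nn as nn
    from transformers import Adafactor
    from library.adafactor_fused import patch_adafactor_fused

    model = nn.Linear(8, 8).to(torch.bfloat16)                 # stand-in parameters
    optimizer = Adafactor(model.parameters(), lr=1e-3,
                          relative_step=False, scale_parameter=False)
    patch_adafactor_fused(optimizer)                           # adds step_param, overrides step

    loss = model(torch.randn(4, 8, dtype=torch.bfloat16)).float().pow(2).mean()
    loss.backward()

    # "Fused" pattern: step each parameter individually instead of one global step(),
    # so gradients can be released early and peak memory stays lower.
    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is not None:
                optimizer.step_param(p, group)
                p.grad = None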
library/attention_processors.py ADDED
@@ -0,0 +1,227 @@
1
+ import math
2
+ from typing import Any
3
+ from einops import rearrange
4
+ import torch
5
+ from diffusers.models.attention_processor import Attention
6
+
7
+
8
+ # flash attention forwards and backwards
9
+
10
+ # https://arxiv.org/abs/2205.14135
11
+
12
+ EPSILON = 1e-6
13
+
14
+
15
+ class FlashAttentionFunction(torch.autograd.function.Function):
16
+ @staticmethod
17
+ @torch.no_grad()
18
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
19
+ """Algorithm 2 in the paper"""
20
+
21
+ device = q.device
22
+ dtype = q.dtype
23
+ max_neg_value = -torch.finfo(q.dtype).max
24
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
25
+
26
+ o = torch.zeros_like(q)
27
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
28
+ all_row_maxes = torch.full(
29
+ (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
30
+ )
31
+
32
+ scale = q.shape[-1] ** -0.5
33
+
34
+ if mask is None:
35
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
36
+ else:
37
+ mask = rearrange(mask, "b n -> b 1 1 n")
38
+ mask = mask.split(q_bucket_size, dim=-1)
39
+
40
+ row_splits = zip(
41
+ q.split(q_bucket_size, dim=-2),
42
+ o.split(q_bucket_size, dim=-2),
43
+ mask,
44
+ all_row_sums.split(q_bucket_size, dim=-2),
45
+ all_row_maxes.split(q_bucket_size, dim=-2),
46
+ )
47
+
48
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
49
+ q_start_index = ind * q_bucket_size - qk_len_diff
50
+
51
+ col_splits = zip(
52
+ k.split(k_bucket_size, dim=-2),
53
+ v.split(k_bucket_size, dim=-2),
54
+ )
55
+
56
+ for k_ind, (kc, vc) in enumerate(col_splits):
57
+ k_start_index = k_ind * k_bucket_size
58
+
59
+ attn_weights = (
60
+ torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
61
+ )
62
+
63
+ if row_mask is not None:
64
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
65
+
66
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
67
+ causal_mask = torch.ones(
68
+ (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
69
+ ).triu(q_start_index - k_start_index + 1)
70
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
71
+
72
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
73
+ attn_weights -= block_row_maxes
74
+ exp_weights = torch.exp(attn_weights)
75
+
76
+ if row_mask is not None:
77
+ exp_weights.masked_fill_(~row_mask, 0.0)
78
+
79
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
80
+ min=EPSILON
81
+ )
82
+
83
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
84
+
85
+ exp_values = torch.einsum(
86
+ "... i j, ... j d -> ... i d", exp_weights, vc
87
+ )
88
+
89
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
90
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
91
+
92
+ new_row_sums = (
93
+ exp_row_max_diff * row_sums
94
+ + exp_block_row_max_diff * block_row_sums
95
+ )
96
+
97
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
98
+ (exp_block_row_max_diff / new_row_sums) * exp_values
99
+ )
100
+
101
+ row_maxes.copy_(new_row_maxes)
102
+ row_sums.copy_(new_row_sums)
103
+
104
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
105
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
106
+
107
+ return o
108
+
109
+ @staticmethod
110
+ @torch.no_grad()
111
+ def backward(ctx, do):
112
+ """Algorithm 4 in the paper"""
113
+
114
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
115
+ q, k, v, o, l, m = ctx.saved_tensors
116
+
117
+ device = q.device
118
+
119
+ max_neg_value = -torch.finfo(q.dtype).max
120
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
121
+
122
+ dq = torch.zeros_like(q)
123
+ dk = torch.zeros_like(k)
124
+ dv = torch.zeros_like(v)
125
+
126
+ row_splits = zip(
127
+ q.split(q_bucket_size, dim=-2),
128
+ o.split(q_bucket_size, dim=-2),
129
+ do.split(q_bucket_size, dim=-2),
130
+ mask,
131
+ l.split(q_bucket_size, dim=-2),
132
+ m.split(q_bucket_size, dim=-2),
133
+ dq.split(q_bucket_size, dim=-2),
134
+ )
135
+
136
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
137
+ q_start_index = ind * q_bucket_size - qk_len_diff
138
+
139
+ col_splits = zip(
140
+ k.split(k_bucket_size, dim=-2),
141
+ v.split(k_bucket_size, dim=-2),
142
+ dk.split(k_bucket_size, dim=-2),
143
+ dv.split(k_bucket_size, dim=-2),
144
+ )
145
+
146
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
147
+ k_start_index = k_ind * k_bucket_size
148
+
149
+ attn_weights = (
150
+ torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
151
+ )
152
+
153
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
154
+ causal_mask = torch.ones(
155
+ (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
156
+ ).triu(q_start_index - k_start_index + 1)
157
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
158
+
159
+ exp_attn_weights = torch.exp(attn_weights - mc)
160
+
161
+ if row_mask is not None:
162
+ exp_attn_weights.masked_fill_(~row_mask, 0.0)
163
+
164
+ p = exp_attn_weights / lc
165
+
166
+ dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
167
+ dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
168
+
169
+ D = (doc * oc).sum(dim=-1, keepdims=True)
170
+ ds = p * scale * (dp - D)
171
+
172
+ dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
173
+ dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
174
+
175
+ dqc.add_(dq_chunk)
176
+ dkc.add_(dk_chunk)
177
+ dvc.add_(dv_chunk)
178
+
179
+ return dq, dk, dv, None, None, None, None
180
+
181
+
182
+ class FlashAttnProcessor:
183
+ def __call__(
184
+ self,
185
+ attn: Attention,
186
+ hidden_states,
187
+ encoder_hidden_states=None,
188
+ attention_mask=None,
189
+ ) -> Any:
190
+ q_bucket_size = 512
191
+ k_bucket_size = 1024
192
+
193
+ h = attn.heads
194
+ q = attn.to_q(hidden_states)
195
+
196
+ encoder_hidden_states = (
197
+ encoder_hidden_states
198
+ if encoder_hidden_states is not None
199
+ else hidden_states
200
+ )
201
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
202
+
203
+ if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
204
+ context_k, context_v = attn.hypernetwork.forward(
205
+ hidden_states, encoder_hidden_states
206
+ )
207
+ context_k = context_k.to(hidden_states.dtype)
208
+ context_v = context_v.to(hidden_states.dtype)
209
+ else:
210
+ context_k = encoder_hidden_states
211
+ context_v = encoder_hidden_states
212
+
213
+ k = attn.to_k(context_k)
214
+ v = attn.to_v(context_v)
215
+ del encoder_hidden_states, hidden_states
216
+
217
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
218
+
219
+ out = FlashAttentionFunction.apply(
220
+ q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
221
+ )
222
+
223
+ out = rearrange(out, "b h n d -> b n (h d)")
224
+
225
+ out = attn.to_out[0](out)
226
+ out = attn.to_out[1](out)
227
+ return out
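A rough sketch of exercising the bucketed flash-attention function above directly (the same entry point FlashAttnProcessor uses), with toy shapes chosen only for illustration:

    import torch
    from library.attention_processors import FlashAttentionFunction

    b, h, n, d = 2, 4, 1024, 64                       # batch, heads, sequence length, head dim
    q = torch.randn(b, h, n, d)
    k = torch.randn(b, h, n, d)
    v = torch.randn(b, h, n, d)

    # mask=None, causal=False, and the bucket sizes hard-coded in FlashAttnProcessor
    out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)
    print(out.shape)                                  # torch.Size([2, 4, 1024, 64])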
library/config_util.py ADDED
@@ -0,0 +1,716 @@
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+
12
+ # from toolz import curry
13
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
14
+
15
+ import toml
16
+ import voluptuous
17
+ from voluptuous import (
18
+ Any,
19
+ ExactSequence,
20
+ MultipleInvalid,
21
+ Object,
22
+ Required,
23
+ Schema,
24
+ )
25
+ from transformers import CLIPTokenizer
26
+
27
+ from . import train_util
28
+ from .train_util import (
29
+ DreamBoothSubset,
30
+ FineTuningSubset,
31
+ ControlNetSubset,
32
+ DreamBoothDataset,
33
+ FineTuningDataset,
34
+ ControlNetDataset,
35
+ DatasetGroup,
36
+ )
37
+ from .utils import setup_logging
38
+
39
+ setup_logging()
40
+ import logging
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
+ def add_config_arguments(parser: argparse.ArgumentParser):
46
+ parser.add_argument(
47
+ "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
48
+ )
49
+
50
+
51
+ # TODO: inherit Params class in Subset, Dataset
52
+
53
+
54
+ @dataclass
55
+ class BaseSubsetParams:
56
+ image_dir: Optional[str] = None
57
+ num_repeats: int = 1
58
+ shuffle_caption: bool = False
59
+ caption_separator: str = ","
60
+ keep_tokens: int = 0
61
+ keep_tokens_separator: Optional[str] = None
62
+ secondary_separator: Optional[str] = None
63
+ enable_wildcard: bool = False
64
+ color_aug: bool = False
65
+ flip_aug: bool = False
66
+ face_crop_aug_range: Optional[Tuple[float, float]] = None
67
+ random_crop: bool = False
68
+ caption_prefix: Optional[str] = None
69
+ caption_suffix: Optional[str] = None
70
+ caption_dropout_rate: float = 0.0
71
+ caption_dropout_every_n_epochs: int = 0
72
+ caption_tag_dropout_rate: float = 0.0
73
+ token_warmup_min: int = 1
74
+ token_warmup_step: float = 0
75
+ custom_attributes: Optional[Dict[str, Any]] = None
76
+
77
+
78
+ @dataclass
79
+ class DreamBoothSubsetParams(BaseSubsetParams):
80
+ is_reg: bool = False
81
+ class_tokens: Optional[str] = None
82
+ caption_extension: str = ".caption"
83
+ cache_info: bool = False
84
+ alpha_mask: bool = False
85
+
86
+
87
+ @dataclass
88
+ class FineTuningSubsetParams(BaseSubsetParams):
89
+ metadata_file: Optional[str] = None
90
+ alpha_mask: bool = False
91
+
92
+
93
+ @dataclass
94
+ class ControlNetSubsetParams(BaseSubsetParams):
95
+ conditioning_data_dir: str = None
96
+ caption_extension: str = ".caption"
97
+ cache_info: bool = False
98
+
99
+
100
+ @dataclass
101
+ class BaseDatasetParams:
102
+ resolution: Optional[Tuple[int, int]] = None
103
+ network_multiplier: float = 1.0
104
+ debug_dataset: bool = False
105
+
106
+
107
+ @dataclass
108
+ class DreamBoothDatasetParams(BaseDatasetParams):
109
+ batch_size: int = 1
110
+ enable_bucket: bool = False
111
+ min_bucket_reso: int = 256
112
+ max_bucket_reso: int = 1024
113
+ bucket_reso_steps: int = 64
114
+ bucket_no_upscale: bool = False
115
+ prior_loss_weight: float = 1.0
116
+
117
+
118
+ @dataclass
119
+ class FineTuningDatasetParams(BaseDatasetParams):
120
+ batch_size: int = 1
121
+ enable_bucket: bool = False
122
+ min_bucket_reso: int = 256
123
+ max_bucket_reso: int = 1024
124
+ bucket_reso_steps: int = 64
125
+ bucket_no_upscale: bool = False
126
+
127
+
128
+ @dataclass
129
+ class ControlNetDatasetParams(BaseDatasetParams):
130
+ batch_size: int = 1
131
+ enable_bucket: bool = False
132
+ min_bucket_reso: int = 256
133
+ max_bucket_reso: int = 1024
134
+ bucket_reso_steps: int = 64
135
+ bucket_no_upscale: bool = False
136
+
137
+
138
+ @dataclass
139
+ class SubsetBlueprint:
140
+ params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
141
+
142
+
143
+ @dataclass
144
+ class DatasetBlueprint:
145
+ is_dreambooth: bool
146
+ is_controlnet: bool
147
+ params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
148
+ subsets: Sequence[SubsetBlueprint]
149
+
150
+
151
+ @dataclass
152
+ class DatasetGroupBlueprint:
153
+ datasets: Sequence[DatasetBlueprint]
154
+
155
+
156
+ @dataclass
157
+ class Blueprint:
158
+ dataset_group: DatasetGroupBlueprint
159
+
160
+
161
+ class ConfigSanitizer:
162
+ # @curry
163
+ @staticmethod
164
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
165
+ Schema(ExactSequence([klass, klass]))(value)
166
+ return tuple(value)
167
+
168
+ # @curry
169
+ @staticmethod
170
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
171
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
172
+ try:
173
+ Schema(klass)(value)
174
+ return (value, value)
175
+ except:
176
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
177
+
178
+ # subset schema
179
+ SUBSET_ASCENDABLE_SCHEMA = {
180
+ "color_aug": bool,
181
+ "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
182
+ "flip_aug": bool,
183
+ "num_repeats": int,
184
+ "random_crop": bool,
185
+ "shuffle_caption": bool,
186
+ "keep_tokens": int,
187
+ "keep_tokens_separator": str,
188
+ "secondary_separator": str,
189
+ "caption_separator": str,
190
+ "enable_wildcard": bool,
191
+ "token_warmup_min": int,
192
+ "token_warmup_step": Any(float, int),
193
+ "caption_prefix": str,
194
+ "caption_suffix": str,
195
+ "custom_attributes": dict,
196
+ }
197
+ # DO means DropOut
198
+ DO_SUBSET_ASCENDABLE_SCHEMA = {
199
+ "caption_dropout_every_n_epochs": int,
200
+ "caption_dropout_rate": Any(float, int),
201
+ "caption_tag_dropout_rate": Any(float, int),
202
+ }
203
+ # DB means DreamBooth
204
+ DB_SUBSET_ASCENDABLE_SCHEMA = {
205
+ "caption_extension": str,
206
+ "class_tokens": str,
207
+ "cache_info": bool,
208
+ }
209
+ DB_SUBSET_DISTINCT_SCHEMA = {
210
+ Required("image_dir"): str,
211
+ "is_reg": bool,
212
+ "alpha_mask": bool,
213
+ }
214
+ # FT means FineTuning
215
+ FT_SUBSET_DISTINCT_SCHEMA = {
216
+ Required("metadata_file"): str,
217
+ "image_dir": str,
218
+ "alpha_mask": bool,
219
+ }
220
+ CN_SUBSET_ASCENDABLE_SCHEMA = {
221
+ "caption_extension": str,
222
+ "cache_info": bool,
223
+ }
224
+ CN_SUBSET_DISTINCT_SCHEMA = {
225
+ Required("image_dir"): str,
226
+ Required("conditioning_data_dir"): str,
227
+ }
228
+
229
+ # datasets schema
230
+ DATASET_ASCENDABLE_SCHEMA = {
231
+ "batch_size": int,
232
+ "bucket_no_upscale": bool,
233
+ "bucket_reso_steps": int,
234
+ "enable_bucket": bool,
235
+ "max_bucket_reso": int,
236
+ "min_bucket_reso": int,
237
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
238
+ "network_multiplier": float,
239
+ }
240
+
241
+ # options handled by argparse but not handled by user config
242
+ ARGPARSE_SPECIFIC_SCHEMA = {
243
+ "debug_dataset": bool,
244
+ "max_token_length": Any(None, int),
245
+ "prior_loss_weight": Any(float, int),
246
+ }
247
+ # for handling default None value of argparse
248
+ ARGPARSE_NULLABLE_OPTNAMES = [
249
+ "face_crop_aug_range",
250
+ "resolution",
251
+ ]
252
+ # prepare map because option name may differ among argparse and user config
253
+ ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
254
+ "train_batch_size": "batch_size",
255
+ "dataset_repeats": "num_repeats",
256
+ }
257
+
258
+ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
259
+ assert support_dreambooth or support_finetuning or support_controlnet, (
260
+ "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
261
+ + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
262
+ )
263
+
264
+ self.db_subset_schema = self.__merge_dict(
265
+ self.SUBSET_ASCENDABLE_SCHEMA,
266
+ self.DB_SUBSET_DISTINCT_SCHEMA,
267
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
268
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
269
+ )
270
+
271
+ self.ft_subset_schema = self.__merge_dict(
272
+ self.SUBSET_ASCENDABLE_SCHEMA,
273
+ self.FT_SUBSET_DISTINCT_SCHEMA,
274
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
275
+ )
276
+
277
+ self.cn_subset_schema = self.__merge_dict(
278
+ self.SUBSET_ASCENDABLE_SCHEMA,
279
+ self.CN_SUBSET_DISTINCT_SCHEMA,
280
+ self.CN_SUBSET_ASCENDABLE_SCHEMA,
281
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
282
+ )
283
+
284
+ self.db_dataset_schema = self.__merge_dict(
285
+ self.DATASET_ASCENDABLE_SCHEMA,
286
+ self.SUBSET_ASCENDABLE_SCHEMA,
287
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
288
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
289
+ {"subsets": [self.db_subset_schema]},
290
+ )
291
+
292
+ self.ft_dataset_schema = self.__merge_dict(
293
+ self.DATASET_ASCENDABLE_SCHEMA,
294
+ self.SUBSET_ASCENDABLE_SCHEMA,
295
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
296
+ {"subsets": [self.ft_subset_schema]},
297
+ )
298
+
299
+ self.cn_dataset_schema = self.__merge_dict(
300
+ self.DATASET_ASCENDABLE_SCHEMA,
301
+ self.SUBSET_ASCENDABLE_SCHEMA,
302
+ self.CN_SUBSET_ASCENDABLE_SCHEMA,
303
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
304
+ {"subsets": [self.cn_subset_schema]},
305
+ )
306
+
307
+ if support_dreambooth and support_finetuning:
308
+
309
+ def validate_flex_dataset(dataset_config: dict):
310
+ subsets_config = dataset_config.get("subsets", [])
311
+
312
+ if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
313
+ return Schema(self.cn_dataset_schema)(dataset_config)
314
+ # check dataset meets FT style
315
+ # NOTE: all FT subsets should have "metadata_file"
316
+ elif all(["metadata_file" in subset for subset in subsets_config]):
317
+ return Schema(self.ft_dataset_schema)(dataset_config)
318
+ # check dataset meets DB style
319
+ # NOTE: all DB subsets should have no "metadata_file"
320
+ elif all(["metadata_file" not in subset for subset in subsets_config]):
321
+ return Schema(self.db_dataset_schema)(dataset_config)
322
+ else:
323
+ raise voluptuous.Invalid(
324
+ "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
325
+ )
326
+
327
+ self.dataset_schema = validate_flex_dataset
328
+ elif support_dreambooth:
329
+ if support_controlnet:
330
+ self.dataset_schema = self.cn_dataset_schema
331
+ else:
332
+ self.dataset_schema = self.db_dataset_schema
333
+ elif support_finetuning:
334
+ self.dataset_schema = self.ft_dataset_schema
335
+ elif support_controlnet:
336
+ self.dataset_schema = self.cn_dataset_schema
337
+
338
+ self.general_schema = self.__merge_dict(
339
+ self.DATASET_ASCENDABLE_SCHEMA,
340
+ self.SUBSET_ASCENDABLE_SCHEMA,
341
+ self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
342
+ self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
343
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
344
+ )
345
+
346
+ self.user_config_validator = Schema(
347
+ {
348
+ "general": self.general_schema,
349
+ "datasets": [self.dataset_schema],
350
+ }
351
+ )
352
+
353
+ self.argparse_schema = self.__merge_dict(
354
+ self.general_schema,
355
+ self.ARGPARSE_SPECIFIC_SCHEMA,
356
+ {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
357
+ {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
358
+ )
359
+
360
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
361
+
362
+ def sanitize_user_config(self, user_config: dict) -> dict:
363
+ try:
364
+ return self.user_config_validator(user_config)
365
+ except MultipleInvalid:
366
+ # TODO: エラー発生時のメッセージをわかりやすくする
367
+ logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
368
+ raise
369
+
370
+ # NOTE: In nature, argument parser result is not needed to be sanitize
371
+ # However this will help us to detect program bug
372
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
373
+ try:
374
+ return self.argparse_config_validator(argparse_namespace)
375
+ except MultipleInvalid:
376
+ # XXX: this should be a bug
377
+ logger.error(
378
+ "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
379
+ )
380
+ raise
381
+
382
+ # NOTE: value would be overwritten by latter dict if there is already the same key
383
+ @staticmethod
384
+ def __merge_dict(*dict_list: dict) -> dict:
385
+ merged = {}
386
+ for schema in dict_list:
387
+ # merged |= schema
388
+ for k, v in schema.items():
389
+ merged[k] = v
390
+ return merged
391
+
392
+
393
+ class BlueprintGenerator:
394
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
395
+
396
+ def __init__(self, sanitizer: ConfigSanitizer):
397
+ self.sanitizer = sanitizer
398
+
399
+ # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
400
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
401
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
402
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
403
+
404
+ # convert argparse namespace to dict like config
405
+ # NOTE: it is ok to have extra entries in dict
406
+ optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
407
+ argparse_config = {
408
+ optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()
409
+ }
410
+
411
+ general_config = sanitized_user_config.get("general", {})
412
+
413
+ dataset_blueprints = []
414
+ for dataset_config in sanitized_user_config.get("datasets", []):
415
+ # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
416
+ subsets = dataset_config.get("subsets", [])
417
+ is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
418
+ is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
419
+ if is_controlnet:
420
+ subset_params_klass = ControlNetSubsetParams
421
+ dataset_params_klass = ControlNetDatasetParams
422
+ elif is_dreambooth:
423
+ subset_params_klass = DreamBoothSubsetParams
424
+ dataset_params_klass = DreamBoothDatasetParams
425
+ else:
426
+ subset_params_klass = FineTuningSubsetParams
427
+ dataset_params_klass = FineTuningDatasetParams
428
+
429
+ subset_blueprints = []
430
+ for subset_config in subsets:
431
+ params = self.generate_params_by_fallbacks(
432
+ subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]
433
+ )
434
+ subset_blueprints.append(SubsetBlueprint(params))
435
+
436
+ params = self.generate_params_by_fallbacks(
437
+ dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
438
+ )
439
+ dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
440
+
441
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
442
+
443
+ return Blueprint(dataset_group_blueprint)
444
+
445
+ @staticmethod
446
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
447
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
448
+ search_value = BlueprintGenerator.search_value
449
+ default_params = asdict(param_klass())
450
+ param_names = default_params.keys()
451
+
452
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
453
+
454
+ return param_klass(**params)
455
+
456
+ @staticmethod
457
+ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
458
+ for cand in fallbacks:
459
+ value = cand.get(key)
460
+ if value is not None:
461
+ return value
462
+
463
+ return default_value
464
+
465
+
466
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
467
+ datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
468
+
469
+ for dataset_blueprint in dataset_group_blueprint.datasets:
470
+ if dataset_blueprint.is_controlnet:
471
+ subset_klass = ControlNetSubset
472
+ dataset_klass = ControlNetDataset
473
+ elif dataset_blueprint.is_dreambooth:
474
+ subset_klass = DreamBoothSubset
475
+ dataset_klass = DreamBoothDataset
476
+ else:
477
+ subset_klass = FineTuningSubset
478
+ dataset_klass = FineTuningDataset
479
+
480
+ subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
481
+ dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
482
+ datasets.append(dataset)
483
+
484
+ # print info
485
+ info = ""
486
+ for i, dataset in enumerate(datasets):
487
+ is_dreambooth = isinstance(dataset, DreamBoothDataset)
488
+ is_controlnet = isinstance(dataset, ControlNetDataset)
489
+ info += dedent(
490
+ f"""\
491
+ [Dataset {i}]
492
+ batch_size: {dataset.batch_size}
493
+ resolution: {(dataset.width, dataset.height)}
494
+ enable_bucket: {dataset.enable_bucket}
495
+ network_multiplier: {dataset.network_multiplier}
496
+ """
497
+ )
498
+
499
+ if dataset.enable_bucket:
500
+ info += indent(
501
+ dedent(
502
+ f"""\
503
+ min_bucket_reso: {dataset.min_bucket_reso}
504
+ max_bucket_reso: {dataset.max_bucket_reso}
505
+ bucket_reso_steps: {dataset.bucket_reso_steps}
506
+ bucket_no_upscale: {dataset.bucket_no_upscale}
507
+ \n"""
508
+ ),
509
+ " ",
510
+ )
511
+ else:
512
+ info += "\n"
513
+
514
+ for j, subset in enumerate(dataset.subsets):
515
+ info += indent(
516
+ dedent(
517
+ f"""\
518
+ [Subset {j} of Dataset {i}]
519
+ image_dir: "{subset.image_dir}"
520
+ image_count: {subset.img_count}
521
+ num_repeats: {subset.num_repeats}
522
+ shuffle_caption: {subset.shuffle_caption}
523
+ keep_tokens: {subset.keep_tokens}
524
+ keep_tokens_separator: {subset.keep_tokens_separator}
525
+ caption_separator: {subset.caption_separator}
526
+ secondary_separator: {subset.secondary_separator}
527
+ enable_wildcard: {subset.enable_wildcard}
528
+ caption_dropout_rate: {subset.caption_dropout_rate}
529
+ caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
530
+ caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
531
+ caption_prefix: {subset.caption_prefix}
532
+ caption_suffix: {subset.caption_suffix}
533
+ color_aug: {subset.color_aug}
534
+ flip_aug: {subset.flip_aug}
535
+ face_crop_aug_range: {subset.face_crop_aug_range}
536
+ random_crop: {subset.random_crop}
537
+ token_warmup_min: {subset.token_warmup_min}
538
+ token_warmup_step: {subset.token_warmup_step}
539
+ alpha_mask: {subset.alpha_mask}
540
+ custom_attributes: {subset.custom_attributes}
541
+ """
542
+ ),
543
+ " ",
544
+ )
545
+
546
+ if is_dreambooth:
547
+ info += indent(
548
+ dedent(
549
+ f"""\
550
+ is_reg: {subset.is_reg}
551
+ class_tokens: {subset.class_tokens}
552
+ caption_extension: {subset.caption_extension}
553
+ \n"""
554
+ ),
555
+ " ",
556
+ )
557
+ elif not is_controlnet:
558
+ info += indent(
559
+ dedent(
560
+ f"""\
561
+ metadata_file: {subset.metadata_file}
562
+ \n"""
563
+ ),
564
+ " ",
565
+ )
566
+
567
+ logger.info(f"{info}")
568
+
569
+ # make buckets first because it determines the length of dataset
570
+ # and set the same seed for all datasets
571
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
572
+ for i, dataset in enumerate(datasets):
573
+ logger.info(f"[Dataset {i}]")
574
+ dataset.make_buckets()
575
+ dataset.set_seed(seed)
576
+
577
+ return DatasetGroup(datasets)
578
+
579
+
580
+ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
581
+ def extract_dreambooth_params(name: str) -> Tuple[int, str]:
582
+ tokens = name.split("_")
583
+ try:
584
+ n_repeats = int(tokens[0])
585
+ except ValueError as e:
586
+ logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
587
+ return 0, ""
588
+ caption_by_folder = "_".join(tokens[1:])
589
+ return n_repeats, caption_by_folder
590
+
591
+ def generate(base_dir: Optional[str], is_reg: bool):
592
+ if base_dir is None:
593
+ return []
594
+
595
+ base_dir: Path = Path(base_dir)
596
+ if not base_dir.is_dir():
597
+ return []
598
+
599
+ subsets_config = []
600
+ for subdir in base_dir.iterdir():
601
+ if not subdir.is_dir():
602
+ continue
603
+
604
+ num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
605
+ if num_repeats < 1:
606
+ continue
607
+
608
+ subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
609
+ subsets_config.append(subset_config)
610
+
611
+ return subsets_config
612
+
613
+ subsets_config = []
614
+ subsets_config += generate(train_data_dir, False)
615
+ subsets_config += generate(reg_data_dir, True)
616
+
617
+ return subsets_config
618
+
619
+
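The "<repeats>_<class tokens>" directory convention parsed above maps directly onto subset configs; a brief illustration with placeholder paths:

    from library.config_util import generate_dreambooth_subsets_config_by_subdirs

    # Given train_data/10_myclass and reg_data/1_person (placeholder directories),
    # each subdirectory becomes one subset, with num_repeats taken from its prefix:
    subsets = generate_dreambooth_subsets_config_by_subdirs("train_data", "reg_data")
    # -> [{"image_dir": "train_data/10_myclass", "num_repeats": 10,
    #      "is_reg": False, "class_tokens": "myclass"},
    #     {"image_dir": "reg_data/1_person", "num_repeats": 1,
    #      "is_reg": True, "class_tokens": "person"}]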
620
+ def generate_controlnet_subsets_config_by_subdirs(
621
+ train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
622
+ ):
623
+ def generate(base_dir: Optional[str]):
624
+ if base_dir is None:
625
+ return []
626
+
627
+ base_dir: Path = Path(base_dir)
628
+ if not base_dir.is_dir():
629
+ return []
630
+
631
+ subsets_config = []
632
+ subset_config = {
633
+ "image_dir": train_data_dir,
634
+ "conditioning_data_dir": conditioning_data_dir,
635
+ "caption_extension": caption_extension,
636
+ "num_repeats": 1,
637
+ }
638
+ subsets_config.append(subset_config)
639
+
640
+ return subsets_config
641
+
642
+ subsets_config = []
643
+ subsets_config += generate(train_data_dir)
644
+
645
+ return subsets_config
646
+
647
+
648
+ def load_user_config(file: str) -> dict:
649
+ file: Path = Path(file)
650
+ if not file.is_file():
651
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
652
+
653
+ if file.name.lower().endswith(".json"):
654
+ try:
655
+ with open(file, "r") as f:
656
+ config = json.load(f)
657
+ except Exception:
658
+ logger.error(
659
+ f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
660
+ )
661
+ raise
662
+ elif file.name.lower().endswith(".toml"):
663
+ try:
664
+ config = toml.load(file)
665
+ except Exception:
666
+ logger.error(
667
+ f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
668
+ )
669
+ raise
670
+ else:
671
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
672
+
673
+ return config
674
+
675
+
676
+ # for config test
677
+ if __name__ == "__main__":
678
+ parser = argparse.ArgumentParser()
679
+ parser.add_argument("--support_dreambooth", action="store_true")
680
+ parser.add_argument("--support_finetuning", action="store_true")
681
+ parser.add_argument("--support_controlnet", action="store_true")
682
+ parser.add_argument("--support_dropout", action="store_true")
683
+ parser.add_argument("dataset_config")
684
+ config_args, remain = parser.parse_known_args()
685
+
686
+ parser = argparse.ArgumentParser()
687
+ train_util.add_dataset_arguments(
688
+ parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
689
+ )
690
+ train_util.add_training_arguments(parser, config_args.support_dreambooth)
691
+ argparse_namespace = parser.parse_args(remain)
692
+ train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
693
+
694
+ logger.info("[argparse_namespace]")
695
+ logger.info(f"{vars(argparse_namespace)}")
696
+
697
+ user_config = load_user_config(config_args.dataset_config)
698
+
699
+ logger.info("")
700
+ logger.info("[user_config]")
701
+ logger.info(f"{user_config}")
702
+
703
+ sanitizer = ConfigSanitizer(
704
+ config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
705
+ )
706
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
707
+
708
+ logger.info("")
709
+ logger.info("[sanitized_user_config]")
710
+ logger.info(f"{sanitized_user_config}")
711
+
712
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
713
+
714
+ logger.info("")
715
+ logger.info("[blueprint]")
716
+ logger.info(f"{blueprint}")
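As a reference for what the sanitizer accepts, here is a minimal DreamBooth-style user config written as the Python dict a parsed TOML/JSON dataset_config would produce; the directory name and class token are placeholders:

    from library.config_util import ConfigSanitizer

    user_config = {
        "general": {"enable_bucket": True, "caption_extension": ".txt"},
        "datasets": [
            {
                "resolution": 512,        # a scalar is expanded to (512, 512)
                "batch_size": 2,
                "subsets": [
                    {"image_dir": "train/10_myclass", "class_tokens": "myclass", "num_repeats": 10},
                ],
            }
        ],
    }

    sanitizer = ConfigSanitizer(support_dreambooth=True, support_finetuning=False,
                                support_controlnet=False, support_dropout=True)
    sanitized = sanitizer.sanitize_user_config(user_config)
    # BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace) would then
    # merge this with argparse defaults and runtime params into a Blueprint / DatasetGroup.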
library/custom_offloading_utils.py ADDED
@@ -0,0 +1,227 @@
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import time
3
+ from typing import Optional
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from library.device_utils import clean_memory_on_device
8
+
9
+
10
+ def synchronize_device(device: torch.device):
11
+ if device.type == "cuda":
12
+ torch.cuda.synchronize()
13
+ elif device.type == "xpu":
14
+ torch.xpu.synchronize()
15
+ elif device.type == "mps":
16
+ torch.mps.synchronize()
17
+
18
+
19
+ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
20
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
21
+
22
+ weight_swap_jobs = []
23
+
24
+ # This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
25
+ # for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
26
+ # print(module_to_cpu.__class__, module_to_cuda.__class__)
27
+ # if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
28
+ # weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
29
+
30
+ modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
31
+ for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
32
+ if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
33
+ module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
34
+ if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
35
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
36
+ else:
37
+ if module_to_cuda.weight.data.device.type != device.type:
38
+ # print(
39
+ # f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
40
+ # )
41
+ module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
42
+
43
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
44
+
45
+ stream = torch.cuda.Stream()
46
+ with torch.cuda.stream(stream):
47
+ # cuda to cpu
48
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
49
+ cuda_data_view.record_stream(stream)
50
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
51
+
52
+ stream.synchronize()
53
+
54
+ # cpu to cuda
55
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
56
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
57
+ module_to_cuda.weight.data = cuda_data_view
58
+
59
+ stream.synchronize()
60
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
61
+
62
+
63
+ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
64
+ """
65
+ not tested
66
+ """
67
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
68
+
69
+ weight_swap_jobs = []
70
+ for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
71
+ if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
72
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
73
+
74
+ # device to cpu
75
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
76
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
77
+
78
+ synchronize_device(device)
79
+
80
+ # cpu to device
81
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
82
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
83
+ module_to_cuda.weight.data = cuda_data_view
84
+
85
+ synchronize_device(device)
86
+
87
+
88
+ def weighs_to_device(layer: nn.Module, device: torch.device):
89
+ for module in layer.modules():
90
+ if hasattr(module, "weight") and module.weight is not None:
91
+ module.weight.data = module.weight.data.to(device, non_blocking=True)
92
+
93
+
94
+ class Offloader:
95
+ """
96
+ common offloading class
97
+ """
98
+
99
+ def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
100
+ self.num_blocks = num_blocks
101
+ self.blocks_to_swap = blocks_to_swap
102
+ self.device = device
103
+ self.debug = debug
104
+
105
+ self.thread_pool = ThreadPoolExecutor(max_workers=1)
106
+ self.futures = {}
107
+ self.cuda_available = device.type == "cuda"
108
+
109
+ def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
110
+ if self.cuda_available:
111
+ swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
112
+ else:
113
+ swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
114
+
115
+ def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
116
+ def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
117
+ if self.debug:
118
+ start_time = time.perf_counter()
119
+ print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
120
+
121
+ self.swap_weight_devices(block_to_cpu, block_to_cuda)
122
+
123
+ if self.debug:
124
+ print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
125
+ return bidx_to_cpu, bidx_to_cuda # , event
126
+
127
+ block_to_cpu = blocks[block_idx_to_cpu]
128
+ block_to_cuda = blocks[block_idx_to_cuda]
129
+
130
+ self.futures[block_idx_to_cuda] = self.thread_pool.submit(
131
+ move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
132
+ )
133
+
134
+ def _wait_blocks_move(self, block_idx):
135
+ if block_idx not in self.futures:
136
+ return
137
+
138
+ if self.debug:
139
+ print(f"Wait for block {block_idx}")
140
+ start_time = time.perf_counter()
141
+
142
+ future = self.futures.pop(block_idx)
143
+ _, bidx_to_cuda = future.result()
144
+
145
+ assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
146
+
147
+ if self.debug:
148
+ print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
149
+
150
+
151
+ class ModelOffloader(Offloader):
152
+ """
153
+ supports forward offloading
154
+ """
155
+
156
+ def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
157
+ super().__init__(num_blocks, blocks_to_swap, device, debug)
158
+
159
+ # register backward hooks
160
+ self.remove_handles = []
161
+ for i, block in enumerate(blocks):
162
+ hook = self.create_backward_hook(blocks, i)
163
+ if hook is not None:
164
+ handle = block.register_full_backward_hook(hook)
165
+ self.remove_handles.append(handle)
166
+
167
+ def __del__(self):
168
+ for handle in self.remove_handles:
169
+ handle.remove()
170
+
171
+ def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
172
+ # -1 for 0-based index
173
+ num_blocks_propagated = self.num_blocks - block_index - 1
174
+ swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
175
+ waiting = block_index > 0 and block_index <= self.blocks_to_swap
176
+
177
+ if not swapping and not waiting:
178
+ return None
179
+
180
+ # create hook
181
+ block_idx_to_cpu = self.num_blocks - num_blocks_propagated
182
+ block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
183
+ block_idx_to_wait = block_index - 1
184
+
185
+ def backward_hook(module, grad_input, grad_output):
186
+ if self.debug:
187
+ print(f"Backward hook for block {block_index}")
188
+
189
+ if swapping:
190
+ self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
191
+ if waiting:
192
+ self._wait_blocks_move(block_idx_to_wait)
193
+ return None
194
+
195
+ return backward_hook
196
+
197
+ def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
198
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
199
+ return
200
+
201
+ if self.debug:
202
+ print("Prepare block devices before forward")
203
+
204
+ for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
205
+ b.to(self.device)
206
+ weighs_to_device(b, self.device) # make sure weights are on device
207
+
208
+ for b in blocks[self.num_blocks - self.blocks_to_swap :]:
209
+ b.to(self.device) # move block to device first
210
+ weighs_to_device(b, "cpu") # make sure weights are on cpu
211
+
212
+ synchronize_device(self.device)
213
+ clean_memory_on_device(self.device)
214
+
215
+ def wait_for_block(self, block_idx: int):
216
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
217
+ return
218
+ self._wait_blocks_move(block_idx)
219
+
220
+ def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
221
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
222
+ return
223
+ if block_idx >= self.blocks_to_swap:
224
+ return
225
+ block_idx_to_cpu = block_idx
226
+ block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
227
+ self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
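A rough sketch of the calling pattern ModelOffloader is built for, with stand-in blocks (real use wires it into the transformer's forward loop); it assumes a CUDA device, since the stream-based swap path above requires one:

    import torch
    import torch.nn as nn
    from library.custom_offloading_utils import ModelOffloader

    device = torch.device("cuda")
    blocks = [nn.Linear(64, 64) for _ in range(8)]          # stand-in transformer blocks
    offloader = ModelOffloader(blocks, num_blocks=len(blocks), blocks_to_swap=4,
                               device=device, debug=False)
    offloader.prepare_block_devices_before_forward(blocks)  # first 4 on GPU, last 4 weights on CPU

    x = torch.randn(2, 64, device=device)
    for i, block in enumerate(blocks):
        offloader.wait_for_block(i)                # ensure this block's weights reached the GPU
        x = block(x)
        offloader.submit_move_blocks(blocks, i)    # swap block i out, prefetch a later block in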
library/custom_train_functions.py ADDED
@@ -0,0 +1,559 @@
1
+ import torch
2
+ import argparse
3
+ import random
4
+ import re
5
+ from typing import List, Optional, Union
6
+ from .utils import setup_logging
7
+
8
+ setup_logging()
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def prepare_scheduler_for_custom_training(noise_scheduler, device):
15
+ if hasattr(noise_scheduler, "all_snr"):
16
+ return
17
+
18
+ alphas_cumprod = noise_scheduler.alphas_cumprod
19
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
20
+ sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
21
+ alpha = sqrt_alphas_cumprod
22
+ sigma = sqrt_one_minus_alphas_cumprod
23
+ all_snr = (alpha / sigma) ** 2
24
+
25
+ noise_scheduler.all_snr = all_snr.to(device)
26
+
27
+
28
+ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
29
+ # fix beta: zero terminal SNR
30
+ logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
31
+
32
+ def enforce_zero_terminal_snr(betas):
33
+ # Convert betas to alphas_bar_sqrt
34
+ alphas = 1 - betas
35
+ alphas_bar = alphas.cumprod(0)
36
+ alphas_bar_sqrt = alphas_bar.sqrt()
37
+
38
+ # Store old values.
39
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
40
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
41
+ # Shift so last timestep is zero.
42
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
43
+ # Scale so first timestep is back to old value.
44
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
45
+
46
+ # Convert alphas_bar_sqrt to betas
47
+ alphas_bar = alphas_bar_sqrt**2
48
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
49
+ alphas = torch.cat([alphas_bar[0:1], alphas])
50
+ betas = 1 - alphas
51
+ return betas
52
+
53
+ betas = noise_scheduler.betas
54
+ betas = enforce_zero_terminal_snr(betas)
55
+ alphas = 1.0 - betas
56
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
57
+
58
+ # logger.info(f"original: {noise_scheduler.betas}")
59
+ # logger.info(f"fixed: {betas}")
60
+
61
+ noise_scheduler.betas = betas
62
+ noise_scheduler.alphas = alphas
63
+ noise_scheduler.alphas_cumprod = alphas_cumprod
64
+
65
+
66
+ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
67
+ snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
68
+ min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
69
+ if v_prediction:
70
+ snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
71
+ else:
72
+ snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
73
+ loss = loss * snr_weight
74
+ return loss
75
+
76
+
77
+ def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
78
+ scale = get_snr_scale(timesteps, noise_scheduler)
79
+ loss = loss * scale
80
+ return loss
81
+
82
+
83
+ def get_snr_scale(timesteps, noise_scheduler):
84
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
85
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
86
+ scale = snr_t / (snr_t + 1)
87
+ # # show debug info
88
+ # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
89
+ return scale
90
+
91
+
92
+ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
93
+ scale = get_snr_scale(timesteps, noise_scheduler)
94
+ # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
95
+ loss = loss + loss / scale * v_pred_like_loss
96
+ return loss
97
+
98
+
99
+ def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
100
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
101
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
102
+ if v_prediction:
103
+ weight = 1 / (snr_t + 1)
104
+ else:
105
+ weight = 1 / torch.sqrt(snr_t)
106
+ loss = weight * loss
107
+ return loss
108
+
109
+
110
+ # TODO: this logic is split between here and train_util; consolidate it into one of them
111
+
112
+
113
+ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
114
+ parser.add_argument(
115
+ "--min_snr_gamma",
116
+ type=float,
117
+ default=None,
118
+ help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
119
+ )
120
+ parser.add_argument(
121
+ "--scale_v_pred_loss_like_noise_pred",
122
+ action="store_true",
123
+ help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
124
+ )
125
+ parser.add_argument(
126
+ "--v_pred_like_loss",
127
+ type=float,
128
+ default=None,
129
+ help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
130
+ )
131
+ parser.add_argument(
132
+ "--debiased_estimation_loss",
133
+ action="store_true",
134
+ help="debiased estimation loss / debiased estimation loss",
135
+ )
136
+ if support_weighted_captions:
137
+ parser.add_argument(
138
+ "--weighted_captions",
139
+ action="store_true",
140
+ default=False,
141
+ help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
142
+ )
143
+
144
+
145
+ re_attention = re.compile(
146
+ r"""
147
+ \\\(|
148
+ \\\)|
149
+ \\\[|
150
+ \\]|
151
+ \\\\|
152
+ \\|
153
+ \(|
154
+ \[|
155
+ :([+-]?[.\d]+)\)|
156
+ \)|
157
+ ]|
158
+ [^\\()\[\]:]+|
159
+ :
160
+ """,
161
+ re.X,
162
+ )
163
+
164
+
165
+ def parse_prompt_attention(text):
166
+ """
167
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
168
+ Accepted tokens are:
169
+ (abc) - increases attention to abc by a multiplier of 1.1
170
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
171
+ [abc] - decreases attention to abc by a multiplier of 1.1
172
+ \( - literal character '('
173
+ \[ - literal character '['
174
+ \) - literal character ')'
175
+ \] - literal character ']'
176
+ \\ - literal character '\'
177
+ anything else - just text
178
+ >>> parse_prompt_attention('normal text')
179
+ [['normal text', 1.0]]
180
+ >>> parse_prompt_attention('an (important) word')
181
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
182
+ >>> parse_prompt_attention('(unbalanced')
183
+ [['unbalanced', 1.1]]
184
+ >>> parse_prompt_attention('\(literal\]')
185
+ [['(literal]', 1.0]]
186
+ >>> parse_prompt_attention('(unnecessary)(parens)')
187
+ [['unnecessaryparens', 1.1]]
188
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
189
+ [['a ', 1.0],
190
+ ['house', 1.5730000000000004],
191
+ [' ', 1.1],
192
+ ['on', 1.0],
193
+ [' a ', 1.1],
194
+ ['hill', 0.55],
195
+ [', sun, ', 1.1],
196
+ ['sky', 1.4641000000000006],
197
+ ['.', 1.1]]
198
+ """
199
+
200
+ res = []
201
+ round_brackets = []
202
+ square_brackets = []
203
+
204
+ round_bracket_multiplier = 1.1
205
+ square_bracket_multiplier = 1 / 1.1
206
+
207
+ def multiply_range(start_position, multiplier):
208
+ for p in range(start_position, len(res)):
209
+ res[p][1] *= multiplier
210
+
211
+ for m in re_attention.finditer(text):
212
+ text = m.group(0)
213
+ weight = m.group(1)
214
+
215
+ if text.startswith("\\"):
216
+ res.append([text[1:], 1.0])
217
+ elif text == "(":
218
+ round_brackets.append(len(res))
219
+ elif text == "[":
220
+ square_brackets.append(len(res))
221
+ elif weight is not None and len(round_brackets) > 0:
222
+ multiply_range(round_brackets.pop(), float(weight))
223
+ elif text == ")" and len(round_brackets) > 0:
224
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
225
+ elif text == "]" and len(square_brackets) > 0:
226
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
227
+ else:
228
+ res.append([text, 1.0])
229
+
230
+ for pos in round_brackets:
231
+ multiply_range(pos, round_bracket_multiplier)
232
+
233
+ for pos in square_brackets:
234
+ multiply_range(pos, square_bracket_multiplier)
235
+
236
+ if len(res) == 0:
237
+ res = [["", 1.0]]
238
+
239
+ # merge runs of identical weights
240
+ i = 0
241
+ while i + 1 < len(res):
242
+ if res[i][1] == res[i + 1][1]:
243
+ res[i][0] += res[i + 1][0]
244
+ res.pop(i + 1)
245
+ else:
246
+ i += 1
247
+
248
+ return res
249
+
250
+
251
+ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
252
+ r"""
253
+ Tokenize a list of prompts and return its tokens with weights of each token.
254
+
255
+ No padding, starting or ending token is included.
256
+ """
257
+ tokens = []
258
+ weights = []
259
+ truncated = False
260
+ for text in prompt:
261
+ texts_and_weights = parse_prompt_attention(text)
262
+ text_token = []
263
+ text_weight = []
264
+ for word, weight in texts_and_weights:
265
+ # tokenize and discard the starting and the ending token
266
+ token = tokenizer(word).input_ids[1:-1]
267
+ text_token += token
268
+ # copy the weight by length of token
269
+ text_weight += [weight] * len(token)
270
+ # stop if the text is too long (longer than truncation limit)
271
+ if len(text_token) > max_length:
272
+ truncated = True
273
+ break
274
+ # truncate
275
+ if len(text_token) > max_length:
276
+ truncated = True
277
+ text_token = text_token[:max_length]
278
+ text_weight = text_weight[:max_length]
279
+ tokens.append(text_token)
280
+ weights.append(text_weight)
281
+ if truncated:
282
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
283
+ return tokens, weights
284
+
285
+
286
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
287
+ r"""
288
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
289
+ """
290
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
291
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
292
+ for i in range(len(tokens)):
293
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
294
+ if no_boseos_middle:
295
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
296
+ else:
297
+ w = []
298
+ if len(weights[i]) == 0:
299
+ w = [1.0] * weights_length
300
+ else:
301
+ for j in range(max_embeddings_multiples):
302
+ w.append(1.0) # weight for starting token in this chunk
303
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
304
+ w.append(1.0) # weight for ending token in this chunk
305
+ w += [1.0] * (weights_length - len(w))
306
+ weights[i] = w[:]
307
+
308
+ return tokens, weights
309
+
310
+
311
+ def get_unweighted_text_embeddings(
312
+ tokenizer,
313
+ text_encoder,
314
+ text_input: torch.Tensor,
315
+ chunk_length: int,
316
+ clip_skip: int,
317
+ eos: int,
318
+ pad: int,
319
+ no_boseos_middle: Optional[bool] = True,
320
+ ):
321
+ """
322
+ When the length of tokens is a multiple of the capacity of the text encoder,
323
+ it should be split into chunks and sent to the text encoder individually.
324
+ """
325
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
326
+ if max_embeddings_multiples > 1:
327
+ text_embeddings = []
328
+ for i in range(max_embeddings_multiples):
329
+ # extract the i-th chunk
330
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
331
+
332
+ # cover the head and the tail by the starting and the ending tokens
333
+ text_input_chunk[:, 0] = text_input[0, 0]
334
+ if pad == eos: # v1
335
+ text_input_chunk[:, -1] = text_input[0, -1]
336
+ else: # v2
337
+ for j in range(len(text_input_chunk)):
338
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # the last token is a normal character
339
+ text_input_chunk[j, -1] = eos
340
+ if text_input_chunk[j, 1] == pad: # only BOS, the rest is PAD
341
+ text_input_chunk[j, 1] = eos
342
+
343
+ if clip_skip is None or clip_skip == 1:
344
+ text_embedding = text_encoder(text_input_chunk)[0]
345
+ else:
346
+ enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
347
+ text_embedding = enc_out["hidden_states"][-clip_skip]
348
+ text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
349
+
350
+ if no_boseos_middle:
351
+ if i == 0:
352
+ # discard the ending token
353
+ text_embedding = text_embedding[:, :-1]
354
+ elif i == max_embeddings_multiples - 1:
355
+ # discard the starting token
356
+ text_embedding = text_embedding[:, 1:]
357
+ else:
358
+ # discard both starting and ending tokens
359
+ text_embedding = text_embedding[:, 1:-1]
360
+
361
+ text_embeddings.append(text_embedding)
362
+ text_embeddings = torch.concat(text_embeddings, axis=1)
363
+ else:
364
+ if clip_skip is None or clip_skip == 1:
365
+ text_embeddings = text_encoder(text_input)[0]
366
+ else:
367
+ enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
368
+ text_embeddings = enc_out["hidden_states"][-clip_skip]
369
+ text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
370
+ return text_embeddings
371
+
372
+
373
+ def get_weighted_text_embeddings(
374
+ tokenizer,
375
+ text_encoder,
376
+ prompt: Union[str, List[str]],
377
+ device,
378
+ max_embeddings_multiples: Optional[int] = 3,
379
+ no_boseos_middle: Optional[bool] = False,
380
+ clip_skip=None,
381
+ ):
382
+ r"""
383
+ Prompts can be assigned with local weights using brackets. For example,
384
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
385
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
386
+
387
+ Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
388
+
389
+ Args:
390
+ prompt (`str` or `List[str]`):
391
+ The prompt or prompts to guide the image generation.
392
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
393
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
394
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
395
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
396
+ ending token in each of the chunk in the middle.
397
+ skip_parsing (`bool`, *optional*, defaults to `False`):
398
+ Skip the parsing of brackets.
399
+ skip_weighting (`bool`, *optional*, defaults to `False`):
400
+ Skip the weighting. When the parsing is skipped, it is forced True.
401
+ """
402
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
403
+ if isinstance(prompt, str):
404
+ prompt = [prompt]
405
+
406
+ prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
407
+
408
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
409
+ max_length = max([len(token) for token in prompt_tokens])
410
+
411
+ max_embeddings_multiples = min(
412
+ max_embeddings_multiples,
413
+ (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
414
+ )
415
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
416
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
417
+
418
+ # pad the length of tokens and weights
419
+ bos = tokenizer.bos_token_id
420
+ eos = tokenizer.eos_token_id
421
+ pad = tokenizer.pad_token_id
422
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
423
+ prompt_tokens,
424
+ prompt_weights,
425
+ max_length,
426
+ bos,
427
+ eos,
428
+ no_boseos_middle=no_boseos_middle,
429
+ chunk_length=tokenizer.model_max_length,
430
+ )
431
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
432
+
433
+ # get the embeddings
434
+ text_embeddings = get_unweighted_text_embeddings(
435
+ tokenizer,
436
+ text_encoder,
437
+ prompt_tokens,
438
+ tokenizer.model_max_length,
439
+ clip_skip,
440
+ eos,
441
+ pad,
442
+ no_boseos_middle=no_boseos_middle,
443
+ )
444
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
445
+
446
+ # assign weights to the prompts and normalize in the sense of mean
447
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
448
+ text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
449
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
450
+ text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
451
+
452
+ return text_embeddings
453
+
454
+
455
+ # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
456
+ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
457
+ b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
458
+ u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
459
+ for i in range(iterations):
460
+ r = random.random() * 2 + 2 # Rather than always going 2x,
461
+ wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
462
+ noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
463
+ if wn == 1 or hn == 1:
464
+ break # Lowest resolution is 1x1
465
+ return noise / noise.std() # Scaled back to roughly unit variance
466
+
467
+
468
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
469
+ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
470
+ if noise_offset is None:
471
+ return noise
472
+ if adaptive_noise_scale is not None:
473
+ # latent shape: (batch_size, channels, height, width)
474
+ # abs mean value for each channel
475
+ latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
476
+
477
+ # multiply adaptive noise scale to the mean value and add it to the noise offset
478
+ noise_offset = noise_offset + adaptive_noise_scale * latent_mean
479
+ noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
480
+
481
+ noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
482
+ return noise
483
+
484
+
485
+ def apply_masked_loss(loss, batch):
486
+ if "conditioning_images" in batch:
487
+ # conditioning image is -1 to 1. we need to convert it to 0 to 1
488
+ mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
489
+ mask_image = mask_image / 2 + 0.5
490
+ # print(f"conditioning_image: {mask_image.shape}")
491
+ elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
492
+ # alpha mask is 0 to 1
493
+ mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
494
+ # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
495
+ else:
496
+ return loss
497
+
498
+ # resize to the same size as the loss
499
+ mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
500
+ loss = loss * mask_image
501
+ return loss
502
+
503
+
504
+ """
505
+ ##########################################
506
+ # Perlin Noise
507
+ def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
508
+ delta = (res[0] / shape[0], res[1] / shape[1])
509
+ d = (shape[0] // res[0], shape[1] // res[1])
510
+
511
+ grid = (
512
+ torch.stack(
513
+ torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
514
+ dim=-1,
515
+ )
516
+ % 1
517
+ )
518
+ angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
519
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
520
+
521
+ tile_grads = (
522
+ lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
523
+ .repeat_interleave(d[0], 0)
524
+ .repeat_interleave(d[1], 1)
525
+ )
526
+ dot = lambda grad, shift: (
527
+ torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
528
+ * grad[: shape[0], : shape[1]]
529
+ ).sum(dim=-1)
530
+
531
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
532
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
533
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
534
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
535
+ t = fade(grid[: shape[0], : shape[1]])
536
+ return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
537
+
538
+
539
+ def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
540
+ noise = torch.zeros(shape, device=device)
541
+ frequency = 1
542
+ amplitude = 1
543
+ for _ in range(octaves):
544
+ noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
545
+ frequency *= 2
546
+ amplitude *= persistence
547
+ return noise
548
+
549
+
550
+ def perlin_noise(noise, device, octaves):
551
+ _, c, w, h = noise.shape
552
+ perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
553
+ noise_perlin = []
554
+ for _ in range(c):
555
+ noise_perlin.append(perlin())
556
+ noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
557
+ noise += noise_perlin # broadcast for each batch
558
+ return noise / noise.std() # Scaled back to roughly unit variance
559
+ """
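For reference, a minimal usage sketch of the helpers added above (not part of the commit). It assumes a diffusers DDPMScheduler and a small per-sample loss tensor purely for illustration, and only shows how prepare_scheduler_for_custom_training, apply_snr_weight, and parse_prompt_attention are meant to fit together.

import torch
from diffusers import DDPMScheduler
from library.custom_train_functions import (
    prepare_scheduler_for_custom_training,
    apply_snr_weight,
    parse_prompt_attention,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# cache per-timestep SNR on the scheduler once, before the training loop
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
prepare_scheduler_for_custom_training(noise_scheduler, device)

# inside the loop: per-sample loss (B,) and the sampled timesteps (B,)
loss = torch.rand(4, device=device)
timesteps = torch.randint(0, 1000, (4,), device=device)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, gamma=5.0, v_prediction=False)

# prompt weighting syntax handled by the parser above
print(parse_prompt_attention("a photo of a (cat:1.3) on a [table]"))
# [['a photo of a ', 1.0], ['cat', 1.3], [' on a ', 1.0], ['table', ~0.909]]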
library/deepspeed_utils.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+ import argparse
3
+ import torch
4
+ from accelerate import DeepSpeedPlugin, Accelerator
5
+
6
+ from .utils import setup_logging
7
+
8
+ setup_logging()
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def add_deepspeed_arguments(parser: argparse.ArgumentParser):
15
+ # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
16
+ parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
17
+ parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
18
+ parser.add_argument(
19
+ "--offload_optimizer_device",
20
+ type=str,
21
+ default=None,
22
+ choices=[None, "cpu", "nvme"],
23
+ help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
24
+ )
25
+ parser.add_argument(
26
+ "--offload_optimizer_nvme_path",
27
+ type=str,
28
+ default=None,
29
+ help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
30
+ )
31
+ parser.add_argument(
32
+ "--offload_param_device",
33
+ type=str,
34
+ default=None,
35
+ choices=[None, "cpu", "nvme"],
36
+ help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
37
+ )
38
+ parser.add_argument(
39
+ "--offload_param_nvme_path",
40
+ type=str,
41
+ default=None,
42
+ help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
43
+ )
44
+ parser.add_argument(
45
+ "--zero3_init_flag",
46
+ action="store_true",
47
+ help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
48
+ " Only applicable with ZeRO Stage-3.",
49
+ )
50
+ parser.add_argument(
51
+ "--zero3_save_16bit_model",
52
+ action="store_true",
53
+ help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
54
+ )
55
+ parser.add_argument(
56
+ "--fp16_master_weights_and_gradients",
57
+ action="store_true",
58
+ help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
59
+ )
60
+
61
+
62
+ def prepare_deepspeed_args(args: argparse.Namespace):
63
+ if not args.deepspeed:
64
+ return
65
+
66
+ # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
67
+ args.max_data_loader_n_workers = 1
68
+
69
+
70
+ def prepare_deepspeed_plugin(args: argparse.Namespace):
71
+ if not args.deepspeed:
72
+ return None
73
+
74
+ try:
75
+ import deepspeed
76
+ except ImportError as e:
77
+ logger.error(
78
+ "deepspeed is not installed. Please install it in your environment with the following command: DS_BUILD_OPS=0 pip install deepspeed"
79
+ )
80
+ exit(1)
81
+
82
+ deepspeed_plugin = DeepSpeedPlugin(
83
+ zero_stage=args.zero_stage,
84
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
85
+ gradient_clipping=args.max_grad_norm,
86
+ offload_optimizer_device=args.offload_optimizer_device,
87
+ offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
88
+ offload_param_device=args.offload_param_device,
89
+ offload_param_nvme_path=args.offload_param_nvme_path,
90
+ zero3_init_flag=args.zero3_init_flag,
91
+ zero3_save_16bit_model=args.zero3_save_16bit_model,
92
+ )
93
+ deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
94
+ deepspeed_plugin.deepspeed_config["train_batch_size"] = (
95
+ args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
96
+ )
97
+ deepspeed_plugin.set_mixed_precision(args.mixed_precision)
98
+ if args.mixed_precision.lower() == "fp16":
99
+ deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
100
+ if args.full_fp16 or args.fp16_master_weights_and_gradients:
101
+ if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
102
+ deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
103
+ logger.info("[DeepSpeed] full fp16 enabled.")
104
+ else:
105
+ logger.info(
106
+ "[DeepSpeed] full fp16 / fp16_master_weights_and_grads is currently only supported when using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO stage 2."
107
+ )
108
+
109
+ if args.offload_optimizer_device is not None:
110
+ logger.info("[DeepSpeed] start to manually build cpu_adam.")
111
+ deepspeed.ops.op_builder.CPUAdamBuilder().load()
112
+ logger.info("[DeepSpeed] building cpu_adam done.")
113
+
114
+ return deepspeed_plugin
115
+
116
+
117
+ # Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
118
+ def prepare_deepspeed_model(args: argparse.Namespace, **models):
119
+ # remove None from models
120
+ models = {k: v for k, v in models.items() if v is not None}
121
+
122
+ class DeepSpeedWrapper(torch.nn.Module):
123
+ def __init__(self, **kw_models) -> None:
124
+ super().__init__()
125
+ self.models = torch.nn.ModuleDict()
126
+
127
+ for key, model in kw_models.items():
128
+ if isinstance(model, list):
129
+ model = torch.nn.ModuleList(model)
130
+ assert isinstance(
131
+ model, torch.nn.Module
132
+ ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
133
+ self.models.update(torch.nn.ModuleDict({key: model}))
134
+
135
+ def get_models(self):
136
+ return self.models
137
+
138
+ ds_model = DeepSpeedWrapper(**models)
139
+ return ds_model
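Below is a minimal sketch (not part of the commit) of how these helpers are typically wired into an Accelerator. The extra argparse fields (train_batch_size, gradient_accumulation_steps, max_grad_norm, mixed_precision, full_fp16, max_data_loader_n_workers) normally come from train_util and are hard-coded here as assumptions; it is meant to be run via `accelerate launch` with deepspeed installed, since prepare_deepspeed_plugin reads WORLD_SIZE from the environment.

import argparse
from accelerate import Accelerator
from library.deepspeed_utils import (
    add_deepspeed_arguments,
    prepare_deepspeed_args,
    prepare_deepspeed_plugin,
)

parser = argparse.ArgumentParser()
add_deepspeed_arguments(parser)
# stand-ins for the usual train_util arguments (hypothetical defaults)
parser.add_argument("--train_batch_size", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--max_grad_norm", type=float, default=1.0)
parser.add_argument("--mixed_precision", type=str, default="bf16")
parser.add_argument("--full_fp16", action="store_true")
parser.add_argument("--max_data_loader_n_workers", type=int, default=8)
args = parser.parse_args(["--deepspeed", "--zero_stage", "2"])

prepare_deepspeed_args(args)  # forces max_data_loader_n_workers to 1 when --deepspeed is set
deepspeed_plugin = prepare_deepspeed_plugin(args)  # returns None when --deepspeed is not set

accelerator = Accelerator(
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    mixed_precision=args.mixed_precision,
    deepspeed_plugin=deepspeed_plugin,
)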
library/device_utils.py ADDED
@@ -0,0 +1,84 @@
1
+ import functools
2
+ import gc
3
+
4
+ import torch
5
+
6
+ try:
7
+ HAS_CUDA = torch.cuda.is_available()
8
+ except Exception:
9
+ HAS_CUDA = False
10
+
11
+ try:
12
+ HAS_MPS = torch.backends.mps.is_available()
13
+ except Exception:
14
+ HAS_MPS = False
15
+
16
+ try:
17
+ import intel_extension_for_pytorch as ipex # noqa
18
+
19
+ HAS_XPU = torch.xpu.is_available()
20
+ except Exception:
21
+ HAS_XPU = False
22
+
23
+
24
+ def clean_memory():
25
+ gc.collect()
26
+ if HAS_CUDA:
27
+ torch.cuda.empty_cache()
28
+ if HAS_XPU:
29
+ torch.xpu.empty_cache()
30
+ if HAS_MPS:
31
+ torch.mps.empty_cache()
32
+
33
+
34
+ def clean_memory_on_device(device: torch.device):
35
+ r"""
36
+ Clean memory on the specified device, will be called from training scripts.
37
+ """
38
+ gc.collect()
39
+
40
+ # device may "cuda" or "cuda:0", so we need to check the type of device
41
+ if device.type == "cuda":
42
+ torch.cuda.empty_cache()
43
+ if device.type == "xpu":
44
+ torch.xpu.empty_cache()
45
+ if device.type == "mps":
46
+ torch.mps.empty_cache()
47
+
48
+
49
+ @functools.lru_cache(maxsize=None)
50
+ def get_preferred_device() -> torch.device:
51
+ r"""
52
+ Do not call this function from training scripts. Use accelerator.device instead.
53
+ """
54
+ if HAS_CUDA:
55
+ device = torch.device("cuda")
56
+ elif HAS_XPU:
57
+ device = torch.device("xpu")
58
+ elif HAS_MPS:
59
+ device = torch.device("mps")
60
+ else:
61
+ device = torch.device("cpu")
62
+ print(f"get_preferred_device() -> {device}")
63
+ return device
64
+
65
+
66
+ def init_ipex():
67
+ """
68
+ Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
69
+
70
+ This function should run right after importing torch and before doing anything else.
71
+
72
+ If IPEX is not available, this function does nothing.
73
+ """
74
+ try:
75
+ if HAS_XPU:
76
+ from library.ipex import ipex_init
77
+
78
+ is_initialized, error_message = ipex_init()
79
+ if not is_initialized:
80
+ print("failed to initialize ipex:", error_message)
81
+ else:
82
+ return
83
+ except Exception as e:
84
+ print("failed to initialize ipex:", e)
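A short usage sketch for these device helpers (not part of the commit):

import torch
from library.device_utils import init_ipex, get_preferred_device, clean_memory_on_device

init_ipex()  # no-op unless intel_extension_for_pytorch is importable

device = get_preferred_device()  # picks cuda > xpu > mps > cpu, cached after the first call
x = torch.randn(1024, 1024, device=device)
del x
clean_memory_on_device(device)  # gc.collect() plus the matching empty_cache() for the device type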
library/flux_models.py ADDED
@@ -0,0 +1,1237 @@
1
+ # copy from FLUX repo: https://github.com/black-forest-labs/flux
2
+ # license: Apache-2.0 License
3
+
4
+
5
+ from concurrent.futures import Future, ThreadPoolExecutor
6
+ from dataclasses import dataclass
7
+ import math
8
+ import os
9
+ import time
10
+ from typing import Dict, List, Optional, Union
11
+
12
+ from library import utils
13
+ from library.device_utils import init_ipex, clean_memory_on_device
14
+
15
+ init_ipex()
16
+
17
+ import torch
18
+ from einops import rearrange
19
+ from torch import Tensor, nn
20
+ from torch.utils.checkpoint import checkpoint
21
+ from library import custom_offloading_utils
22
+
23
+ # USE_REENTRANT = True
24
+
25
+
26
+ @dataclass
27
+ class FluxParams:
28
+ in_channels: int
29
+ vec_in_dim: int
30
+ context_in_dim: int
31
+ hidden_size: int
32
+ mlp_ratio: float
33
+ num_heads: int
34
+ depth: int
35
+ depth_single_blocks: int
36
+ axes_dim: list[int]
37
+ theta: int
38
+ qkv_bias: bool
39
+ guidance_embed: bool
40
+
41
+
42
+ # region autoencoder
43
+
44
+
45
+ @dataclass
46
+ class AutoEncoderParams:
47
+ resolution: int
48
+ in_channels: int
49
+ ch: int
50
+ out_ch: int
51
+ ch_mult: list[int]
52
+ num_res_blocks: int
53
+ z_channels: int
54
+ scale_factor: float
55
+ shift_factor: float
56
+
57
+
58
+ def swish(x: Tensor) -> Tensor:
59
+ return x * torch.sigmoid(x)
60
+
61
+
62
+ class AttnBlock(nn.Module):
63
+ def __init__(self, in_channels: int):
64
+ super().__init__()
65
+ self.in_channels = in_channels
66
+
67
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
68
+
69
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
70
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
71
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
72
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
73
+
74
+ def attention(self, h_: Tensor) -> Tensor:
75
+ h_ = self.norm(h_)
76
+ q = self.q(h_)
77
+ k = self.k(h_)
78
+ v = self.v(h_)
79
+
80
+ b, c, h, w = q.shape
81
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
82
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
83
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
84
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
85
+
86
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
87
+
88
+ def forward(self, x: Tensor) -> Tensor:
89
+ return x + self.proj_out(self.attention(x))
90
+
91
+
92
+ class ResnetBlock(nn.Module):
93
+ def __init__(self, in_channels: int, out_channels: int):
94
+ super().__init__()
95
+ self.in_channels = in_channels
96
+ out_channels = in_channels if out_channels is None else out_channels
97
+ self.out_channels = out_channels
98
+
99
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
100
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
101
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
102
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
103
+ if self.in_channels != self.out_channels:
104
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
105
+
106
+ def forward(self, x):
107
+ h = x
108
+ h = self.norm1(h)
109
+ h = swish(h)
110
+ h = self.conv1(h)
111
+
112
+ h = self.norm2(h)
113
+ h = swish(h)
114
+ h = self.conv2(h)
115
+
116
+ if self.in_channels != self.out_channels:
117
+ x = self.nin_shortcut(x)
118
+
119
+ return x + h
120
+
121
+
122
+ class Downsample(nn.Module):
123
+ def __init__(self, in_channels: int):
124
+ super().__init__()
125
+ # no asymmetric padding in torch conv, must do it ourselves
126
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
127
+
128
+ def forward(self, x: Tensor):
129
+ pad = (0, 1, 0, 1)
130
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
131
+ x = self.conv(x)
132
+ return x
133
+
134
+
135
+ class Upsample(nn.Module):
136
+ def __init__(self, in_channels: int):
137
+ super().__init__()
138
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
139
+
140
+ def forward(self, x: Tensor):
141
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
142
+ x = self.conv(x)
143
+ return x
144
+
145
+
146
+ class Encoder(nn.Module):
147
+ def __init__(
148
+ self,
149
+ resolution: int,
150
+ in_channels: int,
151
+ ch: int,
152
+ ch_mult: list[int],
153
+ num_res_blocks: int,
154
+ z_channels: int,
155
+ ):
156
+ super().__init__()
157
+ self.ch = ch
158
+ self.num_resolutions = len(ch_mult)
159
+ self.num_res_blocks = num_res_blocks
160
+ self.resolution = resolution
161
+ self.in_channels = in_channels
162
+ # downsampling
163
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
164
+
165
+ curr_res = resolution
166
+ in_ch_mult = (1,) + tuple(ch_mult)
167
+ self.in_ch_mult = in_ch_mult
168
+ self.down = nn.ModuleList()
169
+ block_in = self.ch
170
+ for i_level in range(self.num_resolutions):
171
+ block = nn.ModuleList()
172
+ attn = nn.ModuleList()
173
+ block_in = ch * in_ch_mult[i_level]
174
+ block_out = ch * ch_mult[i_level]
175
+ for _ in range(self.num_res_blocks):
176
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
177
+ block_in = block_out
178
+ down = nn.Module()
179
+ down.block = block
180
+ down.attn = attn
181
+ if i_level != self.num_resolutions - 1:
182
+ down.downsample = Downsample(block_in)
183
+ curr_res = curr_res // 2
184
+ self.down.append(down)
185
+
186
+ # middle
187
+ self.mid = nn.Module()
188
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
189
+ self.mid.attn_1 = AttnBlock(block_in)
190
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
191
+
192
+ # end
193
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
194
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
195
+
196
+ def forward(self, x: Tensor) -> Tensor:
197
+ # downsampling
198
+ hs = [self.conv_in(x)]
199
+ for i_level in range(self.num_resolutions):
200
+ for i_block in range(self.num_res_blocks):
201
+ h = self.down[i_level].block[i_block](hs[-1])
202
+ if len(self.down[i_level].attn) > 0:
203
+ h = self.down[i_level].attn[i_block](h)
204
+ hs.append(h)
205
+ if i_level != self.num_resolutions - 1:
206
+ hs.append(self.down[i_level].downsample(hs[-1]))
207
+
208
+ # middle
209
+ h = hs[-1]
210
+ h = self.mid.block_1(h)
211
+ h = self.mid.attn_1(h)
212
+ h = self.mid.block_2(h)
213
+ # end
214
+ h = self.norm_out(h)
215
+ h = swish(h)
216
+ h = self.conv_out(h)
217
+ return h
218
+
219
+
220
+ class Decoder(nn.Module):
221
+ def __init__(
222
+ self,
223
+ ch: int,
224
+ out_ch: int,
225
+ ch_mult: list[int],
226
+ num_res_blocks: int,
227
+ in_channels: int,
228
+ resolution: int,
229
+ z_channels: int,
230
+ ):
231
+ super().__init__()
232
+ self.ch = ch
233
+ self.num_resolutions = len(ch_mult)
234
+ self.num_res_blocks = num_res_blocks
235
+ self.resolution = resolution
236
+ self.in_channels = in_channels
237
+ self.ffactor = 2 ** (self.num_resolutions - 1)
238
+
239
+ # compute in_ch_mult, block_in and curr_res at lowest res
240
+ block_in = ch * ch_mult[self.num_resolutions - 1]
241
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
242
+ self.z_shape = (1, z_channels, curr_res, curr_res)
243
+
244
+ # z to block_in
245
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
246
+
247
+ # middle
248
+ self.mid = nn.Module()
249
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
250
+ self.mid.attn_1 = AttnBlock(block_in)
251
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
252
+
253
+ # upsampling
254
+ self.up = nn.ModuleList()
255
+ for i_level in reversed(range(self.num_resolutions)):
256
+ block = nn.ModuleList()
257
+ attn = nn.ModuleList()
258
+ block_out = ch * ch_mult[i_level]
259
+ for _ in range(self.num_res_blocks + 1):
260
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
261
+ block_in = block_out
262
+ up = nn.Module()
263
+ up.block = block
264
+ up.attn = attn
265
+ if i_level != 0:
266
+ up.upsample = Upsample(block_in)
267
+ curr_res = curr_res * 2
268
+ self.up.insert(0, up) # prepend to get consistent order
269
+
270
+ # end
271
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
272
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
273
+
274
+ def forward(self, z: Tensor) -> Tensor:
275
+ # z to block_in
276
+ h = self.conv_in(z)
277
+
278
+ # middle
279
+ h = self.mid.block_1(h)
280
+ h = self.mid.attn_1(h)
281
+ h = self.mid.block_2(h)
282
+
283
+ # upsampling
284
+ for i_level in reversed(range(self.num_resolutions)):
285
+ for i_block in range(self.num_res_blocks + 1):
286
+ h = self.up[i_level].block[i_block](h)
287
+ if len(self.up[i_level].attn) > 0:
288
+ h = self.up[i_level].attn[i_block](h)
289
+ if i_level != 0:
290
+ h = self.up[i_level].upsample(h)
291
+
292
+ # end
293
+ h = self.norm_out(h)
294
+ h = swish(h)
295
+ h = self.conv_out(h)
296
+ return h
297
+
298
+
299
+ class DiagonalGaussian(nn.Module):
300
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
301
+ super().__init__()
302
+ self.sample = sample
303
+ self.chunk_dim = chunk_dim
304
+
305
+ def forward(self, z: Tensor) -> Tensor:
306
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
307
+ if self.sample:
308
+ std = torch.exp(0.5 * logvar)
309
+ return mean + std * torch.randn_like(mean)
310
+ else:
311
+ return mean
312
+
313
+
314
+ class AutoEncoder(nn.Module):
315
+ def __init__(self, params: AutoEncoderParams):
316
+ super().__init__()
317
+ self.encoder = Encoder(
318
+ resolution=params.resolution,
319
+ in_channels=params.in_channels,
320
+ ch=params.ch,
321
+ ch_mult=params.ch_mult,
322
+ num_res_blocks=params.num_res_blocks,
323
+ z_channels=params.z_channels,
324
+ )
325
+ self.decoder = Decoder(
326
+ resolution=params.resolution,
327
+ in_channels=params.in_channels,
328
+ ch=params.ch,
329
+ out_ch=params.out_ch,
330
+ ch_mult=params.ch_mult,
331
+ num_res_blocks=params.num_res_blocks,
332
+ z_channels=params.z_channels,
333
+ )
334
+ self.reg = DiagonalGaussian()
335
+
336
+ self.scale_factor = params.scale_factor
337
+ self.shift_factor = params.shift_factor
338
+
339
+ @property
340
+ def device(self) -> torch.device:
341
+ return next(self.parameters()).device
342
+
343
+ @property
344
+ def dtype(self) -> torch.dtype:
345
+ return next(self.parameters()).dtype
346
+
347
+ def encode(self, x: Tensor) -> Tensor:
348
+ z = self.reg(self.encoder(x))
349
+ z = self.scale_factor * (z - self.shift_factor)
350
+ return z
351
+
352
+ def decode(self, z: Tensor) -> Tensor:
353
+ z = z / self.scale_factor + self.shift_factor
354
+ return self.decoder(z)
355
+
356
+ def forward(self, x: Tensor) -> Tensor:
357
+ return self.decode(self.encode(x))
358
+
359
+
360
+ # endregion
361
+ # region config
362
+
363
+
364
+ @dataclass
365
+ class ModelSpec:
366
+ params: FluxParams
367
+ ae_params: AutoEncoderParams
368
+ ckpt_path: str | None
369
+ ae_path: str | None
370
+ # repo_id: str | None
371
+ # repo_flow: str | None
372
+ # repo_ae: str | None
373
+
374
+
375
+ configs = {
376
+ "dev": ModelSpec(
377
+ # repo_id="black-forest-labs/FLUX.1-dev",
378
+ # repo_flow="flux1-dev.sft",
379
+ # repo_ae="ae.sft",
380
+ ckpt_path=None, # os.getenv("FLUX_DEV"),
381
+ params=FluxParams(
382
+ in_channels=64,
383
+ vec_in_dim=768,
384
+ context_in_dim=4096,
385
+ hidden_size=3072,
386
+ mlp_ratio=4.0,
387
+ num_heads=24,
388
+ depth=19,
389
+ depth_single_blocks=38,
390
+ axes_dim=[16, 56, 56],
391
+ theta=10_000,
392
+ qkv_bias=True,
393
+ guidance_embed=True,
394
+ ),
395
+ ae_path=None, # os.getenv("AE"),
396
+ ae_params=AutoEncoderParams(
397
+ resolution=256,
398
+ in_channels=3,
399
+ ch=128,
400
+ out_ch=3,
401
+ ch_mult=[1, 2, 4, 4],
402
+ num_res_blocks=2,
403
+ z_channels=16,
404
+ scale_factor=0.3611,
405
+ shift_factor=0.1159,
406
+ ),
407
+ ),
408
+ "schnell": ModelSpec(
409
+ # repo_id="black-forest-labs/FLUX.1-schnell",
410
+ # repo_flow="flux1-schnell.sft",
411
+ # repo_ae="ae.sft",
412
+ ckpt_path=None, # os.getenv("FLUX_SCHNELL"),
413
+ params=FluxParams(
414
+ in_channels=64,
415
+ vec_in_dim=768,
416
+ context_in_dim=4096,
417
+ hidden_size=3072,
418
+ mlp_ratio=4.0,
419
+ num_heads=24,
420
+ depth=19,
421
+ depth_single_blocks=38,
422
+ axes_dim=[16, 56, 56],
423
+ theta=10_000,
424
+ qkv_bias=True,
425
+ guidance_embed=False,
426
+ ),
427
+ ae_path=None, # os.getenv("AE"),
428
+ ae_params=AutoEncoderParams(
429
+ resolution=256,
430
+ in_channels=3,
431
+ ch=128,
432
+ out_ch=3,
433
+ ch_mult=[1, 2, 4, 4],
434
+ num_res_blocks=2,
435
+ z_channels=16,
436
+ scale_factor=0.3611,
437
+ shift_factor=0.1159,
438
+ ),
439
+ ),
440
+ }
441
+
442
+
443
+ # endregion
444
+
445
+ # region math
446
+
447
+
448
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
449
+ q, k = apply_rope(q, k, pe)
450
+
451
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
452
+ x = rearrange(x, "B H L D -> B L (H D)")
453
+
454
+ return x
455
+
456
+
457
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
458
+ assert dim % 2 == 0
459
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
460
+ omega = 1.0 / (theta**scale)
461
+ out = torch.einsum("...n,d->...nd", pos, omega)
462
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
463
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
464
+ return out.float()
465
+
466
+
467
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
468
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
469
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
470
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
471
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
472
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
473
+
474
+
475
+ # endregion
476
+
477
+
478
+ # region layers
479
+
480
+
481
+ # for cpu_offload_checkpointing
482
+
483
+
484
+ def to_cuda(x):
485
+ if isinstance(x, torch.Tensor):
486
+ return x.cuda()
487
+ elif isinstance(x, (list, tuple)):
488
+ return [to_cuda(elem) for elem in x]
489
+ elif isinstance(x, dict):
490
+ return {k: to_cuda(v) for k, v in x.items()}
491
+ else:
492
+ return x
493
+
494
+
495
+ def to_cpu(x):
496
+ if isinstance(x, torch.Tensor):
497
+ return x.cpu()
498
+ elif isinstance(x, (list, tuple)):
499
+ return [to_cpu(elem) for elem in x]
500
+ elif isinstance(x, dict):
501
+ return {k: to_cpu(v) for k, v in x.items()}
502
+ else:
503
+ return x
504
+
505
+
506
+ class EmbedND(nn.Module):
507
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
508
+ super().__init__()
509
+ self.dim = dim
510
+ self.theta = theta
511
+ self.axes_dim = axes_dim
512
+
513
+ def forward(self, ids: Tensor) -> Tensor:
514
+ n_axes = ids.shape[-1]
515
+ emb = torch.cat(
516
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
517
+ dim=-3,
518
+ )
519
+
520
+ return emb.unsqueeze(1)
521
+
522
+
523
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
524
+ """
525
+ Create sinusoidal timestep embeddings.
526
+ :param t: a 1-D Tensor of N indices, one per batch element.
527
+ These may be fractional.
528
+ :param dim: the dimension of the output.
529
+ :param max_period: controls the minimum frequency of the embeddings.
530
+ :return: an (N, D) Tensor of positional embeddings.
531
+ """
532
+ t = time_factor * t
533
+ half = dim // 2
534
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
535
+
536
+ args = t[:, None].float() * freqs[None]
537
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
538
+ if dim % 2:
539
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
540
+ if torch.is_floating_point(t):
541
+ embedding = embedding.to(t)
542
+ return embedding
543
+
544
+
545
+ class MLPEmbedder(nn.Module):
546
+ def __init__(self, in_dim: int, hidden_dim: int):
547
+ super().__init__()
548
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
549
+ self.silu = nn.SiLU()
550
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
551
+
552
+ self.gradient_checkpointing = False
553
+
554
+ def enable_gradient_checkpointing(self):
555
+ self.gradient_checkpointing = True
556
+
557
+ def disable_gradient_checkpointing(self):
558
+ self.gradient_checkpointing = False
559
+
560
+ def _forward(self, x: Tensor) -> Tensor:
561
+ return self.out_layer(self.silu(self.in_layer(x)))
562
+
563
+ def forward(self, *args, **kwargs):
564
+ if self.training and self.gradient_checkpointing:
565
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
566
+ else:
567
+ return self._forward(*args, **kwargs)
568
+
569
+ # def forward(self, x):
570
+ # if self.training and self.gradient_checkpointing:
571
+ # def create_custom_forward(func):
572
+ # def custom_forward(*inputs):
573
+ # return func(*inputs)
574
+ # return custom_forward
575
+ # return torch.utils.checkpoint.checkpoint(create_custom_forward(self._forward), x, use_reentrant=USE_REENTRANT)
576
+ # else:
577
+ # return self._forward(x)
578
+
579
+
580
+ class RMSNorm(torch.nn.Module):
581
+ def __init__(self, dim: int):
582
+ super().__init__()
583
+ self.scale = nn.Parameter(torch.ones(dim))
584
+
585
+ def forward(self, x: Tensor):
586
+ x_dtype = x.dtype
587
+ x = x.float()
588
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
589
+ # return (x * rrms).to(dtype=x_dtype) * self.scale
590
+ return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
591
+
592
+
593
+ class QKNorm(torch.nn.Module):
594
+ def __init__(self, dim: int):
595
+ super().__init__()
596
+ self.query_norm = RMSNorm(dim)
597
+ self.key_norm = RMSNorm(dim)
598
+
599
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
600
+ q = self.query_norm(q)
601
+ k = self.key_norm(k)
602
+ return q.to(v), k.to(v)
603
+
604
+
605
+ class SelfAttention(nn.Module):
606
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
607
+ super().__init__()
608
+ self.num_heads = num_heads
609
+ head_dim = dim // num_heads
610
+
611
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
612
+ self.norm = QKNorm(head_dim)
613
+ self.proj = nn.Linear(dim, dim)
614
+
615
+ # this is not called from DoubleStreamBlock/SingleStreamBlock because they use the attention function directly
616
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
617
+ qkv = self.qkv(x)
618
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
619
+ q, k = self.norm(q, k, v)
620
+ x = attention(q, k, v, pe=pe)
621
+ x = self.proj(x)
622
+ return x
623
+
624
+
625
+ @dataclass
626
+ class ModulationOut:
627
+ shift: Tensor
628
+ scale: Tensor
629
+ gate: Tensor
630
+
631
+
632
+ class Modulation(nn.Module):
633
+ def __init__(self, dim: int, double: bool):
634
+ super().__init__()
635
+ self.is_double = double
636
+ self.multiplier = 6 if double else 3
637
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
638
+
639
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
640
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
641
+
642
+ return (
643
+ ModulationOut(*out[:3]),
644
+ ModulationOut(*out[3:]) if self.is_double else None,
645
+ )
646
+
647
+
648
+ class DoubleStreamBlock(nn.Module):
649
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
650
+ super().__init__()
651
+
652
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
653
+ self.num_heads = num_heads
654
+ self.hidden_size = hidden_size
655
+ self.img_mod = Modulation(hidden_size, double=True)
656
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
657
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
658
+
659
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
660
+ self.img_mlp = nn.Sequential(
661
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
662
+ nn.GELU(approximate="tanh"),
663
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
664
+ )
665
+
666
+ self.txt_mod = Modulation(hidden_size, double=True)
667
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
668
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
669
+
670
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
671
+ self.txt_mlp = nn.Sequential(
672
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
673
+ nn.GELU(approximate="tanh"),
674
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
675
+ )
676
+
677
+ self.gradient_checkpointing = False
678
+ self.cpu_offload_checkpointing = False
679
+
680
+ def enable_gradient_checkpointing(self, cpu_offload: bool = False):
681
+ self.gradient_checkpointing = True
682
+ self.cpu_offload_checkpointing = cpu_offload
683
+
684
+ def disable_gradient_checkpointing(self):
685
+ self.gradient_checkpointing = False
686
+ self.cpu_offload_checkpointing = False
687
+
688
+ def _forward(
689
+ self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
690
+ ) -> tuple[Tensor, Tensor]:
691
+ img_mod1, img_mod2 = self.img_mod(vec)
692
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
693
+
694
+ # prepare image for attention
695
+ img_modulated = self.img_norm1(img)
696
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
697
+ img_qkv = self.img_attn.qkv(img_modulated)
698
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
699
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
700
+
701
+ # prepare txt for attention
702
+ txt_modulated = self.txt_norm1(txt)
703
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
704
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
705
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
706
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
707
+
708
+ # run actual attention
709
+ q = torch.cat((txt_q, img_q), dim=2)
710
+ k = torch.cat((txt_k, img_k), dim=2)
711
+ v = torch.cat((txt_v, img_v), dim=2)
712
+
713
+ # make attention mask if not None
714
+ attn_mask = None
715
+ if txt_attention_mask is not None:
716
+ # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
717
+ attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
718
+ attn_mask = torch.cat(
719
+ (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1
720
+ ) # b, seq_len + img_len
721
+
722
+ # broadcast attn_mask to all heads
723
+ attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
724
+
725
+ attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
726
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
727
+
728
+ # calculate the img blocks
729
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
730
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
731
+
732
+ # calculate the txt blocks
733
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
734
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
735
+ return img, txt
736
+
737
+ def forward(
738
+ self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
739
+ ) -> tuple[Tensor, Tensor]:
740
+ if self.training and self.gradient_checkpointing:
741
+ if not self.cpu_offload_checkpointing:
742
+ return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False)
743
+ # cpu offload checkpointing
744
+
745
+ def create_custom_forward(func):
746
+ def custom_forward(*inputs):
747
+ cuda_inputs = to_cuda(inputs)
748
+ outputs = func(*cuda_inputs)
749
+ return to_cpu(outputs)
750
+
751
+ return custom_forward
752
+
753
+ return torch.utils.checkpoint.checkpoint(
754
+ create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
755
+ )
756
+
757
+ else:
758
+ return self._forward(img, txt, vec, pe, txt_attention_mask)
759
+
760
+
761
+ class SingleStreamBlock(nn.Module):
762
+ """
763
+ A DiT block with parallel linear layers as described in
764
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
765
+ """
766
+
767
+ def __init__(
768
+ self,
769
+ hidden_size: int,
770
+ num_heads: int,
771
+ mlp_ratio: float = 4.0,
772
+ qk_scale: float | None = None,
773
+ ):
774
+ super().__init__()
775
+ self.hidden_dim = hidden_size
776
+ self.num_heads = num_heads
777
+ head_dim = hidden_size // num_heads
778
+ self.scale = qk_scale or head_dim**-0.5
779
+
780
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
781
+ # qkv and mlp_in
782
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
783
+ # proj and mlp_out
784
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
785
+
786
+ self.norm = QKNorm(head_dim)
787
+
788
+ self.hidden_size = hidden_size
789
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
790
+
791
+ self.mlp_act = nn.GELU(approximate="tanh")
792
+ self.modulation = Modulation(hidden_size, double=False)
793
+
794
+ self.gradient_checkpointing = False
795
+ self.cpu_offload_checkpointing = False
796
+
797
+ def enable_gradient_checkpointing(self, cpu_offload: bool = False):
798
+ self.gradient_checkpointing = True
799
+ self.cpu_offload_checkpointing = cpu_offload
800
+
801
+ def disable_gradient_checkpointing(self):
802
+ self.gradient_checkpointing = False
803
+ self.cpu_offload_checkpointing = False
804
+
805
+ def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
806
+ mod, _ = self.modulation(vec)
807
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
808
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
809
+
810
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
811
+ q, k = self.norm(q, k, v)
812
+
813
+ # make attention mask if not None
814
+ attn_mask = None
815
+ if txt_attention_mask is not None:
816
+ # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
817
+ attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
818
+ attn_mask = torch.cat(
819
+ (
820
+ attn_mask,
821
+ torch.ones(
822
+ attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
823
+ ),
824
+ ),
825
+ dim=1,
826
+ ) # b, seq_len + img_len = x_len
827
+
828
+ # broadcast attn_mask to all heads
829
+ attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
830
+
831
+ # compute attention
832
+ attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
833
+
834
+ # compute activation in mlp stream, cat again and run second linear layer
835
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
836
+ return x + mod.gate * output
837
+
838
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
839
+ if self.training and self.gradient_checkpointing:
840
+ if not self.cpu_offload_checkpointing:
841
+ return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
842
+
843
+ # cpu offload checkpointing
844
+
845
+ def create_custom_forward(func):
846
+ def custom_forward(*inputs):
847
+ cuda_inputs = to_cuda(inputs)
848
+ outputs = func(*cuda_inputs)
849
+ return to_cpu(outputs)
850
+
851
+ return custom_forward
852
+
853
+ return torch.utils.checkpoint.checkpoint(
854
+ create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
855
+ )
856
+ else:
857
+ return self._forward(x, vec, pe, txt_attention_mask)
858
+
859
+
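Note on the mask handling in _forward above: only padded text tokens are masked out; the image tokens are always attended to, so the T5 mask is padded with ones up to the full txt+img length and then broadcast over heads and query positions for F.scaled_dot_product_attention. A standalone sketch of that construction (the helper name and toy shapes are illustrative, not part of this file):

import torch

def build_joint_attn_mask(txt_attention_mask: torch.Tensor, img_len: int, num_heads: int) -> torch.Tensor:
    # txt_attention_mask: (B, txt_len), 1 for real tokens, 0 for padding
    mask = txt_attention_mask.to(torch.bool)
    img_part = torch.ones(mask.shape[0], img_len, dtype=torch.bool, device=mask.device)
    mask = torch.cat((mask, img_part), dim=1)          # (B, txt_len + img_len)
    seq_len = mask.shape[1]
    # broadcast to (B, heads, query_len, key_len), the layout scaled_dot_product_attention expects
    return mask[:, None, None, :].expand(-1, num_heads, seq_len, -1)

print(build_joint_attn_mask(torch.tensor([[1, 1, 0]]), img_len=4, num_heads=2).shape)  # (1, 2, 7, 7)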
860
+ class LastLayer(nn.Module):
861
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
862
+ super().__init__()
863
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
864
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
865
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
866
+
867
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
868
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
869
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
870
+ x = self.linear(x)
871
+ return x
872
+
873
+
874
+ # endregion
875
+
876
+
877
+ class Flux(nn.Module):
878
+ """
879
+ Transformer model for flow matching on sequences.
880
+ """
881
+
882
+ def __init__(self, params: FluxParams):
883
+ super().__init__()
884
+
885
+ self.params = params
886
+ self.in_channels = params.in_channels
887
+ self.out_channels = self.in_channels
888
+ if params.hidden_size % params.num_heads != 0:
889
+ raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
890
+ pe_dim = params.hidden_size // params.num_heads
891
+ if sum(params.axes_dim) != pe_dim:
892
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
893
+ self.hidden_size = params.hidden_size
894
+ self.num_heads = params.num_heads
895
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
896
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
897
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
898
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
899
+ self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
900
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
901
+
902
+ self.double_blocks = nn.ModuleList(
903
+ [
904
+ DoubleStreamBlock(
905
+ self.hidden_size,
906
+ self.num_heads,
907
+ mlp_ratio=params.mlp_ratio,
908
+ qkv_bias=params.qkv_bias,
909
+ )
910
+ for _ in range(params.depth)
911
+ ]
912
+ )
913
+
914
+ self.single_blocks = nn.ModuleList(
915
+ [
916
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
917
+ for _ in range(params.depth_single_blocks)
918
+ ]
919
+ )
920
+
921
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
922
+
923
+ self.gradient_checkpointing = False
924
+ self.cpu_offload_checkpointing = False
925
+ self.blocks_to_swap = None
926
+
927
+ self.offloader_double = None
928
+ self.offloader_single = None
929
+ self.num_double_blocks = len(self.double_blocks)
930
+ self.num_single_blocks = len(self.single_blocks)
931
+
932
+ @property
933
+ def device(self):
934
+ return next(self.parameters()).device
935
+
936
+ @property
937
+ def dtype(self):
938
+ return next(self.parameters()).dtype
939
+
940
+ def enable_gradient_checkpointing(self, cpu_offload: bool = False):
941
+ self.gradient_checkpointing = True
942
+ self.cpu_offload_checkpointing = cpu_offload
943
+
944
+ self.time_in.enable_gradient_checkpointing()
945
+ self.vector_in.enable_gradient_checkpointing()
946
+ if self.guidance_in.__class__ != nn.Identity:
947
+ self.guidance_in.enable_gradient_checkpointing()
948
+
949
+ for block in self.double_blocks + self.single_blocks:
950
+ block.enable_gradient_checkpointing(cpu_offload=cpu_offload)
951
+
952
+ print(f"FLUX: Gradient checkpointing enabled. CPU offload: {cpu_offload}")
953
+
954
+ def disable_gradient_checkpointing(self):
955
+ self.gradient_checkpointing = False
956
+ self.cpu_offload_checkpointing = False
957
+
958
+ self.time_in.disable_gradient_checkpointing()
959
+ self.vector_in.disable_gradient_checkpointing()
960
+ if self.guidance_in.__class__ != nn.Identity:
961
+ self.guidance_in.disable_gradient_checkpointing()
962
+
963
+ for block in self.double_blocks + self.single_blocks:
964
+ block.disable_gradient_checkpointing()
965
+
966
+ print("FLUX: Gradient checkpointing disabled.")
967
+
968
+ def enable_block_swap(self, num_blocks: int, device: torch.device):
969
+ self.blocks_to_swap = num_blocks
970
+ double_blocks_to_swap = num_blocks // 2
971
+ single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2
972
+
973
+ assert double_blocks_to_swap <= self.num_double_blocks - 2 and single_blocks_to_swap <= self.num_single_blocks - 2, (
974
+ f"Cannot swap more than {self.num_double_blocks - 2} double blocks and {self.num_single_blocks - 2} single blocks. "
975
+ f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
976
+ )
977
+
978
+ self.offloader_double = custom_offloading_utils.ModelOffloader(
979
+ self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
980
+ )
981
+ self.offloader_single = custom_offloading_utils.ModelOffloader(
982
+ self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
983
+ )
984
+ print(
985
+ f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
986
+ )
987
+
988
+ def move_to_device_except_swap_blocks(self, device: torch.device):
989
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
990
+ if self.blocks_to_swap:
991
+ save_double_blocks = self.double_blocks
992
+ save_single_blocks = self.single_blocks
993
+ self.double_blocks = None
994
+ self.single_blocks = None
995
+
996
+ self.to(device)
997
+
998
+ if self.blocks_to_swap:
999
+ self.double_blocks = save_double_blocks
1000
+ self.single_blocks = save_single_blocks
1001
+
1002
+ def prepare_block_swap_before_forward(self):
1003
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
1004
+ return
1005
+ self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
1006
+ self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
1007
+
1008
+ def forward(
1009
+ self,
1010
+ img: Tensor,
1011
+ img_ids: Tensor,
1012
+ txt: Tensor,
1013
+ txt_ids: Tensor,
1014
+ timesteps: Tensor,
1015
+ y: Tensor,
1016
+ guidance: Tensor | None = None,
1017
+ txt_attention_mask: Tensor | None = None,
1018
+ ) -> Tensor:
1019
+ if img.ndim != 3 or txt.ndim != 3:
1020
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
1021
+
1022
+ # running on sequences img
1023
+ img = self.img_in(img)
1024
+ vec = self.time_in(timestep_embedding(timesteps, 256))
1025
+ if self.params.guidance_embed:
1026
+ if guidance is None:
1027
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
1028
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
1029
+ vec = vec + self.vector_in(y)
1030
+ txt = self.txt_in(txt)
1031
+
1032
+ ids = torch.cat((txt_ids, img_ids), dim=1)
1033
+ pe = self.pe_embedder(ids)
1034
+
1035
+ if not self.blocks_to_swap:
1036
+ for block in self.double_blocks:
1037
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1038
+ img = torch.cat((txt, img), 1)
1039
+ for block in self.single_blocks:
1040
+ img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1041
+ else:
1042
+ for block_idx, block in enumerate(self.double_blocks):
1043
+ self.offloader_double.wait_for_block(block_idx)
1044
+
1045
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1046
+
1047
+ self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
1048
+
1049
+ img = torch.cat((txt, img), 1)
1050
+
1051
+ for block_idx, block in enumerate(self.single_blocks):
1052
+ self.offloader_single.wait_for_block(block_idx)
1053
+
1054
+ img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1055
+
1056
+ self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
1057
+
1058
+ img = img[:, txt.shape[1] :, ...]
1059
+
1060
+ if self.training and self.cpu_offload_checkpointing:
1061
+ img = img.to(self.device)
1062
+ vec = vec.to(self.device)
1063
+
1064
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
1065
+
1066
+ return img
1067
+
1068
+
1069
+ """
1070
+ class FluxUpper(nn.Module):
1071
+ ""
1072
+ Transformer model for flow matching on sequences.
1073
+ ""
1074
+
1075
+ def __init__(self, params: FluxParams):
1076
+ super().__init__()
1077
+
1078
+ self.params = params
1079
+ self.in_channels = params.in_channels
1080
+ self.out_channels = self.in_channels
1081
+ if params.hidden_size % params.num_heads != 0:
1082
+ raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
1083
+ pe_dim = params.hidden_size // params.num_heads
1084
+ if sum(params.axes_dim) != pe_dim:
1085
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
1086
+ self.hidden_size = params.hidden_size
1087
+ self.num_heads = params.num_heads
1088
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
1089
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
1090
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
1091
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
1092
+ self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
1093
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
1094
+
1095
+ self.double_blocks = nn.ModuleList(
1096
+ [
1097
+ DoubleStreamBlock(
1098
+ self.hidden_size,
1099
+ self.num_heads,
1100
+ mlp_ratio=params.mlp_ratio,
1101
+ qkv_bias=params.qkv_bias,
1102
+ )
1103
+ for _ in range(params.depth)
1104
+ ]
1105
+ )
1106
+
1107
+ self.gradient_checkpointing = False
1108
+
1109
+ @property
1110
+ def device(self):
1111
+ return next(self.parameters()).device
1112
+
1113
+ @property
1114
+ def dtype(self):
1115
+ return next(self.parameters()).dtype
1116
+
1117
+ def enable_gradient_checkpointing(self):
1118
+ self.gradient_checkpointing = True
1119
+
1120
+ self.time_in.enable_gradient_checkpointing()
1121
+ self.vector_in.enable_gradient_checkpointing()
1122
+ if self.guidance_in.__class__ != nn.Identity:
1123
+ self.guidance_in.enable_gradient_checkpointing()
1124
+
1125
+ for block in self.double_blocks:
1126
+ block.enable_gradient_checkpointing()
1127
+
1128
+ print("FLUX: Gradient checkpointing enabled.")
1129
+
1130
+ def disable_gradient_checkpointing(self):
1131
+ self.gradient_checkpointing = False
1132
+
1133
+ self.time_in.disable_gradient_checkpointing()
1134
+ self.vector_in.disable_gradient_checkpointing()
1135
+ if self.guidance_in.__class__ != nn.Identity:
1136
+ self.guidance_in.disable_gradient_checkpointing()
1137
+
1138
+ for block in self.double_blocks:
1139
+ block.disable_gradient_checkpointing()
1140
+
1141
+ print("FLUX: Gradient checkpointing disabled.")
1142
+
1143
+ def forward(
1144
+ self,
1145
+ img: Tensor,
1146
+ img_ids: Tensor,
1147
+ txt: Tensor,
1148
+ txt_ids: Tensor,
1149
+ timesteps: Tensor,
1150
+ y: Tensor,
1151
+ guidance: Tensor | None = None,
1152
+ txt_attention_mask: Tensor | None = None,
1153
+ ) -> Tensor:
1154
+ if img.ndim != 3 or txt.ndim != 3:
1155
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
1156
+
1157
+ # running on sequences img
1158
+ img = self.img_in(img)
1159
+ vec = self.time_in(timestep_embedding(timesteps, 256))
1160
+ if self.params.guidance_embed:
1161
+ if guidance is None:
1162
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
1163
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
1164
+ vec = vec + self.vector_in(y)
1165
+ txt = self.txt_in(txt)
1166
+
1167
+ ids = torch.cat((txt_ids, img_ids), dim=1)
1168
+ pe = self.pe_embedder(ids)
1169
+
1170
+ for block in self.double_blocks:
1171
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1172
+
1173
+ return img, txt, vec, pe
1174
+
1175
+
1176
+ class FluxLower(nn.Module):
1177
+ ""
1178
+ Transformer model for flow matching on sequences.
1179
+ ""
1180
+
1181
+ def __init__(self, params: FluxParams):
1182
+ super().__init__()
1183
+ self.hidden_size = params.hidden_size
1184
+ self.num_heads = params.num_heads
1185
+ self.out_channels = params.in_channels
1186
+
1187
+ self.single_blocks = nn.ModuleList(
1188
+ [
1189
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
1190
+ for _ in range(params.depth_single_blocks)
1191
+ ]
1192
+ )
1193
+
1194
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
1195
+
1196
+ self.gradient_checkpointing = False
1197
+
1198
+ @property
1199
+ def device(self):
1200
+ return next(self.parameters()).device
1201
+
1202
+ @property
1203
+ def dtype(self):
1204
+ return next(self.parameters()).dtype
1205
+
1206
+ def enable_gradient_checkpointing(self):
1207
+ self.gradient_checkpointing = True
1208
+
1209
+ for block in self.single_blocks:
1210
+ block.enable_gradient_checkpointing()
1211
+
1212
+ print("FLUX: Gradient checkpointing enabled.")
1213
+
1214
+ def disable_gradient_checkpointing(self):
1215
+ self.gradient_checkpointing = False
1216
+
1217
+ for block in self.single_blocks:
1218
+ block.disable_gradient_checkpointing()
1219
+
1220
+ print("FLUX: Gradient checkpointing disabled.")
1221
+
1222
+ def forward(
1223
+ self,
1224
+ img: Tensor,
1225
+ txt: Tensor,
1226
+ vec: Tensor | None = None,
1227
+ pe: Tensor | None = None,
1228
+ txt_attention_mask: Tensor | None = None,
1229
+ ) -> Tensor:
1230
+ img = torch.cat((txt, img), 1)
1231
+ for block in self.single_blocks:
1232
+ img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
1233
+ img = img[:, txt.shape[1] :, ...]
1234
+
1235
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
1236
+ return img
1237
+ """
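For orientation, a minimal smoke test of the Flux module defined above, using a deliberately tiny FluxParams configuration (the sizes below are illustrative only; the real FLUX.1 config is far larger). It assumes the FluxParams dataclass and helper layers defined earlier in this file:

import torch
from library import flux_models

params = flux_models.FluxParams(
    in_channels=64, vec_in_dim=16, context_in_dim=32,
    hidden_size=64, mlp_ratio=4.0, num_heads=4,
    depth=1, depth_single_blocks=1,
    axes_dim=[4, 6, 6], theta=10_000,
    qkv_bias=True, guidance_embed=True,
)
model = flux_models.Flux(params)

img = torch.randn(1, 16, 64)      # packed latents: (B, (H/2)*(W/2), 16*2*2)
img_ids = torch.zeros(1, 16, 3)   # three positional axes per image token
txt = torch.randn(1, 8, 32)       # text features with context_in_dim channels
txt_ids = torch.zeros(1, 8, 3)
out = model(
    img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
    timesteps=torch.tensor([0.5]), y=torch.randn(1, 16),
    guidance=torch.full((1,), 3.5),
)
print(out.shape)  # (1, 16, 64): the packed-latent shape is preserved

enable_gradient_checkpointing(), enable_block_swap() and move_to_device_except_swap_blocks() can then be layered onto the same instance to trade speed for memory.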
library/flux_train_utils.py ADDED
@@ -0,0 +1,582 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ import numpy as np
5
+ import toml
6
+ import json
7
+ import time
8
+ from typing import Callable, Dict, List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from accelerate import Accelerator, PartialState
12
+ from transformers import CLIPTextModel
13
+ from tqdm import tqdm
14
+ from PIL import Image
15
+ from safetensors.torch import save_file
16
+
17
+ from library import flux_models, flux_utils, strategy_base, train_util
18
+ from library.device_utils import init_ipex, clean_memory_on_device
19
+
20
+ init_ipex()
21
+
22
+ from .utils import setup_logging, mem_eff_save_file
23
+
24
+ setup_logging()
25
+ import logging
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ # region sample images
31
+
32
+
33
+ def sample_images(
34
+ accelerator: Accelerator,
35
+ args: argparse.Namespace,
36
+ epoch,
37
+ steps,
38
+ flux,
39
+ ae,
40
+ text_encoders,
41
+ sample_prompts_te_outputs,
42
+ prompt_replacement=None,
43
+ ):
44
+ if steps == 0:
45
+ if not args.sample_at_first:
46
+ return
47
+ else:
48
+ if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
49
+ return
50
+ if args.sample_every_n_epochs is not None:
51
+ # sample_every_n_steps is ignored here
52
+ if epoch is None or epoch % args.sample_every_n_epochs != 0:
53
+ return
54
+ else:
55
+ if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
56
+ return
57
+
58
+ logger.info("")
59
+ logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
60
+ if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
61
+ logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
62
+ return
63
+
64
+ distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
65
+
66
+ # unwrap unet and text_encoder(s)
67
+ flux = accelerator.unwrap_model(flux)
68
+ if text_encoders is not None:
69
+ text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
70
+ # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
71
+
72
+ prompts = train_util.load_prompts(args.sample_prompts)
73
+
74
+ save_dir = args.output_dir + "/sample"
75
+ os.makedirs(save_dir, exist_ok=True)
76
+
77
+ # save random state to restore later
78
+ rng_state = torch.get_rng_state()
79
+ cuda_rng_state = None
80
+ try:
81
+ cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
82
+ except Exception:
83
+ pass
84
+
85
+ if distributed_state.num_processes <= 1:
86
+ # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
87
+ with torch.no_grad(), accelerator.autocast():
88
+ for prompt_dict in prompts:
89
+ sample_image_inference(
90
+ accelerator,
91
+ args,
92
+ flux,
93
+ text_encoders,
94
+ ae,
95
+ save_dir,
96
+ prompt_dict,
97
+ epoch,
98
+ steps,
99
+ sample_prompts_te_outputs,
100
+ prompt_replacement,
101
+ )
102
+ else:
103
+ # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
104
+ # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
105
+ per_process_prompts = [] # list of lists
106
+ for i in range(distributed_state.num_processes):
107
+ per_process_prompts.append(prompts[i :: distributed_state.num_processes])
108
+
109
+ with torch.no_grad():
110
+ with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
111
+ for prompt_dict in prompt_dict_lists[0]:
112
+ sample_image_inference(
113
+ accelerator,
114
+ args,
115
+ flux,
116
+ text_encoders,
117
+ ae,
118
+ save_dir,
119
+ prompt_dict,
120
+ epoch,
121
+ steps,
122
+ sample_prompts_te_outputs,
123
+ prompt_replacement,
124
+ )
125
+
126
+ torch.set_rng_state(rng_state)
127
+ if cuda_rng_state is not None:
128
+ torch.cuda.set_rng_state(cuda_rng_state)
129
+
130
+ clean_memory_on_device(accelerator.device)
131
+
132
+
133
+ def sample_image_inference(
134
+ accelerator: Accelerator,
135
+ args: argparse.Namespace,
136
+ flux: flux_models.Flux,
137
+ text_encoders: Optional[List[CLIPTextModel]],
138
+ ae: flux_models.AutoEncoder,
139
+ save_dir,
140
+ prompt_dict,
141
+ epoch,
142
+ steps,
143
+ sample_prompts_te_outputs,
144
+ prompt_replacement,
145
+ ):
146
+ assert isinstance(prompt_dict, dict)
147
+ # negative_prompt = prompt_dict.get("negative_prompt")
148
+ sample_steps = prompt_dict.get("sample_steps", 20)
149
+ width = prompt_dict.get("width", 512)
150
+ height = prompt_dict.get("height", 512)
151
+ scale = prompt_dict.get("scale", 3.5)
152
+ seed = prompt_dict.get("seed")
153
+ # controlnet_image = prompt_dict.get("controlnet_image")
154
+ prompt: str = prompt_dict.get("prompt", "")
155
+ # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
156
+
157
+ if prompt_replacement is not None:
158
+ prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
159
+ # if negative_prompt is not None:
160
+ # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
161
+
162
+ if seed is not None:
163
+ torch.manual_seed(seed)
164
+ torch.cuda.manual_seed(seed)
165
+ else:
166
+ # True random sample image generation
167
+ torch.seed()
168
+ torch.cuda.seed()
169
+
170
+ # if negative_prompt is None:
171
+ # negative_prompt = ""
172
+
173
+ height = max(64, height - height % 16) # round to divisible by 16
174
+ width = max(64, width - width % 16) # round to divisible by 16
175
+ logger.info(f"prompt: {prompt}")
176
+ # logger.info(f"negative_prompt: {negative_prompt}")
177
+ logger.info(f"height: {height}")
178
+ logger.info(f"width: {width}")
179
+ logger.info(f"sample_steps: {sample_steps}")
180
+ logger.info(f"scale: {scale}")
181
+ # logger.info(f"sample_sampler: {sampler_name}")
182
+ if seed is not None:
183
+ logger.info(f"seed: {seed}")
184
+
185
+ # encode prompts
186
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
187
+ encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
188
+
189
+ text_encoder_conds = []
190
+ if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
191
+ text_encoder_conds = sample_prompts_te_outputs[prompt]
192
+ print(f"Using cached text encoder outputs for prompt: {prompt}")
193
+ if text_encoders is not None:
194
+ print(f"Encoding prompt: {prompt}")
195
+ tokens_and_masks = tokenize_strategy.tokenize(prompt)
196
+ # strategy has apply_t5_attn_mask option
197
+ encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
198
+
199
+ # if text_encoder_conds is not cached, use encoded_text_encoder_conds
200
+ if len(text_encoder_conds) == 0:
201
+ text_encoder_conds = encoded_text_encoder_conds
202
+ else:
203
+ # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
204
+ for i in range(len(encoded_text_encoder_conds)):
205
+ if encoded_text_encoder_conds[i] is not None:
206
+ text_encoder_conds[i] = encoded_text_encoder_conds[i]
207
+
208
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
209
+
210
+ # sample image
211
+ weight_dtype = ae.dtype # TODO: pass dtype as an argument
212
+ packed_latent_height = height // 16
213
+ packed_latent_width = width // 16
214
+ noise = torch.randn(
215
+ 1,
216
+ packed_latent_height * packed_latent_width,
217
+ 16 * 2 * 2,
218
+ device=accelerator.device,
219
+ dtype=weight_dtype,
220
+ generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
221
+ )
222
+ timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
223
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
224
+ t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
225
+
226
+ with accelerator.autocast(), torch.no_grad():
227
+ x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
228
+
229
+ x = x.float()
230
+ x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
231
+
232
+ # latent to image
233
+ clean_memory_on_device(accelerator.device)
234
+ org_vae_device = ae.device # will be on cpu
235
+ ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
236
+ with accelerator.autocast(), torch.no_grad():
237
+ x = ae.decode(x)
238
+ ae.to(org_vae_device)
239
+ clean_memory_on_device(accelerator.device)
240
+
241
+ x = x.clamp(-1, 1)
242
+ x = x.permute(0, 2, 3, 1)
243
+ image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
244
+
245
+ # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
246
+ # but adding 'enum' to the filename should be enough
247
+
248
+ ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
249
+ num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
250
+ seed_suffix = "" if seed is None else f"_{seed}"
251
+ i: int = prompt_dict["enum"]
252
+ img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
253
+ image.save(os.path.join(save_dir, img_filename))
254
+
255
+ # send images to wandb if enabled
256
+ if "wandb" in [tracker.name for tracker in accelerator.trackers]:
257
+ wandb_tracker = accelerator.get_tracker("wandb")
258
+
259
+ import wandb
260
+
261
+ # not to commit images to avoid inconsistency between training and logging steps
262
+ wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
263
+
264
+
265
+ def time_shift(mu: float, sigma: float, t: torch.Tensor):
266
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
267
+
268
+
269
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
270
+ m = (y2 - y1) / (x2 - x1)
271
+ b = y1 - m * x1
272
+ return lambda x: m * x + b
273
+
274
+
275
+ def get_schedule(
276
+ num_steps: int,
277
+ image_seq_len: int,
278
+ base_shift: float = 0.5,
279
+ max_shift: float = 1.15,
280
+ shift: bool = True,
281
+ ) -> list[float]:
282
+ # extra step for zero
283
+ timesteps = torch.linspace(1, 0, num_steps + 1)
284
+
285
+ # shifting the schedule to favor high timesteps for higher signal images
286
+ if shift:
287
+ # estimate mu by linear interpolation between two reference points
288
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
289
+ timesteps = time_shift(mu, 1.0, timesteps)
290
+
291
+ return timesteps.tolist()
292
+
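In get_schedule above, get_lin_function maps the packed image sequence length onto the shift parameter mu through the two reference points (256, 0.5) and (4096, 1.15), and time_shift then warps the uniform 1-to-0 grid toward t = 1, so larger images spend more steps at high noise. A small numeric illustration (values rounded):

import torch
from library.flux_train_utils import get_lin_function, time_shift

mu = get_lin_function(y1=0.5, y2=1.15)(1024)   # approx. 0.63 for a 1024-token packed latent
t = torch.linspace(1, 0, 5)                    # [1.00, 0.75, 0.50, 0.25, 0.00]
print(time_shift(mu, 1.0, t))                  # approx. [1.00, 0.85, 0.65, 0.38, 0.00]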
293
+
294
+ def denoise(
295
+ model: flux_models.Flux,
296
+ img: torch.Tensor,
297
+ img_ids: torch.Tensor,
298
+ txt: torch.Tensor,
299
+ txt_ids: torch.Tensor,
300
+ vec: torch.Tensor,
301
+ timesteps: list[float],
302
+ guidance: float = 4.0,
303
+ t5_attn_mask: Optional[torch.Tensor] = None,
304
+ ):
305
+ # this is ignored for schnell
306
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
307
+ for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
308
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
309
+ model.prepare_block_swap_before_forward()
310
+ pred = model(
311
+ img=img,
312
+ img_ids=img_ids,
313
+ txt=txt,
314
+ txt_ids=txt_ids,
315
+ y=vec,
316
+ timesteps=t_vec,
317
+ guidance=guidance_vec,
318
+ txt_attention_mask=t5_attn_mask,
319
+ )
320
+
321
+ img = img + (t_prev - t_curr) * pred
322
+
323
+ model.prepare_block_swap_before_forward()
324
+ return img
325
+
326
+
327
+ # endregion
328
+
329
+
330
+ # region train
331
+ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
332
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
333
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
334
+ timesteps = timesteps.to(device)
335
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
336
+
337
+ sigma = sigmas[step_indices].flatten()
338
+ while len(sigma.shape) < n_dim:
339
+ sigma = sigma.unsqueeze(-1)
340
+ return sigma
341
+
342
+
343
+ def compute_density_for_timestep_sampling(
344
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
345
+ ):
346
+ """Compute the density for sampling the timesteps when doing SD3 training.
347
+
348
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
349
+
350
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
351
+ """
352
+ if weighting_scheme == "logit_normal":
353
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
354
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
355
+ u = torch.nn.functional.sigmoid(u)
356
+ elif weighting_scheme == "mode":
357
+ u = torch.rand(size=(batch_size,), device="cpu")
358
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
359
+ else:
360
+ u = torch.rand(size=(batch_size,), device="cpu")
361
+ return u
362
+
363
+
364
+ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
365
+ """Computes loss weighting scheme for SD3 training.
366
+
367
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
368
+
369
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
370
+ """
371
+ if weighting_scheme == "sigma_sqrt":
372
+ weighting = (sigmas**-2.0).float()
373
+ elif weighting_scheme == "cosmap":
374
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
375
+ weighting = 2 / (math.pi * bot)
376
+ else:
377
+ weighting = torch.ones_like(sigmas)
378
+ return weighting
379
+
380
+
381
+ def get_noisy_model_input_and_timesteps(
382
+ args, noise_scheduler, latents, noise, device, dtype
383
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
384
+ bsz, _, h, w = latents.shape
385
+ sigmas = None
386
+
387
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
388
+ # Simple random t-based noise sampling
389
+ if args.timestep_sampling == "sigmoid":
390
+ # https://github.com/XLabs-AI/x-flux/tree/main
391
+ t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
392
+ else:
393
+ t = torch.rand((bsz,), device=device)
394
+
395
+ timesteps = t * 1000.0
396
+ t = t.view(-1, 1, 1, 1)
397
+ noisy_model_input = (1 - t) * latents + t * noise
398
+ elif args.timestep_sampling == "shift":
399
+ shift = args.discrete_flow_shift
400
+ logits_norm = torch.randn(bsz, device=device)
401
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
402
+ timesteps = logits_norm.sigmoid()
403
+ timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
404
+
405
+ t = timesteps.view(-1, 1, 1, 1)
406
+ timesteps = timesteps * 1000.0
407
+ noisy_model_input = (1 - t) * latents + t * noise
408
+ elif args.timestep_sampling == "flux_shift":
409
+ logits_norm = torch.randn(bsz, device=device)
410
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
411
+ timesteps = logits_norm.sigmoid()
412
+ mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
413
+ timesteps = time_shift(mu, 1.0, timesteps)
414
+
415
+ t = timesteps.view(-1, 1, 1, 1)
416
+ timesteps = timesteps * 1000.0
417
+ noisy_model_input = (1 - t) * latents + t * noise
418
+ else:
419
+ # Sample a random timestep for each image
420
+ # for weighting schemes where we sample timesteps non-uniformly
421
+ u = compute_density_for_timestep_sampling(
422
+ weighting_scheme=args.weighting_scheme,
423
+ batch_size=bsz,
424
+ logit_mean=args.logit_mean,
425
+ logit_std=args.logit_std,
426
+ mode_scale=args.mode_scale,
427
+ )
428
+ indices = (u * noise_scheduler.config.num_train_timesteps).long()
429
+ timesteps = noise_scheduler.timesteps[indices].to(device=device)
430
+
431
+ # Add noise according to flow matching.
432
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
433
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
434
+
435
+ return noisy_model_input, timesteps, sigmas
436
+
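The "shift" branch above is the discrete-flow version of the same warp: a logit-normal draw is squashed through a sigmoid and then remapped by t' = s*t / (1 + (s - 1)*t), which with the default discrete_flow_shift of 3.0 sends t = 0.5 to 0.75, biasing training toward the high-noise end before the flow-matching mix (1 - t)*latents + t*noise is formed. A condensed restatement with toy latents:

import torch

shift = 3.0                                   # args.discrete_flow_shift default
t = torch.sigmoid(torch.randn(4))             # logit-normal samples in (0, 1)
t = (t * shift) / (1 + (shift - 1) * t)       # e.g. 0.5 -> 0.75
timesteps = t * 1000.0                        # what the model receives
latents, noise = torch.randn(4, 16, 8, 8), torch.randn(4, 16, 8, 8)
noisy_model_input = (1 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise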
437
+
438
+ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
439
+ weighting = None
440
+ if args.model_prediction_type == "raw":
441
+ pass
442
+ elif args.model_prediction_type == "additive":
443
+ # add the model_pred to the noisy_model_input
444
+ model_pred = model_pred + noisy_model_input
445
+ elif args.model_prediction_type == "sigma_scaled":
446
+ # apply sigma scaling
447
+ model_pred = model_pred * (-sigmas) + noisy_model_input
448
+
449
+ # these weighting schemes use a uniform timestep sampling
450
+ # and instead post-weight the loss
451
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
452
+
453
+ return model_pred, weighting
454
+
455
+
456
+ def save_models(
457
+ ckpt_path: str,
458
+ flux: flux_models.Flux,
459
+ sai_metadata: Optional[dict],
460
+ save_dtype: Optional[torch.dtype] = None,
461
+ use_mem_eff_save: bool = False,
462
+ ):
463
+ state_dict = {}
464
+
465
+ def update_sd(prefix, sd):
466
+ for k, v in sd.items():
467
+ key = prefix + k
468
+ if save_dtype is not None and v.dtype != save_dtype:
469
+ v = v.detach().clone().to("cpu").to(save_dtype)
470
+ state_dict[key] = v
471
+
472
+ update_sd("", flux.state_dict())
473
+
474
+ if not use_mem_eff_save:
475
+ save_file(state_dict, ckpt_path, metadata=sai_metadata)
476
+ else:
477
+ mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
478
+
479
+
480
+ def save_flux_model_on_train_end(
481
+ args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
482
+ ):
483
+ def sd_saver(ckpt_file, epoch_no, global_step):
484
+ sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
485
+ save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
486
+
487
+ train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
488
+
489
+
490
+ # Saving at epoch end and at step intervals is handled by one function, since the arguments are the same and the metadata includes epoch/step
491
+ # on_epoch_end: True = called at the end of an epoch, False = called after a step interval
492
+ def save_flux_model_on_epoch_end_or_stepwise(
493
+ args: argparse.Namespace,
494
+ on_epoch_end: bool,
495
+ accelerator,
496
+ save_dtype: torch.dtype,
497
+ epoch: int,
498
+ num_train_epochs: int,
499
+ global_step: int,
500
+ flux: flux_models.Flux,
501
+ ):
502
+ def sd_saver(ckpt_file, epoch_no, global_step):
503
+ sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
504
+ save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
505
+
506
+ train_util.save_sd_model_on_epoch_end_or_stepwise_common(
507
+ args,
508
+ on_epoch_end,
509
+ accelerator,
510
+ True,
511
+ True,
512
+ epoch,
513
+ num_train_epochs,
514
+ global_step,
515
+ sd_saver,
516
+ None,
517
+ )
518
+
519
+
520
+ # endregion
521
+
522
+
523
+ def add_flux_train_arguments(parser: argparse.ArgumentParser):
524
+ parser.add_argument(
525
+ "--clip_l",
526
+ type=str,
527
+ help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提",
528
+ )
529
+ parser.add_argument(
530
+ "--t5xxl",
531
+ type=str,
532
+ help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
533
+ )
534
+ parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
535
+ parser.add_argument(
536
+ "--t5xxl_max_token_length",
537
+ type=int,
538
+ default=None,
539
+ help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
540
+ " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
541
+ )
542
+ parser.add_argument(
543
+ "--apply_t5_attn_mask",
544
+ action="store_true",
545
+ help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
546
+ )
547
+
548
+ parser.add_argument(
549
+ "--guidance_scale",
550
+ type=float,
551
+ default=3.5,
552
+ help="the FLUX.1 dev variant is a guidance distilled model",
553
+ )
554
+
555
+ parser.add_argument(
556
+ "--timestep_sampling",
557
+ choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
558
+ default="sigma",
559
+ help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
560
+ " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
561
+ )
562
+ parser.add_argument(
563
+ "--sigmoid_scale",
564
+ type=float,
565
+ default=1.0,
566
+ help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
567
+ )
568
+ parser.add_argument(
569
+ "--model_prediction_type",
570
+ choices=["raw", "additive", "sigma_scaled"],
571
+ default="sigma_scaled",
572
+ help="How to interpret and process the model prediction: "
573
+ "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
574
+ " / モデル予測の解釈と処理方法:"
575
+ "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
576
+ )
577
+ parser.add_argument(
578
+ "--discrete_flow_shift",
579
+ type=float,
580
+ default=3.0,
581
+ help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
582
+ )
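add_flux_train_arguments only registers the FLUX-specific options; the training scripts in this repo are expected to call it on their own parser. A minimal wiring sketch (the flag values are just examples):

import argparse
from library.flux_train_utils import add_flux_train_arguments

parser = argparse.ArgumentParser()
add_flux_train_arguments(parser)
args = parser.parse_args(["--timestep_sampling", "shift", "--discrete_flow_shift", "3.0", "--guidance_scale", "1.0"])
print(args.timestep_sampling, args.discrete_flow_shift, args.apply_t5_attn_mask)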
library/flux_train_utils_recraft.py ADDED
@@ -0,0 +1,659 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ import numpy as np
5
+ import toml
6
+ import json
7
+ import time
8
+ from typing import Callable, Dict, List, Optional, Tuple, Union
9
+ import pdb
10
+
11
+ import torch
12
+ from accelerate import Accelerator, PartialState
13
+ from transformers import CLIPTextModel
14
+ from tqdm import tqdm
15
+ from PIL import Image
16
+ from safetensors.torch import save_file
17
+
18
+ from library import flux_models, flux_utils, strategy_base, train_util
19
+ from library.device_utils import init_ipex, clean_memory_on_device
20
+
21
+ init_ipex()
22
+
23
+ from .utils import setup_logging, mem_eff_save_file
24
+
25
+ setup_logging()
26
+ import logging
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ # region sample images
32
+
33
+ def sample_images(
34
+ accelerator: Accelerator,
35
+ args: argparse.Namespace,
36
+ epoch,
37
+ steps,
38
+ flux,
39
+ ae,
40
+ text_encoders,
41
+ sample_prompts_te_outputs,
42
+ prompt_replacement=None,
43
+ sample_images_ae_outputs=None
44
+ ):
45
+ if steps == 0:
46
+ if not args.sample_at_first:
47
+ return
48
+ else:
49
+ if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
50
+ return
51
+ if args.sample_every_n_epochs is not None:
52
+ # sample_every_n_steps is ignored here
53
+ if epoch is None or epoch % args.sample_every_n_epochs != 0:
54
+ return
55
+ else:
56
+ if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
57
+ return
58
+
59
+ logger.info("")
60
+ logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
61
+ if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
62
+ logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
63
+ return
64
+
65
+ distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
66
+
67
+ # unwrap unet and text_encoder(s)
68
+ flux = accelerator.unwrap_model(flux)
69
+ if text_encoders is not None:
70
+ text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
71
+ # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
72
+
73
+ prompts = train_util.load_prompts(args.sample_prompts)
74
+
75
+ save_dir = args.output_dir + "/sample"
76
+ os.makedirs(save_dir, exist_ok=True)
77
+
78
+ # save random state to restore later
79
+ rng_state = torch.get_rng_state()
80
+ cuda_rng_state = None
81
+ try:
82
+ cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
83
+ except Exception:
84
+ pass
85
+
86
+ if distributed_state.num_processes <= 1:
87
+ # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
88
+ with torch.no_grad(), accelerator.autocast():
89
+ for prompt_dict in prompts:
90
+ sample_image_inference(
91
+ accelerator,
92
+ args,
93
+ flux,
94
+ text_encoders,
95
+ ae,
96
+ save_dir,
97
+ prompt_dict,
98
+ epoch,
99
+ steps,
100
+ sample_prompts_te_outputs,
101
+ prompt_replacement,
102
+ sample_images_ae_outputs
103
+ )
104
+ else:
105
+ # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
106
+ # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
107
+ per_process_prompts = [] # list of lists
108
+ for i in range(distributed_state.num_processes):
109
+ per_process_prompts.append(prompts[i :: distributed_state.num_processes])
110
+
111
+ with torch.no_grad():
112
+ with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
113
+ for prompt_dict in prompt_dict_lists[0]:
114
+ sample_image_inference(
115
+ accelerator,
116
+ args,
117
+ flux,
118
+ text_encoders,
119
+ ae,
120
+ save_dir,
121
+ prompt_dict,
122
+ epoch,
123
+ steps,
124
+ sample_prompts_te_outputs,
125
+ prompt_replacement,
126
+ sample_images_ae_outputs
127
+ )
128
+
129
+ torch.set_rng_state(rng_state)
130
+ if cuda_rng_state is not None:
131
+ torch.cuda.set_rng_state(cuda_rng_state)
132
+
133
+ clean_memory_on_device(accelerator.device)
134
+
135
+
136
+ def sample_image_inference(
137
+ accelerator: Accelerator,
138
+ args: argparse.Namespace,
139
+ flux: flux_models.Flux,
140
+ text_encoders: Optional[List[CLIPTextModel]],
141
+ ae: flux_models.AutoEncoder,
142
+ save_dir,
143
+ prompt_dict,
144
+ epoch,
145
+ steps,
146
+ sample_prompts_te_outputs,
147
+ prompt_replacement,
148
+ sample_images_ae_outputs
149
+ ):
150
+ assert isinstance(prompt_dict, dict)
151
+ # negative_prompt = prompt_dict.get("negative_prompt")
152
+ sample_steps = prompt_dict.get("sample_steps", 20)
153
+ width = prompt_dict.get("width", 1024) if args.frame_num==4 else prompt_dict.get("width", 1056)
154
+ height = prompt_dict.get("height", 1024) if args.frame_num==4 else prompt_dict.get("height", 1056)
155
+ scale = prompt_dict.get("scale", 1.0)
156
+ seed = prompt_dict.get("seed")
157
+ # controlnet_image = prompt_dict.get("controlnet_image")
158
+ prompt: str = prompt_dict.get("prompt", "")
159
+ # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
160
+
161
+ if prompt_replacement is not None:
162
+ prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
163
+ # if negative_prompt is not None:
164
+ # negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
165
+
166
+ if seed is not None:
167
+ torch.manual_seed(seed)
168
+ torch.cuda.manual_seed(seed)
169
+ else:
170
+ # True random sample image generation
171
+ torch.seed()
172
+ torch.cuda.seed()
173
+
174
+ # if negative_prompt is None:
175
+ # negative_prompt = ""
176
+
177
+ height = max(64, height - height % 16) # round to divisible by 16
178
+ width = max(64, width - width % 16) # round to divisible by 16
179
+ logger.info(f"prompt: {prompt}")
180
+ # logger.info(f"negative_prompt: {negative_prompt}")
181
+ logger.info(f"height: {height}")
182
+ logger.info(f"width: {width}")
183
+ logger.info(f"sample_steps: {sample_steps}")
184
+ logger.info(f"scale: {scale}")
185
+ # logger.info(f"sample_sampler: {sampler_name}")
186
+ if seed is not None:
187
+ logger.info(f"seed: {seed}")
188
+
189
+ # encode prompts
190
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
191
+ encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
192
+
193
+ text_encoder_conds = []
194
+ if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
195
+ text_encoder_conds = sample_prompts_te_outputs[prompt]
196
+ print(f"Using cached text encoder outputs for prompt: {prompt}")
197
+ if text_encoders is not None:
198
+ print(f"Encoding prompt: {prompt}")
199
+ tokens_and_masks = tokenize_strategy.tokenize(prompt)
200
+ # strategy has apply_t5_attn_mask option
201
+ encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
202
+
203
+ # if text_encoder_conds is not cached, use encoded_text_encoder_conds
204
+ if len(text_encoder_conds) == 0:
205
+ text_encoder_conds = encoded_text_encoder_conds
206
+ else:
207
+ # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
208
+ for i in range(len(encoded_text_encoder_conds)):
209
+ if encoded_text_encoder_conds[i] is not None:
210
+ text_encoder_conds[i] = encoded_text_encoder_conds[i]
211
+
212
+ if sample_images_ae_outputs and prompt in sample_images_ae_outputs:
213
+ ae_outputs = sample_images_ae_outputs[prompt]
214
+ else:
215
+ ae_outputs = None
216
+
217
+ l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
218
+
219
+ # sample image
220
+ weight_dtype = ae.dtype # TODO: pass dtype as an argument
221
+ packed_latent_height = height // 16
222
+ packed_latent_width = width // 16
223
+ noise = torch.randn(
224
+ 1,
225
+ packed_latent_height * packed_latent_width,
226
+ 16 * 2 * 2,
227
+ device=accelerator.device,
228
+ dtype=weight_dtype,
229
+ generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
230
+ )
231
+ timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
232
+ img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
233
+ t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
234
+
235
+ with accelerator.autocast(), torch.no_grad():
236
+ x = denoise(args, flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, ae_outputs=ae_outputs)
237
+
238
+ x = x.float()
239
+ x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
240
+
241
+ # latent to image
242
+ clean_memory_on_device(accelerator.device)
243
+ org_vae_device = ae.device # will be on cpu
244
+ ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
245
+ with accelerator.autocast(), torch.no_grad():
246
+ x = ae.decode(x)
247
+ ae.to(org_vae_device)
248
+ clean_memory_on_device(accelerator.device)
249
+
250
+ x = x.clamp(-1, 1)
251
+ x = x.permute(0, 2, 3, 1)
252
+ image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
253
+
254
+ # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
255
+ # but adding 'enum' to the filename should be enough
256
+
257
+ ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
258
+ num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
259
+ seed_suffix = "" if seed is None else f"_{seed}"
260
+ i: int = prompt_dict["enum"]
261
+ img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
262
+ image.save(os.path.join(save_dir, img_filename))
263
+
264
+ # send images to wandb if enabled
265
+ if "wandb" in [tracker.name for tracker in accelerator.trackers]:
266
+ wandb_tracker = accelerator.get_tracker("wandb")
267
+
268
+ import wandb
269
+ # not to commit images to avoid inconsistency between training and logging steps
270
+ wandb_tracker.log(
271
+ {f"sample_{i}": wandb.Image(
272
+ image,
273
+ caption=prompt # positive prompt as a caption
274
+ )},
275
+ commit=False
276
+ )
277
+
278
+
279
+ def time_shift(mu: float, sigma: float, t: torch.Tensor):
280
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
281
+
282
+
283
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
284
+ m = (y2 - y1) / (x2 - x1)
285
+ b = y1 - m * x1
286
+ return lambda x: m * x + b
287
+
288
+
289
+ def get_schedule(
290
+ num_steps: int,
291
+ image_seq_len: int,
292
+ base_shift: float = 0.5,
293
+ max_shift: float = 1.15,
294
+ shift: bool = True,
295
+ ) -> list[float]:
296
+ # extra step for zero
297
+ timesteps = torch.linspace(1, 0, num_steps + 1)
298
+
299
+ # shifting the schedule to favor high timesteps for higher signal images
300
+ if shift:
301
+ # estimate mu by linear interpolation between two reference points
302
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
303
+ timesteps = time_shift(mu, 1.0, timesteps)
304
+
305
+ return timesteps.tolist()
306
+
307
+
308
+ def denoise(
309
+ args: argparse.Namespace,
310
+ model: flux_models.Flux,
311
+ img: torch.Tensor,
312
+ img_ids: torch.Tensor,
313
+ txt: torch.Tensor,
314
+ txt_ids: torch.Tensor,
315
+ vec: torch.Tensor,
316
+ timesteps: list[float],
317
+ guidance: float = 4.0,
318
+ t5_attn_mask: Optional[torch.Tensor] = None,
319
+ ae_outputs: Optional[torch.Tensor] = None,
320
+ ):
321
+ # this is ignored for schnell
322
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
323
+ img_ids = img_ids.to(img.device)
324
+ txt_ids = txt_ids.to(img.device)
325
+ vec = vec.to(img.device)
326
+ txt = txt.to(img.device)
327
+
328
+ for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
329
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
330
+ model.prepare_block_swap_before_forward()
331
+ if args.frame_num == 4:
332
+ packed_latent_height, packed_latent_width = ae_outputs.shape[2]*2 // 2, ae_outputs.shape[3]*2 // 2
333
+ img = flux_utils.unpack_latents(img, packed_latent_height, packed_latent_width)
334
+ img[:,:, img.shape[2] // 2: img.shape[2], :img.shape[3] // 2] = ae_outputs
335
+ else:
336
+ packed_latent_height, packed_latent_width = ae_outputs.shape[2]*3 // 2, ae_outputs.shape[3]*3 // 2
337
+ img = flux_utils.unpack_latents(img, packed_latent_height, packed_latent_width)
338
+ img[:,:, 2*img.shape[2] // 3: img.shape[2], 2*img.shape[3] // 3:img.shape[3]] = ae_outputs
339
+
340
+ img = flux_utils.pack_latents(img)
341
+ pred = model(
342
+ img=img,
343
+ img_ids=img_ids,
344
+ txt=txt,
345
+ txt_ids=txt_ids,
346
+ y=vec,
347
+ timesteps=t_vec,
348
+ guidance=guidance_vec,
349
+ txt_attention_mask=t5_attn_mask,
350
+ )
351
+
352
+ img = img + (t_prev - t_curr) * pred
353
+
354
+ model.prepare_block_swap_before_forward()
355
+ return img
356
+
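Unlike the base denoise in flux_train_utils.py, this loop re-imposes the conditioning at every step: the packed latents are unpacked to (B, C, H, W), the grid cell that carries the encoded reference image (one quadrant when frame_num == 4, one cell of a 3x3 grid otherwise) is overwritten with the cached ae_outputs, and the result is re-packed before the model call, so that region never drifts away from the condition. A schematic sketch of the same idea on a plain tensor (the helper and the particular cell below are illustrative; the exact slices are the ones in the code above):

import torch

def pin_condition_region(latent: torch.Tensor, cond: torch.Tensor, rows: slice, cols: slice) -> torch.Tensor:
    # latent: (B, C, H, W) unpacked latents; cond: matching (B, C, h, w) reference latents
    latent = latent.clone()
    latent[:, :, rows, cols] = cond
    return latent

lat = torch.randn(1, 16, 8, 8)
ref = torch.randn(1, 16, 4, 4)
# e.g. keep one quadrant of a 2x2 grid pinned to the encoded reference at every step
lat = pin_condition_region(lat, ref, slice(4, 8), slice(0, 4))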
357
+
358
+ # endregion
359
+
360
+
361
+ # region train
362
+ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
363
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
364
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
365
+ timesteps = timesteps.to(device)
366
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
367
+
368
+ sigma = sigmas[step_indices].flatten()
369
+ while len(sigma.shape) < n_dim:
370
+ sigma = sigma.unsqueeze(-1)
371
+ return sigma
372
+
373
+
374
+ def compute_density_for_timestep_sampling(
375
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
376
+ ):
377
+ """Compute the density for sampling the timesteps when doing SD3 training.
378
+
379
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
380
+
381
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
382
+ """
383
+ if weighting_scheme == "logit_normal":
384
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
385
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
386
+ u = torch.nn.functional.sigmoid(u)
387
+ elif weighting_scheme == "mode":
388
+ u = torch.rand(size=(batch_size,), device="cpu")
389
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
390
+ else:
391
+ u = torch.rand(size=(batch_size,), device="cpu")
392
+ return u
393
+
394
+
395
+ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
396
+ """Computes loss weighting scheme for SD3 training.
397
+
398
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
399
+
400
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
401
+ """
402
+ if weighting_scheme == "sigma_sqrt":
403
+ weighting = (sigmas**-2.0).float()
404
+ elif weighting_scheme == "cosmap":
405
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
406
+ weighting = 2 / (math.pi * bot)
407
+ else:
408
+ weighting = torch.ones_like(sigmas)
409
+ return weighting
410
+
411
+
412
+ def get_noisy_model_input_and_timesteps(
413
+ args, noise_scheduler, latents, noise, device, dtype
414
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
415
+ bsz, _, h, w = latents.shape
416
+ sigmas = None
417
+
418
+ if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
419
+ # Simple random t-based noise sampling
420
+ if args.timestep_sampling == "sigmoid":
421
+ # https://github.com/XLabs-AI/x-flux/tree/main
422
+ t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
423
+ else:
424
+ t = torch.rand((bsz,), device=device)
425
+
426
+ timesteps = t * 1000.0
427
+ t = t.view(-1, 1, 1, 1)
428
+ noisy_model_input = (1 - t) * latents + t * noise
429
+ elif args.timestep_sampling == "shift":
430
+ shift = args.discrete_flow_shift
431
+ logits_norm = torch.randn(bsz, device=device)
432
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
433
+ timesteps = logits_norm.sigmoid()
434
+ timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
435
+
436
+ t = timesteps.view(-1, 1, 1, 1)
437
+ timesteps = timesteps * 1000.0
438
+ noisy_model_input = (1 - t) * latents + t * noise
439
+ elif args.timestep_sampling == "flux_shift":
440
+ logits_norm = torch.randn(bsz, device=device)
441
+ logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
442
+ timesteps = logits_norm.sigmoid()
443
+ mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
444
+ timesteps = time_shift(mu, 1.0, timesteps)
445
+
446
+ t = timesteps.view(-1, 1, 1, 1)
447
+ timesteps = timesteps * 1000.0
448
+ noisy_model_input = (1 - t) * latents + t * noise
449
+ else:
450
+ # Sample a random timestep for each image
451
+ # for weighting schemes where we sample timesteps non-uniformly
452
+ u = compute_density_for_timestep_sampling(
453
+ weighting_scheme=args.weighting_scheme,
454
+ batch_size=bsz,
455
+ logit_mean=args.logit_mean,
456
+ logit_std=args.logit_std,
457
+ mode_scale=args.mode_scale,
458
+ )
459
+ indices = (u * noise_scheduler.config.num_train_timesteps).long()
460
+ timesteps = noise_scheduler.timesteps[indices].to(device=device)
461
+
462
+ # Add noise according to flow matching.
463
+ sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
464
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
465
+
466
+ # replace part of the region with the original latents
467
+ h, w = noisy_model_input.shape[2], noisy_model_input.shape[3]
468
+ # import pdb; pdb.set_trace()
469
+ if args.frame_num == 4:
470
+ noisy_model_input[:, :, h//2 : h, w//2 : w] = latents[:, :, h//2:h, w//2:w]
471
+ else:
472
+ noisy_model_input[:, :, 2*h//3 : h, 2*w//3 : w] = latents[:, :, 2*h//3:h, 2*w//3:w]
473
+
474
+
475
+ return noisy_model_input, timesteps, sigmas
476
+
477
+
478
+ def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
479
+ weighting = None
480
+ if args.model_prediction_type == "raw":
481
+ pass
482
+ elif args.model_prediction_type == "additive":
483
+ # add the model_pred to the noisy_model_input
484
+ model_pred = model_pred + noisy_model_input
485
+ elif args.model_prediction_type == "sigma_scaled":
486
+ # apply sigma scaling
487
+ model_pred = model_pred * (-sigmas) + noisy_model_input
488
+
489
+ # these weighting schemes use a uniform timestep sampling
490
+ # and instead post-weight the loss
491
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
492
+
493
+ return model_pred, weighting
494
+
495
+
496
+ def save_models(
497
+ ckpt_path: str,
498
+ flux: flux_models.Flux,
499
+ sai_metadata: Optional[dict],
500
+ save_dtype: Optional[torch.dtype] = None,
501
+ use_mem_eff_save: bool = False,
502
+ ):
503
+ state_dict = {}
504
+
505
+ def update_sd(prefix, sd):
506
+ for k, v in sd.items():
507
+ key = prefix + k
508
+ if save_dtype is not None and v.dtype != save_dtype:
509
+ v = v.detach().clone().to("cpu").to(save_dtype)
510
+ state_dict[key] = v
511
+
512
+ update_sd("", flux.state_dict())
513
+
514
+ if not use_mem_eff_save:
515
+ save_file(state_dict, ckpt_path, metadata=sai_metadata)
516
+ else:
517
+ mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
518
+
519
+
520
+ def save_flux_model_on_train_end(
521
+ args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
522
+ ):
523
+ def sd_saver(ckpt_file, epoch_no, global_step):
524
+ sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
525
+ save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
526
+
527
+ train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
528
+
529
+
530
+ # Saving for epoch end and step intervals is unified here because the metadata includes epoch/step and the arguments are the same
531
+ # on_epoch_end: if True, save at the end of an epoch; if False, save when the step interval is reached
532
+ def save_flux_model_on_epoch_end_or_stepwise(
533
+ args: argparse.Namespace,
534
+ on_epoch_end: bool,
535
+ accelerator,
536
+ save_dtype: torch.dtype,
537
+ epoch: int,
538
+ num_train_epochs: int,
539
+ global_step: int,
540
+ flux: flux_models.Flux,
541
+ ):
542
+ def sd_saver(ckpt_file, epoch_no, global_step):
543
+ sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
544
+ save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
545
+
546
+ train_util.save_sd_model_on_epoch_end_or_stepwise_common(
547
+ args,
548
+ on_epoch_end,
549
+ accelerator,
550
+ True,
551
+ True,
552
+ epoch,
553
+ num_train_epochs,
554
+ global_step,
555
+ sd_saver,
556
+ None,
557
+ )
558
+
559
+
560
+ # endregion
561
+
562
+
563
+ def add_flux_train_arguments(parser: argparse.ArgumentParser):
564
+ parser.add_argument(
565
+ "--clip_l",
566
+ type=str,
567
+ help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス(*.sftまたは*.safetensors)、float16が前提",
568
+ )
569
+ parser.add_argument(
570
+ "--t5xxl",
571
+ type=str,
572
+ help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
573
+ )
574
+ parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
575
+ parser.add_argument(
576
+ "--t5xxl_max_token_length",
577
+ type=int,
578
+ default=None,
579
+ help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
580
+ " / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
581
+ )
582
+ parser.add_argument(
583
+ "--apply_t5_attn_mask",
584
+ action="store_true",
585
+ help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
586
+ )
587
+ parser.add_argument(
588
+ "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
589
+ )
590
+ parser.add_argument(
591
+ "--cache_text_encoder_outputs_to_disk",
592
+ action="store_true",
593
+ help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
594
+ )
595
+ parser.add_argument(
596
+ "--text_encoder_batch_size",
597
+ type=int,
598
+ default=None,
599
+ help="text encoder batch size (default: None, use dataset's batch size)"
600
+ + " / text encoderのバッチサイズ(デフォルト: None, データセットのバッチサイズを使用)",
601
+ )
602
+ parser.add_argument(
603
+ "--disable_mmap_load_safetensors",
604
+ action="store_true",
605
+ help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
606
+ )
607
+
608
+ # copy from Diffusers
609
+ parser.add_argument(
610
+ "--weighting_scheme",
611
+ type=str,
612
+ default="none",
613
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
614
+ )
615
+ parser.add_argument(
616
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
617
+ )
618
+ parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
619
+ parser.add_argument(
620
+ "--mode_scale",
621
+ type=float,
622
+ default=1.29,
623
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
624
+ )
625
+ parser.add_argument(
626
+ "--guidance_scale",
627
+ type=float,
628
+ default=3.5,
629
+ help="the FLUX.1 dev variant is a guidance distilled model",
630
+ )
631
+
632
+ parser.add_argument(
633
+ "--timestep_sampling",
634
+ choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
635
+ default="sigma",
636
+ help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
637
+ " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
638
+ )
639
+ parser.add_argument(
640
+ "--sigmoid_scale",
641
+ type=float,
642
+ default=1.0,
643
+ help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
644
+ )
645
+ parser.add_argument(
646
+ "--model_prediction_type",
647
+ choices=["raw", "additive", "sigma_scaled"],
648
+ default="sigma_scaled",
649
+ help="How to interpret and process the model prediction: "
650
+ "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
651
+ " / モデル予測の解釈と処理方法:"
652
+ "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
653
+ )
654
+ parser.add_argument(
655
+ "--discrete_flow_shift",
656
+ type=float,
657
+ default=3.0,
658
+ help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
659
+ )
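For reference, the `shift` branch above draws t as the sigmoid of a scaled normal sample and remaps it with t' = t*s / (1 + (s - 1)*t) before multiplying by 1000, while the `cosmap` scheme weights the loss by 2 / (pi * (1 - 2*sigma + 2*sigma^2)). A minimal standalone sketch of just that arithmetic (illustrative only; the batch size and hyperparameter values below are made up, not taken from the commit):

```python
import math
import torch

# Illustrative sketch of the "shift" timestep mapping and the "cosmap" loss weight.
bsz, sigmoid_scale, shift = 4, 1.0, 3.0  # assumed values for demonstration

t = torch.sigmoid(sigmoid_scale * torch.randn(bsz))  # sigmoid of a scaled normal sample, in (0, 1)
t_shifted = (t * shift) / (1 + (shift - 1) * t)      # discrete flow shift pushes t toward the noisy end
timesteps = t_shifted * 1000.0                       # the model is conditioned on t scaled to [0, 1000]

# cosmap weighting as a function of sigma (post-weights the loss when --weighting_scheme cosmap is used)
sigmas = torch.linspace(0.05, 0.95, 5)
cosmap_weight = 2 / (math.pi * (1 - 2 * sigmas + 2 * sigmas**2))
print(timesteps)
print(cosmap_weight)
```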
library/flux_utils.py ADDED
@@ -0,0 +1,472 @@
1
+ from dataclasses import replace
2
+ import json
3
+ import os
4
+ from typing import List, Optional, Tuple, Union
5
+ import einops
6
+ import torch
7
+
8
+ from safetensors.torch import load_file
9
+ from safetensors import safe_open
10
+ from accelerate import init_empty_weights
11
+ from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
12
+
13
+ from library.utils import setup_logging
14
+
15
+ setup_logging()
16
+ import logging
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ from library import flux_models
21
+ from library.utils import load_safetensors
22
+
23
+ MODEL_VERSION_FLUX_V1 = "flux1"
24
+ MODEL_NAME_DEV = "dev"
25
+ MODEL_NAME_SCHNELL = "schnell"
26
+
27
+
28
+ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
29
+ """
30
+ Analyze the checkpoint state and return whether it is Diffusers or BFL, whether it is dev or schnell, and the computed block counts.
31
+
32
+ Args:
33
+ ckpt_path (str): Path to the checkpoint file or directory.
34
+
35
+ Returns:
36
+ Tuple[bool, bool, Tuple[int, int], List[str]]:
37
+ - bool: Flag indicating whether the checkpoint is in Diffusers format.
38
+ - bool: Flag indicating whether the model is schnell.
39
+ - Tuple[int, int]: Number of double blocks and single blocks.
40
+ - List[str]: List of keys contained in the checkpoint.
41
+ """
42
+ # check the state dict: Diffusers or BFL, dev or schnell, number of blocks
43
+ logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
44
+
45
+ if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
46
+ ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
47
+ if "00001-of-00003" in ckpt_path:
48
+ ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
49
+ else:
50
+ ckpt_paths = [ckpt_path]
51
+
52
+ keys = []
53
+ for ckpt_path in ckpt_paths:
54
+ with safe_open(ckpt_path, framework="pt") as f:
55
+ keys.extend(f.keys())
56
+
57
+ # if the key has annoying prefix, remove it
58
+ if keys[0].startswith("model.diffusion_model."):
59
+ keys = [key.replace("model.diffusion_model.", "") for key in keys]
60
+
61
+ is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
62
+ is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
63
+
64
+ # check number of double and single blocks
65
+ if not is_diffusers:
66
+ max_double_block_index = max(
67
+ [int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
68
+ )
69
+ max_single_block_index = max(
70
+ [int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
71
+ )
72
+ else:
73
+ max_double_block_index = max(
74
+ [
75
+ int(key.split(".")[1])
76
+ for key in keys
77
+ if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
78
+ ]
79
+ )
80
+ max_single_block_index = max(
81
+ [
82
+ int(key.split(".")[1])
83
+ for key in keys
84
+ if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
85
+ ]
86
+ )
87
+
88
+ num_double_blocks = max_double_block_index + 1
89
+ num_single_blocks = max_single_block_index + 1
90
+
91
+ return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths
92
+
93
+
94
+ def load_flow_model(
95
+ ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
96
+ ) -> Tuple[bool, flux_models.Flux]:
97
+ is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
98
+ name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
99
+
100
+ # build model
101
+ logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
102
+ with torch.device("meta"):
103
+ params = flux_models.configs[name].params
104
+
105
+ # set the number of blocks
106
+ if params.depth != num_double_blocks:
107
+ logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
108
+ params = replace(params, depth=num_double_blocks)
109
+ if params.depth_single_blocks != num_single_blocks:
110
+ logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
111
+ params = replace(params, depth_single_blocks=num_single_blocks)
112
+
113
+ model = flux_models.Flux(params)
114
+ if dtype is not None:
115
+ model = model.to(dtype)
116
+
117
+ # load_sft doesn't support torch.device
118
+ logger.info(f"Loading state dict from {ckpt_path}")
119
+ sd = {}
120
+ for ckpt_path in ckpt_paths:
121
+ sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
122
+
123
+ # convert Diffusers to BFL
124
+ if is_diffusers:
125
+ logger.info("Converting Diffusers to BFL")
126
+ sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
127
+ logger.info("Converted Diffusers to BFL")
128
+
129
+ # if the key has annoying prefix, remove it
130
+ for key in list(sd.keys()):
131
+ new_key = key.replace("model.diffusion_model.", "")
132
+ if new_key == key:
133
+ break # the model doesn't have annoying prefix
134
+ sd[new_key] = sd.pop(key)
135
+
136
+ info = model.load_state_dict(sd, strict=False, assign=True)
137
+ logger.info(f"Loaded Flux: {info}")
138
+ return is_schnell, model
139
+
140
+
141
+ def load_ae(
142
+ ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
143
+ ) -> flux_models.AutoEncoder:
144
+ logger.info("Building AutoEncoder")
145
+ with torch.device("meta"):
146
+ # dev and schnell have the same AE params
147
+ ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)
148
+
149
+ logger.info(f"Loading state dict from {ckpt_path}")
150
+ sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
151
+ info = ae.load_state_dict(sd, strict=False, assign=True)
152
+ logger.info(f"Loaded AE: {info}")
153
+ return ae
154
+
155
+
156
+ def load_clip_l(
157
+ ckpt_path: Optional[str],
158
+ dtype: torch.dtype,
159
+ device: Union[str, torch.device],
160
+ disable_mmap: bool = False,
161
+ state_dict: Optional[dict] = None,
162
+ ) -> CLIPTextModel:
163
+ logger.info("Building CLIP-L")
164
+ CLIPL_CONFIG = {
165
+ "_name_or_path": "clip-vit-large-patch14/",
166
+ "architectures": ["CLIPModel"],
167
+ "initializer_factor": 1.0,
168
+ "logit_scale_init_value": 2.6592,
169
+ "model_type": "clip",
170
+ "projection_dim": 768,
171
+ # "text_config": {
172
+ "_name_or_path": "",
173
+ "add_cross_attention": False,
174
+ "architectures": None,
175
+ "attention_dropout": 0.0,
176
+ "bad_words_ids": None,
177
+ "bos_token_id": 0,
178
+ "chunk_size_feed_forward": 0,
179
+ "cross_attention_hidden_size": None,
180
+ "decoder_start_token_id": None,
181
+ "diversity_penalty": 0.0,
182
+ "do_sample": False,
183
+ "dropout": 0.0,
184
+ "early_stopping": False,
185
+ "encoder_no_repeat_ngram_size": 0,
186
+ "eos_token_id": 2,
187
+ "finetuning_task": None,
188
+ "forced_bos_token_id": None,
189
+ "forced_eos_token_id": None,
190
+ "hidden_act": "quick_gelu",
191
+ "hidden_size": 768,
192
+ "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
193
+ "initializer_factor": 1.0,
194
+ "initializer_range": 0.02,
195
+ "intermediate_size": 3072,
196
+ "is_decoder": False,
197
+ "is_encoder_decoder": False,
198
+ "label2id": {"LABEL_0": 0, "LABEL_1": 1},
199
+ "layer_norm_eps": 1e-05,
200
+ "length_penalty": 1.0,
201
+ "max_length": 20,
202
+ "max_position_embeddings": 77,
203
+ "min_length": 0,
204
+ "model_type": "clip_text_model",
205
+ "no_repeat_ngram_size": 0,
206
+ "num_attention_heads": 12,
207
+ "num_beam_groups": 1,
208
+ "num_beams": 1,
209
+ "num_hidden_layers": 12,
210
+ "num_return_sequences": 1,
211
+ "output_attentions": False,
212
+ "output_hidden_states": False,
213
+ "output_scores": False,
214
+ "pad_token_id": 1,
215
+ "prefix": None,
216
+ "problem_type": None,
217
+ "projection_dim": 768,
218
+ "pruned_heads": {},
219
+ "remove_invalid_values": False,
220
+ "repetition_penalty": 1.0,
221
+ "return_dict": True,
222
+ "return_dict_in_generate": False,
223
+ "sep_token_id": None,
224
+ "task_specific_params": None,
225
+ "temperature": 1.0,
226
+ "tie_encoder_decoder": False,
227
+ "tie_word_embeddings": True,
228
+ "tokenizer_class": None,
229
+ "top_k": 50,
230
+ "top_p": 1.0,
231
+ "torch_dtype": None,
232
+ "torchscript": False,
233
+ "transformers_version": "4.16.0.dev0",
234
+ "use_bfloat16": False,
235
+ "vocab_size": 49408,
236
+ "hidden_act": "gelu",
237
+ "hidden_size": 1280,
238
+ "intermediate_size": 5120,
239
+ "num_attention_heads": 20,
240
+ "num_hidden_layers": 32,
241
+ # },
242
+ # "text_config_dict": {
243
+ "hidden_size": 768,
244
+ "intermediate_size": 3072,
245
+ "num_attention_heads": 12,
246
+ "num_hidden_layers": 12,
247
+ "projection_dim": 768,
248
+ # },
249
+ # "torch_dtype": "float32",
250
+ # "transformers_version": None,
251
+ }
252
+ config = CLIPConfig(**CLIPL_CONFIG)
253
+ with init_empty_weights():
254
+ clip = CLIPTextModel._from_config(config)
255
+
256
+ if state_dict is not None:
257
+ sd = state_dict
258
+ else:
259
+ logger.info(f"Loading state dict from {ckpt_path}")
260
+ sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
261
+ info = clip.load_state_dict(sd, strict=False, assign=True)
262
+ logger.info(f"Loaded CLIP-L: {info}")
263
+ return clip
264
+
265
+
266
+ def load_t5xxl(
267
+ ckpt_path: str,
268
+ dtype: Optional[torch.dtype],
269
+ device: Union[str, torch.device],
270
+ disable_mmap: bool = False,
271
+ state_dict: Optional[dict] = None,
272
+ ) -> T5EncoderModel:
273
+ T5_CONFIG_JSON = """
274
+ {
275
+ "architectures": [
276
+ "T5EncoderModel"
277
+ ],
278
+ "classifier_dropout": 0.0,
279
+ "d_ff": 10240,
280
+ "d_kv": 64,
281
+ "d_model": 4096,
282
+ "decoder_start_token_id": 0,
283
+ "dense_act_fn": "gelu_new",
284
+ "dropout_rate": 0.1,
285
+ "eos_token_id": 1,
286
+ "feed_forward_proj": "gated-gelu",
287
+ "initializer_factor": 1.0,
288
+ "is_encoder_decoder": true,
289
+ "is_gated_act": true,
290
+ "layer_norm_epsilon": 1e-06,
291
+ "model_type": "t5",
292
+ "num_decoder_layers": 24,
293
+ "num_heads": 64,
294
+ "num_layers": 24,
295
+ "output_past": true,
296
+ "pad_token_id": 0,
297
+ "relative_attention_max_distance": 128,
298
+ "relative_attention_num_buckets": 32,
299
+ "tie_word_embeddings": false,
300
+ "torch_dtype": "float16",
301
+ "transformers_version": "4.41.2",
302
+ "use_cache": true,
303
+ "vocab_size": 32128
304
+ }
305
+ """
306
+ config = json.loads(T5_CONFIG_JSON)
307
+ config = T5Config(**config)
308
+ with init_empty_weights():
309
+ t5xxl = T5EncoderModel._from_config(config)
310
+
311
+ if state_dict is not None:
312
+ sd = state_dict
313
+ else:
314
+ logger.info(f"Loading state dict from {ckpt_path}")
315
+ sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
316
+ info = t5xxl.load_state_dict(sd, strict=False, assign=True)
317
+ logger.info(f"Loaded T5xxl: {info}")
318
+ return t5xxl
319
+
320
+
321
+ def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
322
+ # nn.Embedding is the first layer, but it could be casted to bfloat16 or float32
323
+ return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
324
+
325
+
326
+ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
327
+ img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
328
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
329
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :]
330
+ img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
331
+ return img_ids
332
+
333
+
334
+ def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
335
+ """
336
+ x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
337
+ """
338
+ x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
339
+ return x
340
+
341
+
342
+ def pack_latents(x: torch.Tensor) -> torch.Tensor:
343
+ """
344
+ x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
345
+ """
346
+ x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
347
+ return x
348
+
349
+
350
+ # region Diffusers
351
+
352
+ NUM_DOUBLE_BLOCKS = 19
353
+ NUM_SINGLE_BLOCKS = 38
354
+
355
+ BFL_TO_DIFFUSERS_MAP = {
356
+ "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
357
+ "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
358
+ "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
359
+ "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
360
+ "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
361
+ "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
362
+ "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
363
+ "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
364
+ "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
365
+ "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
366
+ "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
367
+ "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
368
+ "txt_in.weight": ["context_embedder.weight"],
369
+ "txt_in.bias": ["context_embedder.bias"],
370
+ "img_in.weight": ["x_embedder.weight"],
371
+ "img_in.bias": ["x_embedder.bias"],
372
+ "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
373
+ "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
374
+ "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
375
+ "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
376
+ "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
377
+ "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
378
+ "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
379
+ "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
380
+ "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
381
+ "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
382
+ "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
383
+ "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
384
+ "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
385
+ "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
386
+ "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
387
+ "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
388
+ "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
389
+ "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
390
+ "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
391
+ "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
392
+ "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
393
+ "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
394
+ "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
395
+ "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
396
+ "single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
397
+ "single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
398
+ "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
399
+ "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
400
+ "single_blocks.().linear2.weight": ["proj_out.weight"],
401
+ "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
402
+ "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
403
+ "single_blocks.().linear2.weight": ["proj_out.weight"],
404
+ "single_blocks.().linear2.bias": ["proj_out.bias"],
405
+ "final_layer.linear.weight": ["proj_out.weight"],
406
+ "final_layer.linear.bias": ["proj_out.bias"],
407
+ "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
408
+ "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
409
+ }
410
+
411
+
412
+ def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
413
+ # make reverse map from diffusers map
414
+ diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
415
+ for b in range(num_double_blocks):
416
+ for key, weights in BFL_TO_DIFFUSERS_MAP.items():
417
+ if key.startswith("double_blocks."):
418
+ block_prefix = f"transformer_blocks.{b}."
419
+ for i, weight in enumerate(weights):
420
+ diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
421
+ for b in range(num_single_blocks):
422
+ for key, weights in BFL_TO_DIFFUSERS_MAP.items():
423
+ if key.startswith("single_blocks."):
424
+ block_prefix = f"single_transformer_blocks.{b}."
425
+ for i, weight in enumerate(weights):
426
+ diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
427
+ for key, weights in BFL_TO_DIFFUSERS_MAP.items():
428
+ if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
429
+ for i, weight in enumerate(weights):
430
+ diffusers_to_bfl_map[weight] = (i, key)
431
+ return diffusers_to_bfl_map
432
+
433
+
434
+ def convert_diffusers_sd_to_bfl(
435
+ diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
436
+ ) -> dict[str, torch.Tensor]:
437
+ diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)
438
+
439
+ # iterate over three safetensors files to reduce memory usage
440
+ flux_sd = {}
441
+ for diffusers_key, tensor in diffusers_sd.items():
442
+ if diffusers_key in diffusers_to_bfl_map:
443
+ index, bfl_key = diffusers_to_bfl_map[diffusers_key]
444
+ if bfl_key not in flux_sd:
445
+ flux_sd[bfl_key] = []
446
+ flux_sd[bfl_key].append((index, tensor))
447
+ else:
448
+ logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}")
449
+ raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}")
450
+
451
+ # concat tensors if multiple tensors are mapped to a single key, sort by index
452
+ for key, values in flux_sd.items():
453
+ if len(values) == 1:
454
+ flux_sd[key] = values[0][1]
455
+ else:
456
+ flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])])
457
+
458
+ # special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias
459
+ def swap_scale_shift(weight):
460
+ shift, scale = weight.chunk(2, dim=0)
461
+ new_weight = torch.cat([scale, shift], dim=0)
462
+ return new_weight
463
+
464
+ if "final_layer.adaLN_modulation.1.weight" in flux_sd:
465
+ flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"])
466
+ if "final_layer.adaLN_modulation.1.bias" in flux_sd:
467
+ flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"])
468
+
469
+ return flux_sd
470
+
471
+
472
+ # endregion
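A minimal sketch of how the loaders and packing helpers above fit together (illustrative only; the checkpoint filenames are placeholders, and real FLUX weights plus a CUDA device are assumed for the loading calls):

```python
import torch
from library import flux_utils

# Placeholder paths; substitute real checkpoint files.
is_schnell, flux = flux_utils.load_flow_model("flux1-dev.safetensors", torch.bfloat16, "cuda")
ae = flux_utils.load_ae("ae.safetensors", torch.bfloat16, "cuda")

# Latents are packed into 2x2 patches before entering the transformer and unpacked afterwards.
latents = torch.randn(1, 16, 64, 64)              # [b, c, h, w] in latent space
packed = flux_utils.pack_latents(latents)         # [1, 32*32, 16*2*2]
img_ids = flux_utils.prepare_img_ids(1, 32, 32)   # positional ids for the packed 32x32 grid
restored = flux_utils.unpack_latents(packed, 32, 32)
assert torch.equal(latents, restored)
```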
library/huggingface_util.py ADDED
@@ -0,0 +1,84 @@
1
+ from typing import Union, BinaryIO
2
+ from huggingface_hub import HfApi
3
+ from pathlib import Path
4
+ import argparse
5
+ import os
6
+ from library.utils import fire_in_thread
7
+ from library.utils import setup_logging
8
+ setup_logging()
9
+ import logging
10
+ logger = logging.getLogger(__name__)
11
+
12
+ def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
13
+ api = HfApi(
14
+ token=token,
15
+ )
16
+ try:
17
+ api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
18
+ return True
19
+ except Exception:
20
+ return False
21
+
22
+
23
+ def upload(
24
+ args: argparse.Namespace,
25
+ src: Union[str, Path, bytes, BinaryIO],
26
+ dest_suffix: str = "",
27
+ force_sync_upload: bool = False,
28
+ ):
29
+ repo_id = args.huggingface_repo_id
30
+ repo_type = args.huggingface_repo_type
31
+ token = args.huggingface_token
32
+ path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
33
+ private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
34
+ api = HfApi(token=token)
35
+ if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
36
+ try:
37
+ api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
38
+ except Exception as e:  # RepositoryNotFoundError has been confirmed, but other exceptions could occur as well
39
+ logger.error("===========================================")
40
+ logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
41
+ logger.error("===========================================")
42
+
43
+ is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
44
+
45
+ def uploader():
46
+ try:
47
+ if is_folder:
48
+ api.upload_folder(
49
+ repo_id=repo_id,
50
+ repo_type=repo_type,
51
+ folder_path=src,
52
+ path_in_repo=path_in_repo,
53
+ )
54
+ else:
55
+ api.upload_file(
56
+ repo_id=repo_id,
57
+ repo_type=repo_type,
58
+ path_or_fileobj=src,
59
+ path_in_repo=path_in_repo,
60
+ )
61
+ except Exception as e:  # RuntimeError has been confirmed, but other exceptions could occur as well
62
+ logger.error("===========================================")
63
+ logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
64
+ logger.error("===========================================")
65
+
66
+ if args.async_upload and not force_sync_upload:
67
+ fire_in_thread(uploader)
68
+ else:
69
+ uploader()
70
+
71
+
72
+ def list_dir(
73
+ repo_id: str,
74
+ subfolder: str,
75
+ repo_type: str,
76
+ revision: str = "main",
77
+ token: str = None,
78
+ ):
79
+ api = HfApi(
80
+ token=token,
81
+ )
82
+ repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
83
+ file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
84
+ return file_list
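A minimal sketch of the namespace fields that `upload` reads (illustrative only; the repo id, token, paths, and flag values are hypothetical):

```python
import argparse
from library import huggingface_util

# Hypothetical values; in training these come from the parsed command-line arguments.
args = argparse.Namespace(
    huggingface_repo_id="your-name/your-model",
    huggingface_repo_type="model",
    huggingface_token=None,               # or a write-enabled token
    huggingface_path_in_repo="checkpoints",
    huggingface_repo_visibility=None,     # anything other than "public" creates a private repo
    async_upload=False,                   # True runs the upload in a background thread
)
huggingface_util.upload(args, src="output/my_lora.safetensors", dest_suffix="/step-1000")
```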
library/hypernetwork.py ADDED
@@ -0,0 +1,223 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from diffusers.models.attention_processor import (
4
+ Attention,
5
+ AttnProcessor2_0,
6
+ SlicedAttnProcessor,
7
+ XFormersAttnProcessor
8
+ )
9
+
10
+ try:
11
+ import xformers.ops
12
+ except Exception:
13
+ xformers = None
14
+
15
+
16
+ loaded_networks = []
17
+
18
+
19
+ def apply_single_hypernetwork(
20
+ hypernetwork, hidden_states, encoder_hidden_states
21
+ ):
22
+ context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
23
+ return context_k, context_v
24
+
25
+
26
+ def apply_hypernetworks(context_k, context_v, layer=None):
27
+ if len(loaded_networks) == 0:
28
+ return context_v, context_v
29
+ for hypernetwork in loaded_networks:
30
+ context_k, context_v = hypernetwork.forward(context_k, context_v)
31
+
32
+ context_k = context_k.to(dtype=context_k.dtype)
33
+ context_v = context_v.to(dtype=context_k.dtype)
34
+
35
+ return context_k, context_v
36
+
37
+
38
+
39
+ def xformers_forward(
40
+ self: XFormersAttnProcessor,
41
+ attn: Attention,
42
+ hidden_states: torch.Tensor,
43
+ encoder_hidden_states: torch.Tensor = None,
44
+ attention_mask: torch.Tensor = None,
45
+ ):
46
+ batch_size, sequence_length, _ = (
47
+ hidden_states.shape
48
+ if encoder_hidden_states is None
49
+ else encoder_hidden_states.shape
50
+ )
51
+
52
+ attention_mask = attn.prepare_attention_mask(
53
+ attention_mask, sequence_length, batch_size
54
+ )
55
+
56
+ query = attn.to_q(hidden_states)
57
+
58
+ if encoder_hidden_states is None:
59
+ encoder_hidden_states = hidden_states
60
+ elif attn.norm_cross:
61
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
62
+
63
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
64
+
65
+ key = attn.to_k(context_k)
66
+ value = attn.to_v(context_v)
67
+
68
+ query = attn.head_to_batch_dim(query).contiguous()
69
+ key = attn.head_to_batch_dim(key).contiguous()
70
+ value = attn.head_to_batch_dim(value).contiguous()
71
+
72
+ hidden_states = xformers.ops.memory_efficient_attention(
73
+ query,
74
+ key,
75
+ value,
76
+ attn_bias=attention_mask,
77
+ op=self.attention_op,
78
+ scale=attn.scale,
79
+ )
80
+ hidden_states = hidden_states.to(query.dtype)
81
+ hidden_states = attn.batch_to_head_dim(hidden_states)
82
+
83
+ # linear proj
84
+ hidden_states = attn.to_out[0](hidden_states)
85
+ # dropout
86
+ hidden_states = attn.to_out[1](hidden_states)
87
+ return hidden_states
88
+
89
+
90
+ def sliced_attn_forward(
91
+ self: SlicedAttnProcessor,
92
+ attn: Attention,
93
+ hidden_states: torch.Tensor,
94
+ encoder_hidden_states: torch.Tensor = None,
95
+ attention_mask: torch.Tensor = None,
96
+ ):
97
+ batch_size, sequence_length, _ = (
98
+ hidden_states.shape
99
+ if encoder_hidden_states is None
100
+ else encoder_hidden_states.shape
101
+ )
102
+ attention_mask = attn.prepare_attention_mask(
103
+ attention_mask, sequence_length, batch_size
104
+ )
105
+
106
+ query = attn.to_q(hidden_states)
107
+ dim = query.shape[-1]
108
+ query = attn.head_to_batch_dim(query)
109
+
110
+ if encoder_hidden_states is None:
111
+ encoder_hidden_states = hidden_states
112
+ elif attn.norm_cross:
113
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
114
+
115
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
116
+
117
+ key = attn.to_k(context_k)
118
+ value = attn.to_v(context_v)
119
+ key = attn.head_to_batch_dim(key)
120
+ value = attn.head_to_batch_dim(value)
121
+
122
+ batch_size_attention, query_tokens, _ = query.shape
123
+ hidden_states = torch.zeros(
124
+ (batch_size_attention, query_tokens, dim // attn.heads),
125
+ device=query.device,
126
+ dtype=query.dtype,
127
+ )
128
+
129
+ for i in range(batch_size_attention // self.slice_size):
130
+ start_idx = i * self.slice_size
131
+ end_idx = (i + 1) * self.slice_size
132
+
133
+ query_slice = query[start_idx:end_idx]
134
+ key_slice = key[start_idx:end_idx]
135
+ attn_mask_slice = (
136
+ attention_mask[start_idx:end_idx] if attention_mask is not None else None
137
+ )
138
+
139
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
140
+
141
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
142
+
143
+ hidden_states[start_idx:end_idx] = attn_slice
144
+
145
+ hidden_states = attn.batch_to_head_dim(hidden_states)
146
+
147
+ # linear proj
148
+ hidden_states = attn.to_out[0](hidden_states)
149
+ # dropout
150
+ hidden_states = attn.to_out[1](hidden_states)
151
+
152
+ return hidden_states
153
+
154
+
155
+ def v2_0_forward(
156
+ self: AttnProcessor2_0,
157
+ attn: Attention,
158
+ hidden_states,
159
+ encoder_hidden_states=None,
160
+ attention_mask=None,
161
+ ):
162
+ batch_size, sequence_length, _ = (
163
+ hidden_states.shape
164
+ if encoder_hidden_states is None
165
+ else encoder_hidden_states.shape
166
+ )
167
+ inner_dim = hidden_states.shape[-1]
168
+
169
+ if attention_mask is not None:
170
+ attention_mask = attn.prepare_attention_mask(
171
+ attention_mask, sequence_length, batch_size
172
+ )
173
+ # scaled_dot_product_attention expects attention_mask shape to be
174
+ # (batch, heads, source_length, target_length)
175
+ attention_mask = attention_mask.view(
176
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
177
+ )
178
+
179
+ query = attn.to_q(hidden_states)
180
+
181
+ if encoder_hidden_states is None:
182
+ encoder_hidden_states = hidden_states
183
+ elif attn.norm_cross:
184
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
185
+
186
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
187
+
188
+ key = attn.to_k(context_k)
189
+ value = attn.to_v(context_v)
190
+
191
+ head_dim = inner_dim // attn.heads
192
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
193
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
194
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
195
+
196
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
197
+ # TODO: add support for attn.scale when we move to Torch 2.1
198
+ hidden_states = F.scaled_dot_product_attention(
199
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
200
+ )
201
+
202
+ hidden_states = hidden_states.transpose(1, 2).reshape(
203
+ batch_size, -1, attn.heads * head_dim
204
+ )
205
+ hidden_states = hidden_states.to(query.dtype)
206
+
207
+ # linear proj
208
+ hidden_states = attn.to_out[0](hidden_states)
209
+ # dropout
210
+ hidden_states = attn.to_out[1](hidden_states)
211
+ return hidden_states
212
+
213
+
214
+ def replace_attentions_for_hypernetwork():
215
+ import diffusers.models.attention_processor
216
+
217
+ diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
218
+ xformers_forward
219
+ )
220
+ diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
221
+ sliced_attn_forward
222
+ )
223
+ diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
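A minimal sketch of how these patches are meant to be wired up (illustrative only; the identity hypernetwork below is a stand-in, since the real hypernetwork class with a `forward(context_k, context_v)` method is defined elsewhere):

```python
from library import hypernetwork

# Stand-in: any object with forward(context_k, context_v) -> (context_k, context_v) works here.
class IdentityHypernetwork:
    def forward(self, context_k, context_v):
        return context_k, context_v

hypernetwork.loaded_networks.append(IdentityHypernetwork())
# Monkey-patch the Diffusers attention processors so K/V projections pass through the hypernetworks.
hypernetwork.replace_attentions_for_hypernetwork()
```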
library/ipex/__init__.py ADDED
@@ -0,0 +1,180 @@
1
+ import os
2
+ import sys
3
+ import contextlib
4
+ import torch
5
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
6
+ from .hijacks import ipex_hijacks
7
+
8
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
9
+
10
+ def ipex_init(): # pylint: disable=too-many-statements
11
+ try:
12
+ if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
13
+ return True, "Skipping IPEX hijack"
14
+ else:
15
+ # Replace cuda with xpu:
16
+ torch.cuda.current_device = torch.xpu.current_device
17
+ torch.cuda.current_stream = torch.xpu.current_stream
18
+ torch.cuda.device = torch.xpu.device
19
+ torch.cuda.device_count = torch.xpu.device_count
20
+ torch.cuda.device_of = torch.xpu.device_of
21
+ torch.cuda.get_device_name = torch.xpu.get_device_name
22
+ torch.cuda.get_device_properties = torch.xpu.get_device_properties
23
+ torch.cuda.init = torch.xpu.init
24
+ torch.cuda.is_available = torch.xpu.is_available
25
+ torch.cuda.is_initialized = torch.xpu.is_initialized
26
+ torch.cuda.is_current_stream_capturing = lambda: False
27
+ torch.cuda.set_device = torch.xpu.set_device
28
+ torch.cuda.stream = torch.xpu.stream
29
+ torch.cuda.synchronize = torch.xpu.synchronize
30
+ torch.cuda.Event = torch.xpu.Event
31
+ torch.cuda.Stream = torch.xpu.Stream
32
+ torch.cuda.FloatTensor = torch.xpu.FloatTensor
33
+ torch.Tensor.cuda = torch.Tensor.xpu
34
+ torch.Tensor.is_cuda = torch.Tensor.is_xpu
35
+ torch.nn.Module.cuda = torch.nn.Module.xpu
36
+ torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
37
+ torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
38
+ torch.cuda._initialized = torch.xpu.lazy_init._initialized
39
+ torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
40
+ torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
41
+ torch.cuda._tls = torch.xpu.lazy_init._tls
42
+ torch.cuda.threading = torch.xpu.lazy_init.threading
43
+ torch.cuda.traceback = torch.xpu.lazy_init.traceback
44
+ torch.cuda.Optional = torch.xpu.Optional
45
+ torch.cuda.__cached__ = torch.xpu.__cached__
46
+ torch.cuda.__loader__ = torch.xpu.__loader__
47
+ torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
48
+ torch.cuda.Tuple = torch.xpu.Tuple
49
+ torch.cuda.streams = torch.xpu.streams
50
+ torch.cuda._lazy_new = torch.xpu._lazy_new
51
+ torch.cuda.FloatStorage = torch.xpu.FloatStorage
52
+ torch.cuda.Any = torch.xpu.Any
53
+ torch.cuda.__doc__ = torch.xpu.__doc__
54
+ torch.cuda.default_generators = torch.xpu.default_generators
55
+ torch.cuda.HalfTensor = torch.xpu.HalfTensor
56
+ torch.cuda._get_device_index = torch.xpu._get_device_index
57
+ torch.cuda.__path__ = torch.xpu.__path__
58
+ torch.cuda.Device = torch.xpu.Device
59
+ torch.cuda.IntTensor = torch.xpu.IntTensor
60
+ torch.cuda.ByteStorage = torch.xpu.ByteStorage
61
+ torch.cuda.set_stream = torch.xpu.set_stream
62
+ torch.cuda.BoolStorage = torch.xpu.BoolStorage
63
+ torch.cuda.os = torch.xpu.os
64
+ torch.cuda.torch = torch.xpu.torch
65
+ torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
66
+ torch.cuda.Union = torch.xpu.Union
67
+ torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
68
+ torch.cuda.ShortTensor = torch.xpu.ShortTensor
69
+ torch.cuda.LongTensor = torch.xpu.LongTensor
70
+ torch.cuda.IntStorage = torch.xpu.IntStorage
71
+ torch.cuda.LongStorage = torch.xpu.LongStorage
72
+ torch.cuda.__annotations__ = torch.xpu.__annotations__
73
+ torch.cuda.__package__ = torch.xpu.__package__
74
+ torch.cuda.__builtins__ = torch.xpu.__builtins__
75
+ torch.cuda.CharTensor = torch.xpu.CharTensor
76
+ torch.cuda.List = torch.xpu.List
77
+ torch.cuda._lazy_init = torch.xpu._lazy_init
78
+ torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
79
+ torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
80
+ torch.cuda.ByteTensor = torch.xpu.ByteTensor
81
+ torch.cuda.StreamContext = torch.xpu.StreamContext
82
+ torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
83
+ torch.cuda.ShortStorage = torch.xpu.ShortStorage
84
+ torch.cuda._lazy_call = torch.xpu._lazy_call
85
+ torch.cuda.HalfStorage = torch.xpu.HalfStorage
86
+ torch.cuda.random = torch.xpu.random
87
+ torch.cuda._device = torch.xpu._device
88
+ torch.cuda.classproperty = torch.xpu.classproperty
89
+ torch.cuda.__name__ = torch.xpu.__name__
90
+ torch.cuda._device_t = torch.xpu._device_t
91
+ torch.cuda.warnings = torch.xpu.warnings
92
+ torch.cuda.__spec__ = torch.xpu.__spec__
93
+ torch.cuda.BoolTensor = torch.xpu.BoolTensor
94
+ torch.cuda.CharStorage = torch.xpu.CharStorage
95
+ torch.cuda.__file__ = torch.xpu.__file__
96
+ torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
97
+ # torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
98
+
99
+ # Memory:
100
+ torch.cuda.memory = torch.xpu.memory
101
+ if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
102
+ torch.xpu.empty_cache = lambda: None
103
+ torch.cuda.empty_cache = torch.xpu.empty_cache
104
+ torch.cuda.memory_stats = torch.xpu.memory_stats
105
+ torch.cuda.memory_summary = torch.xpu.memory_summary
106
+ torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
107
+ torch.cuda.memory_allocated = torch.xpu.memory_allocated
108
+ torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
109
+ torch.cuda.memory_reserved = torch.xpu.memory_reserved
110
+ torch.cuda.memory_cached = torch.xpu.memory_reserved
111
+ torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
112
+ torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
113
+ torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
114
+ torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
115
+ torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
116
+ torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
117
+ torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
118
+
119
+ # RNG:
120
+ torch.cuda.get_rng_state = torch.xpu.get_rng_state
121
+ torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
122
+ torch.cuda.set_rng_state = torch.xpu.set_rng_state
123
+ torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
124
+ torch.cuda.manual_seed = torch.xpu.manual_seed
125
+ torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
126
+ torch.cuda.seed = torch.xpu.seed
127
+ torch.cuda.seed_all = torch.xpu.seed_all
128
+ torch.cuda.initial_seed = torch.xpu.initial_seed
129
+
130
+ # AMP:
131
+ torch.cuda.amp = torch.xpu.amp
132
+ torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
133
+ torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
134
+
135
+ if not hasattr(torch.cuda.amp, "common"):
136
+ torch.cuda.amp.common = contextlib.nullcontext()
137
+ torch.cuda.amp.common.amp_definitely_not_available = lambda: False
138
+
139
+ try:
140
+ torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
141
+ except Exception: # pylint: disable=broad-exception-caught
142
+ try:
143
+ from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
144
+ gradscaler_init()
145
+ torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
146
+ except Exception: # pylint: disable=broad-exception-caught
147
+ torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
148
+
149
+ # C
150
+ torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
151
+ ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
152
+ ipex._C._DeviceProperties.major = 2024
153
+ ipex._C._DeviceProperties.minor = 0
154
+
155
+ # Fix functions with ipex:
156
+ torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
157
+ torch._utils._get_available_device_type = lambda: "xpu"
158
+ torch.has_cuda = True
159
+ torch.cuda.has_half = True
160
+ torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
161
+ torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
162
+ torch.backends.cuda.is_built = lambda *args, **kwargs: True
163
+ torch.version.cuda = "12.1"
164
+ torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
165
+ torch.cuda.get_device_properties.major = 12
166
+ torch.cuda.get_device_properties.minor = 1
167
+ torch.cuda.ipc_collect = lambda *args, **kwargs: None
168
+ torch.cuda.utilization = lambda *args, **kwargs: 0
169
+
170
+ ipex_hijacks()
171
+ if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
172
+ try:
173
+ from .diffusers import ipex_diffusers
174
+ ipex_diffusers()
175
+ except Exception: # pylint: disable=broad-exception-caught
176
+ pass
177
+ torch.cuda.is_xpu_hijacked = True
178
+ except Exception as e:
179
+ return False, e
180
+ return True, None
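A minimal sketch of how this module is intended to be used (illustrative only; it assumes a machine with intel_extension_for_pytorch and an XPU device, otherwise the import itself fails):

```python
# Apply the hijack once at startup, before any model or optimizer is built,
# so that later torch.cuda.* calls are transparently routed to torch.xpu.*.
from library.ipex import ipex_init

ok, message = ipex_init()
if not ok:
    print(f"IPEX hijack failed: {message}")
```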
library/ipex/attention.py ADDED
@@ -0,0 +1,177 @@
1
+ import os
2
+ import torch
3
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
4
+ from functools import cache
5
+
6
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
7
+
8
+ # ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
9
+
10
+ sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
11
+ attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
12
+
13
+ # Find something divisible with the input_tokens
14
+ @cache
15
+ def find_slice_size(slice_size, slice_block_size):
16
+ while (slice_size * slice_block_size) > attention_slice_rate:
17
+ slice_size = slice_size // 2
18
+ if slice_size <= 1:
19
+ slice_size = 1
20
+ break
21
+ return slice_size
22
+
23
+ # Find slice sizes for SDPA
24
+ @cache
25
+ def find_sdpa_slice_sizes(query_shape, query_element_size):
26
+ if len(query_shape) == 3:
27
+ batch_size_attention, query_tokens, shape_three = query_shape
28
+ shape_four = 1
29
+ else:
30
+ batch_size_attention, query_tokens, shape_three, shape_four = query_shape
31
+
32
+ slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
33
+ block_size = batch_size_attention * slice_block_size
34
+
35
+ split_slice_size = batch_size_attention
36
+ split_2_slice_size = query_tokens
37
+ split_3_slice_size = shape_three
38
+
39
+ do_split = False
40
+ do_split_2 = False
41
+ do_split_3 = False
42
+
43
+ if block_size > sdpa_slice_trigger_rate:
44
+ do_split = True
45
+ split_slice_size = find_slice_size(split_slice_size, slice_block_size)
46
+ if split_slice_size * slice_block_size > attention_slice_rate:
47
+ slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
48
+ do_split_2 = True
49
+ split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
50
+ if split_2_slice_size * slice_2_block_size > attention_slice_rate:
51
+ slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
52
+ do_split_3 = True
53
+ split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
54
+
55
+ return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
56
+
57
+ # Find slice sizes for BMM
58
+ @cache
59
+ def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
60
+ batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
61
+ slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
62
+ block_size = batch_size_attention * slice_block_size
63
+
64
+ split_slice_size = batch_size_attention
65
+ split_2_slice_size = input_tokens
66
+ split_3_slice_size = mat2_atten_shape
67
+
68
+ do_split = False
69
+ do_split_2 = False
70
+ do_split_3 = False
71
+
72
+ if block_size > attention_slice_rate:
73
+ do_split = True
74
+ split_slice_size = find_slice_size(split_slice_size, slice_block_size)
75
+ if split_slice_size * slice_block_size > attention_slice_rate:
76
+ slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
77
+ do_split_2 = True
78
+ split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
79
+ if split_2_slice_size * slice_2_block_size > attention_slice_rate:
80
+ slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
81
+ do_split_3 = True
82
+ split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
83
+
84
+ return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
85
+
86
+
87
+ original_torch_bmm = torch.bmm
88
+ def torch_bmm_32_bit(input, mat2, *, out=None):
89
+ if input.device.type != "xpu":
90
+ return original_torch_bmm(input, mat2, out=out)
91
+ do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
92
+
93
+ # Slice BMM
94
+ if do_split:
95
+ batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
96
+ hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
97
+ for i in range(batch_size_attention // split_slice_size):
98
+ start_idx = i * split_slice_size
99
+ end_idx = (i + 1) * split_slice_size
100
+ if do_split_2:
101
+ for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
102
+ start_idx_2 = i2 * split_2_slice_size
103
+ end_idx_2 = (i2 + 1) * split_2_slice_size
104
+ if do_split_3:
105
+ for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
106
+ start_idx_3 = i3 * split_3_slice_size
107
+ end_idx_3 = (i3 + 1) * split_3_slice_size
108
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
109
+ input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
110
+ mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
111
+ out=out
112
+ )
113
+ else:
114
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
115
+ input[start_idx:end_idx, start_idx_2:end_idx_2],
116
+ mat2[start_idx:end_idx, start_idx_2:end_idx_2],
117
+ out=out
118
+ )
119
+ else:
120
+ hidden_states[start_idx:end_idx] = original_torch_bmm(
121
+ input[start_idx:end_idx],
122
+ mat2[start_idx:end_idx],
123
+ out=out
124
+ )
125
+ torch.xpu.synchronize(input.device)
126
+ else:
127
+ return original_torch_bmm(input, mat2, out=out)
128
+ return hidden_states
129
+
130
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
131
+ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
132
+ if query.device.type != "xpu":
133
+ return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
134
+ do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
135
+
136
+ # Slice SDPA
137
+ if do_split:
138
+ batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
139
+ hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
140
+ for i in range(batch_size_attention // split_slice_size):
141
+ start_idx = i * split_slice_size
142
+ end_idx = (i + 1) * split_slice_size
143
+ if do_split_2:
144
+ for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
145
+ start_idx_2 = i2 * split_2_slice_size
146
+ end_idx_2 = (i2 + 1) * split_2_slice_size
147
+ if do_split_3:
148
+ for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
149
+ start_idx_3 = i3 * split_3_slice_size
150
+ end_idx_3 = (i3 + 1) * split_3_slice_size
151
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
152
+ query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
153
+ key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
154
+ value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
155
+ attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
156
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs
157
+ )
158
+ else:
159
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
160
+ query[start_idx:end_idx, start_idx_2:end_idx_2],
161
+ key[start_idx:end_idx, start_idx_2:end_idx_2],
162
+ value[start_idx:end_idx, start_idx_2:end_idx_2],
163
+ attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
164
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs
165
+ )
166
+ else:
167
+ hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
168
+ query[start_idx:end_idx],
169
+ key[start_idx:end_idx],
170
+ value[start_idx:end_idx],
171
+ attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
172
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs
173
+ )
174
+ torch.xpu.synchronize(query.device)
175
+ else:
176
+ return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
177
+ return hidden_states
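These sliced variants keep the signatures of the functions they replace, so the hijack step (in `hijacks.py`, which is not shown in this view) can presumably swap them in directly. A hedged sketch of that wiring, not the actual hijack code:

```python
import torch
from library.ipex.attention import torch_bmm_32_bit, scaled_dot_product_attention_32_bit

# Route the ops through the sliced versions so a single allocation stays under the ~4GB ARC limit.
torch.bmm = torch_bmm_32_bit
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention_32_bit
```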
library/ipex/diffusers.py ADDED
@@ -0,0 +1,312 @@
1
+ import os
2
+ import torch
3
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
4
+ import diffusers #0.24.0 # pylint: disable=import-error
5
+ from diffusers.models.attention_processor import Attention
6
+ from diffusers.utils import USE_PEFT_BACKEND
7
+ from functools import cache
8
+
9
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
10
+
11
+ attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
12
+
13
+ @cache
14
+ def find_slice_size(slice_size, slice_block_size):
15
+ while (slice_size * slice_block_size) > attention_slice_rate:
16
+ slice_size = slice_size // 2
17
+ if slice_size <= 1:
18
+ slice_size = 1
19
+ break
20
+ return slice_size
21
+
22
+ @cache
23
+ def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
24
+ if len(query_shape) == 3:
25
+ batch_size_attention, query_tokens, shape_three = query_shape
26
+ shape_four = 1
27
+ else:
28
+ batch_size_attention, query_tokens, shape_three, shape_four = query_shape
29
+ if slice_size is not None:
30
+ batch_size_attention = slice_size
31
+
32
+ slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
33
+ block_size = batch_size_attention * slice_block_size
34
+
35
+ split_slice_size = batch_size_attention
36
+ split_2_slice_size = query_tokens
37
+ split_3_slice_size = shape_three
38
+
39
+ do_split = False
40
+ do_split_2 = False
41
+ do_split_3 = False
42
+
43
+ if query_device_type != "xpu":
44
+ return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
45
+
46
+ if block_size > attention_slice_rate:
47
+ do_split = True
48
+ split_slice_size = find_slice_size(split_slice_size, slice_block_size)
49
+ if split_slice_size * slice_block_size > attention_slice_rate:
50
+ slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
51
+ do_split_2 = True
52
+ split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
53
+ if split_2_slice_size * slice_2_block_size > attention_slice_rate:
54
+ slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
55
+ do_split_3 = True
56
+ split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
57
+
58
+ return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
59
+
60
+ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
61
+ r"""
62
+ Processor for implementing sliced attention.
63
+
64
+ Args:
65
+ slice_size (`int`, *optional*):
66
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
67
+ `attention_head_dim` must be a multiple of the `slice_size`.
68
+ """
69
+
70
+ def __init__(self, slice_size):
71
+ self.slice_size = slice_size
72
+
73
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
74
+ encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
75
+
76
+ residual = hidden_states
77
+
78
+ input_ndim = hidden_states.ndim
79
+
80
+ if input_ndim == 4:
81
+ batch_size, channel, height, width = hidden_states.shape
82
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
83
+
84
+ batch_size, sequence_length, _ = (
85
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
86
+ )
87
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
88
+
89
+ if attn.group_norm is not None:
90
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
91
+
92
+ query = attn.to_q(hidden_states)
93
+ dim = query.shape[-1]
94
+ query = attn.head_to_batch_dim(query)
95
+
96
+ if encoder_hidden_states is None:
97
+ encoder_hidden_states = hidden_states
98
+ elif attn.norm_cross:
99
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
100
+
101
+ key = attn.to_k(encoder_hidden_states)
102
+ value = attn.to_v(encoder_hidden_states)
103
+ key = attn.head_to_batch_dim(key)
104
+ value = attn.head_to_batch_dim(value)
105
+
106
+ batch_size_attention, query_tokens, shape_three = query.shape
107
+ hidden_states = torch.zeros(
108
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
109
+ )
110
+
111
+ ####################################################################
112
+ # ARC GPUs can't allocate more than 4GB to a single block, so slice it:
113
+ _, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
114
+
115
+ for i in range(batch_size_attention // split_slice_size):
116
+ start_idx = i * split_slice_size
117
+ end_idx = (i + 1) * split_slice_size
118
+ if do_split_2:
119
+ for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
120
+ start_idx_2 = i2 * split_2_slice_size
121
+ end_idx_2 = (i2 + 1) * split_2_slice_size
122
+ if do_split_3:
123
+ for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
124
+ start_idx_3 = i3 * split_3_slice_size
125
+ end_idx_3 = (i3 + 1) * split_3_slice_size
126
+
127
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
128
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
129
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
130
+
131
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
132
+ del query_slice
133
+ del key_slice
134
+ del attn_mask_slice
135
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
136
+
137
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
138
+ del attn_slice
139
+ else:
140
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
141
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
142
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
143
+
144
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
145
+ del query_slice
146
+ del key_slice
147
+ del attn_mask_slice
148
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
149
+
150
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
151
+ del attn_slice
152
+ torch.xpu.synchronize(query.device)
153
+ else:
154
+ query_slice = query[start_idx:end_idx]
155
+ key_slice = key[start_idx:end_idx]
156
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
157
+
158
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
159
+ del query_slice
160
+ del key_slice
161
+ del attn_mask_slice
162
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
163
+
164
+ hidden_states[start_idx:end_idx] = attn_slice
165
+ del attn_slice
166
+ ####################################################################
167
+
168
+ hidden_states = attn.batch_to_head_dim(hidden_states)
169
+
170
+ # linear proj
171
+ hidden_states = attn.to_out[0](hidden_states)
172
+ # dropout
173
+ hidden_states = attn.to_out[1](hidden_states)
174
+
175
+ if input_ndim == 4:
176
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
177
+
178
+ if attn.residual_connection:
179
+ hidden_states = hidden_states + residual
180
+
181
+ hidden_states = hidden_states / attn.rescale_output_factor
182
+
183
+ return hidden_states
184
+
185
+
186
+ class AttnProcessor:
187
+ r"""
188
+ Default processor for performing attention-related computations.
189
+ """
190
+
191
+ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
192
+ encoder_hidden_states=None, attention_mask=None,
193
+ temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
194
+
195
+ residual = hidden_states
196
+
197
+ args = () if USE_PEFT_BACKEND else (scale,)
198
+
199
+ if attn.spatial_norm is not None:
200
+ hidden_states = attn.spatial_norm(hidden_states, temb)
201
+
202
+ input_ndim = hidden_states.ndim
203
+
204
+ if input_ndim == 4:
205
+ batch_size, channel, height, width = hidden_states.shape
206
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
207
+
208
+ batch_size, sequence_length, _ = (
209
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
210
+ )
211
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
212
+
213
+ if attn.group_norm is not None:
214
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
215
+
216
+ query = attn.to_q(hidden_states, *args)
217
+
218
+ if encoder_hidden_states is None:
219
+ encoder_hidden_states = hidden_states
220
+ elif attn.norm_cross:
221
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
222
+
223
+ key = attn.to_k(encoder_hidden_states, *args)
224
+ value = attn.to_v(encoder_hidden_states, *args)
225
+
226
+ query = attn.head_to_batch_dim(query)
227
+ key = attn.head_to_batch_dim(key)
228
+ value = attn.head_to_batch_dim(value)
229
+
230
+ ####################################################################
231
+ # ARC GPUs can't allocate more than 4GB to a single block, so slice it:
232
+ batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
233
+ hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
234
+ do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
235
+
236
+ if do_split:
237
+ for i in range(batch_size_attention // split_slice_size):
238
+ start_idx = i * split_slice_size
239
+ end_idx = (i + 1) * split_slice_size
240
+ if do_split_2:
241
+ for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
242
+ start_idx_2 = i2 * split_2_slice_size
243
+ end_idx_2 = (i2 + 1) * split_2_slice_size
244
+ if do_split_3:
245
+ for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
246
+ start_idx_3 = i3 * split_3_slice_size
247
+ end_idx_3 = (i3 + 1) * split_3_slice_size
248
+
249
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
250
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
251
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
252
+
253
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
254
+ del query_slice
255
+ del key_slice
256
+ del attn_mask_slice
257
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
258
+
259
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
260
+ del attn_slice
261
+ else:
262
+ query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
263
+ key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
264
+ attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
265
+
266
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
267
+ del query_slice
268
+ del key_slice
269
+ del attn_mask_slice
270
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
271
+
272
+ hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
273
+ del attn_slice
274
+ else:
275
+ query_slice = query[start_idx:end_idx]
276
+ key_slice = key[start_idx:end_idx]
277
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
278
+
279
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
280
+ del query_slice
281
+ del key_slice
282
+ del attn_mask_slice
283
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
284
+
285
+ hidden_states[start_idx:end_idx] = attn_slice
286
+ del attn_slice
287
+ torch.xpu.synchronize(query.device)
288
+ else:
289
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
290
+ hidden_states = torch.bmm(attention_probs, value)
291
+ ####################################################################
292
+ hidden_states = attn.batch_to_head_dim(hidden_states)
293
+
294
+ # linear proj
295
+ hidden_states = attn.to_out[0](hidden_states, *args)
296
+ # dropout
297
+ hidden_states = attn.to_out[1](hidden_states)
298
+
299
+ if input_ndim == 4:
300
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
301
+
302
+ if attn.residual_connection:
303
+ hidden_states = hidden_states + residual
304
+
305
+ hidden_states = hidden_states / attn.rescale_output_factor
306
+
307
+ return hidden_states
308
+
309
+ def ipex_diffusers():
310
+ # ARC GPUs can't allocate more than 4GB to a single block:
311
+ diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
312
+ diffusers.models.attention_processor.AttnProcessor = AttnProcessor
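Usage sketch (illustrative only): ipex_diffusers() swaps the processor classes in diffusers.models.attention_processor, so a pipeline that later enables attention slicing should pick up the XPU-aware SlicedAttnProcessor defined above. The checkpoint id below is a placeholder, and IPEX_ATTENTION_SLICE_RATE can be set in the environment to tune the slicing heuristic.

import torch
from diffusers import StableDiffusionPipeline
from library.ipex.diffusers import ipex_diffusers

ipex_diffusers()  # patch AttnProcessor / SlicedAttnProcessor before building the pipeline
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint
    torch_dtype=torch.float16,
).to("xpu")
pipe.enable_attention_slicing()  # diffusers now instantiates the patched SlicedAttnProcessor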
library/ipex/gradscaler.py ADDED
@@ -0,0 +1,183 @@
1
+ from collections import defaultdict
2
+ import torch
3
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
4
+ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
5
+
6
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
7
+
8
+ device_supports_fp64 = torch.xpu.has_fp64_dtype()
9
+ OptState = ipex.cpu.autocast._grad_scaler.OptState
10
+ _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
11
+ _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
12
+
13
+ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
14
+ per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
15
+ per_device_found_inf = _MultiDeviceReplicator(found_inf)
16
+
17
+ # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
18
+ # There could be hundreds of grads, so we'd like to iterate through them just once.
19
+ # However, we don't know their devices or dtypes in advance.
20
+
21
+ # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
22
+ # Google says mypy struggles with defaultdicts type annotations.
23
+ per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
24
+ # sync grad to master weight
25
+ if hasattr(optimizer, "sync_grad"):
26
+ optimizer.sync_grad()
27
+ with torch.no_grad():
28
+ for group in optimizer.param_groups:
29
+ for param in group["params"]:
30
+ if param.grad is None:
31
+ continue
32
+ if (not allow_fp16) and param.grad.dtype == torch.float16:
33
+ raise ValueError("Attempting to unscale FP16 gradients.")
34
+ if param.grad.is_sparse:
35
+ # is_coalesced() == False means the sparse grad has values with duplicate indices.
36
+ # coalesce() deduplicates indices and adds all values that have the same index.
37
+ # For scaled fp16 values, there's a good chance coalescing will cause overflow,
38
+ # so we should check the coalesced _values().
39
+ if param.grad.dtype is torch.float16:
40
+ param.grad = param.grad.coalesce()
41
+ to_unscale = param.grad._values()
42
+ else:
43
+ to_unscale = param.grad
44
+
45
+ # TODO: is there a way to split by device and dtype without appending in the inner loop?
46
+ to_unscale = to_unscale.to("cpu")
47
+ per_device_and_dtype_grads[to_unscale.device][
48
+ to_unscale.dtype
49
+ ].append(to_unscale)
50
+
51
+ for _, per_dtype_grads in per_device_and_dtype_grads.items():
52
+ for grads in per_dtype_grads.values():
53
+ core._amp_foreach_non_finite_check_and_unscale_(
54
+ grads,
55
+ per_device_found_inf.get("cpu"),
56
+ per_device_inv_scale.get("cpu"),
57
+ )
58
+
59
+ return per_device_found_inf._per_device_tensors
60
+
61
+ def unscale_(self, optimizer):
62
+ """
63
+ Divides ("unscales") the optimizer's gradient tensors by the scale factor.
64
+ :meth:`unscale_` is optional, serving cases where you need to
65
+ :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
66
+ between the backward pass(es) and :meth:`step`.
67
+ If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
68
+ Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
69
+ ...
70
+ scaler.scale(loss).backward()
71
+ scaler.unscale_(optimizer)
72
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
73
+ scaler.step(optimizer)
74
+ scaler.update()
75
+ Args:
76
+ optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
77
+ .. warning::
78
+ :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
79
+ and only after all gradients for that optimizer's assigned parameters have been accumulated.
80
+ Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
81
+ .. warning::
82
+ :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
83
+ """
84
+ if not self._enabled:
85
+ return
86
+
87
+ self._check_scale_growth_tracker("unscale_")
88
+
89
+ optimizer_state = self._per_optimizer_states[id(optimizer)]
90
+
91
+ if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
92
+ raise RuntimeError(
93
+ "unscale_() has already been called on this optimizer since the last update()."
94
+ )
95
+ elif optimizer_state["stage"] is OptState.STEPPED:
96
+ raise RuntimeError("unscale_() is being called after step().")
97
+
98
+ # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
99
+ assert self._scale is not None
100
+ if device_supports_fp64:
101
+ inv_scale = self._scale.double().reciprocal().float()
102
+ else:
103
+ inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
104
+ found_inf = torch.full(
105
+ (1,), 0.0, dtype=torch.float32, device=self._scale.device
106
+ )
107
+
108
+ optimizer_state["found_inf_per_device"] = self._unscale_grads_(
109
+ optimizer, inv_scale, found_inf, False
110
+ )
111
+ optimizer_state["stage"] = OptState.UNSCALED
112
+
113
+ def update(self, new_scale=None):
114
+ """
115
+ Updates the scale factor.
116
+ If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
117
+ to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
118
+ the scale is multiplied by ``growth_factor`` to increase it.
119
+ Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
120
+ used directly, it's used to fill GradScaler's internal scale tensor. So if
121
+ ``new_scale`` was a tensor, later in-place changes to that tensor will not further
122
+ affect the scale GradScaler uses internally.)
123
+ Args:
124
+ new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
125
+ .. warning::
126
+ :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
127
+ been invoked for all optimizers used this iteration.
128
+ """
129
+ if not self._enabled:
130
+ return
131
+
132
+ _scale, _growth_tracker = self._check_scale_growth_tracker("update")
133
+
134
+ if new_scale is not None:
135
+ # Accept a new user-defined scale.
136
+ if isinstance(new_scale, float):
137
+ self._scale.fill_(new_scale) # type: ignore[union-attr]
138
+ else:
139
+ reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
140
+ assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
141
+ assert new_scale.numel() == 1, reason
142
+ assert new_scale.requires_grad is False, reason
143
+ self._scale.copy_(new_scale) # type: ignore[union-attr]
144
+ else:
145
+ # Consume shared inf/nan data collected from optimizers to update the scale.
146
+ # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
147
+ found_infs = [
148
+ found_inf.to(device="cpu", non_blocking=True)
149
+ for state in self._per_optimizer_states.values()
150
+ for found_inf in state["found_inf_per_device"].values()
151
+ ]
152
+
153
+ assert len(found_infs) > 0, "No inf checks were recorded prior to update."
154
+
155
+ found_inf_combined = found_infs[0]
156
+ if len(found_infs) > 1:
157
+ for i in range(1, len(found_infs)):
158
+ found_inf_combined += found_infs[i]
159
+
160
+ to_device = _scale.device
161
+ _scale = _scale.to("cpu")
162
+ _growth_tracker = _growth_tracker.to("cpu")
163
+
164
+ core._amp_update_scale_(
165
+ _scale,
166
+ _growth_tracker,
167
+ found_inf_combined,
168
+ self._growth_factor,
169
+ self._backoff_factor,
170
+ self._growth_interval,
171
+ )
172
+
173
+ _scale = _scale.to(to_device)
174
+ _growth_tracker = _growth_tracker.to(to_device)
175
+ # To prepare for next iteration, clear the data collected from optimizers this iteration.
176
+ self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
177
+
178
+ def gradscaler_init():
179
+ torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
180
+ torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
181
+ torch.xpu.amp.GradScaler.unscale_ = unscale_
182
+ torch.xpu.amp.GradScaler.update = update
183
+ return torch.xpu.amp.GradScaler
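Usage sketch (illustrative only): gradscaler_init() rebinds torch.xpu.amp.GradScaler to IPEX's CPU-autocast scaler and installs the CPU-offloaded _unscale_grads_/unscale_/update defined above. A minimal training-step sketch follows; model, optimizer, loss_fn and loader are placeholders.

import torch
from library.ipex.gradscaler import gradscaler_init

GradScaler = gradscaler_init()
scaler = GradScaler()

for batch in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast("xpu", dtype=torch.float16):
        loss = loss_fn(model(batch))
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # optional: lets us clip unscaled gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()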
library/ipex/hijacks.py ADDED
@@ -0,0 +1,313 @@
1
+ import os
2
+ from functools import wraps
3
+ from contextlib import nullcontext
4
+ import torch
5
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
6
+ import numpy as np
7
+
8
+ device_supports_fp64 = torch.xpu.has_fp64_dtype()
9
+
10
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
11
+
12
+ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
13
+ def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
14
+ if isinstance(device_ids, list) and len(device_ids) > 1:
15
+ print("IPEX backend doesn't support DataParallel on multiple XPU devices")
16
+ return module.to("xpu")
17
+
18
+ def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
19
+ return nullcontext()
20
+
21
+ @property
22
+ def is_cuda(self):
23
+ return self.device.type == 'xpu' or self.device.type == 'cuda'
24
+
25
+ def check_device(device):
26
+ return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
27
+
28
+ def return_xpu(device):
29
+ return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
30
+
31
+
32
+ # Autocast
33
+ original_autocast_init = torch.amp.autocast_mode.autocast.__init__
34
+ @wraps(torch.amp.autocast_mode.autocast.__init__)
35
+ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
36
+ if device_type == "cuda":
37
+ return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
38
+ else:
39
+ return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
40
+
41
+ # Latent Antialias CPU Offload:
42
+ original_interpolate = torch.nn.functional.interpolate
43
+ @wraps(torch.nn.functional.interpolate)
44
+ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
45
+ if antialias or align_corners is not None or mode == 'bicubic':
46
+ return_device = tensor.device
47
+ return_dtype = tensor.dtype
48
+ return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
49
+ align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
50
+ else:
51
+ return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
52
+ align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
53
+
54
+
55
+ # Diffusers Float64 (Alchemist GPUs don't support 64-bit floats):
56
+ original_from_numpy = torch.from_numpy
57
+ @wraps(torch.from_numpy)
58
+ def from_numpy(ndarray):
59
+ if ndarray.dtype == float:
60
+ return original_from_numpy(ndarray.astype('float32'))
61
+ else:
62
+ return original_from_numpy(ndarray)
63
+
64
+ original_as_tensor = torch.as_tensor
65
+ @wraps(torch.as_tensor)
66
+ def as_tensor(data, dtype=None, device=None):
67
+ if check_device(device):
68
+ device = return_xpu(device)
69
+ if isinstance(data, np.ndarray) and data.dtype == float and not (
70
+ (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
71
+ return original_as_tensor(data, dtype=torch.float32, device=device)
72
+ else:
73
+ return original_as_tensor(data, dtype=dtype, device=device)
74
+
75
+
76
+ if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
77
+ original_torch_bmm = torch.bmm
78
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
79
+ else:
80
+ # 32 bit attention workarounds for Alchemist:
81
+ try:
82
+ from .attention import torch_bmm_32_bit as original_torch_bmm
83
+ from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
84
+ except Exception: # pylint: disable=broad-exception-caught
85
+ original_torch_bmm = torch.bmm
86
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
87
+
88
+
89
+ # Data Type Errors:
90
+ @wraps(torch.bmm)
91
+ def torch_bmm(input, mat2, *, out=None):
92
+ if input.dtype != mat2.dtype:
93
+ mat2 = mat2.to(input.dtype)
94
+ return original_torch_bmm(input, mat2, out=out)
95
+
96
+ @wraps(torch.nn.functional.scaled_dot_product_attention)
97
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
98
+ if query.dtype != key.dtype:
99
+ key = key.to(dtype=query.dtype)
100
+ if query.dtype != value.dtype:
101
+ value = value.to(dtype=query.dtype)
102
+ if attn_mask is not None and query.dtype != attn_mask.dtype:
103
+ attn_mask = attn_mask.to(dtype=query.dtype)
104
+ return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
105
+
106
+ # A1111 FP16
107
+ original_functional_group_norm = torch.nn.functional.group_norm
108
+ @wraps(torch.nn.functional.group_norm)
109
+ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
110
+ if weight is not None and input.dtype != weight.data.dtype:
111
+ input = input.to(dtype=weight.data.dtype)
112
+ if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
113
+ bias.data = bias.data.to(dtype=weight.data.dtype)
114
+ return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
115
+
116
+ # A1111 BF16
117
+ original_functional_layer_norm = torch.nn.functional.layer_norm
118
+ @wraps(torch.nn.functional.layer_norm)
119
+ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
120
+ if weight is not None and input.dtype != weight.data.dtype:
121
+ input = input.to(dtype=weight.data.dtype)
122
+ if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
123
+ bias.data = bias.data.to(dtype=weight.data.dtype)
124
+ return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
125
+
126
+ # Training
127
+ original_functional_linear = torch.nn.functional.linear
128
+ @wraps(torch.nn.functional.linear)
129
+ def functional_linear(input, weight, bias=None):
130
+ if input.dtype != weight.data.dtype:
131
+ input = input.to(dtype=weight.data.dtype)
132
+ if bias is not None and bias.data.dtype != weight.data.dtype:
133
+ bias.data = bias.data.to(dtype=weight.data.dtype)
134
+ return original_functional_linear(input, weight, bias=bias)
135
+
136
+ original_functional_conv2d = torch.nn.functional.conv2d
137
+ @wraps(torch.nn.functional.conv2d)
138
+ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
139
+ if input.dtype != weight.data.dtype:
140
+ input = input.to(dtype=weight.data.dtype)
141
+ if bias is not None and bias.data.dtype != weight.data.dtype:
142
+ bias.data = bias.data.to(dtype=weight.data.dtype)
143
+ return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
144
+
145
+ # A1111 Embedding BF16
146
+ original_torch_cat = torch.cat
147
+ @wraps(torch.cat)
148
+ def torch_cat(tensor, *args, **kwargs):
149
+ if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
150
+ return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
151
+ else:
152
+ return original_torch_cat(tensor, *args, **kwargs)
153
+
154
+ # SwinIR BF16:
155
+ original_functional_pad = torch.nn.functional.pad
156
+ @wraps(torch.nn.functional.pad)
157
+ def functional_pad(input, pad, mode='constant', value=None):
158
+ if mode == 'reflect' and input.dtype == torch.bfloat16:
159
+ return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
160
+ else:
161
+ return original_functional_pad(input, pad, mode=mode, value=value)
162
+
163
+
164
+ original_torch_tensor = torch.tensor
165
+ @wraps(torch.tensor)
166
+ def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
167
+ if check_device(device):
168
+ device = return_xpu(device)
169
+ if not device_supports_fp64:
170
+ if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
171
+ if dtype == torch.float64:
172
+ dtype = torch.float32
173
+ elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
174
+ dtype = torch.float32
175
+ return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
176
+
177
+ original_Tensor_to = torch.Tensor.to
178
+ @wraps(torch.Tensor.to)
179
+ def Tensor_to(self, device=None, *args, **kwargs):
180
+ if check_device(device):
181
+ return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
182
+ else:
183
+ return original_Tensor_to(self, device, *args, **kwargs)
184
+
185
+ original_Tensor_cuda = torch.Tensor.cuda
186
+ @wraps(torch.Tensor.cuda)
187
+ def Tensor_cuda(self, device=None, *args, **kwargs):
188
+ if check_device(device):
189
+ return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
190
+ else:
191
+ return original_Tensor_cuda(self, device, *args, **kwargs)
192
+
193
+ original_Tensor_pin_memory = torch.Tensor.pin_memory
194
+ @wraps(torch.Tensor.pin_memory)
195
+ def Tensor_pin_memory(self, device=None, *args, **kwargs):
196
+ if device is None:
197
+ device = "xpu"
198
+ if check_device(device):
199
+ return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
200
+ else:
201
+ return original_Tensor_pin_memory(self, device, *args, **kwargs)
202
+
203
+ original_UntypedStorage_init = torch.UntypedStorage.__init__
204
+ @wraps(torch.UntypedStorage.__init__)
205
+ def UntypedStorage_init(*args, device=None, **kwargs):
206
+ if check_device(device):
207
+ return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
208
+ else:
209
+ return original_UntypedStorage_init(*args, device=device, **kwargs)
210
+
211
+ original_UntypedStorage_cuda = torch.UntypedStorage.cuda
212
+ @wraps(torch.UntypedStorage.cuda)
213
+ def UntypedStorage_cuda(self, device=None, *args, **kwargs):
214
+ if check_device(device):
215
+ return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
216
+ else:
217
+ return original_UntypedStorage_cuda(self, device, *args, **kwargs)
218
+
219
+ original_torch_empty = torch.empty
220
+ @wraps(torch.empty)
221
+ def torch_empty(*args, device=None, **kwargs):
222
+ if check_device(device):
223
+ return original_torch_empty(*args, device=return_xpu(device), **kwargs)
224
+ else:
225
+ return original_torch_empty(*args, device=device, **kwargs)
226
+
227
+ original_torch_randn = torch.randn
228
+ @wraps(torch.randn)
229
+ def torch_randn(*args, device=None, dtype=None, **kwargs):
230
+ if dtype == bytes:
231
+ dtype = None
232
+ if check_device(device):
233
+ return original_torch_randn(*args, device=return_xpu(device), **kwargs)
234
+ else:
235
+ return original_torch_randn(*args, device=device, **kwargs)
236
+
237
+ original_torch_ones = torch.ones
238
+ @wraps(torch.ones)
239
+ def torch_ones(*args, device=None, **kwargs):
240
+ if check_device(device):
241
+ return original_torch_ones(*args, device=return_xpu(device), **kwargs)
242
+ else:
243
+ return original_torch_ones(*args, device=device, **kwargs)
244
+
245
+ original_torch_zeros = torch.zeros
246
+ @wraps(torch.zeros)
247
+ def torch_zeros(*args, device=None, **kwargs):
248
+ if check_device(device):
249
+ return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
250
+ else:
251
+ return original_torch_zeros(*args, device=device, **kwargs)
252
+
253
+ original_torch_linspace = torch.linspace
254
+ @wraps(torch.linspace)
255
+ def torch_linspace(*args, device=None, **kwargs):
256
+ if check_device(device):
257
+ return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
258
+ else:
259
+ return original_torch_linspace(*args, device=device, **kwargs)
260
+
261
+ original_torch_Generator = torch.Generator
262
+ @wraps(torch.Generator)
263
+ def torch_Generator(device=None):
264
+ if check_device(device):
265
+ return original_torch_Generator(return_xpu(device))
266
+ else:
267
+ return original_torch_Generator(device)
268
+
269
+ original_torch_load = torch.load
270
+ @wraps(torch.load)
271
+ def torch_load(f, map_location=None, *args, **kwargs):
272
+ if map_location is None:
273
+ map_location = "xpu"
274
+ if check_device(map_location):
275
+ return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
276
+ else:
277
+ return original_torch_load(f, *args, map_location=map_location, **kwargs)
278
+
279
+
280
+ # Hijack Functions:
281
+ def ipex_hijacks():
282
+ torch.tensor = torch_tensor
283
+ torch.Tensor.to = Tensor_to
284
+ torch.Tensor.cuda = Tensor_cuda
285
+ torch.Tensor.pin_memory = Tensor_pin_memory
286
+ torch.UntypedStorage.__init__ = UntypedStorage_init
287
+ torch.UntypedStorage.cuda = UntypedStorage_cuda
288
+ torch.empty = torch_empty
289
+ torch.randn = torch_randn
290
+ torch.ones = torch_ones
291
+ torch.zeros = torch_zeros
292
+ torch.linspace = torch_linspace
293
+ torch.Generator = torch_Generator
294
+ torch.load = torch_load
295
+
296
+ torch.backends.cuda.sdp_kernel = return_null_context
297
+ torch.nn.DataParallel = DummyDataParallel
298
+ torch.UntypedStorage.is_cuda = is_cuda
299
+ torch.amp.autocast_mode.autocast.__init__ = autocast_init
300
+
301
+ torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
302
+ torch.nn.functional.group_norm = functional_group_norm
303
+ torch.nn.functional.layer_norm = functional_layer_norm
304
+ torch.nn.functional.linear = functional_linear
305
+ torch.nn.functional.conv2d = functional_conv2d
306
+ torch.nn.functional.interpolate = interpolate
307
+ torch.nn.functional.pad = functional_pad
308
+
309
+ torch.bmm = torch_bmm
310
+ torch.cat = torch_cat
311
+ if not device_supports_fp64:
312
+ torch.from_numpy = from_numpy
313
+ torch.as_tensor = as_tensor
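Usage sketch (illustrative only): ipex_hijacks() should run before any tensors or modules are created, after which CUDA-style device arguments are transparently rewritten to XPU by the wrappers above. MyModel is a placeholder module.

import torch
from library.ipex.hijacks import ipex_hijacks

ipex_hijacks()
model = MyModel().to("cuda")                        # parameters land on the XPU via the hijacked Tensor.to
noise = torch.randn(1, 4, 64, 64, device="cuda:0")  # created on xpu:0 by the hijacked torch.randn
with torch.autocast("cuda", dtype=torch.bfloat16):  # autocast_init redirects "cuda" to the XPU backend
    out = model(noise)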
library/lpw_stable_diffusion.py ADDED
@@ -0,0 +1,1233 @@
1
+ # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
2
+ # and modified to support SD2.x
3
+
4
+ import inspect
5
+ import re
6
+ from typing import Callable, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import PIL.Image
10
+ import torch
11
+ from packaging import version
12
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
13
+
14
+ import diffusers
15
+ from diffusers import SchedulerMixin, StableDiffusionPipeline
16
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
17
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
18
+ from diffusers.utils import logging
19
+
20
+ try:
21
+ from diffusers.utils import PIL_INTERPOLATION
22
+ except ImportError:
23
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
24
+ PIL_INTERPOLATION = {
25
+ "linear": PIL.Image.Resampling.BILINEAR,
26
+ "bilinear": PIL.Image.Resampling.BILINEAR,
27
+ "bicubic": PIL.Image.Resampling.BICUBIC,
28
+ "lanczos": PIL.Image.Resampling.LANCZOS,
29
+ "nearest": PIL.Image.Resampling.NEAREST,
30
+ }
31
+ else:
32
+ PIL_INTERPOLATION = {
33
+ "linear": PIL.Image.LINEAR,
34
+ "bilinear": PIL.Image.BILINEAR,
35
+ "bicubic": PIL.Image.BICUBIC,
36
+ "lanczos": PIL.Image.LANCZOS,
37
+ "nearest": PIL.Image.NEAREST,
38
+ }
39
+ # ------------------------------------------------------------------------------
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+ re_attention = re.compile(
44
+ r"""
45
+ \\\(|
46
+ \\\)|
47
+ \\\[|
48
+ \\]|
49
+ \\\\|
50
+ \\|
51
+ \(|
52
+ \[|
53
+ :([+-]?[.\d]+)\)|
54
+ \)|
55
+ ]|
56
+ [^\\()\[\]:]+|
57
+ :
58
+ """,
59
+ re.X,
60
+ )
61
+
62
+
63
+ def parse_prompt_attention(text):
64
+ """
65
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
66
+ Accepted tokens are:
67
+ (abc) - increases attention to abc by a multiplier of 1.1
68
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
69
+ [abc] - decreases attention to abc by a multiplier of 1.1
70
+ \( - literal character '('
71
+ \[ - literal character '['
72
+ \) - literal character ')'
73
+ \] - literal character ']'
74
+ \\ - literal character '\'
75
+ anything else - just text
76
+ >>> parse_prompt_attention('normal text')
77
+ [['normal text', 1.0]]
78
+ >>> parse_prompt_attention('an (important) word')
79
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
80
+ >>> parse_prompt_attention('(unbalanced')
81
+ [['unbalanced', 1.1]]
82
+ >>> parse_prompt_attention('\(literal\]')
83
+ [['(literal]', 1.0]]
84
+ >>> parse_prompt_attention('(unnecessary)(parens)')
85
+ [['unnecessaryparens', 1.1]]
86
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
87
+ [['a ', 1.0],
88
+ ['house', 1.5730000000000004],
89
+ [' ', 1.1],
90
+ ['on', 1.0],
91
+ [' a ', 1.1],
92
+ ['hill', 0.55],
93
+ [', sun, ', 1.1],
94
+ ['sky', 1.4641000000000006],
95
+ ['.', 1.1]]
96
+ """
97
+
98
+ res = []
99
+ round_brackets = []
100
+ square_brackets = []
101
+
102
+ round_bracket_multiplier = 1.1
103
+ square_bracket_multiplier = 1 / 1.1
104
+
105
+ def multiply_range(start_position, multiplier):
106
+ for p in range(start_position, len(res)):
107
+ res[p][1] *= multiplier
108
+
109
+ for m in re_attention.finditer(text):
110
+ text = m.group(0)
111
+ weight = m.group(1)
112
+
113
+ if text.startswith("\\"):
114
+ res.append([text[1:], 1.0])
115
+ elif text == "(":
116
+ round_brackets.append(len(res))
117
+ elif text == "[":
118
+ square_brackets.append(len(res))
119
+ elif weight is not None and len(round_brackets) > 0:
120
+ multiply_range(round_brackets.pop(), float(weight))
121
+ elif text == ")" and len(round_brackets) > 0:
122
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
123
+ elif text == "]" and len(square_brackets) > 0:
124
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
125
+ else:
126
+ res.append([text, 1.0])
127
+
128
+ for pos in round_brackets:
129
+ multiply_range(pos, round_bracket_multiplier)
130
+
131
+ for pos in square_brackets:
132
+ multiply_range(pos, square_bracket_multiplier)
133
+
134
+ if len(res) == 0:
135
+ res = [["", 1.0]]
136
+
137
+ # merge runs of identical weights
138
+ i = 0
139
+ while i + 1 < len(res):
140
+ if res[i][1] == res[i + 1][1]:
141
+ res[i][0] += res[i + 1][0]
142
+ res.pop(i + 1)
143
+ else:
144
+ i += 1
145
+
146
+ return res
147
+
148
+
149
+ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
150
+ r"""
151
+ Tokenize a list of prompts and return its tokens with weights of each token.
152
+
153
+ No padding, starting or ending token is included.
154
+ """
155
+ tokens = []
156
+ weights = []
157
+ truncated = False
158
+ for text in prompt:
159
+ texts_and_weights = parse_prompt_attention(text)
160
+ text_token = []
161
+ text_weight = []
162
+ for word, weight in texts_and_weights:
163
+ # tokenize and discard the starting and the ending token
164
+ token = pipe.tokenizer(word).input_ids[1:-1]
165
+ text_token += token
166
+ # copy the weight by length of token
167
+ text_weight += [weight] * len(token)
168
+ # stop if the text is too long (longer than truncation limit)
169
+ if len(text_token) > max_length:
170
+ truncated = True
171
+ break
172
+ # truncate
173
+ if len(text_token) > max_length:
174
+ truncated = True
175
+ text_token = text_token[:max_length]
176
+ text_weight = text_weight[:max_length]
177
+ tokens.append(text_token)
178
+ weights.append(text_weight)
179
+ if truncated:
180
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
181
+ return tokens, weights
182
+
183
+
184
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
185
+ r"""
186
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
187
+ """
188
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
189
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
190
+ for i in range(len(tokens)):
191
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
192
+ if no_boseos_middle:
193
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
194
+ else:
195
+ w = []
196
+ if len(weights[i]) == 0:
197
+ w = [1.0] * weights_length
198
+ else:
199
+ for j in range(max_embeddings_multiples):
200
+ w.append(1.0) # weight for starting token in this chunk
201
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
202
+ w.append(1.0) # weight for ending token in this chunk
203
+ w += [1.0] * (weights_length - len(w))
204
+ weights[i] = w[:]
205
+
206
+ return tokens, weights
207
+
208
+
209
+ def get_unweighted_text_embeddings(
210
+ pipe: StableDiffusionPipeline,
211
+ text_input: torch.Tensor,
212
+ chunk_length: int,
213
+ clip_skip: int,
214
+ eos: int,
215
+ pad: int,
216
+ no_boseos_middle: Optional[bool] = True,
217
+ ):
218
+ """
219
+ When the length of the token sequence exceeds the capacity of a single text encoder chunk,
220
+ it is split into chunks and each chunk is sent to the text encoder individually.
221
+ """
222
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
223
+ if max_embeddings_multiples > 1:
224
+ text_embeddings = []
225
+ for i in range(max_embeddings_multiples):
226
+ # extract the i-th chunk
227
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
228
+
229
+ # cover the head and the tail by the starting and the ending tokens
230
+ text_input_chunk[:, 0] = text_input[0, 0]
231
+ if pad == eos: # v1
232
+ text_input_chunk[:, -1] = text_input[0, -1]
233
+ else: # v2
234
+ for j in range(len(text_input_chunk)):
235
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # the last token is a normal token (not EOS or PAD)
236
+ text_input_chunk[j, -1] = eos
237
+ if text_input_chunk[j, 1] == pad: # only BOS, the rest is PAD
238
+ text_input_chunk[j, 1] = eos
239
+
240
+ if clip_skip is None or clip_skip == 1:
241
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
242
+ else:
243
+ enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
244
+ text_embedding = enc_out["hidden_states"][-clip_skip]
245
+ text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
246
+
247
+ if no_boseos_middle:
248
+ if i == 0:
249
+ # discard the ending token
250
+ text_embedding = text_embedding[:, :-1]
251
+ elif i == max_embeddings_multiples - 1:
252
+ # discard the starting token
253
+ text_embedding = text_embedding[:, 1:]
254
+ else:
255
+ # discard both starting and ending tokens
256
+ text_embedding = text_embedding[:, 1:-1]
257
+
258
+ text_embeddings.append(text_embedding)
259
+ text_embeddings = torch.concat(text_embeddings, axis=1)
260
+ else:
261
+ if clip_skip is None or clip_skip == 1:
262
+ text_embeddings = pipe.text_encoder(text_input)[0]
263
+ else:
264
+ enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
265
+ text_embeddings = enc_out["hidden_states"][-clip_skip]
266
+ text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
267
+ return text_embeddings
268
+
269
+
270
+ def get_weighted_text_embeddings(
271
+ pipe: StableDiffusionPipeline,
272
+ prompt: Union[str, List[str]],
273
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
274
+ max_embeddings_multiples: Optional[int] = 3,
275
+ no_boseos_middle: Optional[bool] = False,
276
+ skip_parsing: Optional[bool] = False,
277
+ skip_weighting: Optional[bool] = False,
278
+ clip_skip=None,
279
+ ):
280
+ r"""
281
+ Prompts can be assigned with local weights using brackets. For example,
282
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
283
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
284
+
285
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
286
+
287
+ Args:
288
+ pipe (`StableDiffusionPipeline`):
289
+ Pipe to provide access to the tokenizer and the text encoder.
290
+ prompt (`str` or `List[str]`):
291
+ The prompt or prompts to guide the image generation.
292
+ uncond_prompt (`str` or `List[str]`):
293
+ The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
294
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
295
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
296
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
297
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
298
+ If the length of the text tokens is a multiple of the text encoder capacity, whether to keep the starting and
299
+ ending token in each of the middle chunks.
300
+ skip_parsing (`bool`, *optional*, defaults to `False`):
301
+ Skip the parsing of brackets.
302
+ skip_weighting (`bool`, *optional*, defaults to `False`):
303
+ Skip the weighting. When the parsing is skipped, it is forced True.
304
+ """
305
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
306
+ if isinstance(prompt, str):
307
+ prompt = [prompt]
308
+
309
+ if not skip_parsing:
310
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
311
+ if uncond_prompt is not None:
312
+ if isinstance(uncond_prompt, str):
313
+ uncond_prompt = [uncond_prompt]
314
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
315
+ else:
316
+ prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
317
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
318
+ if uncond_prompt is not None:
319
+ if isinstance(uncond_prompt, str):
320
+ uncond_prompt = [uncond_prompt]
321
+ uncond_tokens = [
322
+ token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
323
+ ]
324
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
325
+
326
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
327
+ max_length = max([len(token) for token in prompt_tokens])
328
+ if uncond_prompt is not None:
329
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
330
+
331
+ max_embeddings_multiples = min(
332
+ max_embeddings_multiples,
333
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
334
+ )
335
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
336
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
337
+
338
+ # pad the length of tokens and weights
339
+ bos = pipe.tokenizer.bos_token_id
340
+ eos = pipe.tokenizer.eos_token_id
341
+ pad = pipe.tokenizer.pad_token_id
342
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
343
+ prompt_tokens,
344
+ prompt_weights,
345
+ max_length,
346
+ bos,
347
+ eos,
348
+ no_boseos_middle=no_boseos_middle,
349
+ chunk_length=pipe.tokenizer.model_max_length,
350
+ )
351
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
352
+ if uncond_prompt is not None:
353
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
354
+ uncond_tokens,
355
+ uncond_weights,
356
+ max_length,
357
+ bos,
358
+ eos,
359
+ no_boseos_middle=no_boseos_middle,
360
+ chunk_length=pipe.tokenizer.model_max_length,
361
+ )
362
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
363
+
364
+ # get the embeddings
365
+ text_embeddings = get_unweighted_text_embeddings(
366
+ pipe,
367
+ prompt_tokens,
368
+ pipe.tokenizer.model_max_length,
369
+ clip_skip,
370
+ eos,
371
+ pad,
372
+ no_boseos_middle=no_boseos_middle,
373
+ )
374
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
375
+ if uncond_prompt is not None:
376
+ uncond_embeddings = get_unweighted_text_embeddings(
377
+ pipe,
378
+ uncond_tokens,
379
+ pipe.tokenizer.model_max_length,
380
+ clip_skip,
381
+ eos,
382
+ pad,
383
+ no_boseos_middle=no_boseos_middle,
384
+ )
385
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
386
+
387
+ # assign weights to the prompts and normalize in the sense of mean
388
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
389
+ if (not skip_parsing) and (not skip_weighting):
390
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
391
+ text_embeddings *= prompt_weights.unsqueeze(-1)
392
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
393
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
394
+ if uncond_prompt is not None:
395
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
396
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
397
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
398
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
399
+
400
+ if uncond_prompt is not None:
401
+ return text_embeddings, uncond_embeddings
402
+ return text_embeddings, None
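Usage sketch (illustrative only): get_weighted_text_embeddings is the entry point for the attention-weight syntax parsed above. A minimal call, assuming pipe is an already-loaded StableDiffusionPipeline; prompt and uncond_prompt are arbitrary examples.

text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
    pipe,
    prompt="a (very beautiful:1.2) masterpiece, (hill:0.8)",
    uncond_prompt="lowres, blurry",
    max_embeddings_multiples=3,  # allow prompts up to roughly 3x the encoder's 77-token window
)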
403
+
404
+
405
+ def preprocess_image(image):
406
+ w, h = image.size
407
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
408
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
409
+ image = np.array(image).astype(np.float32) / 255.0
410
+ image = image[None].transpose(0, 3, 1, 2)
411
+ image = torch.from_numpy(image)
412
+ return 2.0 * image - 1.0
413
+
414
+
415
+ def preprocess_mask(mask, scale_factor=8):
416
+ mask = mask.convert("L")
417
+ w, h = mask.size
418
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
419
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
420
+ mask = np.array(mask).astype(np.float32) / 255.0
421
+ mask = np.tile(mask, (4, 1, 1))
422
+ mask = mask[None].transpose(0, 1, 2, 3) # add a batch dimension (the identity transpose is a no-op kept for compatibility)
423
+ mask = 1 - mask # repaint white, keep black
424
+ mask = torch.from_numpy(mask)
425
+ return mask
426
+
427
+
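# Illustrative sketch, not part of the committed file: preprocess_mask downscales the
# mask to latent resolution, tiles it across the 4 latent channels, and inverts it so
# white pixels (to be repainted) end up as 0.
import numpy as np
import torch
from PIL import Image

mask = Image.new("L", (512, 512), color=255)              # all white -> repaint everything
w, h = (x - x % 32 for x in mask.size)
mask = mask.resize((w // 8, h // 8), resample=Image.NEAREST)  # 8 = VAE scale factor
m = 1 - np.tile(np.array(mask).astype(np.float32) / 255.0, (4, 1, 1))[None]
print(torch.from_numpy(m).shape)                          # torch.Size([1, 4, 64, 64])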
428
+ def prepare_controlnet_image(
429
+ image: PIL.Image.Image,
430
+ width: int,
431
+ height: int,
432
+ batch_size: int,
433
+ num_images_per_prompt: int,
434
+ device: torch.device,
435
+ dtype: torch.dtype,
436
+ do_classifier_free_guidance: bool = False,
437
+ guess_mode: bool = False,
438
+ ):
439
+ if not isinstance(image, torch.Tensor):
440
+ if isinstance(image, PIL.Image.Image):
441
+ image = [image]
442
+
443
+ if isinstance(image[0], PIL.Image.Image):
444
+ images = []
445
+
446
+ for image_ in image:
447
+ image_ = image_.convert("RGB")
448
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
449
+ image_ = np.array(image_)
450
+ image_ = image_[None, :]
451
+ images.append(image_)
452
+
453
+ image = images
454
+
455
+ image = np.concatenate(image, axis=0)
456
+ image = np.array(image).astype(np.float32) / 255.0
457
+ image = image.transpose(0, 3, 1, 2)
458
+ image = torch.from_numpy(image)
459
+ elif isinstance(image[0], torch.Tensor):
460
+ image = torch.cat(image, dim=0)
461
+
462
+ image_batch_size = image.shape[0]
463
+
464
+ if image_batch_size == 1:
465
+ repeat_by = batch_size
466
+ else:
467
+ # image batch size is the same as prompt batch size
468
+ repeat_by = num_images_per_prompt
469
+
470
+ image = image.repeat_interleave(repeat_by, dim=0)
471
+
472
+ image = image.to(device=device, dtype=dtype)
473
+
474
+ if do_classifier_free_guidance and not guess_mode:
475
+ image = torch.cat([image] * 2)
476
+
477
+ return image
478
+
479
+
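# Illustrative sketch, not part of the committed file: with classifier-free guidance on
# (and guess_mode off), the prepared conditioning image is repeated per prompt and then
# duplicated so its batch matches the doubled latent batch fed to the UNet.
import torch

cond = torch.rand(1, 3, 512, 512)                         # one prepared controlnet image
cond = cond.repeat_interleave(2, dim=0)                   # repeat_by = batch_size (here 2)
cond = torch.cat([cond] * 2)                              # uncond + cond copies for CFG
print(cond.shape)                                         # torch.Size([4, 3, 512, 512])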
480
+ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
481
+ r"""
482
+ Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
483
+ weighting in prompt.
484
+
485
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
486
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
487
+
488
+ Args:
489
+ vae ([`AutoencoderKL`]):
490
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
491
+ text_encoder ([`CLIPTextModel`]):
492
+ Frozen text-encoder. Stable Diffusion uses the text portion of
493
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
494
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
495
+ tokenizer (`CLIPTokenizer`):
496
+ Tokenizer of class
497
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
498
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
499
+ scheduler ([`SchedulerMixin`]):
500
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
501
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
502
+ safety_checker ([`StableDiffusionSafetyChecker`]):
503
+ Classification module that estimates whether generated images could be considered offensive or harmful.
504
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
505
+ feature_extractor ([`CLIPFeatureExtractor`]):
506
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
507
+ """
508
+
509
+ # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
510
+
511
+ def __init__(
512
+ self,
513
+ vae: AutoencoderKL,
514
+ text_encoder: CLIPTextModel,
515
+ tokenizer: CLIPTokenizer,
516
+ unet: UNet2DConditionModel,
517
+ scheduler: SchedulerMixin,
518
+ # clip_skip: int,
519
+ safety_checker: StableDiffusionSafetyChecker,
520
+ feature_extractor: CLIPFeatureExtractor,
521
+ requires_safety_checker: bool = True,
522
+ image_encoder: CLIPVisionModelWithProjection = None,
523
+ clip_skip: int = 1,
524
+ ):
525
+ super().__init__(
526
+ vae=vae,
527
+ text_encoder=text_encoder,
528
+ tokenizer=tokenizer,
529
+ unet=unet,
530
+ scheduler=scheduler,
531
+ safety_checker=safety_checker,
532
+ feature_extractor=feature_extractor,
533
+ requires_safety_checker=requires_safety_checker,
534
+ image_encoder=image_encoder,
535
+ )
536
+ self.custom_clip_skip = clip_skip
537
+ self.__init__additional__()
538
+
539
+ def __init__additional__(self):
540
+ if not hasattr(self, "vae_scale_factor"):
541
+ setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
542
+
543
+ @property
544
+ def _execution_device(self):
545
+ r"""
546
+ Returns the device on which the pipeline's models will be executed. After calling
547
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
548
+ hooks.
549
+ """
550
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
551
+ return self.device
552
+ for module in self.unet.modules():
553
+ if (
554
+ hasattr(module, "_hf_hook")
555
+ and hasattr(module._hf_hook, "execution_device")
556
+ and module._hf_hook.execution_device is not None
557
+ ):
558
+ return torch.device(module._hf_hook.execution_device)
559
+ return self.device
560
+
561
+ def _encode_prompt(
562
+ self,
563
+ prompt,
564
+ device,
565
+ num_images_per_prompt,
566
+ do_classifier_free_guidance,
567
+ negative_prompt,
568
+ max_embeddings_multiples,
569
+ ):
570
+ r"""
571
+ Encodes the prompt into text encoder hidden states.
572
+
573
+ Args:
574
+ prompt (`str` or `list(int)`):
575
+ prompt to be encoded
576
+ device: (`torch.device`):
577
+ torch device
578
+ num_images_per_prompt (`int`):
579
+ number of images that should be generated per prompt
580
+ do_classifier_free_guidance (`bool`):
581
+ whether to use classifier free guidance or not
582
+ negative_prompt (`str` or `List[str]`):
583
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
584
+ if `guidance_scale` is less than `1`).
585
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
586
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
587
+ """
588
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
589
+
590
+ if negative_prompt is None:
591
+ negative_prompt = [""] * batch_size
592
+ elif isinstance(negative_prompt, str):
593
+ negative_prompt = [negative_prompt] * batch_size
594
+ if batch_size != len(negative_prompt):
595
+ raise ValueError(
596
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
597
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
598
+ " the batch size of `prompt`."
599
+ )
600
+
601
+ text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
602
+ pipe=self,
603
+ prompt=prompt,
604
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
605
+ max_embeddings_multiples=max_embeddings_multiples,
606
+ clip_skip=self.custom_clip_skip,
607
+ )
608
+ bs_embed, seq_len, _ = text_embeddings.shape
609
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
610
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
611
+
612
+ if do_classifier_free_guidance:
613
+ bs_embed, seq_len, _ = uncond_embeddings.shape
614
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
615
+ uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
616
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
617
+
618
+ return text_embeddings
619
+
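# Illustrative sketch, not part of the committed file: for classifier-free guidance the
# negative (unconditional) embeddings are concatenated in front of the positive ones,
# which matches the noise_pred.chunk(2) split in the denoising loop below.
import torch

bs, n_images, seq, dim = 1, 2, 77, 768
cond = torch.randn(bs, seq, dim).repeat(1, n_images, 1).view(bs * n_images, seq, dim)
uncond = torch.randn(bs, seq, dim).repeat(1, n_images, 1).view(bs * n_images, seq, dim)
text_embeddings = torch.cat([uncond, cond])
print(text_embeddings.shape)                              # torch.Size([4, 77, 768])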
620
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
621
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
622
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
623
+
624
+ if strength < 0 or strength > 1:
625
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
626
+
627
+ if height % 8 != 0 or width % 8 != 0:
628
+ logger.info(f'{height} {width}')
629
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
630
+
631
+ if (callback_steps is None) or (
632
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
633
+ ):
634
+ raise ValueError(
635
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
636
+ )
637
+
638
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
639
+ if is_text2img:
640
+ return self.scheduler.timesteps.to(device), num_inference_steps
641
+ else:
642
+ # get the original timestep using init_timestep
643
+ offset = self.scheduler.config.get("steps_offset", 0)
644
+ init_timestep = int(num_inference_steps * strength) + offset
645
+ init_timestep = min(init_timestep, num_inference_steps)
646
+
647
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
648
+ timesteps = self.scheduler.timesteps[t_start:].to(device)
649
+ return timesteps, num_inference_steps - t_start
650
+
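# Illustrative sketch, not part of the committed file: how strength trims the schedule
# for img2img/inpaint. With 50 steps and strength 0.8, only the last 40 timesteps run.
# The scheduler settings mirror common SD v1 defaults and are assumptions here.
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
num_inference_steps, strength = 50, 0.8
scheduler.set_timesteps(num_inference_steps)

offset = scheduler.config.get("steps_offset", 0)
init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
print(len(scheduler.timesteps[t_start:]))                 # 40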
651
+ def run_safety_checker(self, image, device, dtype):
652
+ if self.safety_checker is not None:
653
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
654
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
655
+ else:
656
+ has_nsfw_concept = None
657
+ return image, has_nsfw_concept
658
+
659
+ def decode_latents(self, latents):
660
+ latents = 1 / 0.18215 * latents
661
+ image = self.vae.decode(latents).sample
662
+ image = (image / 2 + 0.5).clamp(0, 1)
663
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
664
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
665
+ return image
666
+
667
+ def prepare_extra_step_kwargs(self, generator, eta):
668
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
669
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
670
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
671
+ # and should be between [0, 1]
672
+
673
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
674
+ extra_step_kwargs = {}
675
+ if accepts_eta:
676
+ extra_step_kwargs["eta"] = eta
677
+
678
+ # check if the scheduler accepts generator
679
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
680
+ if accepts_generator:
681
+ extra_step_kwargs["generator"] = generator
682
+ return extra_step_kwargs
683
+
684
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
685
+ if image is None:
686
+ shape = (
687
+ batch_size,
688
+ self.unet.in_channels,
689
+ height // self.vae_scale_factor,
690
+ width // self.vae_scale_factor,
691
+ )
692
+
693
+ if latents is None:
694
+ if device.type == "mps":
695
+ # randn does not work reproducibly on mps
696
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
697
+ else:
698
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
699
+ else:
700
+ if latents.shape != shape:
701
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
702
+ latents = latents.to(device)
703
+
704
+ # scale the initial noise by the standard deviation required by the scheduler
705
+ latents = latents * self.scheduler.init_noise_sigma
706
+ return latents, None, None
707
+ else:
708
+ init_latent_dist = self.vae.encode(image).latent_dist
709
+ init_latents = init_latent_dist.sample(generator=generator)
710
+ init_latents = 0.18215 * init_latents
711
+ init_latents = torch.cat([init_latents] * batch_size, dim=0)
712
+ init_latents_orig = init_latents
713
+ shape = init_latents.shape
714
+
715
+ # add noise to latents using the timesteps
716
+ if device.type == "mps":
717
+ noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
718
+ else:
719
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
720
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
721
+ return latents, init_latents_orig, noise
722
+
723
+ @torch.no_grad()
724
+ def __call__(
725
+ self,
726
+ prompt: Union[str, List[str]],
727
+ negative_prompt: Optional[Union[str, List[str]]] = None,
728
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
729
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
730
+ height: int = 512,
731
+ width: int = 512,
732
+ num_inference_steps: int = 50,
733
+ guidance_scale: float = 7.5,
734
+ strength: float = 0.8,
735
+ num_images_per_prompt: Optional[int] = 1,
736
+ eta: float = 0.0,
737
+ generator: Optional[torch.Generator] = None,
738
+ latents: Optional[torch.FloatTensor] = None,
739
+ max_embeddings_multiples: Optional[int] = 3,
740
+ output_type: Optional[str] = "pil",
741
+ return_dict: bool = True,
742
+ controlnet=None,
743
+ controlnet_image=None,
744
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
745
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
746
+ callback_steps: int = 1,
747
+ ):
748
+ r"""
749
+ Function invoked when calling the pipeline for generation.
750
+
751
+ Args:
752
+ prompt (`str` or `List[str]`):
753
+ The prompt or prompts to guide the image generation.
754
+ negative_prompt (`str` or `List[str]`, *optional*):
755
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
756
+ if `guidance_scale` is less than `1`).
757
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
758
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
759
+ process.
760
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
761
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
762
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
763
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
764
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
765
+ height (`int`, *optional*, defaults to 512):
766
+ The height in pixels of the generated image.
767
+ width (`int`, *optional*, defaults to 512):
768
+ The width in pixels of the generated image.
769
+ num_inference_steps (`int`, *optional*, defaults to 50):
770
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
771
+ expense of slower inference.
772
+ guidance_scale (`float`, *optional*, defaults to 7.5):
773
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
774
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
775
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
776
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
777
+ usually at the expense of lower image quality.
778
+ strength (`float`, *optional*, defaults to 0.8):
779
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
780
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
781
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
782
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
783
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
784
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
785
+ The number of images to generate per prompt.
786
+ eta (`float`, *optional*, defaults to 0.0):
787
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
788
+ [`schedulers.DDIMScheduler`], will be ignored for others.
789
+ generator (`torch.Generator`, *optional*):
790
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
791
+ deterministic.
792
+ latents (`torch.FloatTensor`, *optional*):
793
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
794
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
795
+ tensor will be generated by sampling using the supplied random `generator`.
796
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
797
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
798
+ output_type (`str`, *optional*, defaults to `"pil"`):
799
+ The output format of the generated image. Choose between
800
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
801
+ return_dict (`bool`, *optional*, defaults to `True`):
802
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
803
+ plain tuple.
804
+ controlnet (`diffusers.ControlNetModel`, *optional*):
805
+ A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
806
+ controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
807
+ `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
808
+ inference.
809
+ callback (`Callable`, *optional*):
810
+ A function that will be called every `callback_steps` steps during inference. The function will be
811
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
812
+ is_cancelled_callback (`Callable`, *optional*):
813
+ A function that will be called every `callback_steps` steps during inference. If the function returns
814
+ `True`, the inference will be cancelled.
815
+ callback_steps (`int`, *optional*, defaults to 1):
816
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
817
+ called at every step.
818
+
819
+ Returns:
820
+ `None` if the run is cancelled by `is_cancelled_callback`; otherwise the final denoised
+ latents as a `torch.FloatTensor`. Unlike the stock pipeline, this implementation does not
+ decode the latents; call `latents_to_image()` to obtain PIL images.
826
+ """
827
+ if controlnet is not None and controlnet_image is None:
828
+ raise ValueError("controlnet_image must be provided if controlnet is not None.")
829
+
830
+ # 0. Default height and width to unet
831
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
832
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
833
+
834
+ # 1. Check inputs. Raise error if not correct
835
+ self.check_inputs(prompt, height, width, strength, callback_steps)
836
+
837
+ # 2. Define call parameters
838
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
839
+ device = self._execution_device
840
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
841
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
842
+ # corresponds to doing no classifier free guidance.
843
+ do_classifier_free_guidance = guidance_scale > 1.0
844
+
845
+ # 3. Encode input prompt
846
+ text_embeddings = self._encode_prompt(
847
+ prompt,
848
+ device,
849
+ num_images_per_prompt,
850
+ do_classifier_free_guidance,
851
+ negative_prompt,
852
+ max_embeddings_multiples,
853
+ )
854
+ dtype = text_embeddings.dtype
855
+
856
+ # 4. Preprocess image and mask
857
+ if isinstance(image, PIL.Image.Image):
858
+ image = preprocess_image(image)
859
+ if image is not None:
860
+ image = image.to(device=self.device, dtype=dtype)
861
+ if isinstance(mask_image, PIL.Image.Image):
862
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
863
+ if mask_image is not None:
864
+ mask = mask_image.to(device=self.device, dtype=dtype)
865
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
866
+ else:
867
+ mask = None
868
+
869
+ if controlnet_image is not None:
870
+ controlnet_image = prepare_controlnet_image(
871
+ controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
872
+ )
873
+
874
+ # 5. set timesteps
875
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
876
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
877
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
878
+
879
+ # 6. Prepare latent variables
880
+ latents, init_latents_orig, noise = self.prepare_latents(
881
+ image,
882
+ latent_timestep,
883
+ batch_size * num_images_per_prompt,
884
+ height,
885
+ width,
886
+ dtype,
887
+ device,
888
+ generator,
889
+ latents,
890
+ )
891
+
892
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
893
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
894
+
895
+ # 8. Denoising loop
896
+ for i, t in enumerate(self.progress_bar(timesteps)):
897
+ # expand the latents if we are doing classifier free guidance
898
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
899
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
900
+
901
+ unet_additional_args = {}
902
+ if controlnet is not None:
903
+ down_block_res_samples, mid_block_res_sample = controlnet(
904
+ latent_model_input,
905
+ t,
906
+ encoder_hidden_states=text_embeddings,
907
+ controlnet_cond=controlnet_image,
908
+ conditioning_scale=1.0,
909
+ guess_mode=False,
910
+ return_dict=False,
911
+ )
912
+ unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
913
+ unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
914
+
915
+ # predict the noise residual
916
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
917
+
918
+ # perform guidance
919
+ if do_classifier_free_guidance:
920
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
921
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
922
+
923
+ # compute the previous noisy sample x_t -> x_t-1
924
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
925
+
926
+ if mask is not None:
927
+ # masking
928
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
929
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
930
+
931
+ # call the callback, if provided
932
+ if i % callback_steps == 0:
933
+ if callback is not None:
934
+ callback(i, t, latents)
935
+ if is_cancelled_callback is not None and is_cancelled_callback():
936
+ return None
937
+
938
+ return latents
939
+
940
+ def latents_to_image(self, latents):
941
+ # 9. Post-processing
942
+ image = self.decode_latents(latents.to(self.vae.dtype))
943
+ image = self.numpy_to_pil(image)
944
+ return image
945
+
946
+ def text2img(
947
+ self,
948
+ prompt: Union[str, List[str]],
949
+ negative_prompt: Optional[Union[str, List[str]]] = None,
950
+ height: int = 512,
951
+ width: int = 512,
952
+ num_inference_steps: int = 50,
953
+ guidance_scale: float = 7.5,
954
+ num_images_per_prompt: Optional[int] = 1,
955
+ eta: float = 0.0,
956
+ generator: Optional[torch.Generator] = None,
957
+ latents: Optional[torch.FloatTensor] = None,
958
+ max_embeddings_multiples: Optional[int] = 3,
959
+ output_type: Optional[str] = "pil",
960
+ return_dict: bool = True,
961
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
962
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
963
+ callback_steps: int = 1,
964
+ ):
965
+ r"""
966
+ Function for text-to-image generation.
967
+ Args:
968
+ prompt (`str` or `List[str]`):
969
+ The prompt or prompts to guide the image generation.
970
+ negative_prompt (`str` or `List[str]`, *optional*):
971
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
972
+ if `guidance_scale` is less than `1`).
973
+ height (`int`, *optional*, defaults to 512):
974
+ The height in pixels of the generated image.
975
+ width (`int`, *optional*, defaults to 512):
976
+ The width in pixels of the generated image.
977
+ num_inference_steps (`int`, *optional*, defaults to 50):
978
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
979
+ expense of slower inference.
980
+ guidance_scale (`float`, *optional*, defaults to 7.5):
981
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
982
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
983
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
984
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
985
+ usually at the expense of lower image quality.
986
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
987
+ The number of images to generate per prompt.
988
+ eta (`float`, *optional*, defaults to 0.0):
989
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
990
+ [`schedulers.DDIMScheduler`], will be ignored for others.
991
+ generator (`torch.Generator`, *optional*):
992
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
993
+ deterministic.
994
+ latents (`torch.FloatTensor`, *optional*):
995
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
996
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
997
+ tensor will be generated by sampling using the supplied random `generator`.
998
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
999
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1000
+ output_type (`str`, *optional*, defaults to `"pil"`):
1001
+ The output format of the generated image. Choose between
1002
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1003
+ return_dict (`bool`, *optional*, defaults to `True`):
1004
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1005
+ plain tuple.
1006
+ callback (`Callable`, *optional*):
1007
+ A function that will be called every `callback_steps` steps during inference. The function will be
1008
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1009
+ is_cancelled_callback (`Callable`, *optional*):
1010
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1011
+ `True`, the inference will be cancelled.
1012
+ callback_steps (`int`, *optional*, defaults to 1):
1013
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1014
+ called at every step.
1015
+ Returns:
1016
+ The final denoised latents as a `torch.FloatTensor` (this implementation does not decode
+ them; call `latents_to_image()` to obtain PIL images), or `None` if the run is cancelled
+ by `is_cancelled_callback`.
1021
+ """
1022
+ return self.__call__(
1023
+ prompt=prompt,
1024
+ negative_prompt=negative_prompt,
1025
+ height=height,
1026
+ width=width,
1027
+ num_inference_steps=num_inference_steps,
1028
+ guidance_scale=guidance_scale,
1029
+ num_images_per_prompt=num_images_per_prompt,
1030
+ eta=eta,
1031
+ generator=generator,
1032
+ latents=latents,
1033
+ max_embeddings_multiples=max_embeddings_multiples,
1034
+ output_type=output_type,
1035
+ return_dict=return_dict,
1036
+ callback=callback,
1037
+ is_cancelled_callback=is_cancelled_callback,
1038
+ callback_steps=callback_steps,
1039
+ )
1040
+
1041
+ def img2img(
1042
+ self,
1043
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1044
+ prompt: Union[str, List[str]],
1045
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1046
+ strength: float = 0.8,
1047
+ num_inference_steps: Optional[int] = 50,
1048
+ guidance_scale: Optional[float] = 7.5,
1049
+ num_images_per_prompt: Optional[int] = 1,
1050
+ eta: Optional[float] = 0.0,
1051
+ generator: Optional[torch.Generator] = None,
1052
+ max_embeddings_multiples: Optional[int] = 3,
1053
+ output_type: Optional[str] = "pil",
1054
+ return_dict: bool = True,
1055
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1056
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1057
+ callback_steps: int = 1,
1058
+ ):
1059
+ r"""
1060
+ Function for image-to-image generation.
1061
+ Args:
1062
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1063
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1064
+ process.
1065
+ prompt (`str` or `List[str]`):
1066
+ The prompt or prompts to guide the image generation.
1067
+ negative_prompt (`str` or `List[str]`, *optional*):
1068
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1069
+ if `guidance_scale` is less than `1`).
1070
+ strength (`float`, *optional*, defaults to 0.8):
1071
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1072
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1073
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1074
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
1075
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1076
+ num_inference_steps (`int`, *optional*, defaults to 50):
1077
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1078
+ expense of slower inference. This parameter will be modulated by `strength`.
1079
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1080
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1081
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1082
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1083
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1084
+ usually at the expense of lower image quality.
1085
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1086
+ The number of images to generate per prompt.
1087
+ eta (`float`, *optional*, defaults to 0.0):
1088
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1089
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1090
+ generator (`torch.Generator`, *optional*):
1091
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1092
+ deterministic.
1093
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1094
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1095
+ output_type (`str`, *optional*, defaults to `"pil"`):
1096
+ The output format of the generate image. Choose between
1097
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1098
+ return_dict (`bool`, *optional*, defaults to `True`):
1099
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1100
+ plain tuple.
1101
+ callback (`Callable`, *optional*):
1102
+ A function that will be called every `callback_steps` steps during inference. The function will be
1103
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1104
+ is_cancelled_callback (`Callable`, *optional*):
1105
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1106
+ `True`, the inference will be cancelled.
1107
+ callback_steps (`int`, *optional*, defaults to 1):
1108
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1109
+ called at every step.
1110
+ Returns:
1111
+ The final denoised latents as a `torch.FloatTensor` (this implementation does not decode
+ them; call `latents_to_image()` to obtain PIL images), or `None` if the run is cancelled
+ by `is_cancelled_callback`.
1116
+ """
1117
+ return self.__call__(
1118
+ prompt=prompt,
1119
+ negative_prompt=negative_prompt,
1120
+ image=image,
1121
+ num_inference_steps=num_inference_steps,
1122
+ guidance_scale=guidance_scale,
1123
+ strength=strength,
1124
+ num_images_per_prompt=num_images_per_prompt,
1125
+ eta=eta,
1126
+ generator=generator,
1127
+ max_embeddings_multiples=max_embeddings_multiples,
1128
+ output_type=output_type,
1129
+ return_dict=return_dict,
1130
+ callback=callback,
1131
+ is_cancelled_callback=is_cancelled_callback,
1132
+ callback_steps=callback_steps,
1133
+ )
1134
+
1135
+ def inpaint(
1136
+ self,
1137
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1138
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1139
+ prompt: Union[str, List[str]],
1140
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1141
+ strength: float = 0.8,
1142
+ num_inference_steps: Optional[int] = 50,
1143
+ guidance_scale: Optional[float] = 7.5,
1144
+ num_images_per_prompt: Optional[int] = 1,
1145
+ eta: Optional[float] = 0.0,
1146
+ generator: Optional[torch.Generator] = None,
1147
+ max_embeddings_multiples: Optional[int] = 3,
1148
+ output_type: Optional[str] = "pil",
1149
+ return_dict: bool = True,
1150
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1151
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1152
+ callback_steps: int = 1,
1153
+ ):
1154
+ r"""
1155
+ Function for inpaint.
1156
+ Args:
1157
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1158
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1159
+ process. This is the image whose masked region will be inpainted.
1160
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1161
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1162
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1163
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1164
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1165
+ prompt (`str` or `List[str]`):
1166
+ The prompt or prompts to guide the image generation.
1167
+ negative_prompt (`str` or `List[str]`, *optional*):
1168
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1169
+ if `guidance_scale` is less than `1`).
1170
+ strength (`float`, *optional*, defaults to 0.8):
1171
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1172
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1173
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1174
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1175
+ num_inference_steps (`int`, *optional*, defaults to 50):
1176
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1177
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1178
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1179
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1180
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1181
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1182
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1183
+ usually at the expense of lower image quality.
1184
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1185
+ The number of images to generate per prompt.
1186
+ eta (`float`, *optional*, defaults to 0.0):
1187
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1188
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1189
+ generator (`torch.Generator`, *optional*):
1190
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1191
+ deterministic.
1192
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1193
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1194
+ output_type (`str`, *optional*, defaults to `"pil"`):
1195
+ The output format of the generated image. Choose between
1196
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1197
+ return_dict (`bool`, *optional*, defaults to `True`):
1198
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1199
+ plain tuple.
1200
+ callback (`Callable`, *optional*):
1201
+ A function that will be called every `callback_steps` steps during inference. The function will be
1202
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1203
+ is_cancelled_callback (`Callable`, *optional*):
1204
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1205
+ `True`, the inference will be cancelled.
1206
+ callback_steps (`int`, *optional*, defaults to 1):
1207
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1208
+ called at every step.
1209
+ Returns:
1210
+ The final denoised latents as a `torch.FloatTensor` (this implementation does not decode
+ them; call `latents_to_image()` to obtain PIL images), or `None` if the run is cancelled
+ by `is_cancelled_callback`.
1215
+ """
1216
+ return self.__call__(
1217
+ prompt=prompt,
1218
+ negative_prompt=negative_prompt,
1219
+ image=image,
1220
+ mask_image=mask_image,
1221
+ num_inference_steps=num_inference_steps,
1222
+ guidance_scale=guidance_scale,
1223
+ strength=strength,
1224
+ num_images_per_prompt=num_images_per_prompt,
1225
+ eta=eta,
1226
+ generator=generator,
1227
+ max_embeddings_multiples=max_embeddings_multiples,
1228
+ output_type=output_type,
1229
+ return_dict=return_dict,
1230
+ callback=callback,
1231
+ is_cancelled_callback=is_cancelled_callback,
1232
+ callback_steps=callback_steps,
1233
+ )
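# Hypothetical usage sketch, not part of the committed file. It assumes the repository
# root is on PYTHONPATH, that the "runwayml/stable-diffusion-v1-5" weights can be
# downloaded, and that a CUDA device is available; the "(detailed:1.2)" emphasis syntax
# and all argument values are illustrative. The wrappers return raw latents, so
# latents_to_image() converts them to PIL images.
import torch
from diffusers import StableDiffusionPipeline
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline

base = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = StableDiffusionLongPromptWeightingPipeline(**base.components, clip_skip=1).to("cuda")
latents = pipe.text2img(
    "a (detailed:1.2) watercolor of a lighthouse at dusk",
    width=512, height=512, num_inference_steps=30, guidance_scale=7.5,
)
pipe.latents_to_image(latents)[0].save("lighthouse.png")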
library/model_util.py ADDED
@@ -0,0 +1,1356 @@
1
+ # v1: split from train_db_fixed.py.
2
+ # v2: support safetensors
3
+
4
+ import math
5
+ import os
6
+
7
+ import torch
8
+ from library.device_utils import init_ipex
9
+ init_ipex()
10
+
11
+ import diffusers
12
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
13
+ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
14
+ from safetensors.torch import load_file, save_file
15
+ from library.original_unet import UNet2DConditionModel
16
+ from library.utils import setup_logging
17
+ setup_logging()
18
+ import logging
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # model parameters of the Diffusers version of Stable Diffusion
22
+ NUM_TRAIN_TIMESTEPS = 1000
23
+ BETA_START = 0.00085
24
+ BETA_END = 0.0120
25
+
26
+ UNET_PARAMS_MODEL_CHANNELS = 320
27
+ UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
28
+ UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
29
+ UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
30
+ UNET_PARAMS_IN_CHANNELS = 4
31
+ UNET_PARAMS_OUT_CHANNELS = 4
32
+ UNET_PARAMS_NUM_RES_BLOCKS = 2
33
+ UNET_PARAMS_CONTEXT_DIM = 768
34
+ UNET_PARAMS_NUM_HEADS = 8
35
+ # UNET_PARAMS_USE_LINEAR_PROJECTION = False
36
+
37
+ VAE_PARAMS_Z_CHANNELS = 4
38
+ VAE_PARAMS_RESOLUTION = 256
39
+ VAE_PARAMS_IN_CHANNELS = 3
40
+ VAE_PARAMS_OUT_CH = 3
41
+ VAE_PARAMS_CH = 128
42
+ VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
43
+ VAE_PARAMS_NUM_RES_BLOCKS = 2
44
+
45
+ # V2
46
+ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
47
+ V2_UNET_PARAMS_CONTEXT_DIM = 1024
48
+ # V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
49
+
50
+ # reference model IDs for loading the Diffusers configuration
51
+ DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
52
+ DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
53
+
54
+
55
+ # region StableDiffusion -> Diffusers conversion code
56
+ # copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
57
+
58
+
59
+ def shave_segments(path, n_shave_prefix_segments=1):
60
+ """
61
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
62
+ """
63
+ if n_shave_prefix_segments >= 0:
64
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
65
+ else:
66
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
67
+
68
+
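# Illustrative sketch, not part of the committed file: shave_segments drops leading
# (positive n) or trailing (negative n) dot-separated segments from a state-dict key.
key = "input_blocks.1.0.in_layers.0.weight"
print(".".join(key.split(".")[1:]))    # "1.0.in_layers.0.weight"       == shave_segments(key, 1)
print(".".join(key.split(".")[:-1]))   # "input_blocks.1.0.in_layers.0" == shave_segments(key, -1)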
69
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
70
+ """
71
+ Updates paths inside resnets to the new naming scheme (local renaming)
72
+ """
73
+ mapping = []
74
+ for old_item in old_list:
75
+ new_item = old_item.replace("in_layers.0", "norm1")
76
+ new_item = new_item.replace("in_layers.2", "conv1")
77
+
78
+ new_item = new_item.replace("out_layers.0", "norm2")
79
+ new_item = new_item.replace("out_layers.3", "conv2")
80
+
81
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
82
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
83
+
84
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
85
+
86
+ mapping.append({"old": old_item, "new": new_item})
87
+
88
+ return mapping
89
+
90
+
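# Illustrative sketch, not part of the committed file: the local renaming performed by
# renew_resnet_paths, applied to two typical LDM resnet keys.
for old in ["input_blocks.1.0.in_layers.2.weight", "input_blocks.1.0.emb_layers.1.bias"]:
    new = (old.replace("in_layers.0", "norm1").replace("in_layers.2", "conv1")
              .replace("out_layers.0", "norm2").replace("out_layers.3", "conv2")
              .replace("emb_layers.1", "time_emb_proj").replace("skip_connection", "conv_shortcut"))
    print(old, "->", new)
# input_blocks.1.0.in_layers.2.weight -> input_blocks.1.0.conv1.weight
# input_blocks.1.0.emb_layers.1.bias  -> input_blocks.1.0.time_emb_proj.bias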
91
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
92
+ """
93
+ Updates paths inside resnets to the new naming scheme (local renaming)
94
+ """
95
+ mapping = []
96
+ for old_item in old_list:
97
+ new_item = old_item
98
+
99
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
100
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
101
+
102
+ mapping.append({"old": old_item, "new": new_item})
103
+
104
+ return mapping
105
+
106
+
107
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
108
+ """
109
+ Updates paths inside attentions to the new naming scheme (local renaming)
110
+ """
111
+ mapping = []
112
+ for old_item in old_list:
113
+ new_item = old_item
114
+
115
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
116
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
117
+
118
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
119
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
120
+
121
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
122
+
123
+ mapping.append({"old": old_item, "new": new_item})
124
+
125
+ return mapping
126
+
127
+
128
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
129
+ """
130
+ Updates paths inside attentions to the new naming scheme (local renaming)
131
+ """
132
+ mapping = []
133
+ for old_item in old_list:
134
+ new_item = old_item
135
+
136
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
137
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
138
+
139
+ if diffusers.__version__ < "0.17.0":
140
+ new_item = new_item.replace("q.weight", "query.weight")
141
+ new_item = new_item.replace("q.bias", "query.bias")
142
+
143
+ new_item = new_item.replace("k.weight", "key.weight")
144
+ new_item = new_item.replace("k.bias", "key.bias")
145
+
146
+ new_item = new_item.replace("v.weight", "value.weight")
147
+ new_item = new_item.replace("v.bias", "value.bias")
148
+
149
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
150
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
151
+ else:
152
+ new_item = new_item.replace("q.weight", "to_q.weight")
153
+ new_item = new_item.replace("q.bias", "to_q.bias")
154
+
155
+ new_item = new_item.replace("k.weight", "to_k.weight")
156
+ new_item = new_item.replace("k.bias", "to_k.bias")
157
+
158
+ new_item = new_item.replace("v.weight", "to_v.weight")
159
+ new_item = new_item.replace("v.bias", "to_v.bias")
160
+
161
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
162
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
163
+
164
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
165
+
166
+ mapping.append({"old": old_item, "new": new_item})
167
+
168
+ return mapping
169
+
170
+
171
+ def assign_to_checkpoint(
172
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
173
+ ):
174
+ """
175
+ This does the final conversion step: take locally converted weights and apply a global renaming
176
+ to them. It splits attention layers, and takes into account additional replacements
177
+ that may arise.
178
+
179
+ Assigns the weights to the new checkpoint.
180
+ """
181
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
182
+
183
+ # Splits the attention layers into three variables.
184
+ if attention_paths_to_split is not None:
185
+ for path, path_map in attention_paths_to_split.items():
186
+ old_tensor = old_checkpoint[path]
187
+ channels = old_tensor.shape[0] // 3
188
+
189
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
190
+
191
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
192
+
193
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
194
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
195
+
196
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
197
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
198
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
199
+
200
+ for path in paths:
201
+ new_path = path["new"]
202
+
203
+ # These have already been assigned
204
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
205
+ continue
206
+
207
+ # Global renaming happens here
208
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
209
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
210
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
211
+
212
+ if additional_replacements is not None:
213
+ for replacement in additional_replacements:
214
+ new_path = new_path.replace(replacement["old"], replacement["new"])
215
+
216
+ # proj_attn.weight has to be converted from conv 1D to linear
217
+ reshaping = False
218
+ if diffusers.__version__ < "0.17.0":
219
+ if "proj_attn.weight" in new_path:
220
+ reshaping = True
221
+ else:
222
+ if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
223
+ reshaping = True
224
+
225
+ if reshaping:
226
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
227
+ else:
228
+ checkpoint[new_path] = old_checkpoint[path["old"]]
229
+
230
+
231
+ def conv_attn_to_linear(checkpoint):
232
+ keys = list(checkpoint.keys())
233
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
234
+ for key in keys:
235
+ if ".".join(key.split(".")[-2:]) in attn_keys:
236
+ if checkpoint[key].ndim > 2:
237
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
238
+ elif "proj_attn.weight" in key:
239
+ if checkpoint[key].ndim > 2:
240
+ checkpoint[key] = checkpoint[key][:, :, 0]
241
+
242
+
243
+ def linear_transformer_to_conv(checkpoint):
244
+ keys = list(checkpoint.keys())
245
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
246
+ for key in keys:
247
+ if ".".join(key.split(".")[-2:]) in tf_keys:
248
+ if checkpoint[key].ndim == 2:
249
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
250
+
251
+
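# Illustrative sketch, not part of the committed file: the conv <-> linear conversions
# above only add or drop two trailing 1x1 spatial dimensions on the weight tensor.
import torch

linear_w = torch.randn(320, 320)                 # e.g. proj_in.weight stored as a linear layer
conv_w = linear_w.unsqueeze(2).unsqueeze(2)      # linear_transformer_to_conv: -> (320, 320, 1, 1)
restored = conv_w[:, :, 0, 0]                    # conv_attn_to_linear:        -> (320, 320)
print(conv_w.shape, torch.equal(restored, linear_w))   # torch.Size([320, 320, 1, 1]) True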
252
+ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
253
+ """
254
+ Takes a state dict and a config, and returns a converted checkpoint.
255
+ """
256
+
257
+ # extract state_dict for UNet
258
+ unet_state_dict = {}
259
+ unet_key = "model.diffusion_model."
260
+ keys = list(checkpoint.keys())
261
+ for key in keys:
262
+ if key.startswith(unet_key):
263
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
264
+
265
+ new_checkpoint = {}
266
+
267
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
268
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
269
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
270
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
271
+
272
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
273
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
274
+
275
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
276
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
277
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
278
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
279
+
280
+ # Retrieves the keys for the input blocks only
281
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
282
+ input_blocks = {
283
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks)
284
+ }
285
+
286
+ # Retrieves the keys for the middle blocks only
287
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
288
+ middle_blocks = {
289
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks)
290
+ }
291
+
292
+ # Retrieves the keys for the output blocks only
293
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
294
+ output_blocks = {
295
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks)
296
+ }
297
+
298
+ for i in range(1, num_input_blocks):
299
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
300
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
301
+
302
+ resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
303
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
304
+
305
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
306
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
307
+ f"input_blocks.{i}.0.op.weight"
308
+ )
309
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
310
+
311
+ paths = renew_resnet_paths(resnets)
312
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
313
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
314
+
315
+ if len(attentions):
316
+ paths = renew_attention_paths(attentions)
317
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
318
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
319
+
320
+ resnet_0 = middle_blocks[0]
321
+ attentions = middle_blocks[1]
322
+ resnet_1 = middle_blocks[2]
323
+
324
+ resnet_0_paths = renew_resnet_paths(resnet_0)
325
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
326
+
327
+ resnet_1_paths = renew_resnet_paths(resnet_1)
328
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
329
+
330
+ attentions_paths = renew_attention_paths(attentions)
331
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
332
+ assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
333
+
334
+ for i in range(num_output_blocks):
335
+ block_id = i // (config["layers_per_block"] + 1)
336
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
337
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
338
+ output_block_list = {}
339
+
340
+ for layer in output_block_layers:
341
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
342
+ if layer_id in output_block_list:
343
+ output_block_list[layer_id].append(layer_name)
344
+ else:
345
+ output_block_list[layer_id] = [layer_name]
346
+
347
+ if len(output_block_list) > 1:
348
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
349
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
350
+
351
+ resnet_0_paths = renew_resnet_paths(resnets)
352
+ paths = renew_resnet_paths(resnets)
353
+
354
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
355
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
356
+
357
+ # original:
358
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
359
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
360
+
361
+ # make this independent of the order of bias and weight; there is probably a better way
362
+ for l in output_block_list.values():
363
+ l.sort()
364
+
365
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
366
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
367
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
368
+ f"output_blocks.{i}.{index}.conv.bias"
369
+ ]
370
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
371
+ f"output_blocks.{i}.{index}.conv.weight"
372
+ ]
373
+
374
+ # Clear attentions as they have been attributed above.
375
+ if len(attentions) == 2:
376
+ attentions = []
377
+
378
+ if len(attentions):
379
+ paths = renew_attention_paths(attentions)
380
+ meta_path = {
381
+ "old": f"output_blocks.{i}.1",
382
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
383
+ }
384
+ assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
385
+ else:
386
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
387
+ for path in resnet_0_paths:
388
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
389
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
390
+
391
+ new_checkpoint[new_path] = unet_state_dict[old_path]
392
+
393
+ # in SD v2, the 1x1 conv2d has been replaced with a linear layer
394
+ # the Diffusers side was mistakenly left as conv2d, so conversion is required
395
+ if v2 and not config.get("use_linear_projection", False):
396
+ linear_transformer_to_conv(new_checkpoint)
397
+
398
+ return new_checkpoint
399
+
400
+
401
+ def convert_ldm_vae_checkpoint(checkpoint, config):
402
+ # extract state dict for VAE
403
+ vae_state_dict = {}
404
+ vae_key = "first_stage_model."
405
+ keys = list(checkpoint.keys())
406
+ for key in keys:
407
+ if key.startswith(vae_key):
408
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
409
+ # if len(vae_state_dict) == 0:
410
+ # # the passed checkpoint is a VAE state_dict, not a checkpoint loaded from a .ckpt
411
+ # vae_state_dict = checkpoint
412
+
413
+ new_checkpoint = {}
414
+
415
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
416
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
417
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
418
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
419
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
420
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
421
+
422
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
423
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
424
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
425
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
426
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
427
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
428
+
429
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
430
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
431
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
432
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
433
+
434
+ # Retrieves the keys for the encoder down blocks only
435
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
436
+ down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)}
437
+
438
+ # Retrieves the keys for the decoder up blocks only
439
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
440
+ up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)}
441
+
442
+ for i in range(num_down_blocks):
443
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
444
+
445
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
446
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
447
+ f"encoder.down.{i}.downsample.conv.weight"
448
+ )
449
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
450
+ f"encoder.down.{i}.downsample.conv.bias"
451
+ )
452
+
453
+ paths = renew_vae_resnet_paths(resnets)
454
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
455
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
456
+
457
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
458
+ num_mid_res_blocks = 2
459
+ for i in range(1, num_mid_res_blocks + 1):
460
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
461
+
462
+ paths = renew_vae_resnet_paths(resnets)
463
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
464
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
465
+
466
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
467
+ paths = renew_vae_attention_paths(mid_attentions)
468
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
469
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
470
+ conv_attn_to_linear(new_checkpoint)
471
+
472
+ for i in range(num_up_blocks):
473
+ block_id = num_up_blocks - 1 - i
474
+ resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key]
475
+
476
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
477
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
478
+ f"decoder.up.{block_id}.upsample.conv.weight"
479
+ ]
480
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
481
+ f"decoder.up.{block_id}.upsample.conv.bias"
482
+ ]
483
+
484
+ paths = renew_vae_resnet_paths(resnets)
485
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
486
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
487
+
488
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
489
+ num_mid_res_blocks = 2
490
+ for i in range(1, num_mid_res_blocks + 1):
491
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
492
+
493
+ paths = renew_vae_resnet_paths(resnets)
494
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
495
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
496
+
497
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
498
+ paths = renew_vae_attention_paths(mid_attentions)
499
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
500
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
501
+ conv_attn_to_linear(new_checkpoint)
502
+ return new_checkpoint
503
+
504
+
505
+ def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
506
+ """
507
+ Creates a config for the diffusers based on the config of the LDM model.
508
+ """
509
+ # unet_params = original_config.model.params.unet_config.params
510
+
511
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
512
+
513
+ down_block_types = []
514
+ resolution = 1
515
+ for i in range(len(block_out_channels)):
516
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
517
+ down_block_types.append(block_type)
518
+ if i != len(block_out_channels) - 1:
519
+ resolution *= 2
520
+
521
+ up_block_types = []
522
+ for i in range(len(block_out_channels)):
523
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
524
+ up_block_types.append(block_type)
525
+ resolution //= 2
526
+
527
+ config = dict(
528
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
529
+ in_channels=UNET_PARAMS_IN_CHANNELS,
530
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
531
+ down_block_types=tuple(down_block_types),
532
+ up_block_types=tuple(up_block_types),
533
+ block_out_channels=tuple(block_out_channels),
534
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
535
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
536
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
537
+ # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
538
+ )
539
+ if v2 and use_linear_projection_in_v2:
540
+ config["use_linear_projection"] = True
541
+
542
+ return config
543
+
544
+
545
+ def create_vae_diffusers_config():
546
+ """
547
+ Creates a config for the diffusers based on the config of the LDM model.
548
+ """
549
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
550
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
551
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
552
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
553
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
554
+
555
+ config = dict(
556
+ sample_size=VAE_PARAMS_RESOLUTION,
557
+ in_channels=VAE_PARAMS_IN_CHANNELS,
558
+ out_channels=VAE_PARAMS_OUT_CH,
559
+ down_block_types=tuple(down_block_types),
560
+ up_block_types=tuple(up_block_types),
561
+ block_out_channels=tuple(block_out_channels),
562
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
563
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
564
+ )
565
+ return config
566
+
567
+
568
+ def convert_ldm_clip_checkpoint_v1(checkpoint):
569
+ keys = list(checkpoint.keys())
570
+ text_model_dict = {}
571
+ for key in keys:
572
+ if key.startswith("cond_stage_model.transformer"):
573
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
574
+
575
+ # remove position_ids for newer transformer, which causes error :(
576
+ if "text_model.embeddings.position_ids" in text_model_dict:
577
+ text_model_dict.pop("text_model.embeddings.position_ids")
578
+
579
+ return text_model_dict
580
+
581
+
582
+ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
583
+ # the key layout differs annoyingly much!
584
+ def convert_key(key):
585
+ if not key.startswith("cond_stage_model"):
586
+ return None
587
+
588
+ # common conversion
589
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
590
+ key = key.replace("cond_stage_model.model.", "text_model.")
591
+
592
+ if "resblocks" in key:
593
+ # resblocks conversion
594
+ key = key.replace(".resblocks.", ".layers.")
595
+ if ".ln_" in key:
596
+ key = key.replace(".ln_", ".layer_norm")
597
+ elif ".mlp." in key:
598
+ key = key.replace(".c_fc.", ".fc1.")
599
+ key = key.replace(".c_proj.", ".fc2.")
600
+ elif ".attn.out_proj" in key:
601
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
602
+ elif ".attn.in_proj" in key:
603
+ key = None # special case, handled later
604
+ else:
605
+ raise ValueError(f"unexpected key in SD: {key}")
606
+ elif ".positional_embedding" in key:
607
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
608
+ elif ".text_projection" in key:
609
+ key = None # not used???
610
+ elif ".logit_scale" in key:
611
+ key = None # not used???
612
+ elif ".token_embedding" in key:
613
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
614
+ elif ".ln_final" in key:
615
+ key = key.replace(".ln_final", ".final_layer_norm")
616
+ return key
617
+
618
+ keys = list(checkpoint.keys())
619
+ new_sd = {}
620
+ for key in keys:
621
+ # remove resblocks 23
622
+ if ".resblocks.23." in key:
623
+ continue
624
+ new_key = convert_key(key)
625
+ if new_key is None:
626
+ continue
627
+ new_sd[new_key] = checkpoint[key]
628
+
629
+ # convert attention weights
630
+ for key in keys:
631
+ if ".resblocks.23." in key:
632
+ continue
633
+ if ".resblocks" in key and ".attn.in_proj_" in key:
634
+ # split into three (q, k, v)
635
+ values = torch.chunk(checkpoint[key], 3)
636
+
637
+ key_suffix = ".weight" if "weight" in key else ".bias"
638
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
639
+ key_pfx = key_pfx.replace("_weight", "")
640
+ key_pfx = key_pfx.replace("_bias", "")
641
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
642
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
643
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
644
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
645
+
646
+ # rename or add position_ids
647
+ ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
648
+ if ANOTHER_POSITION_IDS_KEY in new_sd:
649
+ # waifu diffusion v1.4
650
+ position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
651
+ del new_sd[ANOTHER_POSITION_IDS_KEY]
652
+ else:
653
+ position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
654
+
655
+ new_sd["text_model.embeddings.position_ids"] = position_ids
656
+ return new_sd
657
+
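For intuition, the SD v2 text encoder stores the attention q, k and v projections as one fused in_proj tensor, which the loop above splits with torch.chunk. A minimal sketch of that split, with hypothetical shapes and not part of this commit:

import torch

dim = 1024  # hidden size of the SD v2 (OpenCLIP) text encoder
in_proj_weight = torch.randn(3 * dim, dim)  # fused q/k/v projection weight

# torch.chunk along dim 0 yields the three (dim, dim) projection weights
q_w, k_w, v_w = torch.chunk(in_proj_weight, 3)
assert q_w.shape == k_w.shape == v_w.shape == (dim, dim)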
658
+
659
+ # endregion
660
+
661
+
662
+ # region Diffusers -> StableDiffusion conversion code
663
+ # copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
664
+
665
+
666
+ def conv_transformer_to_linear(checkpoint):
667
+ keys = list(checkpoint.keys())
668
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
669
+ for key in keys:
670
+ if ".".join(key.split(".")[-2:]) in tf_keys:
671
+ if checkpoint[key].ndim > 2:
672
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
673
+
674
+
675
+ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
676
+ unet_conversion_map = [
677
+ # (stable-diffusion, HF Diffusers)
678
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
679
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
680
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
681
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
682
+ ("input_blocks.0.0.weight", "conv_in.weight"),
683
+ ("input_blocks.0.0.bias", "conv_in.bias"),
684
+ ("out.0.weight", "conv_norm_out.weight"),
685
+ ("out.0.bias", "conv_norm_out.bias"),
686
+ ("out.2.weight", "conv_out.weight"),
687
+ ("out.2.bias", "conv_out.bias"),
688
+ ]
689
+
690
+ unet_conversion_map_resnet = [
691
+ # (stable-diffusion, HF Diffusers)
692
+ ("in_layers.0", "norm1"),
693
+ ("in_layers.2", "conv1"),
694
+ ("out_layers.0", "norm2"),
695
+ ("out_layers.3", "conv2"),
696
+ ("emb_layers.1", "time_emb_proj"),
697
+ ("skip_connection", "conv_shortcut"),
698
+ ]
699
+
700
+ unet_conversion_map_layer = []
701
+ for i in range(4):
702
+ # loop over downblocks/upblocks
703
+
704
+ for j in range(2):
705
+ # loop over resnets/attentions for downblocks
706
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
707
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
708
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
709
+
710
+ if i < 3:
711
+ # no attention layers in down_blocks.3
712
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
713
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
714
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
715
+
716
+ for j in range(3):
717
+ # loop over resnets/attentions for upblocks
718
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
719
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
720
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
721
+
722
+ if i > 0:
723
+ # no attention layers in up_blocks.0
724
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
725
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
726
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
727
+
728
+ if i < 3:
729
+ # no downsample in down_blocks.3
730
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
731
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
732
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
733
+
734
+ # no upsample in up_blocks.3
735
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
736
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
737
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
738
+
739
+ hf_mid_atn_prefix = "mid_block.attentions.0."
740
+ sd_mid_atn_prefix = "middle_block.1."
741
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
742
+
743
+ for j in range(2):
744
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
745
+ sd_mid_res_prefix = f"middle_block.{2*j}."
746
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
747
+
748
+ # buyer beware: this is a *brittle* function,
749
+ # and correct output requires that all of these pieces interact in
750
+ # the exact order in which I have arranged them.
751
+ mapping = {k: k for k in unet_state_dict.keys()}
752
+ for sd_name, hf_name in unet_conversion_map:
753
+ mapping[hf_name] = sd_name
754
+ for k, v in mapping.items():
755
+ if "resnets" in k:
756
+ for sd_part, hf_part in unet_conversion_map_resnet:
757
+ v = v.replace(hf_part, sd_part)
758
+ mapping[k] = v
759
+ for k, v in mapping.items():
760
+ for sd_part, hf_part in unet_conversion_map_layer:
761
+ v = v.replace(hf_part, sd_part)
762
+ mapping[k] = v
763
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
764
+
765
+ if v2:
766
+ conv_transformer_to_linear(new_state_dict)
767
+
768
+ return new_state_dict
769
+
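For intuition (not part of the commit), here is how one Diffusers key is rewritten by the resnet map and then the layer map built above; the key is just an illustrative example:

# Diffusers key for the first resnet's first conv in down block 0
hf_key = "down_blocks.0.resnets.0.conv1.weight"
# resnet map: conv1 -> in_layers.2; layer map: down_blocks.0.resnets.0. -> input_blocks.1.0.
sd_key = "input_blocks.1.0.in_layers.2.weight"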
770
+
771
+ def controlnet_conversion_map():
772
+ unet_conversion_map = [
773
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
774
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
775
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
776
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
777
+ ("input_blocks.0.0.weight", "conv_in.weight"),
778
+ ("input_blocks.0.0.bias", "conv_in.bias"),
779
+ ("middle_block_out.0.weight", "controlnet_mid_block.weight"),
780
+ ("middle_block_out.0.bias", "controlnet_mid_block.bias"),
781
+ ]
782
+
783
+ unet_conversion_map_resnet = [
784
+ ("in_layers.0", "norm1"),
785
+ ("in_layers.2", "conv1"),
786
+ ("out_layers.0", "norm2"),
787
+ ("out_layers.3", "conv2"),
788
+ ("emb_layers.1", "time_emb_proj"),
789
+ ("skip_connection", "conv_shortcut"),
790
+ ]
791
+
792
+ unet_conversion_map_layer = []
793
+ for i in range(4):
794
+ for j in range(2):
795
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
796
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
797
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
798
+
799
+ if i < 3:
800
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
801
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
802
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
803
+
804
+ if i < 3:
805
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
806
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
807
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
808
+
809
+ hf_mid_atn_prefix = "mid_block.attentions.0."
810
+ sd_mid_atn_prefix = "middle_block.1."
811
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
812
+
813
+ for j in range(2):
814
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
815
+ sd_mid_res_prefix = f"middle_block.{2*j}."
816
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
817
+
818
+ controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
819
+ for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
820
+ hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
821
+ sd_prefix = f"input_hint_block.{i*2}."
822
+ unet_conversion_map_layer.append((sd_prefix, hf_prefix))
823
+
824
+ for i in range(12):
825
+ hf_prefix = f"controlnet_down_blocks.{i}."
826
+ sd_prefix = f"zero_convs.{i}.0."
827
+ unet_conversion_map_layer.append((sd_prefix, hf_prefix))
828
+
829
+ return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
830
+
831
+
832
+ def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
833
+ unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
834
+
835
+ mapping = {k: k for k in controlnet_state_dict.keys()}
836
+ for sd_name, diffusers_name in unet_conversion_map:
837
+ mapping[diffusers_name] = sd_name
838
+ for k, v in mapping.items():
839
+ if "resnets" in k:
840
+ for sd_part, diffusers_part in unet_conversion_map_resnet:
841
+ v = v.replace(diffusers_part, sd_part)
842
+ mapping[k] = v
843
+ for k, v in mapping.items():
844
+ for sd_part, diffusers_part in unet_conversion_map_layer:
845
+ v = v.replace(diffusers_part, sd_part)
846
+ mapping[k] = v
847
+ new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
848
+ return new_state_dict
849
+
850
+
851
+ def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
852
+ unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
853
+
854
+ mapping = {k: k for k in controlnet_state_dict.keys()}
855
+ for sd_name, diffusers_name in unet_conversion_map:
856
+ mapping[sd_name] = diffusers_name
857
+ for k, v in mapping.items():
858
+ for sd_part, diffusers_part in unet_conversion_map_layer:
859
+ v = v.replace(sd_part, diffusers_part)
860
+ mapping[k] = v
861
+ for k, v in mapping.items():
862
+ if "resnets" in v:
863
+ for sd_part, diffusers_part in unet_conversion_map_resnet:
864
+ v = v.replace(sd_part, diffusers_part)
865
+ mapping[k] = v
866
+ new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
867
+ return new_state_dict
868
+
869
+
870
+ # ================#
871
+ # VAE Conversion #
872
+ # ================#
873
+
874
+
875
+ def reshape_weight_for_sd(w):
876
+ # convert HF linear weights to SD conv2d weights
877
+ return w.reshape(*w.shape, 1, 1)
878
+
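As a quick shape check (illustrative, not part of the commit), reshape_weight_for_sd turns a Diffusers linear weight into an SD-style 1x1 conv weight:

import torch

w = torch.randn(512, 512)          # HF Diffusers nn.Linear weight
w_sd = w.reshape(*w.shape, 1, 1)   # same operation as reshape_weight_for_sd
assert w_sd.shape == (512, 512, 1, 1)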
879
+
880
+ def convert_vae_state_dict(vae_state_dict):
881
+ vae_conversion_map = [
882
+ # (stable-diffusion, HF Diffusers)
883
+ ("nin_shortcut", "conv_shortcut"),
884
+ ("norm_out", "conv_norm_out"),
885
+ ("mid.attn_1.", "mid_block.attentions.0."),
886
+ ]
887
+
888
+ for i in range(4):
889
+ # down_blocks have two resnets
890
+ for j in range(2):
891
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
892
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
893
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
894
+
895
+ if i < 3:
896
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
897
+ sd_downsample_prefix = f"down.{i}.downsample."
898
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
899
+
900
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
901
+ sd_upsample_prefix = f"up.{3-i}.upsample."
902
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
903
+
904
+ # up_blocks have three resnets
905
+ # also, up blocks in hf are numbered in reverse from sd
906
+ for j in range(3):
907
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
908
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
909
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
910
+
911
+ # this part accounts for mid blocks in both the encoder and the decoder
912
+ for i in range(2):
913
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
914
+ sd_mid_res_prefix = f"mid.block_{i+1}."
915
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
916
+
917
+ if diffusers.__version__ < "0.17.0":
918
+ vae_conversion_map_attn = [
919
+ # (stable-diffusion, HF Diffusers)
920
+ ("norm.", "group_norm."),
921
+ ("q.", "query."),
922
+ ("k.", "key."),
923
+ ("v.", "value."),
924
+ ("proj_out.", "proj_attn."),
925
+ ]
926
+ else:
927
+ vae_conversion_map_attn = [
928
+ # (stable-diffusion, HF Diffusers)
929
+ ("norm.", "group_norm."),
930
+ ("q.", "to_q."),
931
+ ("k.", "to_k."),
932
+ ("v.", "to_v."),
933
+ ("proj_out.", "to_out.0."),
934
+ ]
935
+
936
+ mapping = {k: k for k in vae_state_dict.keys()}
937
+ for k, v in mapping.items():
938
+ for sd_part, hf_part in vae_conversion_map:
939
+ v = v.replace(hf_part, sd_part)
940
+ mapping[k] = v
941
+ for k, v in mapping.items():
942
+ if "attentions" in k:
943
+ for sd_part, hf_part in vae_conversion_map_attn:
944
+ v = v.replace(hf_part, sd_part)
945
+ mapping[k] = v
946
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
947
+ weights_to_convert = ["q", "k", "v", "proj_out"]
948
+ for k, v in new_state_dict.items():
949
+ for weight_name in weights_to_convert:
950
+ if f"mid.attn_1.{weight_name}.weight" in k:
951
+ # logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
952
+ new_state_dict[k] = reshape_weight_for_sd(v)
953
+
954
+ return new_state_dict
955
+
956
+
957
+ # endregion
958
+
959
+ # region custom model loading/saving utilities
960
+
961
+
962
+ def is_safetensors(path):
963
+ return os.path.splitext(path)[1].lower() == ".safetensors"
964
+
965
+
966
+ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
967
+ # handle models whose text encoder is stored in a different format ('text_model' prefix is missing)
968
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
969
+ ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
970
+ ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
971
+ ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
972
+ ]
973
+
974
+ if is_safetensors(ckpt_path):
975
+ checkpoint = None
976
+ state_dict = load_file(ckpt_path) # , device) # may cause an error
977
+ else:
978
+ checkpoint = torch.load(ckpt_path, map_location=device)
979
+ if "state_dict" in checkpoint:
980
+ state_dict = checkpoint["state_dict"]
981
+ else:
982
+ state_dict = checkpoint
983
+ checkpoint = None
984
+
985
+ key_reps = []
986
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
987
+ for key in state_dict.keys():
988
+ if key.startswith(rep_from):
989
+ new_key = rep_to + key[len(rep_from) :]
990
+ key_reps.append((key, new_key))
991
+
992
+ for key, new_key in key_reps:
993
+ state_dict[new_key] = state_dict[key]
994
+ del state_dict[key]
995
+
996
+ return checkpoint, state_dict
997
+
998
+
999
+ # TODO the behavior of the dtype option is questionable and should be verified; it is unconfirmed whether the text_encoder can be created with the specified dtype
1000
+ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
1001
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
1002
+
1003
+ # Convert the UNet2DConditionModel model.
1004
+ unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
1005
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
1006
+
1007
+ unet = UNet2DConditionModel(**unet_config).to(device)
1008
+ info = unet.load_state_dict(converted_unet_checkpoint)
1009
+ logger.info(f"loading u-net: {info}")
1010
+
1011
+ # Convert the VAE model.
1012
+ vae_config = create_vae_diffusers_config()
1013
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
1014
+
1015
+ vae = AutoencoderKL(**vae_config).to(device)
1016
+ info = vae.load_state_dict(converted_vae_checkpoint)
1017
+ logger.info(f"loading vae: {info}")
1018
+
1019
+ # convert text_model
1020
+ if v2:
1021
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
1022
+ cfg = CLIPTextConfig(
1023
+ vocab_size=49408,
1024
+ hidden_size=1024,
1025
+ intermediate_size=4096,
1026
+ num_hidden_layers=23,
1027
+ num_attention_heads=16,
1028
+ max_position_embeddings=77,
1029
+ hidden_act="gelu",
1030
+ layer_norm_eps=1e-05,
1031
+ dropout=0.0,
1032
+ attention_dropout=0.0,
1033
+ initializer_range=0.02,
1034
+ initializer_factor=1.0,
1035
+ pad_token_id=1,
1036
+ bos_token_id=0,
1037
+ eos_token_id=2,
1038
+ model_type="clip_text_model",
1039
+ projection_dim=512,
1040
+ torch_dtype="float32",
1041
+ transformers_version="4.25.0.dev0",
1042
+ )
1043
+ text_model = CLIPTextModel._from_config(cfg)
1044
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
1045
+ else:
1046
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
1047
+
1048
+ # logging.set_verbosity_error() # don't show annoying warning
1049
+ # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
1050
+ # logging.set_verbosity_warning()
1051
+ # logger.info(f"config: {text_model.config}")
1052
+ cfg = CLIPTextConfig(
1053
+ vocab_size=49408,
1054
+ hidden_size=768,
1055
+ intermediate_size=3072,
1056
+ num_hidden_layers=12,
1057
+ num_attention_heads=12,
1058
+ max_position_embeddings=77,
1059
+ hidden_act="quick_gelu",
1060
+ layer_norm_eps=1e-05,
1061
+ dropout=0.0,
1062
+ attention_dropout=0.0,
1063
+ initializer_range=0.02,
1064
+ initializer_factor=1.0,
1065
+ pad_token_id=1,
1066
+ bos_token_id=0,
1067
+ eos_token_id=2,
1068
+ model_type="clip_text_model",
1069
+ projection_dim=768,
1070
+ torch_dtype="float32",
1071
+ )
1072
+ text_model = CLIPTextModel._from_config(cfg)
1073
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
1074
+ logger.info(f"loading text encoder: {info}")
1075
+
1076
+ return text_model, vae, unet
1077
+
1078
+
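A minimal usage sketch for the loader above; the checkpoint path is hypothetical and this snippet is not part of the commit:

# returns Diffusers-style CLIPTextModel, AutoencoderKL and UNet2DConditionModel
text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(
    v2=False, ckpt_path="models/sd-v1-5.safetensors", device="cpu"
)
unet.eval()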
1079
+ def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
1080
+ # only for reference
1081
+ version_str = "sd"
1082
+ if v2:
1083
+ version_str += "_v2"
1084
+ else:
1085
+ version_str += "_v1"
1086
+ if v_parameterization:
1087
+ version_str += "_v"
1088
+ return version_str
1089
+
1090
+
1091
+ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
1092
+ def convert_key(key):
1093
+ # remove position_ids
1094
+ if ".position_ids" in key:
1095
+ return None
1096
+
1097
+ # common
1098
+ key = key.replace("text_model.encoder.", "transformer.")
1099
+ key = key.replace("text_model.", "")
1100
+ if "layers" in key:
1101
+ # resblocks conversion
1102
+ key = key.replace(".layers.", ".resblocks.")
1103
+ if ".layer_norm" in key:
1104
+ key = key.replace(".layer_norm", ".ln_")
1105
+ elif ".mlp." in key:
1106
+ key = key.replace(".fc1.", ".c_fc.")
1107
+ key = key.replace(".fc2.", ".c_proj.")
1108
+ elif ".self_attn.out_proj" in key:
1109
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
1110
+ elif ".self_attn." in key:
1111
+ key = None # special case, handled later
1112
+ else:
1113
+ raise ValueError(f"unexpected key in Diffusers model: {key}")
1114
+ elif ".position_embedding" in key:
1115
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
1116
+ elif ".token_embedding" in key:
1117
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
1118
+ elif "final_layer_norm" in key:
1119
+ key = key.replace("final_layer_norm", "ln_final")
1120
+ return key
1121
+
1122
+ keys = list(checkpoint.keys())
1123
+ new_sd = {}
1124
+ for key in keys:
1125
+ new_key = convert_key(key)
1126
+ if new_key is None:
1127
+ continue
1128
+ new_sd[new_key] = checkpoint[key]
1129
+
1130
+ # convert attention weights
1131
+ for key in keys:
1132
+ if "layers" in key and "q_proj" in key:
1133
+ # concatenate the three (q, k, v)
1134
+ key_q = key
1135
+ key_k = key.replace("q_proj", "k_proj")
1136
+ key_v = key.replace("q_proj", "v_proj")
1137
+
1138
+ value_q = checkpoint[key_q]
1139
+ value_k = checkpoint[key_k]
1140
+ value_v = checkpoint[key_v]
1141
+ value = torch.cat([value_q, value_k, value_v])
1142
+
1143
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
1144
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
1145
+ new_sd[new_key] = value
1146
+
1147
+ # whether to fabricate the last layer and other missing weights
1148
+ if make_dummy_weights:
1149
+ logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
1150
+ keys = list(new_sd.keys())
1151
+ for key in keys:
1152
+ if key.startswith("transformer.resblocks.22."):
1153
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # without clone, saving with safetensors fails
1154
+
1155
+ # create weights that are not included in the Diffusers model
1156
+ new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
1157
+ new_sd["logit_scale"] = torch.tensor(1)
1158
+
1159
+ return new_sd
1160
+
1161
+
1162
+ def save_stable_diffusion_checkpoint(
1163
+ v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
1164
+ ):
1165
+ if ckpt_path is not None:
1166
+ # reference epoch/step; also reload the checkpoint including the VAE, e.g. when the VAE is not in memory
1167
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
1168
+ if checkpoint is None: # safetensors file, or a ckpt that is only a state_dict
1169
+ checkpoint = {}
1170
+ strict = False
1171
+ else:
1172
+ strict = True
1173
+ if "state_dict" in state_dict:
1174
+ del state_dict["state_dict"]
1175
+ else:
1176
+ # build a new checkpoint from scratch
1177
+ assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
1178
+ checkpoint = {}
1179
+ state_dict = {}
1180
+ strict = False
1181
+
1182
+ def update_sd(prefix, sd):
1183
+ for k, v in sd.items():
1184
+ key = prefix + k
1185
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
1186
+ if save_dtype is not None:
1187
+ v = v.detach().clone().to("cpu").to(save_dtype)
1188
+ state_dict[key] = v
1189
+
1190
+ # Convert the UNet model
1191
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
1192
+ update_sd("model.diffusion_model.", unet_state_dict)
1193
+
1194
+ # Convert the text encoder model
1195
+ if v2:
1196
+ make_dummy = ckpt_path is None # if there is no reference checkpoint, insert dummy weights, e.g. by duplicating the last layer from the previous one
1197
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
1198
+ update_sd("cond_stage_model.model.", text_enc_dict)
1199
+ else:
1200
+ text_enc_dict = text_encoder.state_dict()
1201
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
1202
+
1203
+ # Convert the VAE
1204
+ if vae is not None:
1205
+ vae_dict = convert_vae_state_dict(vae.state_dict())
1206
+ update_sd("first_stage_model.", vae_dict)
1207
+
1208
+ # Put together new checkpoint
1209
+ key_count = len(state_dict.keys())
1210
+ new_ckpt = {"state_dict": state_dict}
1211
+
1212
+ # epoch and global_step are sometimes not int
1213
+ try:
1214
+ if "epoch" in checkpoint:
1215
+ epochs += checkpoint["epoch"]
1216
+ if "global_step" in checkpoint:
1217
+ steps += checkpoint["global_step"]
1218
+ except:
1219
+ pass
1220
+
1221
+ new_ckpt["epoch"] = epochs
1222
+ new_ckpt["global_step"] = steps
1223
+
1224
+ if is_safetensors(output_file):
1225
+ # TODO should non-Tensor values be removed from the dict?
1226
+ save_file(state_dict, output_file, metadata)
1227
+ else:
1228
+ torch.save(new_ckpt, output_file)
1229
+
1230
+ return key_count
1231
+
1232
+
1233
+ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
1234
+ if pretrained_model_name_or_path is None:
1235
+ # load default settings for v1/v2
1236
+ if v2:
1237
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
1238
+ else:
1239
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
1240
+
1241
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
1242
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
1243
+ if vae is None:
1244
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
1245
+
1246
+ # original U-Net cannot be saved, so we need to convert it to the Diffusers version
1247
+ # TODO this consumes a lot of memory
1248
+ diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
1249
+ diffusers_unet.load_state_dict(unet.state_dict())
1250
+
1251
+ pipeline = StableDiffusionPipeline(
1252
+ unet=diffusers_unet,
1253
+ text_encoder=text_encoder,
1254
+ vae=vae,
1255
+ scheduler=scheduler,
1256
+ tokenizer=tokenizer,
1257
+ safety_checker=None,
1258
+ feature_extractor=None,
1259
+ requires_safety_checker=None,
1260
+ )
1261
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
1262
+
1263
+
1264
+ VAE_PREFIX = "first_stage_model."
1265
+
1266
+
1267
+ def load_vae(vae_id, dtype):
1268
+ logger.info(f"load VAE: {vae_id}")
1269
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
1270
+ # Diffusers local/remote
1271
+ try:
1272
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
1273
+ except EnvironmentError as e:
1274
+ logger.error(f"exception occurs in loading vae: {e}")
1275
+ logger.error("retry with subfolder='vae'")
1276
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
1277
+ return vae
1278
+
1279
+ # local
1280
+ vae_config = create_vae_diffusers_config()
1281
+
1282
+ if vae_id.endswith(".bin"):
1283
+ # SD 1.5 VAE on Huggingface
1284
+ converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
1285
+ else:
1286
+ # StableDiffusion
1287
+ vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
1288
+ vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
1289
+
1290
+ # vae only or full model
1291
+ full_model = False
1292
+ for vae_key in vae_sd:
1293
+ if vae_key.startswith(VAE_PREFIX):
1294
+ full_model = True
1295
+ break
1296
+ if not full_model:
1297
+ sd = {}
1298
+ for key, value in vae_sd.items():
1299
+ sd[VAE_PREFIX + key] = value
1300
+ vae_sd = sd
1301
+ del sd
1302
+
1303
+ # Convert the VAE model.
1304
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
1305
+
1306
+ vae = AutoencoderKL(**vae_config)
1307
+ vae.load_state_dict(converted_vae_checkpoint)
1308
+ return vae
1309
+
1310
+
1311
+ # endregion
1312
+
1313
+
1314
+ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
1315
+ max_width, max_height = max_reso
1316
+ max_area = max_width * max_height
1317
+
1318
+ resos = set()
1319
+
1320
+ width = int(math.sqrt(max_area) // divisible) * divisible
1321
+ resos.add((width, width))
1322
+
1323
+ width = min_size
1324
+ while width <= max_size:
1325
+ height = min(max_size, int((max_area // width) // divisible) * divisible)
1326
+ if height >= min_size:
1327
+ resos.add((width, height))
1328
+ resos.add((height, width))
1329
+
1330
+ # # make additional resos
1331
+ # if width >= height and width - divisible >= min_size:
1332
+ # resos.add((width - divisible, height))
1333
+ # resos.add((height, width - divisible))
1334
+ # if height >= width and height - divisible >= min_size:
1335
+ # resos.add((width, height - divisible))
1336
+ # resos.add((height - divisible, width))
1337
+
1338
+ width += divisible
1339
+
1340
+ resos = list(resos)
1341
+ resos.sort()
1342
+ return resos
1343
+
1344
+
1345
+ if __name__ == "__main__":
1346
+ resos = make_bucket_resolutions((512, 768))
1347
+ logger.info(f"{len(resos)}")
1348
+ logger.info(f"{resos}")
1349
+ aspect_ratios = [w / h for w, h in resos]
1350
+ logger.info(f"{aspect_ratios}")
1351
+
1352
+ ars = set()
1353
+ for ar in aspect_ratios:
1354
+ if ar in ars:
1355
+ logger.error(f"error! duplicate ar: {ar}")
1356
+ ars.add(ar)
library/original_unet.py ADDED
@@ -0,0 +1,1919 @@
1
+ # Diffusers 0.10.2からStable Diffusionに必要な部分だけを持ってくる
2
+ # 条件分岐等で不要な部分は削除している
3
+ # コードの多くはDiffusersからコピーしている
4
+ # 制約として、モデルのstate_dictがDiffusers 0.10.2のものと同じ形式である必要がある
5
+
6
+ # Copy from Diffusers 0.10.2 for Stable Diffusion. Most of the code is copied from Diffusers.
7
+ # Unnecessary parts are deleted by condition branching.
8
+ # As a constraint, the state_dict of the model must be in the same format as that of Diffusers 0.10.2
9
+
10
+ """
11
+ The differences between v1.5 and v2.1 are:
12
+ - whether attention_head_dim is an int or a list[int]
13
+ - whether cross_attention_dim is 768 or 1024
14
+ - whether "use_linear_projection": true is absent (= False, v1.5) or present
15
+ - whether upcast_attention is False (v1.5) or True (v2.1)
16
+ - (the following can probably be ignored)
17
+ - whether sample_size is 64 or 96
18
+ - whether dual_cross_attention is present or not
19
+ - whether num_class_embeds is present or not
20
+ - whether only_cross_attention is present or not
21
+
22
+ v1.5
23
+ {
24
+ "_class_name": "UNet2DConditionModel",
25
+ "_diffusers_version": "0.6.0",
26
+ "act_fn": "silu",
27
+ "attention_head_dim": 8,
28
+ "block_out_channels": [
29
+ 320,
30
+ 640,
31
+ 1280,
32
+ 1280
33
+ ],
34
+ "center_input_sample": false,
35
+ "cross_attention_dim": 768,
36
+ "down_block_types": [
37
+ "CrossAttnDownBlock2D",
38
+ "CrossAttnDownBlock2D",
39
+ "CrossAttnDownBlock2D",
40
+ "DownBlock2D"
41
+ ],
42
+ "downsample_padding": 1,
43
+ "flip_sin_to_cos": true,
44
+ "freq_shift": 0,
45
+ "in_channels": 4,
46
+ "layers_per_block": 2,
47
+ "mid_block_scale_factor": 1,
48
+ "norm_eps": 1e-05,
49
+ "norm_num_groups": 32,
50
+ "out_channels": 4,
51
+ "sample_size": 64,
52
+ "up_block_types": [
53
+ "UpBlock2D",
54
+ "CrossAttnUpBlock2D",
55
+ "CrossAttnUpBlock2D",
56
+ "CrossAttnUpBlock2D"
57
+ ]
58
+ }
59
+
60
+ v2.1
61
+ {
62
+ "_class_name": "UNet2DConditionModel",
63
+ "_diffusers_version": "0.10.0.dev0",
64
+ "act_fn": "silu",
65
+ "attention_head_dim": [
66
+ 5,
67
+ 10,
68
+ 20,
69
+ 20
70
+ ],
71
+ "block_out_channels": [
72
+ 320,
73
+ 640,
74
+ 1280,
75
+ 1280
76
+ ],
77
+ "center_input_sample": false,
78
+ "cross_attention_dim": 1024,
79
+ "down_block_types": [
80
+ "CrossAttnDownBlock2D",
81
+ "CrossAttnDownBlock2D",
82
+ "CrossAttnDownBlock2D",
83
+ "DownBlock2D"
84
+ ],
85
+ "downsample_padding": 1,
86
+ "dual_cross_attention": false,
87
+ "flip_sin_to_cos": true,
88
+ "freq_shift": 0,
89
+ "in_channels": 4,
90
+ "layers_per_block": 2,
91
+ "mid_block_scale_factor": 1,
92
+ "norm_eps": 1e-05,
93
+ "norm_num_groups": 32,
94
+ "num_class_embeds": null,
95
+ "only_cross_attention": false,
96
+ "out_channels": 4,
97
+ "sample_size": 96,
98
+ "up_block_types": [
99
+ "UpBlock2D",
100
+ "CrossAttnUpBlock2D",
101
+ "CrossAttnUpBlock2D",
102
+ "CrossAttnUpBlock2D"
103
+ ],
104
+ "use_linear_projection": true,
105
+ "upcast_attention": true
106
+ }
107
+ """
108
+
109
+ import math
110
+ from types import SimpleNamespace
111
+ from typing import Dict, Optional, Tuple, Union
112
+ import torch
113
+ from torch import nn
114
+ from torch.nn import functional as F
115
+ from einops import rearrange
116
+ from library.utils import setup_logging
117
+ setup_logging()
118
+ import logging
119
+ logger = logging.getLogger(__name__)
120
+
121
+ BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
122
+ TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
123
+ TIME_EMBED_DIM = BLOCK_OUT_CHANNELS[0] * 4
124
+ IN_CHANNELS: int = 4
125
+ OUT_CHANNELS: int = 4
126
+ LAYERS_PER_BLOCK: int = 2
127
+ LAYERS_PER_BLOCK_UP: int = LAYERS_PER_BLOCK + 1
128
+ TIME_EMBED_FLIP_SIN_TO_COS: bool = True
129
+ TIME_EMBED_FREQ_SHIFT: int = 0
130
+ NORM_GROUPS: int = 32
131
+ NORM_EPS: float = 1e-5
132
+ TRANSFORMER_NORM_NUM_GROUPS = 32
133
+
134
+ DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
135
+ UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
136
+
137
+
138
+ # region memory efficient attention
139
+
140
+ # CrossAttention using FlashAttention
141
+ # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
142
+ # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
143
+
144
+ # constants
145
+
146
+ EPSILON = 1e-6
147
+
148
+ # helper functions
149
+
150
+
151
+ def exists(val):
152
+ return val is not None
153
+
154
+
155
+ def default(val, d):
156
+ return val if exists(val) else d
157
+
158
+
159
+ # flash attention forwards and backwards
160
+
161
+ # https://arxiv.org/abs/2205.14135
162
+
163
+
164
+ class FlashAttentionFunction(torch.autograd.Function):
165
+ @staticmethod
166
+ @torch.no_grad()
167
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
168
+ """Algorithm 2 in the paper"""
169
+
170
+ device = q.device
171
+ dtype = q.dtype
172
+ max_neg_value = -torch.finfo(q.dtype).max
173
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
174
+
175
+ o = torch.zeros_like(q)
176
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
177
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
178
+
179
+ scale = q.shape[-1] ** -0.5
180
+
181
+ if not exists(mask):
182
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
183
+ else:
184
+ mask = rearrange(mask, "b n -> b 1 1 n")
185
+ mask = mask.split(q_bucket_size, dim=-1)
186
+
187
+ row_splits = zip(
188
+ q.split(q_bucket_size, dim=-2),
189
+ o.split(q_bucket_size, dim=-2),
190
+ mask,
191
+ all_row_sums.split(q_bucket_size, dim=-2),
192
+ all_row_maxes.split(q_bucket_size, dim=-2),
193
+ )
194
+
195
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
196
+ q_start_index = ind * q_bucket_size - qk_len_diff
197
+
198
+ col_splits = zip(
199
+ k.split(k_bucket_size, dim=-2),
200
+ v.split(k_bucket_size, dim=-2),
201
+ )
202
+
203
+ for k_ind, (kc, vc) in enumerate(col_splits):
204
+ k_start_index = k_ind * k_bucket_size
205
+
206
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
207
+
208
+ if exists(row_mask):
209
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
210
+
211
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
212
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
213
+ q_start_index - k_start_index + 1
214
+ )
215
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
216
+
217
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
218
+ attn_weights -= block_row_maxes
219
+ exp_weights = torch.exp(attn_weights)
220
+
221
+ if exists(row_mask):
222
+ exp_weights.masked_fill_(~row_mask, 0.0)
223
+
224
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
225
+
226
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
227
+
228
+ exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
229
+
230
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
231
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
232
+
233
+ new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
234
+
235
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
236
+
237
+ row_maxes.copy_(new_row_maxes)
238
+ row_sums.copy_(new_row_sums)
239
+
240
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
241
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
242
+
243
+ return o
244
+
245
+ @staticmethod
246
+ @torch.no_grad()
247
+ def backward(ctx, do):
248
+ """Algorithm 4 in the paper"""
249
+
250
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
251
+ q, k, v, o, l, m = ctx.saved_tensors
252
+
253
+ device = q.device
254
+
255
+ max_neg_value = -torch.finfo(q.dtype).max
256
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
257
+
258
+ dq = torch.zeros_like(q)
259
+ dk = torch.zeros_like(k)
260
+ dv = torch.zeros_like(v)
261
+
262
+ row_splits = zip(
263
+ q.split(q_bucket_size, dim=-2),
264
+ o.split(q_bucket_size, dim=-2),
265
+ do.split(q_bucket_size, dim=-2),
266
+ mask,
267
+ l.split(q_bucket_size, dim=-2),
268
+ m.split(q_bucket_size, dim=-2),
269
+ dq.split(q_bucket_size, dim=-2),
270
+ )
271
+
272
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
273
+ q_start_index = ind * q_bucket_size - qk_len_diff
274
+
275
+ col_splits = zip(
276
+ k.split(k_bucket_size, dim=-2),
277
+ v.split(k_bucket_size, dim=-2),
278
+ dk.split(k_bucket_size, dim=-2),
279
+ dv.split(k_bucket_size, dim=-2),
280
+ )
281
+
282
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
283
+ k_start_index = k_ind * k_bucket_size
284
+
285
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
286
+
287
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
288
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
289
+ q_start_index - k_start_index + 1
290
+ )
291
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
292
+
293
+ exp_attn_weights = torch.exp(attn_weights - mc)
294
+
295
+ if exists(row_mask):
296
+ exp_attn_weights.masked_fill_(~row_mask, 0.0)
297
+
298
+ p = exp_attn_weights / lc
299
+
300
+ dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
301
+ dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
302
+
303
+ D = (doc * oc).sum(dim=-1, keepdims=True)
304
+ ds = p * scale * (dp - D)
305
+
306
+ dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
307
+ dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
308
+
309
+ dqc.add_(dq_chunk)
310
+ dkc.add_(dk_chunk)
311
+ dvc.add_(dv_chunk)
312
+
313
+ return dq, dk, dv, None, None, None, None
314
+
315
+
316
+ # endregion
317
+
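A rough usage sketch (not part of the commit) of the autograd function above; the tensor shapes and bucket sizes are illustrative:

import torch

b, h, n, d = 1, 8, 256, 64
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

# arguments after v are: mask, causal, q_bucket_size, k_bucket_size
out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)
assert out.shape == q.shape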
318
+
319
+ def get_parameter_dtype(parameter: torch.nn.Module):
320
+ return next(parameter.parameters()).dtype
321
+
322
+
323
+ def get_parameter_device(parameter: torch.nn.Module):
324
+ return next(parameter.parameters()).device
325
+
326
+
327
+ def get_timestep_embedding(
328
+ timesteps: torch.Tensor,
329
+ embedding_dim: int,
330
+ flip_sin_to_cos: bool = False,
331
+ downscale_freq_shift: float = 1,
332
+ scale: float = 1,
333
+ max_period: int = 10000,
334
+ ):
335
+ """
336
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
337
+
338
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
339
+ These may be fractional.
340
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
341
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
342
+ """
343
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
344
+
345
+ half_dim = embedding_dim // 2
346
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
347
+ exponent = exponent / (half_dim - downscale_freq_shift)
348
+
349
+ emb = torch.exp(exponent)
350
+ emb = timesteps[:, None].float() * emb[None, :]
351
+
352
+ # scale embeddings
353
+ emb = scale * emb
354
+
355
+ # concat sine and cosine embeddings
356
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
357
+
358
+ # flip sine and cosine embeddings
359
+ if flip_sin_to_cos:
360
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
361
+
362
+ # zero pad
363
+ if embedding_dim % 2 == 1:
364
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
365
+ return emb
366
+
367
+
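A small worked example of the embedding shape (illustrative, not part of the commit); the arguments mirror the TIMESTEP_INPUT_DIM, TIME_EMBED_FLIP_SIN_TO_COS and TIME_EMBED_FREQ_SHIFT constants defined above:

import torch

timesteps = torch.tensor([0, 10, 999])
emb = get_timestep_embedding(timesteps, embedding_dim=320, flip_sin_to_cos=True, downscale_freq_shift=0)
print(emb.shape)  # torch.Size([3, 320]) -- concatenated sin/cos features per timestep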
368
+ # Deep Shrink: this helper is intentionally not shared with other modules, to minimize dependencies.
369
+ def resize_like(x, target, mode="bicubic", align_corners=False):
370
+ org_dtype = x.dtype
371
+ if org_dtype == torch.bfloat16:
372
+ x = x.to(torch.float32)
373
+
374
+ if x.shape[-2:] != target.shape[-2:]:
375
+ if mode == "nearest":
376
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode)
377
+ else:
378
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
379
+
380
+ if org_dtype == torch.bfloat16:
381
+ x = x.to(org_dtype)
382
+ return x
383
+
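A small sketch of resize_like above: bfloat16 inputs are resized in float32 and cast back (shapes here are arbitrary examples):
    x = torch.randn(1, 4, 32, 32, dtype=torch.bfloat16)
    target = torch.randn(1, 4, 64, 64)
    y = resize_like(x, target)                           # bicubic by default
    assert y.shape[-2:] == target.shape[-2:] and y.dtype == torch.bfloat16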
384
+
385
+ class SampleOutput:
386
+ def __init__(self, sample):
387
+ self.sample = sample
388
+
389
+
390
+ class TimestepEmbedding(nn.Module):
391
+ def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
392
+ super().__init__()
393
+
394
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
395
+ self.act = None
396
+ if act_fn == "silu":
397
+ self.act = nn.SiLU()
398
+ elif act_fn == "mish":
399
+ self.act = nn.Mish()
400
+
401
+ if out_dim is not None:
402
+ time_embed_dim_out = out_dim
403
+ else:
404
+ time_embed_dim_out = time_embed_dim
405
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
406
+
407
+ def forward(self, sample):
408
+ sample = self.linear_1(sample)
409
+
410
+ if self.act is not None:
411
+ sample = self.act(sample)
412
+
413
+ sample = self.linear_2(sample)
414
+ return sample
415
+
416
+
417
+ class Timesteps(nn.Module):
418
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
419
+ super().__init__()
420
+ self.num_channels = num_channels
421
+ self.flip_sin_to_cos = flip_sin_to_cos
422
+ self.downscale_freq_shift = downscale_freq_shift
423
+
424
+ def forward(self, timesteps):
425
+ t_emb = get_timestep_embedding(
426
+ timesteps,
427
+ self.num_channels,
428
+ flip_sin_to_cos=self.flip_sin_to_cos,
429
+ downscale_freq_shift=self.downscale_freq_shift,
430
+ )
431
+ return t_emb
432
+
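A sketch chaining the two timestep modules above; the 320/1280 widths mirror the usual SD-v1 sizes and are assumptions, not values enforced by these classes:
    time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
    time_embed = TimestepEmbedding(in_channels=320, time_embed_dim=1280)
    emb = time_embed(time_proj(torch.tensor([981, 961])))
    assert emb.shape == (2, 1280)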
433
+
434
+ class ResnetBlock2D(nn.Module):
435
+ def __init__(
436
+ self,
437
+ in_channels,
438
+ out_channels,
439
+ ):
440
+ super().__init__()
441
+ self.in_channels = in_channels
442
+ self.out_channels = out_channels
443
+
444
+ self.norm1 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=in_channels, eps=NORM_EPS, affine=True)
445
+
446
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
447
+
448
+ self.time_emb_proj = torch.nn.Linear(TIME_EMBED_DIM, out_channels)
449
+
450
+ self.norm2 = torch.nn.GroupNorm(num_groups=NORM_GROUPS, num_channels=out_channels, eps=NORM_EPS, affine=True)
451
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
452
+
453
+ # if non_linearity == "swish":
454
+ self.nonlinearity = lambda x: F.silu(x)
455
+
456
+ self.use_in_shortcut = self.in_channels != self.out_channels
457
+
458
+ self.conv_shortcut = None
459
+ if self.use_in_shortcut:
460
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
461
+
462
+ def forward(self, input_tensor, temb):
463
+ hidden_states = input_tensor
464
+
465
+ hidden_states = self.norm1(hidden_states)
466
+ hidden_states = self.nonlinearity(hidden_states)
467
+
468
+ hidden_states = self.conv1(hidden_states)
469
+
470
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
471
+ hidden_states = hidden_states + temb
472
+
473
+ hidden_states = self.norm2(hidden_states)
474
+ hidden_states = self.nonlinearity(hidden_states)
475
+
476
+ hidden_states = self.conv2(hidden_states)
477
+
478
+ if self.conv_shortcut is not None:
479
+ input_tensor = self.conv_shortcut(input_tensor)
480
+
481
+ output_tensor = input_tensor + hidden_states
482
+
483
+ return output_tensor
484
+
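A sketch of ResnetBlock2D above; TIME_EMBED_DIM and the group-norm constants come from earlier in this file, and 320/640 are chosen so they divide evenly into the norm groups:
    block = ResnetBlock2D(in_channels=320, out_channels=640)   # channel change -> conv_shortcut is used
    h = block(torch.randn(1, 320, 32, 32), torch.randn(1, TIME_EMBED_DIM))
    assert h.shape == (1, 640, 32, 32)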
485
+
486
+ class DownBlock2D(nn.Module):
487
+ def __init__(
488
+ self,
489
+ in_channels: int,
490
+ out_channels: int,
491
+ add_downsample=True,
492
+ ):
493
+ super().__init__()
494
+
495
+ self.has_cross_attention = False
496
+ resnets = []
497
+
498
+ for i in range(LAYERS_PER_BLOCK):
499
+ in_channels = in_channels if i == 0 else out_channels
500
+ resnets.append(
501
+ ResnetBlock2D(
502
+ in_channels=in_channels,
503
+ out_channels=out_channels,
504
+ )
505
+ )
506
+ self.resnets = nn.ModuleList(resnets)
507
+
508
+ if add_downsample:
509
+ self.downsamplers = [Downsample2D(out_channels, out_channels=out_channels)]
510
+ else:
511
+ self.downsamplers = None
512
+
513
+ self.gradient_checkpointing = False
514
+
515
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
516
+ pass
517
+
518
+ def set_use_sdpa(self, sdpa):
519
+ pass
520
+
521
+ def forward(self, hidden_states, temb=None):
522
+ output_states = ()
523
+
524
+ for resnet in self.resnets:
525
+ if self.training and self.gradient_checkpointing:
526
+
527
+ def create_custom_forward(module):
528
+ def custom_forward(*inputs):
529
+ return module(*inputs)
530
+
531
+ return custom_forward
532
+
533
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
534
+ else:
535
+ hidden_states = resnet(hidden_states, temb)
536
+
537
+ output_states += (hidden_states,)
538
+
539
+ if self.downsamplers is not None:
540
+ for downsampler in self.downsamplers:
541
+ hidden_states = downsampler(hidden_states)
542
+
543
+ output_states += (hidden_states,)
544
+
545
+ return hidden_states, output_states
546
+
547
+
548
+ class Downsample2D(nn.Module):
549
+ def __init__(self, channels, out_channels):
550
+ super().__init__()
551
+
552
+ self.channels = channels
553
+ self.out_channels = out_channels
554
+
555
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
556
+
557
+ def forward(self, hidden_states):
558
+ assert hidden_states.shape[1] == self.channels
559
+ hidden_states = self.conv(hidden_states)
560
+
561
+ return hidden_states
562
+
563
+
564
+ class CrossAttention(nn.Module):
565
+ def __init__(
566
+ self,
567
+ query_dim: int,
568
+ cross_attention_dim: Optional[int] = None,
569
+ heads: int = 8,
570
+ dim_head: int = 64,
571
+ upcast_attention: bool = False,
572
+ ):
573
+ super().__init__()
574
+ inner_dim = dim_head * heads
575
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
576
+ self.upcast_attention = upcast_attention
577
+
578
+ self.scale = dim_head**-0.5
579
+ self.heads = heads
580
+
581
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
582
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
583
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
584
+
585
+ self.to_out = nn.ModuleList([])
586
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
587
+ # no dropout here
588
+
589
+ self.use_memory_efficient_attention_xformers = False
590
+ self.use_memory_efficient_attention_mem_eff = False
591
+ self.use_sdpa = False
592
+
593
+ # Attention processor
594
+ self.processor = None
595
+
596
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
597
+ self.use_memory_efficient_attention_xformers = xformers
598
+ self.use_memory_efficient_attention_mem_eff = mem_eff
599
+
600
+ def set_use_sdpa(self, sdpa):
601
+ self.use_sdpa = sdpa
602
+
603
+ def reshape_heads_to_batch_dim(self, tensor):
604
+ batch_size, seq_len, dim = tensor.shape
605
+ head_size = self.heads
606
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
607
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
608
+ return tensor
609
+
610
+ def reshape_batch_dim_to_heads(self, tensor):
611
+ batch_size, seq_len, dim = tensor.shape
612
+ head_size = self.heads
613
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
614
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
615
+ return tensor
616
+
617
+ def set_processor(self):
618
+ return self.processor
619
+
620
+ def get_processor(self):
621
+ return self.processor
622
+
623
+ def forward(self, hidden_states, context=None, mask=None, **kwargs):
624
+ if self.processor is not None:
625
+ (
626
+ hidden_states,
627
+ encoder_hidden_states,
628
+ attention_mask,
629
+ ) = translate_attention_names_from_diffusers(
630
+ hidden_states=hidden_states, context=context, mask=mask, **kwargs
631
+ )
632
+ return self.processor(
633
+ attn=self,
634
+ hidden_states=hidden_states,
635
+ encoder_hidden_states=context,
636
+ attention_mask=mask,
637
+ **kwargs
638
+ )
639
+ if self.use_memory_efficient_attention_xformers:
640
+ return self.forward_memory_efficient_xformers(hidden_states, context, mask)
641
+ if self.use_memory_efficient_attention_mem_eff:
642
+ return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
643
+ if self.use_sdpa:
644
+ return self.forward_sdpa(hidden_states, context, mask)
645
+
646
+ query = self.to_q(hidden_states)
647
+ context = context if context is not None else hidden_states
648
+ key = self.to_k(context)
649
+ value = self.to_v(context)
650
+
651
+ query = self.reshape_heads_to_batch_dim(query)
652
+ key = self.reshape_heads_to_batch_dim(key)
653
+ value = self.reshape_heads_to_batch_dim(value)
654
+
655
+ hidden_states = self._attention(query, key, value)
656
+
657
+ # linear proj
658
+ hidden_states = self.to_out[0](hidden_states)
659
+ # hidden_states = self.to_out[1](hidden_states) # no dropout
660
+ return hidden_states
661
+
662
+ def _attention(self, query, key, value):
663
+ if self.upcast_attention:
664
+ query = query.float()
665
+ key = key.float()
666
+
667
+ attention_scores = torch.baddbmm(
668
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
669
+ query,
670
+ key.transpose(-1, -2),
671
+ beta=0,
672
+ alpha=self.scale,
673
+ )
674
+ attention_probs = attention_scores.softmax(dim=-1)
675
+
676
+ # cast back to the original dtype
677
+ attention_probs = attention_probs.to(value.dtype)
678
+
679
+ # compute attention output
680
+ hidden_states = torch.bmm(attention_probs, value)
681
+
682
+ # reshape hidden_states
683
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
684
+ return hidden_states
685
+
686
+ # TODO support Hypernetworks
687
+ def forward_memory_efficient_xformers(self, x, context=None, mask=None):
688
+ import xformers.ops
689
+
690
+ h = self.heads
691
+ q_in = self.to_q(x)
692
+ context = context if context is not None else x
693
+ context = context.to(x.dtype)
694
+ k_in = self.to_k(context)
695
+ v_in = self.to_v(context)
696
+
697
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
698
+ del q_in, k_in, v_in
699
+
700
+ q = q.contiguous()
701
+ k = k.contiguous()
702
+ v = v.contiguous()
703
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # xformers picks the best available kernel
704
+
705
+ out = rearrange(out, "b n h d -> b n (h d)", h=h)
706
+
707
+ out = self.to_out[0](out)
708
+ return out
709
+
710
+ def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
711
+ flash_func = FlashAttentionFunction
712
+
713
+ q_bucket_size = 512
714
+ k_bucket_size = 1024
715
+
716
+ h = self.heads
717
+ q = self.to_q(x)
718
+ context = context if context is not None else x
719
+ context = context.to(x.dtype)
720
+ k = self.to_k(context)
721
+ v = self.to_v(context)
722
+ del context, x
723
+
724
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
725
+
726
+ out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
727
+
728
+ out = rearrange(out, "b h n d -> b n (h d)")
729
+
730
+ out = self.to_out[0](out)
731
+ return out
732
+
733
+ def forward_sdpa(self, x, context=None, mask=None):
734
+ h = self.heads
735
+ q_in = self.to_q(x)
736
+ context = context if context is not None else x
737
+ context = context.to(x.dtype)
738
+ k_in = self.to_k(context)
739
+ v_in = self.to_v(context)
740
+
741
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
742
+ del q_in, k_in, v_in
743
+
744
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
745
+
746
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
747
+
748
+ out = self.to_out[0](out)
749
+ return out
750
+
751
+ def translate_attention_names_from_diffusers(
752
+ hidden_states: torch.FloatTensor,
753
+ context: Optional[torch.FloatTensor] = None,
754
+ mask: Optional[torch.FloatTensor] = None,
755
+ # HF naming
756
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
757
+ attention_mask: Optional[torch.FloatTensor] = None
758
+ ):
759
+ # translate from hugging face diffusers
760
+ context = context if context is not None else encoder_hidden_states
761
+
762
+ # translate from hugging face diffusers
763
+ mask = mask if mask is not None else attention_mask
764
+
765
+ return hidden_states, context, mask
766
+
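A sketch of CrossAttention above with a CLIP-sized context (the 768 width is an assumption). With no flags set it falls through to the plain baddbmm path; set_use_sdpa(True) switches to F.scaled_dot_product_attention:
    attn = CrossAttention(query_dim=320, cross_attention_dim=768, heads=8, dim_head=40)
    x = torch.randn(2, 64, 320)                          # (batch, query tokens, query_dim)
    ctx = torch.randn(2, 77, 768)                        # e.g. text encoder hidden states
    assert attn(x, context=ctx).shape == (2, 64, 320)
    attn.set_use_sdpa(True)                              # same shapes via the SDPA path
    assert attn(x, context=ctx).shape == (2, 64, 320)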
767
+ # feedforward
768
+ class GEGLU(nn.Module):
769
+ r"""
770
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
771
+
772
+ Parameters:
773
+ dim_in (`int`): The number of channels in the input.
774
+ dim_out (`int`): The number of channels in the output.
775
+ """
776
+
777
+ def __init__(self, dim_in: int, dim_out: int):
778
+ super().__init__()
779
+ self.proj = nn.Linear(dim_in, dim_out * 2)
780
+
781
+ def gelu(self, gate):
782
+ if gate.device.type != "mps":
783
+ return F.gelu(gate)
784
+ # mps: gelu is not implemented for float16
785
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
786
+
787
+ def forward(self, hidden_states):
788
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
789
+ return hidden_states * self.gelu(gate)
790
+
791
+
792
+ class FeedForward(nn.Module):
793
+ def __init__(
794
+ self,
795
+ dim: int,
796
+ ):
797
+ super().__init__()
798
+ inner_dim = int(dim * 4) # mult is always 4
799
+
800
+ self.net = nn.ModuleList([])
801
+ # project in
802
+ self.net.append(GEGLU(dim, inner_dim))
803
+ # project dropout
804
+ self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0
805
+ # project out
806
+ self.net.append(nn.Linear(inner_dim, dim))
807
+
808
+ def forward(self, hidden_states):
809
+ for module in self.net:
810
+ hidden_states = module(hidden_states)
811
+ return hidden_states
812
+
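A sketch of the GEGLU feed-forward above; the hidden width is always 4x the input dim:
    ff = FeedForward(dim=320)                            # GEGLU: 320 -> 2x1280, gate, then back to 320
    assert ff(torch.randn(2, 64, 320)).shape == (2, 64, 320)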
813
+
814
+ class BasicTransformerBlock(nn.Module):
815
+ def __init__(
816
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
817
+ ):
818
+ super().__init__()
819
+
820
+ # 1. Self-Attn
821
+ self.attn1 = CrossAttention(
822
+ query_dim=dim,
823
+ cross_attention_dim=None,
824
+ heads=num_attention_heads,
825
+ dim_head=attention_head_dim,
826
+ upcast_attention=upcast_attention,
827
+ )
828
+ self.ff = FeedForward(dim)
829
+
830
+ # 2. Cross-Attn
831
+ self.attn2 = CrossAttention(
832
+ query_dim=dim,
833
+ cross_attention_dim=cross_attention_dim,
834
+ heads=num_attention_heads,
835
+ dim_head=attention_head_dim,
836
+ upcast_attention=upcast_attention,
837
+ )
838
+
839
+ self.norm1 = nn.LayerNorm(dim)
840
+ self.norm2 = nn.LayerNorm(dim)
841
+
842
+ # 3. Feed-forward
843
+ self.norm3 = nn.LayerNorm(dim)
844
+
845
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
846
+ self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
847
+ self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
848
+
849
+ def set_use_sdpa(self, sdpa: bool):
850
+ self.attn1.set_use_sdpa(sdpa)
851
+ self.attn2.set_use_sdpa(sdpa)
852
+
853
+ def forward(self, hidden_states, context=None, timestep=None):
854
+ # 1. Self-Attention
855
+ norm_hidden_states = self.norm1(hidden_states)
856
+
857
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
858
+
859
+ # 2. Cross-Attention
860
+ norm_hidden_states = self.norm2(hidden_states)
861
+ hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
862
+
863
+ # 3. Feed-forward
864
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
865
+
866
+ return hidden_states
867
+
868
+
869
+ class Transformer2DModel(nn.Module):
870
+ def __init__(
871
+ self,
872
+ num_attention_heads: int = 16,
873
+ attention_head_dim: int = 88,
874
+ in_channels: Optional[int] = None,
875
+ cross_attention_dim: Optional[int] = None,
876
+ use_linear_projection: bool = False,
877
+ upcast_attention: bool = False,
878
+ ):
879
+ super().__init__()
880
+ self.in_channels = in_channels
881
+ self.num_attention_heads = num_attention_heads
882
+ self.attention_head_dim = attention_head_dim
883
+ inner_dim = num_attention_heads * attention_head_dim
884
+ self.use_linear_projection = use_linear_projection
885
+
886
+ self.norm = torch.nn.GroupNorm(num_groups=TRANSFORMER_NORM_NUM_GROUPS, num_channels=in_channels, eps=1e-6, affine=True)
887
+
888
+ if use_linear_projection:
889
+ self.proj_in = nn.Linear(in_channels, inner_dim)
890
+ else:
891
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
892
+
893
+ self.transformer_blocks = nn.ModuleList(
894
+ [
895
+ BasicTransformerBlock(
896
+ inner_dim,
897
+ num_attention_heads,
898
+ attention_head_dim,
899
+ cross_attention_dim=cross_attention_dim,
900
+ upcast_attention=upcast_attention,
901
+ )
902
+ ]
903
+ )
904
+
905
+ if use_linear_projection:
906
+ self.proj_out = nn.Linear(in_channels, inner_dim)
907
+ else:
908
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
909
+
910
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
911
+ for transformer in self.transformer_blocks:
912
+ transformer.set_use_memory_efficient_attention(xformers, mem_eff)
913
+
914
+ def set_use_sdpa(self, sdpa):
915
+ for transformer in self.transformer_blocks:
916
+ transformer.set_use_sdpa(sdpa)
917
+
918
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
919
+ # 1. Input
920
+ batch, _, height, weight = hidden_states.shape
921
+ residual = hidden_states
922
+
923
+ hidden_states = self.norm(hidden_states)
924
+ if not self.use_linear_projection:
925
+ hidden_states = self.proj_in(hidden_states)
926
+ inner_dim = hidden_states.shape[1]
927
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
928
+ else:
929
+ inner_dim = hidden_states.shape[1]
930
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
931
+ hidden_states = self.proj_in(hidden_states)
932
+
933
+ # 2. Blocks
934
+ for block in self.transformer_blocks:
935
+ hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
936
+
937
+ # 3. Output
938
+ if not self.use_linear_projection:
939
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
940
+ hidden_states = self.proj_out(hidden_states)
941
+ else:
942
+ hidden_states = self.proj_out(hidden_states)
943
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
944
+
945
+ output = hidden_states + residual
946
+
947
+ if not return_dict:
948
+ return (output,)
949
+
950
+ return SampleOutput(sample=output)
951
+
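A sketch of Transformer2DModel above with conv projections; heads x head_dim must equal in_channels, and in_channels must be divisible by the TRANSFORMER_NORM_NUM_GROUPS constant defined earlier in this file (768 is an assumed context width):
    tf_block = Transformer2DModel(
        num_attention_heads=8, attention_head_dim=40, in_channels=320, cross_attention_dim=768
    )
    out = tf_block(torch.randn(1, 320, 16, 16), encoder_hidden_states=torch.randn(1, 77, 768)).sample
    assert out.shape == (1, 320, 16, 16)                 # the residual connection keeps the shape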
952
+
953
+ class CrossAttnDownBlock2D(nn.Module):
954
+ def __init__(
955
+ self,
956
+ in_channels: int,
957
+ out_channels: int,
958
+ add_downsample=True,
959
+ cross_attention_dim=1280,
960
+ attn_num_head_channels=1,
961
+ use_linear_projection=False,
962
+ upcast_attention=False,
963
+ ):
964
+ super().__init__()
965
+ self.has_cross_attention = True
966
+ resnets = []
967
+ attentions = []
968
+
969
+ self.attn_num_head_channels = attn_num_head_channels
970
+
971
+ for i in range(LAYERS_PER_BLOCK):
972
+ in_channels = in_channels if i == 0 else out_channels
973
+
974
+ resnets.append(ResnetBlock2D(in_channels=in_channels, out_channels=out_channels))
975
+ attentions.append(
976
+ Transformer2DModel(
977
+ attn_num_head_channels,
978
+ out_channels // attn_num_head_channels,
979
+ in_channels=out_channels,
980
+ cross_attention_dim=cross_attention_dim,
981
+ use_linear_projection=use_linear_projection,
982
+ upcast_attention=upcast_attention,
983
+ )
984
+ )
985
+ self.attentions = nn.ModuleList(attentions)
986
+ self.resnets = nn.ModuleList(resnets)
987
+
988
+ if add_downsample:
989
+ self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
990
+ else:
991
+ self.downsamplers = None
992
+
993
+ self.gradient_checkpointing = False
994
+
995
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
996
+ for attn in self.attentions:
997
+ attn.set_use_memory_efficient_attention(xformers, mem_eff)
998
+
999
+ def set_use_sdpa(self, sdpa):
1000
+ for attn in self.attentions:
1001
+ attn.set_use_sdpa(sdpa)
1002
+
1003
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
1004
+ output_states = ()
1005
+
1006
+ for resnet, attn in zip(self.resnets, self.attentions):
1007
+ if self.training and self.gradient_checkpointing:
1008
+
1009
+ def create_custom_forward(module, return_dict=None):
1010
+ def custom_forward(*inputs):
1011
+ if return_dict is not None:
1012
+ return module(*inputs, return_dict=return_dict)
1013
+ else:
1014
+ return module(*inputs)
1015
+
1016
+ return custom_forward
1017
+
1018
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1019
+ hidden_states = torch.utils.checkpoint.checkpoint(
1020
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1021
+ )[0]
1022
+ else:
1023
+ hidden_states = resnet(hidden_states, temb)
1024
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
1025
+
1026
+ output_states += (hidden_states,)
1027
+
1028
+ if self.downsamplers is not None:
1029
+ for downsampler in self.downsamplers:
1030
+ hidden_states = downsampler(hidden_states)
1031
+
1032
+ output_states += (hidden_states,)
1033
+
1034
+ return hidden_states, output_states
1035
+
1036
+
1037
+ class UNetMidBlock2DCrossAttn(nn.Module):
1038
+ def __init__(
1039
+ self,
1040
+ in_channels: int,
1041
+ attn_num_head_channels=1,
1042
+ cross_attention_dim=1280,
1043
+ use_linear_projection=False,
1044
+ ):
1045
+ super().__init__()
1046
+
1047
+ self.has_cross_attention = True
1048
+ self.attn_num_head_channels = attn_num_head_channels
1049
+
1050
+ # Middle block has two resnets and one attention
1051
+ resnets = [
1052
+ ResnetBlock2D(
1053
+ in_channels=in_channels,
1054
+ out_channels=in_channels,
1055
+ ),
1056
+ ResnetBlock2D(
1057
+ in_channels=in_channels,
1058
+ out_channels=in_channels,
1059
+ ),
1060
+ ]
1061
+ attentions = [
1062
+ Transformer2DModel(
1063
+ attn_num_head_channels,
1064
+ in_channels // attn_num_head_channels,
1065
+ in_channels=in_channels,
1066
+ cross_attention_dim=cross_attention_dim,
1067
+ use_linear_projection=use_linear_projection,
1068
+ )
1069
+ ]
1070
+
1071
+ self.attentions = nn.ModuleList(attentions)
1072
+ self.resnets = nn.ModuleList(resnets)
1073
+
1074
+ self.gradient_checkpointing = False
1075
+
1076
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
1077
+ for attn in self.attentions:
1078
+ attn.set_use_memory_efficient_attention(xformers, mem_eff)
1079
+
1080
+ def set_use_sdpa(self, sdpa):
1081
+ for attn in self.attentions:
1082
+ attn.set_use_sdpa(sdpa)
1083
+
1084
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
1085
+ for i, resnet in enumerate(self.resnets):
1086
+ attn = None if i == 0 else self.attentions[i - 1]
1087
+
1088
+ if self.training and self.gradient_checkpointing:
1089
+
1090
+ def create_custom_forward(module, return_dict=None):
1091
+ def custom_forward(*inputs):
1092
+ if return_dict is not None:
1093
+ return module(*inputs, return_dict=return_dict)
1094
+ else:
1095
+ return module(*inputs)
1096
+
1097
+ return custom_forward
1098
+
1099
+ if attn is not None:
1100
+ hidden_states = torch.utils.checkpoint.checkpoint(
1101
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1102
+ )[0]
1103
+
1104
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1105
+ else:
1106
+ if attn is not None:
1107
+ hidden_states = attn(hidden_states, encoder_hidden_states).sample
1108
+ hidden_states = resnet(hidden_states, temb)
1109
+
1110
+ return hidden_states
1111
+
1112
+
1113
+ class Upsample2D(nn.Module):
1114
+ def __init__(self, channels, out_channels):
1115
+ super().__init__()
1116
+ self.channels = channels
1117
+ self.out_channels = out_channels
1118
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
1119
+
1120
+ def forward(self, hidden_states, output_size):
1121
+ assert hidden_states.shape[1] == self.channels
1122
+
1123
+ # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
1124
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
1125
+ # https://github.com/pytorch/pytorch/issues/86679
1126
+ dtype = hidden_states.dtype
1127
+ if dtype == torch.bfloat16:
1128
+ hidden_states = hidden_states.to(torch.float32)
1129
+
1130
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
1131
+ if hidden_states.shape[0] >= 64:
1132
+ hidden_states = hidden_states.contiguous()
1133
+
1134
+ # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
1135
+ if output_size is None:
1136
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
1137
+ else:
1138
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
1139
+
1140
+ # If the input is bfloat16, we cast back to bfloat16
1141
+ if dtype == torch.bfloat16:
1142
+ hidden_states = hidden_states.to(dtype)
1143
+
1144
+ hidden_states = self.conv(hidden_states)
1145
+
1146
+ return hidden_states
1147
+
1148
+
1149
+ class UpBlock2D(nn.Module):
1150
+ def __init__(
1151
+ self,
1152
+ in_channels: int,
1153
+ prev_output_channel: int,
1154
+ out_channels: int,
1155
+ add_upsample=True,
1156
+ ):
1157
+ super().__init__()
1158
+
1159
+ self.has_cross_attention = False
1160
+ resnets = []
1161
+
1162
+ for i in range(LAYERS_PER_BLOCK_UP):
1163
+ res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
1164
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1165
+
1166
+ resnets.append(
1167
+ ResnetBlock2D(
1168
+ in_channels=resnet_in_channels + res_skip_channels,
1169
+ out_channels=out_channels,
1170
+ )
1171
+ )
1172
+
1173
+ self.resnets = nn.ModuleList(resnets)
1174
+
1175
+ if add_upsample:
1176
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
1177
+ else:
1178
+ self.upsamplers = None
1179
+
1180
+ self.gradient_checkpointing = False
1181
+
1182
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
1183
+ pass
1184
+
1185
+ def set_use_sdpa(self, sdpa):
1186
+ pass
1187
+
1188
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1189
+ for resnet in self.resnets:
1190
+ # pop res hidden states
1191
+ res_hidden_states = res_hidden_states_tuple[-1]
1192
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1193
+
1194
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1195
+
1196
+ if self.training and self.gradient_checkpointing:
1197
+
1198
+ def create_custom_forward(module):
1199
+ def custom_forward(*inputs):
1200
+ return module(*inputs)
1201
+
1202
+ return custom_forward
1203
+
1204
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1205
+ else:
1206
+ hidden_states = resnet(hidden_states, temb)
1207
+
1208
+ if self.upsamplers is not None:
1209
+ for upsampler in self.upsamplers:
1210
+ hidden_states = upsampler(hidden_states, upsample_size)
1211
+
1212
+ return hidden_states
1213
+
1214
+
1215
+ class CrossAttnUpBlock2D(nn.Module):
1216
+ def __init__(
1217
+ self,
1218
+ in_channels: int,
1219
+ out_channels: int,
1220
+ prev_output_channel: int,
1221
+ attn_num_head_channels=1,
1222
+ cross_attention_dim=1280,
1223
+ add_upsample=True,
1224
+ use_linear_projection=False,
1225
+ upcast_attention=False,
1226
+ ):
1227
+ super().__init__()
1228
+ resnets = []
1229
+ attentions = []
1230
+
1231
+ self.has_cross_attention = True
1232
+ self.attn_num_head_channels = attn_num_head_channels
1233
+
1234
+ for i in range(LAYERS_PER_BLOCK_UP):
1235
+ res_skip_channels = in_channels if (i == LAYERS_PER_BLOCK_UP - 1) else out_channels
1236
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1237
+
1238
+ resnets.append(
1239
+ ResnetBlock2D(
1240
+ in_channels=resnet_in_channels + res_skip_channels,
1241
+ out_channels=out_channels,
1242
+ )
1243
+ )
1244
+ attentions.append(
1245
+ Transformer2DModel(
1246
+ attn_num_head_channels,
1247
+ out_channels // attn_num_head_channels,
1248
+ in_channels=out_channels,
1249
+ cross_attention_dim=cross_attention_dim,
1250
+ use_linear_projection=use_linear_projection,
1251
+ upcast_attention=upcast_attention,
1252
+ )
1253
+ )
1254
+
1255
+ self.attentions = nn.ModuleList(attentions)
1256
+ self.resnets = nn.ModuleList(resnets)
1257
+
1258
+ if add_upsample:
1259
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
1260
+ else:
1261
+ self.upsamplers = None
1262
+
1263
+ self.gradient_checkpointing = False
1264
+
1265
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
1266
+ for attn in self.attentions:
1267
+ attn.set_use_memory_efficient_attention(xformers, mem_eff)
1268
+
1269
+ def set_use_sdpa(self, sdpa):
1270
+ for attn in self.attentions:
1271
+ attn.set_use_sdpa(sdpa)
1272
+
1273
+ def forward(
1274
+ self,
1275
+ hidden_states,
1276
+ res_hidden_states_tuple,
1277
+ temb=None,
1278
+ encoder_hidden_states=None,
1279
+ upsample_size=None,
1280
+ ):
1281
+ for resnet, attn in zip(self.resnets, self.attentions):
1282
+ # pop res hidden states
1283
+ res_hidden_states = res_hidden_states_tuple[-1]
1284
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1285
+
1286
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1287
+
1288
+ if self.training and self.gradient_checkpointing:
1289
+
1290
+ def create_custom_forward(module, return_dict=None):
1291
+ def custom_forward(*inputs):
1292
+ if return_dict is not None:
1293
+ return module(*inputs, return_dict=return_dict)
1294
+ else:
1295
+ return module(*inputs)
1296
+
1297
+ return custom_forward
1298
+
1299
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
1300
+ hidden_states = torch.utils.checkpoint.checkpoint(
1301
+ create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
1302
+ )[0]
1303
+ else:
1304
+ hidden_states = resnet(hidden_states, temb)
1305
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
1306
+
1307
+ if self.upsamplers is not None:
1308
+ for upsampler in self.upsamplers:
1309
+ hidden_states = upsampler(hidden_states, upsample_size)
1310
+
1311
+ return hidden_states
1312
+
1313
+
1314
+ def get_down_block(
1315
+ down_block_type,
1316
+ in_channels,
1317
+ out_channels,
1318
+ add_downsample,
1319
+ attn_num_head_channels,
1320
+ cross_attention_dim,
1321
+ use_linear_projection,
1322
+ upcast_attention,
1323
+ ):
1324
+ if down_block_type == "DownBlock2D":
1325
+ return DownBlock2D(
1326
+ in_channels=in_channels,
1327
+ out_channels=out_channels,
1328
+ add_downsample=add_downsample,
1329
+ )
1330
+ elif down_block_type == "CrossAttnDownBlock2D":
1331
+ return CrossAttnDownBlock2D(
1332
+ in_channels=in_channels,
1333
+ out_channels=out_channels,
1334
+ add_downsample=add_downsample,
1335
+ cross_attention_dim=cross_attention_dim,
1336
+ attn_num_head_channels=attn_num_head_channels,
1337
+ use_linear_projection=use_linear_projection,
1338
+ upcast_attention=upcast_attention,
1339
+ )
1340
+
1341
+
1342
+ def get_up_block(
1343
+ up_block_type,
1344
+ in_channels,
1345
+ out_channels,
1346
+ prev_output_channel,
1347
+ add_upsample,
1348
+ attn_num_head_channels,
1349
+ cross_attention_dim=None,
1350
+ use_linear_projection=False,
1351
+ upcast_attention=False,
1352
+ ):
1353
+ if up_block_type == "UpBlock2D":
1354
+ return UpBlock2D(
1355
+ in_channels=in_channels,
1356
+ prev_output_channel=prev_output_channel,
1357
+ out_channels=out_channels,
1358
+ add_upsample=add_upsample,
1359
+ )
1360
+ elif up_block_type == "CrossAttnUpBlock2D":
1361
+ return CrossAttnUpBlock2D(
1362
+ in_channels=in_channels,
1363
+ out_channels=out_channels,
1364
+ prev_output_channel=prev_output_channel,
1365
+ attn_num_head_channels=attn_num_head_channels,
1366
+ cross_attention_dim=cross_attention_dim,
1367
+ add_upsample=add_upsample,
1368
+ use_linear_projection=use_linear_projection,
1369
+ upcast_attention=upcast_attention,
1370
+ )
1371
+
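A sketch of the factory above building a cross-attention down block (TIME_EMBED_DIM comes from earlier in this file; 768 is an assumed context width):
    down = get_down_block(
        "CrossAttnDownBlock2D", in_channels=320, out_channels=320, add_downsample=True,
        attn_num_head_channels=8, cross_attention_dim=768,
        use_linear_projection=False, upcast_attention=False,
    )
    h, skips = down(torch.randn(1, 320, 32, 32), temb=torch.randn(1, TIME_EMBED_DIM),
                    encoder_hidden_states=torch.randn(1, 77, 768))
    assert h.shape == (1, 320, 16, 16)                   # the downsampler halves the resolution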
1372
+
1373
+ class UNet2DConditionModel(nn.Module):
1374
+ _supports_gradient_checkpointing = True
1375
+
1376
+ def __init__(
1377
+ self,
1378
+ sample_size: Optional[int] = None,
1379
+ attention_head_dim: Union[int, Tuple[int]] = 8,
1380
+ cross_attention_dim: int = 1280,
1381
+ use_linear_projection: bool = False,
1382
+ upcast_attention: bool = False,
1383
+ **kwargs,
1384
+ ):
1385
+ super().__init__()
1386
+ assert sample_size is not None, "sample_size must be specified"
1387
+ logger.info(
1388
+ f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
1389
+ )
1390
+
1391
+ # defined here so they can be referenced from outside
1392
+ self.in_channels = IN_CHANNELS
1393
+ self.out_channels = OUT_CHANNELS
1394
+
1395
+ self.sample_size = sample_size
1396
+ self.prepare_config(sample_size=sample_size)
1397
+
1398
+ # the way modules are held cannot be changed, since that would change the state_dict format
1399
+
1400
+ # input
1401
+ self.conv_in = nn.Conv2d(IN_CHANNELS, BLOCK_OUT_CHANNELS[0], kernel_size=3, padding=(1, 1))
1402
+
1403
+ # time
1404
+ self.time_proj = Timesteps(BLOCK_OUT_CHANNELS[0], TIME_EMBED_FLIP_SIN_TO_COS, TIME_EMBED_FREQ_SHIFT)
1405
+
1406
+ self.time_embedding = TimestepEmbedding(TIMESTEP_INPUT_DIM, TIME_EMBED_DIM)
1407
+
1408
+ self.down_blocks = nn.ModuleList([])
1409
+ self.mid_block = None
1410
+ self.up_blocks = nn.ModuleList([])
1411
+
1412
+ if isinstance(attention_head_dim, int):
1413
+ attention_head_dim = (attention_head_dim,) * 4
1414
+
1415
+ # down
1416
+ output_channel = BLOCK_OUT_CHANNELS[0]
1417
+ for i, down_block_type in enumerate(DOWN_BLOCK_TYPES):
1418
+ input_channel = output_channel
1419
+ output_channel = BLOCK_OUT_CHANNELS[i]
1420
+ is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
1421
+
1422
+ down_block = get_down_block(
1423
+ down_block_type,
1424
+ in_channels=input_channel,
1425
+ out_channels=output_channel,
1426
+ add_downsample=not is_final_block,
1427
+ attn_num_head_channels=attention_head_dim[i],
1428
+ cross_attention_dim=cross_attention_dim,
1429
+ use_linear_projection=use_linear_projection,
1430
+ upcast_attention=upcast_attention,
1431
+ )
1432
+ self.down_blocks.append(down_block)
1433
+
1434
+ # mid
1435
+ self.mid_block = UNetMidBlock2DCrossAttn(
1436
+ in_channels=BLOCK_OUT_CHANNELS[-1],
1437
+ attn_num_head_channels=attention_head_dim[-1],
1438
+ cross_attention_dim=cross_attention_dim,
1439
+ use_linear_projection=use_linear_projection,
1440
+ )
1441
+
1442
+ # count how many layers upsample the images
1443
+ self.num_upsamplers = 0
1444
+
1445
+ # up
1446
+ reversed_block_out_channels = list(reversed(BLOCK_OUT_CHANNELS))
1447
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
1448
+ output_channel = reversed_block_out_channels[0]
1449
+ for i, up_block_type in enumerate(UP_BLOCK_TYPES):
1450
+ is_final_block = i == len(BLOCK_OUT_CHANNELS) - 1
1451
+
1452
+ prev_output_channel = output_channel
1453
+ output_channel = reversed_block_out_channels[i]
1454
+ input_channel = reversed_block_out_channels[min(i + 1, len(BLOCK_OUT_CHANNELS) - 1)]
1455
+
1456
+ # add upsample block for all BUT final layer
1457
+ if not is_final_block:
1458
+ add_upsample = True
1459
+ self.num_upsamplers += 1
1460
+ else:
1461
+ add_upsample = False
1462
+
1463
+ up_block = get_up_block(
1464
+ up_block_type,
1465
+ in_channels=input_channel,
1466
+ out_channels=output_channel,
1467
+ prev_output_channel=prev_output_channel,
1468
+ add_upsample=add_upsample,
1469
+ attn_num_head_channels=reversed_attention_head_dim[i],
1470
+ cross_attention_dim=cross_attention_dim,
1471
+ use_linear_projection=use_linear_projection,
1472
+ upcast_attention=upcast_attention,
1473
+ )
1474
+ self.up_blocks.append(up_block)
1475
+ prev_output_channel = output_channel
1476
+
1477
+ # out
1478
+ self.conv_norm_out = nn.GroupNorm(num_channels=BLOCK_OUT_CHANNELS[0], num_groups=NORM_GROUPS, eps=NORM_EPS)
1479
+ self.conv_act = nn.SiLU()
1480
+ self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
1481
+
1482
+ # region diffusers compatibility
1483
+ def prepare_config(self, *args, **kwargs):
1484
+ self.config = SimpleNamespace(**kwargs)
1485
+
1486
+ @property
1487
+ def dtype(self) -> torch.dtype:
1488
+ # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1489
+ return get_parameter_dtype(self)
1490
+
1491
+ @property
1492
+ def device(self) -> torch.device:
1493
+ # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
1494
+ return get_parameter_device(self)
1495
+
1496
+ def set_attention_slice(self, slice_size):
1497
+ raise NotImplementedError("Attention slicing is not supported for this model.")
1498
+
1499
+ def is_gradient_checkpointing(self) -> bool:
1500
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
1501
+
1502
+ def enable_gradient_checkpointing(self):
1503
+ self.set_gradient_checkpointing(value=True)
1504
+
1505
+ def disable_gradient_checkpointing(self):
1506
+ self.set_gradient_checkpointing(value=False)
1507
+
1508
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
1509
+ modules = self.down_blocks + [self.mid_block] + self.up_blocks
1510
+ for module in modules:
1511
+ module.set_use_memory_efficient_attention(xformers, mem_eff)
1512
+
1513
+ def set_use_sdpa(self, sdpa: bool) -> None:
1514
+ modules = self.down_blocks + [self.mid_block] + self.up_blocks
1515
+ for module in modules:
1516
+ module.set_use_sdpa(sdpa)
1517
+
1518
+ def set_gradient_checkpointing(self, value=False):
1519
+ modules = self.down_blocks + [self.mid_block] + self.up_blocks
1520
+ for module in modules:
1521
+ logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
1522
+ module.gradient_checkpointing = value
1523
+
1524
+ # endregion
1525
+
1526
+ def forward(
1527
+ self,
1528
+ sample: torch.FloatTensor,
1529
+ timestep: Union[torch.Tensor, float, int],
1530
+ encoder_hidden_states: torch.Tensor,
1531
+ class_labels: Optional[torch.Tensor] = None,
1532
+ return_dict: bool = True,
1533
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1534
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1535
+ ) -> Union[Dict, Tuple]:
1536
+ r"""
1537
+ Args:
1538
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
1539
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
1540
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
1541
+ return_dict (`bool`, *optional*, defaults to `True`):
1542
+ Whether or not to return a dict instead of a plain tuple.
1543
+
1544
+ Returns:
1545
+ `SampleOutput` or `tuple`:
1546
+ `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
1547
+ """
1548
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
1549
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
1550
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1551
+ # on the fly if necessary.
1552
+ # By default, the sample size must be a multiple of 2^(number of upsamplers), i.e. 64
1553
+ # To also support other sizes, the upsample size is adjusted when necessary
1554
+ # Quality will probably suffer, so keeping the size divisible by 64 is recommended
1555
+ default_overall_up_factor = 2**self.num_upsamplers
1556
+
1557
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1558
+ # when the size is not divisible by 64, pass the target size to the upsampler
1559
+ forward_upsample_size = False
1560
+ upsample_size = None
1561
+
1562
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
1563
+ # logger.info("Forward upsample size to force interpolation output size.")
1564
+ forward_upsample_size = True
1565
+
1566
+ # 1. time
1567
+ timesteps = timestep
1568
+ timesteps = self.handle_unusual_timesteps(sample, timesteps)  # only kicks in for unusual inputs
1569
+
1570
+ t_emb = self.time_proj(timesteps)
1571
+
1572
+ # timesteps does not contain any weights and will always return f32 tensors
1573
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1574
+ # there might be better ways to encapsulate this.
1575
+ # timesteps contains no weights, so it always returns float32 tensors
1576
+ # but time_embedding may be running in fp16, so the cast is needed here
1577
+ # (it might be cleaner to cast inside time_proj instead)
1578
+ t_emb = t_emb.to(dtype=self.dtype)
1579
+ emb = self.time_embedding(t_emb)
1580
+
1581
+ # 2. pre-process
1582
+ sample = self.conv_in(sample)
1583
+
1584
+ down_block_res_samples = (sample,)
1585
+ for downsample_block in self.down_blocks:
1586
+ # the down blocks could always take encoder_hidden_states in their forward,
1587
+ # but this way is probably easier to follow
1588
+ if downsample_block.has_cross_attention:
1589
+ sample, res_samples = downsample_block(
1590
+ hidden_states=sample,
1591
+ temb=emb,
1592
+ encoder_hidden_states=encoder_hidden_states,
1593
+ )
1594
+ else:
1595
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1596
+
1597
+ down_block_res_samples += res_samples
1598
+
1599
+ # add the ControlNet outputs to the skip connections
1600
+ if down_block_additional_residuals is not None:
1601
+ down_block_res_samples = list(down_block_res_samples)
1602
+ for i in range(len(down_block_res_samples)):
1603
+ down_block_res_samples[i] += down_block_additional_residuals[i]
1604
+ down_block_res_samples = tuple(down_block_res_samples)
1605
+
1606
+ # 4. mid
1607
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
1608
+
1609
+ # add the ControlNet output
1610
+ if mid_block_additional_residual is not None:
1611
+ sample += mid_block_additional_residual
1612
+
1613
+ # 5. up
1614
+ for i, upsample_block in enumerate(self.up_blocks):
1615
+ is_final_block = i == len(self.up_blocks) - 1
1616
+
1617
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1618
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
1619
+
1620
+ # if we have not reached the final block and need to forward the upsample size, we do it here
1621
+ # as noted above, pass upsample_size to every block except the final one
1622
+ if not is_final_block and forward_upsample_size:
1623
+ upsample_size = down_block_res_samples[-1].shape[2:]
1624
+
1625
+ if upsample_block.has_cross_attention:
1626
+ sample = upsample_block(
1627
+ hidden_states=sample,
1628
+ temb=emb,
1629
+ res_hidden_states_tuple=res_samples,
1630
+ encoder_hidden_states=encoder_hidden_states,
1631
+ upsample_size=upsample_size,
1632
+ )
1633
+ else:
1634
+ sample = upsample_block(
1635
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1636
+ )
1637
+
1638
+ # 6. post-process
1639
+ sample = self.conv_norm_out(sample)
1640
+ sample = self.conv_act(sample)
1641
+ sample = self.conv_out(sample)
1642
+
1643
+ if not return_dict:
1644
+ return (sample,)
1645
+
1646
+ return SampleOutput(sample=sample)
1647
+
1648
+ def handle_unusual_timesteps(self, sample, timesteps):
1649
+ r"""
1650
+ If timesteps is not a Tensor, convert it to one; also broadcast it to the batch size so the model stays compatible with ONNX/Core ML.
1651
+ """
1652
+ if not torch.is_tensor(timesteps):
1653
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1654
+ # This would be a good case for the `match` statement (Python 3.10+)
1655
+ is_mps = sample.device.type == "mps"
1656
+ if isinstance(timesteps, float):
1657
+ dtype = torch.float32 if is_mps else torch.float64
1658
+ else:
1659
+ dtype = torch.int32 if is_mps else torch.int64
1660
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1661
+ elif len(timesteps.shape) == 0:
1662
+ timesteps = timesteps[None].to(sample.device)
1663
+
1664
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1665
+ timesteps = timesteps.expand(sample.shape[0])
1666
+
1667
+ return timesteps
1668
+
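A sketch of a full forward pass through UNet2DConditionModel above; this builds the SD-v1-sized network with random weights (so it is memory-hungry) and only illustrates the call signature, with 768 as an assumed context width:
    unet = UNet2DConditionModel(sample_size=64, attention_head_dim=8, cross_attention_dim=768)
    latents = torch.randn(1, unet.in_channels, 64, 64)   # 64 is divisible by 2 ** num_upsamplers
    text_states = torch.randn(1, 77, 768)
    noise_pred = unet(latents, torch.tensor([981]), text_states).sample
    assert noise_pred.shape == latents.shape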
1669
+
1670
+ class InferUNet2DConditionModel:
1671
+ def __init__(self, original_unet: UNet2DConditionModel):
1672
+ self.delegate = original_unet
1673
+
1674
+ # override original model's forward method: because forward is not called by `__call__`
1675
+ # overriding `__call__` is not enough, because nn.Module.forward has a special handling
1676
+ self.delegate.forward = self.forward
1677
+
1678
+ # override original model's up blocks' forward method
1679
+ for up_block in self.delegate.up_blocks:
1680
+ if up_block.__class__.__name__ == "UpBlock2D":
1681
+
1682
+ def resnet_wrapper(func, block):
1683
+ def forward(*args, **kwargs):
1684
+ return func(block, *args, **kwargs)
1685
+
1686
+ return forward
1687
+
1688
+ up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
1689
+
1690
+ elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
1691
+
1692
+ def cross_attn_up_wrapper(func, block):
1693
+ def forward(*args, **kwargs):
1694
+ return func(block, *args, **kwargs)
1695
+
1696
+ return forward
1697
+
1698
+ up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
1699
+
1700
+ # Deep Shrink
1701
+ self.ds_depth_1 = None
1702
+ self.ds_depth_2 = None
1703
+ self.ds_timesteps_1 = None
1704
+ self.ds_timesteps_2 = None
1705
+ self.ds_ratio = None
1706
+
1707
+ # call original model's methods
1708
+ def __getattr__(self, name):
1709
+ return getattr(self.delegate, name)
1710
+
1711
+ def __call__(self, *args, **kwargs):
1712
+ return self.delegate(*args, **kwargs)
1713
+
1714
+ def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
1715
+ if ds_depth_1 is None:
1716
+ logger.info("Deep Shrink is disabled.")
1717
+ self.ds_depth_1 = None
1718
+ self.ds_timesteps_1 = None
1719
+ self.ds_depth_2 = None
1720
+ self.ds_timesteps_2 = None
1721
+ self.ds_ratio = None
1722
+ else:
1723
+ logger.info(
1724
+ f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
1725
+ )
1726
+ self.ds_depth_1 = ds_depth_1
1727
+ self.ds_timesteps_1 = ds_timesteps_1
1728
+ self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
1729
+ self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
1730
+ self.ds_ratio = ds_ratio
1731
+
1732
+ def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
1733
+ for resnet in _self.resnets:
1734
+ # pop res hidden states
1735
+ res_hidden_states = res_hidden_states_tuple[-1]
1736
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1737
+
1738
+ # Deep Shrink
1739
+ if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
1740
+ hidden_states = resize_like(hidden_states, res_hidden_states)
1741
+
1742
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1743
+ hidden_states = resnet(hidden_states, temb)
1744
+
1745
+ if _self.upsamplers is not None:
1746
+ for upsampler in _self.upsamplers:
1747
+ hidden_states = upsampler(hidden_states, upsample_size)
1748
+
1749
+ return hidden_states
1750
+
1751
+ def cross_attn_up_block_forward(
1752
+ self,
1753
+ _self,
1754
+ hidden_states,
1755
+ res_hidden_states_tuple,
1756
+ temb=None,
1757
+ encoder_hidden_states=None,
1758
+ upsample_size=None,
1759
+ ):
1760
+ for resnet, attn in zip(_self.resnets, _self.attentions):
1761
+ # pop res hidden states
1762
+ res_hidden_states = res_hidden_states_tuple[-1]
1763
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1764
+
1765
+ # Deep Shrink
1766
+ if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
1767
+ hidden_states = resize_like(hidden_states, res_hidden_states)
1768
+
1769
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1770
+ hidden_states = resnet(hidden_states, temb)
1771
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
1772
+
1773
+ if _self.upsamplers is not None:
1774
+ for upsampler in _self.upsamplers:
1775
+ hidden_states = upsampler(hidden_states, upsample_size)
1776
+
1777
+ return hidden_states
1778
+
1779
+ def forward(
1780
+ self,
1781
+ sample: torch.FloatTensor,
1782
+ timestep: Union[torch.Tensor, float, int],
1783
+ encoder_hidden_states: torch.Tensor,
1784
+ class_labels: Optional[torch.Tensor] = None,
1785
+ return_dict: bool = True,
1786
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1787
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1788
+ ) -> Union[Dict, Tuple]:
1789
+ r"""
1790
+ The current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink added.
1791
+ """
1792
+
1793
+ r"""
1794
+ Args:
1795
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
1796
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
1797
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
1798
+ return_dict (`bool`, *optional*, defaults to `True`):
1799
+ Whether or not to return a dict instead of a plain tuple.
1800
+
1801
+ Returns:
1802
+ `SampleOutput` or `tuple`:
1803
+ `SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
1804
+ """
1805
+
1806
+ _self = self.delegate
1807
+
1808
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
1809
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
1810
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1811
+ # on the fly if necessary.
1812
+ # By default, the sample size must be a multiple of 2^(number of upsamplers), i.e. 64
1813
+ # To also support other sizes, the upsample size is adjusted when necessary
1814
+ # Quality will probably suffer, so keeping the size divisible by 64 is recommended
1815
+ default_overall_up_factor = 2**_self.num_upsamplers
1816
+
1817
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1818
+ # when the size is not divisible by 64, pass the target size to the upsampler
1819
+ forward_upsample_size = False
1820
+ upsample_size = None
1821
+
1822
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
1823
+ # logger.info("Forward upsample size to force interpolation output size.")
1824
+ forward_upsample_size = True
1825
+
1826
+ # 1. time
1827
+ timesteps = timestep
1828
+ timesteps = _self.handle_unusual_timesteps(sample, timesteps)  # only kicks in for unusual inputs
1829
+
1830
+ t_emb = _self.time_proj(timesteps)
1831
+
1832
+ # timesteps does not contain any weights and will always return f32 tensors
1833
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1834
+ # there might be better ways to encapsulate this.
1835
+ # timesteps contains no weights, so it always returns float32 tensors
1836
+ # but time_embedding may be running in fp16, so the cast is needed here
1837
+ # (it might be cleaner to cast inside time_proj instead)
1838
+ t_emb = t_emb.to(dtype=_self.dtype)
1839
+ emb = _self.time_embedding(t_emb)
1840
+
1841
+ # 2. pre-process
1842
+ sample = _self.conv_in(sample)
1843
+
1844
+ down_block_res_samples = (sample,)
1845
+ for depth, downsample_block in enumerate(_self.down_blocks):
1846
+ # Deep Shrink
1847
+ if self.ds_depth_1 is not None:
1848
+ if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
1849
+ self.ds_depth_2 is not None
1850
+ and depth == self.ds_depth_2
1851
+ and timesteps[0] < self.ds_timesteps_1
1852
+ and timesteps[0] >= self.ds_timesteps_2
1853
+ ):
1854
+ org_dtype = sample.dtype
1855
+ if org_dtype == torch.bfloat16:
1856
+ sample = sample.to(torch.float32)
1857
+ sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
1858
+
1859
+ # the down blocks could always take encoder_hidden_states in their forward,
1860
+ # but this way is probably easier to follow
1861
+ if downsample_block.has_cross_attention:
1862
+ sample, res_samples = downsample_block(
1863
+ hidden_states=sample,
1864
+ temb=emb,
1865
+ encoder_hidden_states=encoder_hidden_states,
1866
+ )
1867
+ else:
1868
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1869
+
1870
+ down_block_res_samples += res_samples
1871
+
1872
+ # add the ControlNet outputs to the skip connections
1873
+ if down_block_additional_residuals is not None:
1874
+ down_block_res_samples = list(down_block_res_samples)
1875
+ for i in range(len(down_block_res_samples)):
1876
+ down_block_res_samples[i] += down_block_additional_residuals[i]
1877
+ down_block_res_samples = tuple(down_block_res_samples)
1878
+
1879
+ # 4. mid
1880
+ sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
1881
+
1882
+ # add the ControlNet output
1883
+ if mid_block_additional_residual is not None:
1884
+ sample += mid_block_additional_residual
1885
+
1886
+ # 5. up
1887
+ for i, upsample_block in enumerate(_self.up_blocks):
1888
+ is_final_block = i == len(_self.up_blocks) - 1
1889
+
1890
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1891
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
1892
+
1893
+ # if we have not reached the final block and need to forward the upsample size, we do it here
1894
+ # as noted above, pass upsample_size to every block except the final one
1895
+ if not is_final_block and forward_upsample_size:
1896
+ upsample_size = down_block_res_samples[-1].shape[2:]
1897
+
1898
+ if upsample_block.has_cross_attention:
1899
+ sample = upsample_block(
1900
+ hidden_states=sample,
1901
+ temb=emb,
1902
+ res_hidden_states_tuple=res_samples,
1903
+ encoder_hidden_states=encoder_hidden_states,
1904
+ upsample_size=upsample_size,
1905
+ )
1906
+ else:
1907
+ sample = upsample_block(
1908
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1909
+ )
1910
+
1911
+ # 6. post-process
1912
+ sample = _self.conv_norm_out(sample)
1913
+ sample = _self.conv_act(sample)
1914
+ sample = _self.conv_out(sample)
1915
+
1916
+ if not return_dict:
1917
+ return (sample,)
1918
+
1919
+ return SampleOutput(sample=sample)
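A sketch of wrapping the UNet for inference with Deep Shrink enabled, reusing the names from the previous sketch; the depth/timestep values are just plausible examples:
    infer_unet = InferUNet2DConditionModel(unet)
    infer_unet.set_deep_shrink(ds_depth_1=3, ds_timesteps_1=650, ds_ratio=0.5)
    out = infer_unet(latents, torch.tensor([981]), text_states).sample   # t >= 650 -> shrunk at depth 3
    assert out.shape == latents.shape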
library/sai_model_spec.py ADDED
@@ -0,0 +1,334 @@
1
+ # based on https://github.com/Stability-AI/ModelSpec
2
+ import datetime
3
+ import hashlib
4
+ from io import BytesIO
5
+ import os
6
+ from typing import List, Optional, Tuple, Union
7
+ import safetensors
8
+ from library.utils import setup_logging
9
+
10
+ setup_logging()
11
+ import logging
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ r"""
16
+ # Metadata Example
17
+ metadata = {
18
+ # === Must ===
19
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
20
+ "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
21
+ "modelspec.implementation": "sgm",
22
+ "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
23
+ # === Should ===
24
+ "modelspec.author": "Example Corp", # Your name or company name
25
+ "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
26
+ "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
27
+ # === Can ===
28
+ "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
29
+ "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
30
+ }
31
+ """
32
+
33
+ BASE_METADATA = {
34
+ # === Must ===
35
+ "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
36
+ "modelspec.architecture": None,
37
+ "modelspec.implementation": None,
38
+ "modelspec.title": None,
39
+ "modelspec.resolution": None,
40
+ # === Should ===
41
+ "modelspec.description": None,
42
+ "modelspec.author": None,
43
+ "modelspec.date": None,
44
+ # === Can ===
45
+ "modelspec.license": None,
46
+ "modelspec.tags": None,
47
+ "modelspec.merged_from": None,
48
+ "modelspec.prediction_type": None,
49
+ "modelspec.timestep_range": None,
50
+ "modelspec.encoder_layer": None,
51
+ }
52
+
53
+ # define constants only for the keys that are referenced elsewhere
54
+ MODELSPEC_TITLE = "modelspec.title"
55
+
56
+ ARCH_SD_V1 = "stable-diffusion-v1"
57
+ ARCH_SD_V2_512 = "stable-diffusion-v2-512"
58
+ ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
59
+ ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
60
+ ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
61
+ # ARCH_SD3_UNKNOWN = "stable-diffusion-3"
62
+ ARCH_FLUX_1_DEV = "flux-1-dev"
63
+ ARCH_FLUX_1_UNKNOWN = "flux-1"
64
+
65
+ ADAPTER_LORA = "lora"
66
+ ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
67
+
68
+ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
69
+ IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
70
+ IMPL_DIFFUSERS = "diffusers"
71
+ IMPL_FLUX = "https://github.com/black-forest-labs/flux"
72
+
73
+ PRED_TYPE_EPSILON = "epsilon"
74
+ PRED_TYPE_V = "v"
75
+
76
+
77
+ def load_bytes_in_safetensors(tensors):
78
+ bytes = safetensors.torch.save(tensors)
79
+ b = BytesIO(bytes)
80
+
81
+ b.seek(0)
82
+ header = b.read(8)
83
+ n = int.from_bytes(header, "little")
84
+
85
+ offset = n + 8
86
+ b.seek(offset)
87
+
88
+ return b.read()
89
+
90
+
91
+ def precalculate_safetensors_hashes(state_dict):
92
+ # calculate each tensor one by one to reduce memory usage
93
+ hash_sha256 = hashlib.sha256()
94
+ for tensor in state_dict.values():
95
+ single_tensor_sd = {"tensor": tensor}
96
+ bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
97
+ hash_sha256.update(bytes_for_tensor)
98
+
99
+ return f"0x{hash_sha256.hexdigest()}"
100
+
101
+
102
+ def update_hash_sha256(metadata: dict, state_dict: dict):
103
+ raise NotImplementedError
104
+
105
+
106
+ def build_metadata(
107
+ state_dict: Optional[dict],
108
+ v2: bool,
109
+ v_parameterization: bool,
110
+ sdxl: bool,
111
+ lora: bool,
112
+ textual_inversion: bool,
113
+ timestamp: float,
114
+ title: Optional[str] = None,
115
+ reso: Optional[Union[int, Tuple[int, int]]] = None,
116
+ is_stable_diffusion_ckpt: Optional[bool] = None,
117
+ author: Optional[str] = None,
118
+ description: Optional[str] = None,
119
+ license: Optional[str] = None,
120
+ tags: Optional[str] = None,
121
+ merged_from: Optional[str] = None,
122
+ timesteps: Optional[Tuple[int, int]] = None,
123
+ clip_skip: Optional[int] = None,
124
+ sd3: Optional[str] = None,
125
+ flux: Optional[str] = None,
126
+ ):
127
+ """
128
+ sd3: only supports "m", flux: only supports "dev"
129
+ """
130
+ # if state_dict is None, hash is not calculated
131
+
132
+ metadata = {}
133
+ metadata.update(BASE_METADATA)
134
+
135
+ # TODO implement this once a correct hash calculation that does not consume extra memory is found
136
+ # if state_dict is not None:
137
+ # hash = precalculate_safetensors_hashes(state_dict)
138
+ # metadata["modelspec.hash_sha256"] = hash
139
+
140
+ if sdxl:
141
+ arch = ARCH_SD_XL_V1_BASE
142
+ elif sd3 is not None:
143
+ arch = ARCH_SD3_M + "-" + sd3
144
+ elif flux is not None:
145
+ if flux == "dev":
146
+ arch = ARCH_FLUX_1_DEV
147
+ else:
148
+ arch = ARCH_FLUX_1_UNKNOWN
149
+ elif v2:
150
+ if v_parameterization:
151
+ arch = ARCH_SD_V2_768_V
152
+ else:
153
+ arch = ARCH_SD_V2_512
154
+ else:
155
+ arch = ARCH_SD_V1
156
+
157
+ if lora:
158
+ arch += f"/{ADAPTER_LORA}"
159
+ elif textual_inversion:
160
+ arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
161
+
162
+ metadata["modelspec.architecture"] = arch
163
+
164
+ if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
165
+ is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
166
+
167
+ if flux is not None:
168
+ # Flux
169
+ impl = IMPL_FLUX
170
+ elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
171
+ # Stable Diffusion ckpt, TI, SDXL LoRA
172
+ impl = IMPL_STABILITY_AI
173
+ else:
174
+ # v1/v2 LoRA or Diffusers
175
+ impl = IMPL_DIFFUSERS
176
+ metadata["modelspec.implementation"] = impl
177
+
178
+ if title is None:
179
+ if lora:
180
+ title = "LoRA"
181
+ elif textual_inversion:
182
+ title = "TextualInversion"
183
+ else:
184
+ title = "Checkpoint"
185
+ title += f"@{timestamp}"
186
+ metadata[MODELSPEC_TITLE] = title
187
+
188
+ if author is not None:
189
+ metadata["modelspec.author"] = author
190
+ else:
191
+ del metadata["modelspec.author"]
192
+
193
+ if description is not None:
194
+ metadata["modelspec.description"] = description
195
+ else:
196
+ del metadata["modelspec.description"]
197
+
198
+ if merged_from is not None:
199
+ metadata["modelspec.merged_from"] = merged_from
200
+ else:
201
+ del metadata["modelspec.merged_from"]
202
+
203
+ if license is not None:
204
+ metadata["modelspec.license"] = license
205
+ else:
206
+ del metadata["modelspec.license"]
207
+
208
+ if tags is not None:
209
+ metadata["modelspec.tags"] = tags
210
+ else:
211
+ del metadata["modelspec.tags"]
212
+
213
+ # remove microsecond from time
214
+ int_ts = int(timestamp)
215
+
216
+ # time to iso-8601 compliant date
217
+ date = datetime.datetime.fromtimestamp(int_ts).isoformat()
218
+ metadata["modelspec.date"] = date
219
+
220
+ if reso is not None:
221
+ # comma separated to tuple
222
+ if isinstance(reso, str):
223
+ reso = tuple(map(int, reso.split(",")))
224
+ if len(reso) == 1:
225
+ reso = (reso[0], reso[0])
226
+ else:
227
+ # resolution is defined in dataset, so use default
228
+ if sdxl or sd3 is not None or flux is not None:
229
+ reso = 1024
230
+ elif v2 and v_parameterization:
231
+ reso = 768
232
+ else:
233
+ reso = 512
234
+ if isinstance(reso, int):
235
+ reso = (reso, reso)
236
+
237
+ metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
238
+
239
+ if flux is not None:
240
+ del metadata["modelspec.prediction_type"]
241
+ elif v_parameterization:
242
+ metadata["modelspec.prediction_type"] = PRED_TYPE_V
243
+ else:
244
+ metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
245
+
246
+ if timesteps is not None:
247
+ if isinstance(timesteps, str) or isinstance(timesteps, int):
248
+ timesteps = (timesteps, timesteps)
249
+ if len(timesteps) == 1:
250
+ timesteps = (timesteps[0], timesteps[0])
251
+ metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
252
+ else:
253
+ del metadata["modelspec.timestep_range"]
254
+
255
+ if clip_skip is not None:
256
+ metadata["modelspec.encoder_layer"] = f"{clip_skip}"
257
+ else:
258
+ del metadata["modelspec.encoder_layer"]
259
+
260
+ # # assert all values are filled
261
+ # assert all([v is not None for v in metadata.values()]), metadata
262
+ if not all([v is not None for v in metadata.values()]):
263
+ logger.error(f"Internal error: some metadata values are None: {metadata}")
264
+
265
+ return metadata
266
+
267
+
268
+ # region utils
269
+
270
+
271
+ def get_title(metadata: dict) -> Optional[str]:
272
+ return metadata.get(MODELSPEC_TITLE, None)
273
+
274
+
275
+ def load_metadata_from_safetensors(model: str) -> dict:
276
+ if not model.endswith(".safetensors"):
277
+ return {}
278
+
279
+ with safetensors.safe_open(model, framework="pt") as f:
280
+ metadata = f.metadata()
281
+ if metadata is None:
282
+ metadata = {}
283
+ return metadata
284
+
285
+
286
+ def build_merged_from(models: List[str]) -> str:
287
+ def get_title(model: str):
288
+ metadata = load_metadata_from_safetensors(model)
289
+ title = metadata.get(MODELSPEC_TITLE, None)
290
+ if title is None:
291
+ title = os.path.splitext(os.path.basename(model))[0] # use filename
292
+ return title
293
+
294
+ titles = [get_title(model) for model in models]
295
+ return ", ".join(titles)
296
+
297
+
298
+ # endregion
299
+
300
+
301
+ r"""
302
+ if __name__ == "__main__":
303
+ import argparse
304
+ import torch
305
+ from safetensors.torch import load_file
306
+ from library import train_util
307
+
308
+ parser = argparse.ArgumentParser()
309
+ parser.add_argument("--ckpt", type=str, required=True)
310
+ args = parser.parse_args()
311
+
312
+ print(f"Loading {args.ckpt}")
313
+ state_dict = load_file(args.ckpt)
314
+
315
+ print(f"Calculating metadata")
316
+ metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
317
+ print(metadata)
318
+ del state_dict
319
+
320
+ # by reference implementation
321
+ with open(args.ckpt, mode="rb") as file_data:
322
+ file_hash = hashlib.sha256()
323
+ head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
324
+ header = json.loads(file_data.read(head_len[0])) # header itself, json string
325
+ content = (
326
+ file_data.read()
327
+ ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
328
+ file_hash.update(content)
329
+ # ===== Update the hash for modelspec =====
330
+ by_ref = f"0x{file_hash.hexdigest()}"
331
+ print(by_ref)
332
+ print("is same?", by_ref == metadata["modelspec.hash_sha256"])
333
+
334
+ """
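A minimal usage sketch for the helpers above, assuming the module is importable as library.sai_model_spec; the title and file name are made up:

import time
from library import sai_model_spec

metadata = sai_model_spec.build_metadata(
    None,  # state_dict: None skips the (unimplemented) hash calculation
    v2=False, v_parameterization=False, sdxl=False,
    lora=True, textual_inversion=False,
    timestamp=time.time(),
    title="my-flux-lora",
    flux="dev",
)
print(metadata["modelspec.architecture"])  # flux-1-dev/lora
print(metadata["modelspec.resolution"])    # 1024x1024 (default when reso is not given)
# safetensors.torch.save_file(lora_sd, "my-flux-lora.safetensors", metadata=metadata) would attach it on save.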
library/sd3_models.py ADDED
@@ -0,0 +1,1413 @@
1
+ # some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref
2
+ # the original code is licensed under the MIT License
3
+
4
+ # and some module/classes are contributed from KohakuBlueleaf. Thanks for the contribution!
5
+
6
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from dataclasses import dataclass
9
+ from functools import partial
10
+ import math
11
+ from types import SimpleNamespace
12
+ from typing import Dict, List, Optional, Tuple, Union
13
+ import einops
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.utils.checkpoint import checkpoint
19
+ from transformers import CLIPTokenizer, T5TokenizerFast
20
+
21
+ from library import custom_offloading_utils
22
+ from library.device_utils import clean_memory_on_device
23
+
24
+ from .utils import setup_logging
25
+
26
+ setup_logging()
27
+ import logging
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ memory_efficient_attention = None
33
+ try:
34
+ import xformers
35
+ except ImportError:
36
+ pass
37
+
38
+ try:
39
+ from xformers.ops import memory_efficient_attention
40
+ except ImportError:
41
+ memory_efficient_attention = None
42
+
43
+
44
+ # region mmdit
45
+
46
+
47
+ @dataclass
48
+ class SD3Params:
49
+ patch_size: int
50
+ depth: int
51
+ num_patches: int
52
+ pos_embed_max_size: int
53
+ adm_in_channels: int
54
+ qk_norm: Optional[str]
55
+ x_block_self_attn_layers: list[int]
56
+ context_embedder_in_features: int
57
+ context_embedder_out_features: int
58
+ model_type: str
59
+
60
+
61
+ def get_2d_sincos_pos_embed(
62
+ embed_dim,
63
+ grid_size,
64
+ scaling_factor=None,
65
+ offset=None,
66
+ ):
67
+ grid_h = np.arange(grid_size, dtype=np.float32)
68
+ grid_w = np.arange(grid_size, dtype=np.float32)
69
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
70
+ grid = np.stack(grid, axis=0)
71
+ if scaling_factor is not None:
72
+ grid = grid / scaling_factor
73
+ if offset is not None:
74
+ grid = grid - offset
75
+
76
+ grid = grid.reshape([2, 1, grid_size, grid_size])
77
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
78
+ return pos_embed
79
+
80
+
81
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
82
+ assert embed_dim % 2 == 0
83
+
84
+ # use half of dimensions to encode grid_h
85
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
86
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
87
+
88
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
89
+ return emb
90
+
91
+
92
+ def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16):
93
+ """
94
+ This function is contributed by KohakuBlueleaf. Thanks for the contribution!
95
+
96
+ Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions
97
+ when the resolution differs from the training resolution.
98
+
99
+ Args:
100
+ embed_dim (int): Dimension of the positional embedding.
101
+ grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid.
102
+ cls_token (bool): Whether to include class token. Defaults to False.
103
+ extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0.
104
+ sample_size (int): Reference resolution (typically training resolution). Defaults to 64.
105
+ base_size (int): Base grid size used during training. Defaults to 16.
106
+
107
+ Returns:
108
+ numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or
109
+ (H*W + extra_tokens, embed_dim) if cls_token is True.
110
+ """
111
+ # Convert grid_size to tuple if it's an integer
112
+ if isinstance(grid_size, int):
113
+ grid_size = (grid_size, grid_size)
114
+
115
+ # Create normalized grid coordinates (0 to 1)
116
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0]
117
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1]
118
+
119
+ # Calculate scaling factors for height and width
120
+ # This ensures that the central region matches the original resolution's embeddings
121
+ scale_h = base_size * grid_size[0] / (sample_size)
122
+ scale_w = base_size * grid_size[1] / (sample_size)
123
+
124
+ # Calculate shift values to center the original resolution's embedding region
125
+ # This ensures that the central sample_size x sample_size region has similar
126
+ # positional embeddings to the original resolution
127
+ shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0])
128
+ shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1])
129
+
130
+ # Apply scaling and shifting to create the final grid coordinates
131
+ grid_h = grid_h * scale_h - shift_h
132
+ grid_w = grid_w * scale_w - shift_w
133
+
134
+ # Create 2D grid using meshgrid (note: w goes first)
135
+ grid = np.meshgrid(grid_w, grid_h)
136
+ grid = np.stack(grid, axis=0)
137
+
138
+ # # Calculate the starting indices for the central region
139
+ # # This is used for debugging/visualization of the central region
140
+ # st_h = (grid_size[0] - sample_size) // 2
141
+ # st_w = (grid_size[1] - sample_size) // 2
142
+ # print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size])
143
+
144
+ # Reshape grid for positional embedding calculation
145
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
146
+
147
+ # Generate the sinusoidal positional embeddings
148
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
149
+
150
+ # Add zeros for extra tokens (e.g., [CLS] token) if required
151
+ if cls_token and extra_tokens > 0:
152
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
153
+
154
+ return pos_embed
155
+
156
+
157
+ # if __name__ == "__main__":
158
+ # # This is what you get when you load SD3.5 state dict
159
+ # pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed(
160
+ # 1536, [384, 384], sample_size=64, base_size=16
161
+ # )).float().unsqueeze(0)
162
+
163
+
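A quick shape check for the scaled positional embedding above, using illustrative sizes and assuming the module is importable as library.sd3_models:

from library.sd3_models import get_scaled_2d_sincos_pos_embed

pos = get_scaled_2d_sincos_pos_embed(1536, (24, 36), sample_size=24, base_size=16)
assert pos.shape == (24 * 36, 1536)  # one 1536-dim embedding per patch position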
164
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
165
+ """
166
+ embed_dim: output dimension for each position
167
+ pos: a list of positions to be encoded: size (M,)
168
+ out: (M, D)
169
+ """
170
+ assert embed_dim % 2 == 0
171
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
172
+ omega /= embed_dim / 2.0
173
+ omega = 1.0 / 10000**omega # (D/2,)
174
+
175
+ pos = pos.reshape(-1) # (M,)
176
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
177
+
178
+ emb_sin = np.sin(out) # (M, D/2)
179
+ emb_cos = np.cos(out) # (M, D/2)
180
+
181
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
182
+ return emb
183
+
184
+
185
+ def get_1d_sincos_pos_embed_from_grid_torch(
186
+ embed_dim,
187
+ pos,
188
+ device=None,
189
+ dtype=torch.float32,
190
+ ):
191
+ omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
192
+ omega *= 2.0 / embed_dim
193
+ omega = 1.0 / 10000**omega
194
+ out = torch.outer(pos.reshape(-1), omega)
195
+ emb = torch.cat([out.sin(), out.cos()], dim=1)
196
+ return emb
197
+
198
+
199
+ def get_2d_sincos_pos_embed_torch(
200
+ embed_dim,
201
+ w,
202
+ h,
203
+ val_center=7.5,
204
+ val_magnitude=7.5,
205
+ device=None,
206
+ dtype=torch.float32,
207
+ ):
208
+ small = min(h, w)
209
+ val_h = (h / small) * val_magnitude
210
+ val_w = (w / small) * val_magnitude
211
+ grid_h, grid_w = torch.meshgrid(
212
+ torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype),
213
+ torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype),
214
+ indexing="ij",
215
+ )
216
+ emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
217
+ emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
218
+ emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
219
+ return emb
220
+
221
+
222
+ def modulate(x, shift, scale):
223
+ if shift is None:
224
+ shift = torch.zeros_like(scale)
225
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
226
+
227
+
228
+ def default(x, default_value):
229
+ if x is None:
230
+ return default_value
231
+ return x
232
+
233
+
234
+ def timestep_embedding(t, dim, max_period=10000):
235
+ half = dim // 2
236
+ # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
237
+ # device=t.device, dtype=t.dtype
238
+ # )
239
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
240
+ args = t[:, None].float() * freqs[None]
241
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
242
+ if dim % 2:
243
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
244
+ if torch.is_floating_point(t):
245
+ embedding = embedding.to(dtype=t.dtype)
246
+ return embedding
247
+
248
+
249
+ class PatchEmbed(nn.Module):
250
+ def __init__(
251
+ self,
252
+ img_size=256,
253
+ patch_size=4,
254
+ in_channels=3,
255
+ embed_dim=512,
256
+ norm_layer=None,
257
+ flatten=True,
258
+ bias=True,
259
+ strict_img_size=True,
260
+ dynamic_img_pad=False,
261
+ ):
262
+ # dynamic_img_pad and norm is omitted in SD3.5
263
+ super().__init__()
264
+ self.patch_size = patch_size
265
+ self.flatten = flatten
266
+ self.strict_img_size = strict_img_size
267
+ self.dynamic_img_pad = dynamic_img_pad
268
+ if img_size is not None:
269
+ self.img_size = img_size
270
+ self.grid_size = img_size // patch_size
271
+ self.num_patches = self.grid_size**2
272
+ else:
273
+ self.img_size = None
274
+ self.grid_size = None
275
+ self.num_patches = None
276
+
277
+ self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias)
278
+ self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim)
279
+
280
+ def forward(self, x):
281
+ B, C, H, W = x.shape
282
+
283
+ if self.dynamic_img_pad:
284
+ # Pad input so we won't have partial patch
285
+ pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
286
+ pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
287
+ x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
288
+ x = self.proj(x)
289
+ if self.flatten:
290
+ x = x.flatten(2).transpose(1, 2)
291
+ x = self.norm(x)
292
+ return x
293
+
294
+
295
+ # FinalLayer in mmdit.py
296
+ class UnPatch(nn.Module):
297
+ def __init__(self, hidden_size=512, patch_size=4, out_channels=3):
298
+ super().__init__()
299
+ self.patch_size = patch_size
300
+ self.c = out_channels
301
+
302
+ # eps is default in mmdit.py
303
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
304
+ self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels)
305
+ self.adaLN_modulation = nn.Sequential(
306
+ nn.SiLU(),
307
+ nn.Linear(hidden_size, 2 * hidden_size),
308
+ )
309
+
310
+ def forward(self, x: torch.Tensor, cmod, H=None, W=None):
311
+ b, n, _ = x.shape
312
+ p = self.patch_size
313
+ c = self.c
314
+ if H is None and W is None:
315
+ w = h = int(n**0.5)
316
+ assert h * w == n
317
+ else:
318
+ h = H // p if H else n // (W // p)
319
+ w = W // p if W else n // h
320
+ assert h * w == n
321
+
322
+ shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1)
323
+ x = modulate(self.norm_final(x), shift, scale)
324
+ x = self.linear(x)
325
+
326
+ x = x.view(b, h, w, p, p, c)
327
+ x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
328
+ x = x.view(b, c, h * p, w * p)
329
+ return x
330
+
331
+
332
+ class MLP(nn.Module):
333
+ def __init__(
334
+ self,
335
+ in_features,
336
+ hidden_features=None,
337
+ out_features=None,
338
+ act_layer=lambda: nn.GELU(),
339
+ norm_layer=None,
340
+ bias=True,
341
+ use_conv=False,
342
+ ):
343
+ super().__init__()
344
+ out_features = out_features or in_features
345
+ hidden_features = hidden_features or in_features
346
+ self.use_conv = use_conv
347
+
348
+ layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear
349
+
350
+ self.fc1 = layer(in_features, hidden_features, bias=bias)
351
+ self.fc2 = layer(hidden_features, out_features, bias=bias)
352
+ self.act = act_layer()
353
+ self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
354
+
355
+ def forward(self, x):
356
+ x = self.fc1(x)
357
+ x = self.act(x)
358
+ x = self.norm(x)
359
+ x = self.fc2(x)
360
+ return x
361
+
362
+
363
+ class TimestepEmbedding(nn.Module):
364
+ def __init__(self, hidden_size, freq_embed_size=256):
365
+ super().__init__()
366
+ self.mlp = nn.Sequential(
367
+ nn.Linear(freq_embed_size, hidden_size),
368
+ nn.SiLU(),
369
+ nn.Linear(hidden_size, hidden_size),
370
+ )
371
+ self.freq_embed_size = freq_embed_size
372
+
373
+ def forward(self, t, dtype=None, **kwargs):
374
+ t_freq = timestep_embedding(t, self.freq_embed_size).to(dtype)
375
+ t_emb = self.mlp(t_freq)
376
+ return t_emb
377
+
378
+
379
+ class Embedder(nn.Module):
380
+ def __init__(self, input_dim, hidden_size):
381
+ super().__init__()
382
+ self.mlp = nn.Sequential(
383
+ nn.Linear(input_dim, hidden_size),
384
+ nn.SiLU(),
385
+ nn.Linear(hidden_size, hidden_size),
386
+ )
387
+
388
+ def forward(self, x):
389
+ return self.mlp(x)
390
+
391
+
392
+ def rmsnorm(x, eps=1e-6):
393
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
394
+
395
+
396
+ class RMSNorm(torch.nn.Module):
397
+ def __init__(
398
+ self,
399
+ dim: int,
400
+ elementwise_affine: bool = False,
401
+ eps: float = 1e-6,
402
+ device=None,
403
+ dtype=None,
404
+ ):
405
+ """
406
+ Initialize the RMSNorm normalization layer.
407
+ Args:
408
+ dim (int): The dimension of the input tensor.
409
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
410
+ Attributes:
411
+ eps (float): A small value added to the denominator for numerical stability.
412
+ weight (nn.Parameter): Learnable scaling parameter.
413
+ """
414
+ super().__init__()
415
+ self.eps = eps
416
+ self.learnable_scale = elementwise_affine
417
+ if self.learnable_scale:
418
+ self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
419
+ else:
420
+ self.register_parameter("weight", None)
421
+
422
+ def forward(self, x):
423
+ """
424
+ Forward pass through the RMSNorm layer.
425
+ Args:
426
+ x (torch.Tensor): The input tensor.
427
+ Returns:
428
+ torch.Tensor: The output tensor after applying RMSNorm.
429
+ """
430
+ x = rmsnorm(x, eps=self.eps)
431
+ if self.learnable_scale:
432
+ return x * self.weight.to(device=x.device, dtype=x.dtype)
433
+ else:
434
+ return x
435
+
436
+
437
+ class SwiGLUFeedForward(nn.Module):
438
+ def __init__(
439
+ self,
440
+ dim: int,
441
+ hidden_dim: int,
442
+ multiple_of: int,
443
+ ffn_dim_multiplier: float = None,
444
+ ):
445
+ super().__init__()
446
+ hidden_dim = int(2 * hidden_dim / 3)
447
+ # custom dim factor multiplier
448
+ if ffn_dim_multiplier is not None:
449
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
450
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
451
+
452
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
453
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
454
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
455
+
456
+ def forward(self, x):
457
+ return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
458
+
459
+
460
+ # Linears for SelfAttention in mmdit.py
461
+ class AttentionLinears(nn.Module):
462
+ def __init__(
463
+ self,
464
+ dim: int,
465
+ num_heads: int = 8,
466
+ qkv_bias: bool = False,
467
+ pre_only: bool = False,
468
+ qk_norm: Optional[str] = None,
469
+ ):
470
+ super().__init__()
471
+ self.num_heads = num_heads
472
+ self.head_dim = dim // num_heads
473
+
474
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
475
+ if not pre_only:
476
+ self.proj = nn.Linear(dim, dim)
477
+ self.pre_only = pre_only
478
+
479
+ if qk_norm == "rms":
480
+ self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
481
+ self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
482
+ elif qk_norm == "ln":
483
+ self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
484
+ self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
485
+ elif qk_norm is None:
486
+ self.ln_q = nn.Identity()
487
+ self.ln_k = nn.Identity()
488
+ else:
489
+ raise ValueError(qk_norm)
490
+
491
+ def pre_attention(self, x: torch.Tensor) -> torch.Tensor:
492
+ """
493
+ output:
494
+ q, k, v: [B, L, D]
495
+ """
496
+ B, L, C = x.shape
497
+ qkv: torch.Tensor = self.qkv(x)
498
+ q, k, v = qkv.reshape(B, L, -1, self.head_dim).chunk(3, dim=2)
499
+ q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
500
+ k = self.ln_k(k).reshape(k.shape[0], k.shape[1], -1)
501
+ return (q, k, v)
502
+
503
+ def post_attention(self, x: torch.Tensor) -> torch.Tensor:
504
+ assert not self.pre_only
505
+ x = self.proj(x)
506
+ return x
507
+
508
+
509
+ MEMORY_LAYOUTS = {
510
+ "torch": (
511
+ lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2),
512
+ lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1),
513
+ lambda x: (1, x, 1, 1),
514
+ ),
515
+ "xformers": (
516
+ lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim),
517
+ lambda x: x.reshape(x.shape[0], x.shape[1], -1),
518
+ lambda x: (1, 1, x, 1),
519
+ ),
520
+ "math": (
521
+ lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2),
522
+ lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1),
523
+ lambda x: (1, x, 1, 1),
524
+ ),
525
+ }
526
+ # ATTN_FUNCTION = {
527
+ # "torch": F.scaled_dot_product_attention,
528
+ # "xformers": memory_efficient_attention,
529
+ # }
530
+
531
+
532
+ def vanilla_attention(q, k, v, mask, scale=None):
533
+ if scale is None:
534
+ scale = math.sqrt(q.size(-1))
535
+ scores = torch.bmm(q, k.transpose(-1, -2)) / scale
536
+ if mask is not None:
537
+ mask = einops.rearrange(mask, "b ... -> b (...)")
538
+ max_neg_value = -torch.finfo(scores.dtype).max
539
+ mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3))
540
+ scores = scores.masked_fill(~mask, max_neg_value)
541
+ p_attn = F.softmax(scores, dim=-1)
542
+ return torch.bmm(p_attn, v)
543
+
544
+
545
+ def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"):
546
+ """
547
+ q, k, v: [B, L, D]
548
+ """
549
+ pre_attn_layout = MEMORY_LAYOUTS[mode][0]
550
+ post_attn_layout = MEMORY_LAYOUTS[mode][1]
551
+ q = pre_attn_layout(q, head_dim)
552
+ k = pre_attn_layout(k, head_dim)
553
+ v = pre_attn_layout(v, head_dim)
554
+
555
+ # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale)
556
+ if mode == "torch":
557
+ assert scale is None
558
+ scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask) # , scale=scale)
559
+ elif mode == "xformers":
560
+ scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale)
561
+ else:
562
+ scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale)
563
+
564
+ scores = post_attn_layout(scores)
565
+ return scores
566
+
567
+
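An illustrative call of the attention helper above (assuming the library.sd3_models import path): q, k and v stay in the flat [B, L, D] layout, and head_dim tells the helper how to split D into heads.

import torch
from library.sd3_models import attention

q = k = v = torch.randn(2, 16, 512)                  # B=2, L=16, D=512
out = attention(q, k, v, head_dim=64, mode="torch")  # 512 / 64 = 8 heads internally
assert out.shape == (2, 16, 512)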
568
+ # DismantledBlock in mmdit.py
569
+ class SingleDiTBlock(nn.Module):
570
+ """
571
+ A DiT block with gated adaptive layer norm (adaLN) conditioning.
572
+ """
573
+
574
+ def __init__(
575
+ self,
576
+ hidden_size: int,
577
+ num_heads: int,
578
+ mlp_ratio: float = 4.0,
579
+ attn_mode: str = "xformers",
580
+ qkv_bias: bool = False,
581
+ pre_only: bool = False,
582
+ rmsnorm: bool = False,
583
+ scale_mod_only: bool = False,
584
+ swiglu: bool = False,
585
+ qk_norm: Optional[str] = None,
586
+ x_block_self_attn: bool = False,
587
+ **block_kwargs,
588
+ ):
589
+ super().__init__()
590
+ assert attn_mode in MEMORY_LAYOUTS
591
+ self.attn_mode = attn_mode
592
+ if not rmsnorm:
593
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
594
+ else:
595
+ self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
596
+ self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm)
597
+
598
+ self.x_block_self_attn = x_block_self_attn
599
+ if self.x_block_self_attn:
600
+ assert not pre_only
601
+ assert not scale_mod_only
602
+ self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm)
603
+
604
+ if not pre_only:
605
+ if not rmsnorm:
606
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
607
+ else:
608
+ self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
609
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
610
+ if not pre_only:
611
+ if not swiglu:
612
+ self.mlp = MLP(
613
+ in_features=hidden_size,
614
+ hidden_features=mlp_hidden_dim,
615
+ act_layer=lambda: nn.GELU(approximate="tanh"),
616
+ )
617
+ else:
618
+ self.mlp = SwiGLUFeedForward(
619
+ dim=hidden_size,
620
+ hidden_dim=mlp_hidden_dim,
621
+ multiple_of=256,
622
+ )
623
+ self.scale_mod_only = scale_mod_only
624
+ if self.x_block_self_attn:
625
+ n_mods = 9
626
+ elif not scale_mod_only:
627
+ n_mods = 6 if not pre_only else 2
628
+ else:
629
+ n_mods = 4 if not pre_only else 1
630
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size))
631
+ self.pre_only = pre_only
632
+
633
+ def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
634
+ if not self.pre_only:
635
+ if not self.scale_mod_only:
636
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1)
637
+ else:
638
+ shift_msa = None
639
+ shift_mlp = None
640
+ (scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1)
641
+ qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
642
+ return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
643
+ else:
644
+ if not self.scale_mod_only:
645
+ (shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1)
646
+ else:
647
+ shift_msa = None
648
+ scale_msa = self.adaLN_modulation(c)
649
+ qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
650
+ return qkv, None
651
+
652
+ def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
653
+ assert self.x_block_self_attn
654
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation(
655
+ c
656
+ ).chunk(9, dim=1)
657
+ x_norm = self.norm1(x)
658
+ qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
659
+ qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
660
+ return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2)
661
+
662
+ def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
663
+ assert not self.pre_only
664
+ x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
665
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
666
+ return x
667
+
668
+ def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0):
669
+ assert not self.pre_only
670
+ if attn1_dropout > 0.0:
671
+ # Use torch.bernoulli to implement dropout, only dropout the batch dimension
672
+ attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device))
673
+ attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
674
+ else:
675
+ attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
676
+ x = x + attn_
677
+ attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
678
+ x = x + attn2_
679
+ mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
680
+ x = x + mlp_
681
+ return x
682
+
683
+
684
+ # JointBlock + block_mixing in mmdit.py
685
+ class MMDiTBlock(nn.Module):
686
+ def __init__(self, *args, **kwargs):
687
+ super().__init__()
688
+ pre_only = kwargs.pop("pre_only")
689
+ x_block_self_attn = kwargs.pop("x_block_self_attn")
690
+
691
+ self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs)
692
+ self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs)
693
+
694
+ self.head_dim = self.x_block.attn.head_dim
695
+ self.mode = self.x_block.attn_mode
696
+ self.gradient_checkpointing = False
697
+
698
+ def enable_gradient_checkpointing(self):
699
+ self.gradient_checkpointing = True
+
+ def disable_gradient_checkpointing(self):
+ self.gradient_checkpointing = False
700
+
701
+ def _forward(self, context, x, c):
702
+ ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c)
703
+
704
+ if self.x_block.x_block_self_attn:
705
+ x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c)
706
+ else:
707
+ x_qkv, x_intermediates = self.x_block.pre_attention(x, c)
708
+
709
+ ctx_len = ctx_qkv[0].size(1)
710
+
711
+ q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1)
712
+ k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1)
713
+ v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1)
714
+
715
+ attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode)
716
+ ctx_attn_out = attn[:, :ctx_len]
717
+ x_attn_out = attn[:, ctx_len:]
718
+
719
+ if self.x_block.x_block_self_attn:
720
+ x_q2, x_k2, x_v2 = x_qkv2
721
+ attn2 = attention(x_q2, x_k2, x_v2, head_dim=self.x_block.attn2.head_dim, mode=self.mode)  # attention() expects head_dim, not num_heads
722
+ x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates)
723
+ else:
724
+ x = self.x_block.post_attention(x_attn_out, *x_intermediates)
725
+
726
+ if not self.context_block.pre_only:
727
+ context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate)
728
+ else:
729
+ context = None
730
+
731
+ return context, x
732
+
733
+ def forward(self, *args, **kwargs):
734
+ if self.training and self.gradient_checkpointing:
735
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
736
+ else:
737
+ return self._forward(*args, **kwargs)
738
+
739
+
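A minimal sketch of the joint context/image flow implemented by MMDiTBlock, using toy sizes (hidden_size must be num_heads * 64, as in the MMDiT class below; the import path is assumed):

import torch
from library.sd3_models import MMDiTBlock

block = MMDiTBlock(256, 4, attn_mode="torch", qkv_bias=True, pre_only=False, x_block_self_attn=False)
context = torch.randn(1, 16, 256)   # text tokens
x = torch.randn(1, 64, 256)         # image patch tokens
c = torch.randn(1, 256)             # timestep (+ pooled text) conditioning
context, x = block(context, x, c)   # both streams attend over the concatenated sequence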
740
+ class MMDiT(nn.Module):
741
+ """
742
+ Diffusion model with a Transformer backbone.
743
+ """
744
+
745
+ # prepare pos_embed for latent size * 2
746
+ POS_EMBED_MAX_RATIO = 1.5
747
+
748
+ def __init__(
749
+ self,
750
+ input_size: int = 32,
751
+ patch_size: int = 2,
752
+ in_channels: int = 4,
753
+ depth: int = 28,
754
+ # hidden_size: Optional[int] = None,
755
+ # num_heads: Optional[int] = None,
756
+ mlp_ratio: float = 4.0,
757
+ learn_sigma: bool = False,
758
+ adm_in_channels: Optional[int] = None,
759
+ context_embedder_in_features: Optional[int] = None,
760
+ context_embedder_out_features: Optional[int] = None,
761
+ use_checkpoint: bool = False,
762
+ register_length: int = 0,
763
+ attn_mode: str = "torch",
764
+ rmsnorm: bool = False,
765
+ scale_mod_only: bool = False,
766
+ swiglu: bool = False,
767
+ out_channels: Optional[int] = None,
768
+ pos_embed_scaling_factor: Optional[float] = None,
769
+ pos_embed_offset: Optional[float] = None,
770
+ pos_embed_max_size: Optional[int] = None,
771
+ num_patches=None,
772
+ qk_norm: Optional[str] = None,
773
+ x_block_self_attn_layers: Optional[list[int]] = [],
774
+ qkv_bias: bool = True,
775
+ pos_emb_random_crop_rate: float = 0.0,
776
+ use_scaled_pos_embed: bool = False,
777
+ pos_embed_latent_sizes: Optional[list[int]] = None,
778
+ model_type: str = "sd3m",
779
+ ):
780
+ super().__init__()
781
+ self._model_type = model_type
782
+ self.learn_sigma = learn_sigma
783
+ self.in_channels = in_channels
784
+ default_out_channels = in_channels * 2 if learn_sigma else in_channels
785
+ self.out_channels = default(out_channels, default_out_channels)
786
+ self.patch_size = patch_size
787
+ self.pos_embed_scaling_factor = pos_embed_scaling_factor
788
+ self.pos_embed_offset = pos_embed_offset
789
+ self.pos_embed_max_size = pos_embed_max_size
790
+ self.x_block_self_attn_layers = x_block_self_attn_layers
791
+ self.pos_emb_random_crop_rate = pos_emb_random_crop_rate
792
+ self.gradient_checkpointing = use_checkpoint
793
+
794
+ # hidden_size = default(hidden_size, 64 * depth)
795
+ # num_heads = default(num_heads, hidden_size // 64)
796
+
797
+ # apply magic --> this defines a head_size of 64
798
+ self.hidden_size = 64 * depth
799
+ num_heads = depth
800
+
801
+ self.num_heads = num_heads
802
+
803
+ self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes)
804
+
805
+ self.x_embedder = PatchEmbed(
806
+ input_size,
807
+ patch_size,
808
+ in_channels,
809
+ self.hidden_size,
810
+ bias=True,
811
+ strict_img_size=self.pos_embed_max_size is None,
812
+ )
813
+ self.t_embedder = TimestepEmbedding(self.hidden_size)
814
+
815
+ self.y_embedder = None
816
+ if adm_in_channels is not None:
817
+ assert isinstance(adm_in_channels, int)
818
+ self.y_embedder = Embedder(adm_in_channels, self.hidden_size)
819
+
820
+ if context_embedder_in_features is not None:
821
+ self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features)
822
+ else:
823
+ self.context_embedder = nn.Identity()
824
+
825
+ self.register_length = register_length
826
+ if self.register_length > 0:
827
+ self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size))
828
+
829
+ # num_patches = self.x_embedder.num_patches
830
+ # Will use fixed sin-cos embedding:
831
+ # just use a buffer already
832
+ if num_patches is not None:
833
+ self.register_buffer(
834
+ "pos_embed",
835
+ torch.empty(1, num_patches, self.hidden_size),
836
+ )
837
+ else:
838
+ self.pos_embed = None
839
+
840
+ self.use_checkpoint = use_checkpoint
841
+ self.joint_blocks = nn.ModuleList(
842
+ [
843
+ MMDiTBlock(
844
+ self.hidden_size,
845
+ num_heads,
846
+ mlp_ratio=mlp_ratio,
847
+ attn_mode=attn_mode,
848
+ qkv_bias=qkv_bias,
849
+ pre_only=i == depth - 1,
850
+ rmsnorm=rmsnorm,
851
+ scale_mod_only=scale_mod_only,
852
+ swiglu=swiglu,
853
+ qk_norm=qk_norm,
854
+ x_block_self_attn=(i in self.x_block_self_attn_layers),
855
+ )
856
+ for i in range(depth)
857
+ ]
858
+ )
859
+ for block in self.joint_blocks:
860
+ block.gradient_checkpointing = use_checkpoint
861
+
862
+ self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels)
863
+ # self.initialize_weights()
864
+
865
+ self.blocks_to_swap = None
866
+ self.offloader = None
867
+ self.num_blocks = len(self.joint_blocks)
868
+
869
+ def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]):
870
+ self.use_scaled_pos_embed = use_scaled_pos_embed
871
+
872
+ if self.use_scaled_pos_embed:
873
+ # remove pos_embed to free up memory up to 0.4 GB
874
+ self.pos_embed = None
875
+
876
+ # remove duplicates and sort latent sizes in ascending order
877
+ latent_sizes = list(set(latent_sizes))
878
+ latent_sizes = sorted(latent_sizes)
879
+
880
+ patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes]
881
+
882
+ # calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape
883
+ max_areas = []
884
+ for i in range(1, len(patched_sizes)):
885
+ prev_area = patched_sizes[i - 1] ** 2
886
+ area = patched_sizes[i] ** 2
887
+ max_areas.append((prev_area + area) // 2)
888
+
889
+ # area of the last latent size, if the latent size exceeds this, error will be raised
890
+ max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2))
891
+ # print("max_areas", max_areas)
892
+
893
+ self.resolution_area_to_latent_size = [(area, latent_size) for area, latent_size in zip(max_areas, patched_sizes)]
894
+
895
+ self.resolution_pos_embeds = {}
896
+ for patched_size in patched_sizes:
897
+ grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
898
+ pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
899
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
900
+ self.resolution_pos_embeds[patched_size] = pos_embed
901
+ # print(f"pos_embed for {patched_size}x{patched_size} latent size: {pos_embed.shape}")
902
+
903
+ else:
904
+ self.resolution_area_to_latent_size = None
905
+ self.resolution_pos_embeds = None
906
+
907
+ @property
908
+ def model_type(self):
909
+ return self._model_type
910
+
911
+ @property
912
+ def device(self):
913
+ return next(self.parameters()).device
914
+
915
+ @property
916
+ def dtype(self):
917
+ return next(self.parameters()).dtype
918
+
919
+ def enable_gradient_checkpointing(self):
920
+ self.gradient_checkpointing = True
921
+ for block in self.joint_blocks:
922
+ block.enable_gradient_checkpointing()
923
+
924
+ def disable_gradient_checkpointing(self):
925
+ self.gradient_checkpointing = False
926
+ for block in self.joint_blocks:
927
+ block.disable_gradient_checkpointing()
928
+
929
+ def initialize_weights(self):
930
+ # TODO: Init context_embedder?
931
+ # Initialize transformer layers:
932
+ def _basic_init(module):
933
+ if isinstance(module, nn.Linear):
934
+ torch.nn.init.xavier_uniform_(module.weight)
935
+ if module.bias is not None:
936
+ nn.init.constant_(module.bias, 0)
937
+
938
+ self.apply(_basic_init)
939
+
940
+ # Initialize (and freeze) pos_embed by sin-cos embedding
941
+ if self.pos_embed is not None:
942
+ pos_embed = get_2d_sincos_pos_embed(
943
+ self.pos_embed.shape[-1],
944
+ int(self.pos_embed.shape[-2] ** 0.5),
945
+ scaling_factor=self.pos_embed_scaling_factor,
946
+ )
947
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
948
+
949
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
950
+ w = self.x_embedder.proj.weight.data
951
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
952
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
953
+
954
+ if getattr(self, "y_embedder", None) is not None:
955
+ nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
956
+ nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)
957
+
958
+ # Initialize timestep embedding MLP:
959
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
960
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
961
+
962
+ # Zero-out adaLN modulation layers in DiT blocks:
963
+ for block in self.joint_blocks:
964
+ nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
965
+ nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
966
+ nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
967
+ nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)
968
+
969
+ # Zero-out output layers:
970
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
971
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
972
+ nn.init.constant_(self.final_layer.linear.weight, 0)
973
+ nn.init.constant_(self.final_layer.linear.bias, 0)
974
+
975
+ def set_pos_emb_random_crop_rate(self, rate: float):
976
+ self.pos_emb_random_crop_rate = rate
977
+
978
+ def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False):
979
+ p = self.x_embedder.patch_size
980
+ # patched size
981
+ h = (h + 1) // p
982
+ w = (w + 1) // p
983
+ if self.pos_embed is None: # should not happen
984
+ return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
985
+ assert self.pos_embed_max_size is not None
986
+ assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
987
+ assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
988
+
989
+ if not random_crop:
990
+ top = (self.pos_embed_max_size - h) // 2
991
+ left = (self.pos_embed_max_size - w) // 2
992
+ else:
993
+ top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item()
994
+ left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item()
995
+
996
+ spatial_pos_embed = self.pos_embed.reshape(
997
+ 1,
998
+ self.pos_embed_max_size,
999
+ self.pos_embed_max_size,
1000
+ self.pos_embed.shape[-1],
1001
+ )
1002
+ spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
1003
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
1004
+ return spatial_pos_embed
1005
+
1006
+ def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False):
1007
+ p = self.x_embedder.patch_size
1008
+ # patched size
1009
+ h = (h + 1) // p
1010
+ w = (w + 1) // p
1011
+
1012
+ # select pos_embed size based on area
1013
+ area = h * w
1014
+ patched_size = None
1015
+ for area_, patched_size_ in self.resolution_area_to_latent_size:
1016
+ if area <= area_:
1017
+ patched_size = patched_size_
1018
+ break
1019
+ if patched_size is None:
1020
+ raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
1021
+
1022
+ pos_embed = self.resolution_pos_embeds[patched_size]
1023
+ pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
1024
+ if h > pos_embed_size or w > pos_embed_size:
1025
+ # # fallback to normal pos_embed
1026
+ # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
1027
+ # extend pos_embed size
1028
+ logger.warning(
1029
+ f"Requested size {h}x{w} exceeds the prepared scaled pos_embed size {pos_embed_size}; extending the scaled pos_embed. The image is unusually tall or wide."
1030
+ )
1031
+ pos_embed_size = max(h, w)
1032
+ pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
1033
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
1034
+ self.resolution_pos_embeds[patched_size] = pos_embed
1035
+ logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")
1036
+
1037
+ if not random_crop:
1038
+ top = (pos_embed_size - h) // 2
1039
+ left = (pos_embed_size - w) // 2
1040
+ else:
1041
+ top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
1042
+ left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()
1043
+
1044
+ if pos_embed.device != device:
1045
+ pos_embed = pos_embed.to(device)
1046
+ # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.
1047
+ self.resolution_pos_embeds[patched_size] = pos_embed # update device
1048
+ if pos_embed.dtype != dtype:
1049
+ pos_embed = pos_embed.to(dtype)
1050
+ self.resolution_pos_embeds[patched_size] = pos_embed # update dtype
1051
+
1052
+ spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1])
1053
+ spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
1054
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
1055
+ # print(
1056
+ # f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}"
1057
+ # )
1058
+ return spatial_pos_embed
1059
+
1060
+ def enable_block_swap(self, num_blocks: int, device: torch.device):
1061
+ self.blocks_to_swap = num_blocks
1062
+
1063
+ assert (
1064
+ self.blocks_to_swap <= self.num_blocks - 2
1065
+ ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
1066
+
1067
+ self.offloader = custom_offloading_utils.ModelOffloader(
1068
+ self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
1069
+ )
1070
+ logger.info(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
1071
+
1072
+ def move_to_device_except_swap_blocks(self, device: torch.device):
1073
+ # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
1074
+ if self.blocks_to_swap:
1075
+ save_blocks = self.joint_blocks
1076
+ self.joint_blocks = None
1077
+
1078
+ self.to(device)
1079
+
1080
+ if self.blocks_to_swap:
1081
+ self.joint_blocks = save_blocks
1082
+
1083
+ def prepare_block_swap_before_forward(self):
1084
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
1085
+ return
1086
+ self.offloader.prepare_block_devices_before_forward(self.joint_blocks)
1087
+
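A sketch of the intended call order for the block-swap helpers above. This is an assumed usage pattern mirroring how the training scripts in this repo drive similar offloading; mmdit stands for a model built with create_sd3_mmdit, defined later in this file.

import torch

device = torch.device("cuda")
mmdit.enable_block_swap(num_blocks=8, device=device)   # keep 8 joint blocks off the GPU
mmdit.move_to_device_except_swap_blocks(device)        # move the rest of the model
mmdit.prepare_block_swap_before_forward()              # call before every forward pass
# ... mmdit(x, t, y=y, context=context) ...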
1088
+ def forward(
1089
+ self,
1090
+ x: torch.Tensor,
1091
+ t: torch.Tensor,
1092
+ y: Optional[torch.Tensor] = None,
1093
+ context: Optional[torch.Tensor] = None,
1094
+ ) -> torch.Tensor:
1095
+ """
1096
+ Forward pass of DiT.
1097
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
1098
+ t: (N,) tensor of diffusion timesteps
1099
+ y: (N, D) tensor of class labels
1100
+ """
1101
+ pos_emb_random_crop = (
1102
+ False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate
1103
+ )
1104
+
1105
+ B, C, H, W = x.shape
1106
+
1107
+ # x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
1108
+ if not self.use_scaled_pos_embed:
1109
+ pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
1110
+ else:
1111
+ # print(f"Using scaled pos_embed for size {H}x{W}")
1112
+ pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop)
1113
+ x = self.x_embedder(x) + pos_embed
1114
+ del pos_embed
1115
+
1116
+ c = self.t_embedder(t, dtype=x.dtype) # (N, D)
1117
+ if y is not None and self.y_embedder is not None:
1118
+ y = self.y_embedder(y) # (N, D)
1119
+ c = c + y # (N, D)
1120
+
1121
+ if context is not None:
1122
+ context = self.context_embedder(context)
1123
+
1124
+ if self.register_length > 0:
1125
+ context = torch.cat(
1126
+ (einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1
1127
+ )
1128
+
1129
+ if not self.blocks_to_swap:
1130
+ for block in self.joint_blocks:
1131
+ context, x = block(context, x, c)
1132
+ else:
1133
+ for block_idx, block in enumerate(self.joint_blocks):
1134
+ self.offloader.wait_for_block(block_idx)
1135
+
1136
+ context, x = block(context, x, c)
1137
+
1138
+ self.offloader.submit_move_blocks(self.joint_blocks, block_idx)
1139
+
1140
+ x = self.final_layer(x, c, H, W) # Our final layer combined UnPatchify
1141
+ return x[:, :, :H, :W]
1142
+
1143
+
1144
+ def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT:
1145
+ mmdit = MMDiT(
1146
+ input_size=None,
1147
+ pos_embed_max_size=params.pos_embed_max_size,
1148
+ patch_size=params.patch_size,
1149
+ in_channels=16,
1150
+ adm_in_channels=params.adm_in_channels,
1151
+ context_embedder_in_features=params.context_embedder_in_features,
1152
+ context_embedder_out_features=params.context_embedder_out_features,
1153
+ depth=params.depth,
1154
+ mlp_ratio=4,
1155
+ qk_norm=params.qk_norm,
1156
+ x_block_self_attn_layers=params.x_block_self_attn_layers,
1157
+ num_patches=params.num_patches,
1158
+ attn_mode=attn_mode,
1159
+ model_type=params.model_type,
1160
+ )
1161
+ return mmdit
1162
+
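A minimal construction and shape check for create_sd3_mmdit, using deliberately small toy values; real SD3 checkpoints use much larger settings (e.g. depth 24 for a hidden size of 1536), so all numbers below are assumptions for illustration, and the import path is assumed:

import torch
from library.sd3_models import SD3Params, create_sd3_mmdit

params = SD3Params(
    patch_size=2, depth=4, num_patches=64 * 64, pos_embed_max_size=64,
    adm_in_channels=2048, qk_norm=None, x_block_self_attn_layers=[],
    context_embedder_in_features=4096, context_embedder_out_features=256,
    model_type="sd3m",
)
mmdit = create_sd3_mmdit(params, attn_mode="torch")
mmdit.initialize_weights()  # fills the pos_embed buffer; weights stay random/zero, shape check only

x = torch.randn(1, 16, 16, 16)       # 16-channel latent
t = torch.randint(0, 1000, (1,))
y = torch.randn(1, 2048)             # pooled text embeddings
context = torch.randn(1, 32, 4096)   # token embeddings (length is illustrative)
out = mmdit(x, t, y=y, context=context)
print(out.shape)                     # torch.Size([1, 16, 16, 16])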
1163
+
1164
+ # endregion
1165
+
1166
+ # region VAE
1167
+
1168
+ VAE_SCALE_FACTOR = 1.5305
1169
+ VAE_SHIFT_FACTOR = 0.0609
1170
+
1171
+
1172
+ def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
1173
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
1174
+
1175
+
1176
+ class ResnetBlock(torch.nn.Module):
1177
+ def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
1178
+ super().__init__()
1179
+ self.in_channels = in_channels
1180
+ out_channels = in_channels if out_channels is None else out_channels
1181
+ self.out_channels = out_channels
1182
+
1183
+ self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
1184
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
1185
+ self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
1186
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
1187
+ if self.in_channels != self.out_channels:
1188
+ self.nin_shortcut = torch.nn.Conv2d(
1189
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device
1190
+ )
1191
+ else:
1192
+ self.nin_shortcut = None
1193
+ self.swish = torch.nn.SiLU(inplace=True)
1194
+
1195
+ def forward(self, x):
1196
+ hidden = x
1197
+ hidden = self.norm1(hidden)
1198
+ hidden = self.swish(hidden)
1199
+ hidden = self.conv1(hidden)
1200
+ hidden = self.norm2(hidden)
1201
+ hidden = self.swish(hidden)
1202
+ hidden = self.conv2(hidden)
1203
+ if self.in_channels != self.out_channels:
1204
+ x = self.nin_shortcut(x)
1205
+ return x + hidden
1206
+
1207
+
1208
+ class AttnBlock(torch.nn.Module):
1209
+ def __init__(self, in_channels, dtype=torch.float32, device=None):
1210
+ super().__init__()
1211
+ self.norm = Normalize(in_channels, dtype=dtype, device=device)
1212
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
1213
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
1214
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
1215
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
1216
+
1217
+ def forward(self, x):
1218
+ hidden = self.norm(x)
1219
+ q = self.q(hidden)
1220
+ k = self.k(hidden)
1221
+ v = self.v(hidden)
1222
+ b, c, h, w = q.shape
1223
+ q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
1224
+ hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
1225
+ hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
1226
+ hidden = self.proj_out(hidden)
1227
+ return x + hidden
1228
+
1229
+
1230
+ class Downsample(torch.nn.Module):
1231
+ def __init__(self, in_channels, dtype=torch.float32, device=None):
1232
+ super().__init__()
1233
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)
1234
+
1235
+ def forward(self, x):
1236
+ pad = (0, 1, 0, 1)
1237
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
1238
+ x = self.conv(x)
1239
+ return x
1240
+
1241
+
1242
+ class Upsample(torch.nn.Module):
1243
+ def __init__(self, in_channels, dtype=torch.float32, device=None):
1244
+ super().__init__()
1245
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
1246
+
1247
+ def forward(self, x):
1248
+ org_dtype = x.dtype
1249
+ if x.dtype == torch.bfloat16:
1250
+ x = x.to(torch.float32)
1251
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
1252
+ if x.dtype != org_dtype:
1253
+ x = x.to(org_dtype)
1254
+ x = self.conv(x)
1255
+ return x
1256
+
1257
+
1258
+ class VAEEncoder(torch.nn.Module):
1259
+ def __init__(
1260
+ self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None
1261
+ ):
1262
+ super().__init__()
1263
+ self.num_resolutions = len(ch_mult)
1264
+ self.num_res_blocks = num_res_blocks
1265
+ # downsampling
1266
+ self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
1267
+ in_ch_mult = (1,) + tuple(ch_mult)
1268
+ self.in_ch_mult = in_ch_mult
1269
+ self.down = torch.nn.ModuleList()
1270
+ for i_level in range(self.num_resolutions):
1271
+ block = torch.nn.ModuleList()
1272
+ attn = torch.nn.ModuleList()
1273
+ block_in = ch * in_ch_mult[i_level]
1274
+ block_out = ch * ch_mult[i_level]
1275
+ for i_block in range(num_res_blocks):
1276
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
1277
+ block_in = block_out
1278
+ down = torch.nn.Module()
1279
+ down.block = block
1280
+ down.attn = attn
1281
+ if i_level != self.num_resolutions - 1:
1282
+ down.downsample = Downsample(block_in, dtype=dtype, device=device)
1283
+ self.down.append(down)
1284
+ # middle
1285
+ self.mid = torch.nn.Module()
1286
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
1287
+ self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
1288
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
1289
+ # end
1290
+ self.norm_out = Normalize(block_in, dtype=dtype, device=device)
1291
+ self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
1292
+ self.swish = torch.nn.SiLU(inplace=True)
1293
+
1294
+ def forward(self, x):
1295
+ # downsampling
1296
+ hs = [self.conv_in(x)]
1297
+ for i_level in range(self.num_resolutions):
1298
+ for i_block in range(self.num_res_blocks):
1299
+ h = self.down[i_level].block[i_block](hs[-1])
1300
+ hs.append(h)
1301
+ if i_level != self.num_resolutions - 1:
1302
+ hs.append(self.down[i_level].downsample(hs[-1]))
1303
+ # middle
1304
+ h = hs[-1]
1305
+ h = self.mid.block_1(h)
1306
+ h = self.mid.attn_1(h)
1307
+ h = self.mid.block_2(h)
1308
+ # end
1309
+ h = self.norm_out(h)
1310
+ h = self.swish(h)
1311
+ h = self.conv_out(h)
1312
+ return h
1313
+
1314
+
1315
+ class VAEDecoder(torch.nn.Module):
1316
+ def __init__(
1317
+ self,
1318
+ ch=128,
1319
+ out_ch=3,
1320
+ ch_mult=(1, 2, 4, 4),
1321
+ num_res_blocks=2,
1322
+ resolution=256,
1323
+ z_channels=16,
1324
+ dtype=torch.float32,
1325
+ device=None,
1326
+ ):
1327
+ super().__init__()
1328
+ self.num_resolutions = len(ch_mult)
1329
+ self.num_res_blocks = num_res_blocks
1330
+ block_in = ch * ch_mult[self.num_resolutions - 1]
1331
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
1332
+ # z to block_in
1333
+ self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
1334
+ # middle
1335
+ self.mid = torch.nn.Module()
1336
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
1337
+ self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
1338
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
1339
+ # upsampling
1340
+ self.up = torch.nn.ModuleList()
1341
+ for i_level in reversed(range(self.num_resolutions)):
1342
+ block = torch.nn.ModuleList()
1343
+ block_out = ch * ch_mult[i_level]
1344
+ for i_block in range(self.num_res_blocks + 1):
1345
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
1346
+ block_in = block_out
1347
+ up = torch.nn.Module()
1348
+ up.block = block
1349
+ if i_level != 0:
1350
+ up.upsample = Upsample(block_in, dtype=dtype, device=device)
1351
+ curr_res = curr_res * 2
1352
+ self.up.insert(0, up) # prepend to get consistent order
1353
+ # end
1354
+ self.norm_out = Normalize(block_in, dtype=dtype, device=device)
1355
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
1356
+ self.swish = torch.nn.SiLU(inplace=True)
1357
+
1358
+ def forward(self, z):
1359
+ # z to block_in
1360
+ hidden = self.conv_in(z)
1361
+ # middle
1362
+ hidden = self.mid.block_1(hidden)
1363
+ hidden = self.mid.attn_1(hidden)
1364
+ hidden = self.mid.block_2(hidden)
1365
+ # upsampling
1366
+ for i_level in reversed(range(self.num_resolutions)):
1367
+ for i_block in range(self.num_res_blocks + 1):
1368
+ hidden = self.up[i_level].block[i_block](hidden)
1369
+ if i_level != 0:
1370
+ hidden = self.up[i_level].upsample(hidden)
1371
+ # end
1372
+ hidden = self.norm_out(hidden)
1373
+ hidden = self.swish(hidden)
1374
+ hidden = self.conv_out(hidden)
1375
+ return hidden
1376
+
1377
+
1378
+ class SDVAE(torch.nn.Module):
1379
+ def __init__(self, dtype=torch.float32, device=None):
1380
+ super().__init__()
1381
+ self.encoder = VAEEncoder(dtype=dtype, device=device)
1382
+ self.decoder = VAEDecoder(dtype=dtype, device=device)
1383
+
1384
+ @property
1385
+ def device(self):
1386
+ return next(self.parameters()).device
1387
+
1388
+ @property
1389
+ def dtype(self):
1390
+ return next(self.parameters()).dtype
1391
+
1392
+ # @torch.autocast("cuda", dtype=torch.float16)
1393
+ def decode(self, latent):
1394
+ return self.decoder(latent)
1395
+
1396
+ # @torch.autocast("cuda", dtype=torch.float16)
1397
+ def encode(self, image):
1398
+ hidden = self.encoder(image)
1399
+ mean, logvar = torch.chunk(hidden, 2, dim=1)
1400
+ logvar = torch.clamp(logvar, -30.0, 20.0)
1401
+ std = torch.exp(0.5 * logvar)
1402
+ return mean + std * torch.randn_like(mean)
1403
+
1404
+ @staticmethod
1405
+ def process_in(latent):
1406
+ return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR
1407
+
1408
+ @staticmethod
1409
+ def process_out(latent):
1410
+ return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR
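+ # process_in / process_out map between raw SDVAE latents and the normalized space the MMDiT is
+ # trained on; they are exact inverses: process_out(process_in(z)) == z for any latent z.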
1411
+
1412
+
1413
+ # endregion
library/sd3_train_utils.py ADDED
@@ -0,0 +1,945 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ import toml
5
+ import json
6
+ import time
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ from safetensors.torch import save_file
11
+ from accelerate import Accelerator, PartialState
12
+ from tqdm import tqdm
13
+ from PIL import Image
14
+ from transformers import CLIPTextModelWithProjection, T5EncoderModel
15
+
16
+ from library.device_utils import init_ipex, clean_memory_on_device
17
+
18
+ init_ipex()
19
+
20
+ # from transformers import CLIPTokenizer
21
+ # from library import model_util
22
+ # , sdxl_model_util, train_util, sdxl_original_unet
23
+ # from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
24
+ from .utils import setup_logging
25
+
26
+ setup_logging()
27
+ import logging
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ from library import sd3_models, sd3_utils, strategy_base, train_util
32
+
33
+
34
+ def save_models(
35
+ ckpt_path: str,
36
+ mmdit: Optional[sd3_models.MMDiT],
37
+ vae: Optional[sd3_models.SDVAE],
38
+ clip_l: Optional[CLIPTextModelWithProjection],
39
+ clip_g: Optional[CLIPTextModelWithProjection],
40
+ t5xxl: Optional[T5EncoderModel],
41
+ sai_metadata: Optional[dict],
42
+ save_dtype: Optional[torch.dtype] = None,
43
+ ):
44
+ r"""
45
+ Save MMDiT and VAE to the checkpoint file; text encoders are saved to separate files (the unified checkpoint format is not supported).
46
+ """
47
+
48
+ state_dict = {}
49
+
50
+ def update_sd(prefix, sd):
51
+ for k, v in sd.items():
52
+ key = prefix + k
53
+ if save_dtype is not None:
54
+ v = v.detach().clone().to("cpu").to(save_dtype)
55
+ state_dict[key] = v
56
+
57
+ update_sd("model.diffusion_model.", mmdit.state_dict())
58
+ update_sd("first_stage_model.", vae.state_dict())
59
+
60
+ # do not support unified checkpoint format for now
61
+ # if clip_l is not None:
62
+ # update_sd("text_encoders.clip_l.", clip_l.state_dict())
63
+ # if clip_g is not None:
64
+ # update_sd("text_encoders.clip_g.", clip_g.state_dict())
65
+ # if t5xxl is not None:
66
+ # update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
67
+
68
+ save_file(state_dict, ckpt_path, metadata=sai_metadata)
69
+
70
+ if clip_l is not None:
71
+ clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors")
72
+ save_file(clip_l.state_dict(), clip_l_path)
73
+ if clip_g is not None:
74
+ clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors")
75
+ save_file(clip_g.state_dict(), clip_g_path)
76
+ if t5xxl is not None:
77
+ t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
78
+ t5xxl_state_dict = t5xxl.state_dict()
79
+
80
+ # replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file
81
+ shared_weight = t5xxl_state_dict["shared.weight"]
82
+ shared_weight_copy = shared_weight.detach().clone()
83
+ t5xxl_state_dict["shared.weight"] = shared_weight_copy
84
+
85
+ save_file(t5xxl_state_dict, t5xxl_path)
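+ # For example, with ckpt_path="sd3_ft.safetensors" (an illustrative name) this writes the MMDiT+VAE
+ # checkpoint plus sd3_ft_clip_l.safetensors, sd3_ft_clip_g.safetensors and sd3_ft_t5xxl.safetensors
+ # for whichever text encoders are passed in.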
86
+
87
+
88
+ def save_sd3_model_on_train_end(
89
+ args: argparse.Namespace,
90
+ save_dtype: torch.dtype,
91
+ epoch: int,
92
+ global_step: int,
93
+ clip_l: Optional[CLIPTextModelWithProjection],
94
+ clip_g: Optional[CLIPTextModelWithProjection],
95
+ t5xxl: Optional[T5EncoderModel],
96
+ mmdit: sd3_models.MMDiT,
97
+ vae: sd3_models.SDVAE,
98
+ ):
99
+ def sd_saver(ckpt_file, epoch_no, global_step):
100
+ sai_metadata = train_util.get_sai_model_spec(
101
+ None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type
102
+ )
103
+ save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype)
104
+
105
+ train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
106
+
107
+
108
+ # Saving on epoch end and at step intervals is unified here because the metadata contains epoch/step and the arguments are identical
110
+ # on_epoch_end: True to save at the end of an epoch, False to save when the step interval is reached
110
+ def save_sd3_model_on_epoch_end_or_stepwise(
111
+ args: argparse.Namespace,
112
+ on_epoch_end: bool,
113
+ accelerator,
114
+ save_dtype: torch.dtype,
115
+ epoch: int,
116
+ num_train_epochs: int,
117
+ global_step: int,
118
+ clip_l: Optional[CLIPTextModelWithProjection],
119
+ clip_g: Optional[CLIPTextModelWithProjection],
120
+ t5xxl: Optional[T5EncoderModel],
121
+ mmdit: sd3_models.MMDiT,
122
+ vae: sd3_models.SDVAE,
123
+ ):
124
+ def sd_saver(ckpt_file, epoch_no, global_step):
125
+ sai_metadata = train_util.get_sai_model_spec(
126
+ None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type
127
+ )
128
+ save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype)
129
+
130
+ train_util.save_sd_model_on_epoch_end_or_stepwise_common(
131
+ args,
132
+ on_epoch_end,
133
+ accelerator,
134
+ True,
135
+ True,
136
+ epoch,
137
+ num_train_epochs,
138
+ global_step,
139
+ sd_saver,
140
+ None,
141
+ )
142
+
143
+
144
+ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
145
+ parser.add_argument(
146
+ "--clip_l",
147
+ type=str,
148
+ required=False,
149
+ help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用",
150
+ )
151
+ parser.add_argument(
152
+ "--clip_g",
153
+ type=str,
154
+ required=False,
155
+ help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用",
156
+ )
157
+ parser.add_argument(
158
+ "--t5xxl",
159
+ type=str,
160
+ required=False,
161
+ help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用",
162
+ )
163
+ parser.add_argument(
164
+ "--save_clip",
165
+ action="store_true",
166
+ help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
167
+ )
168
+ parser.add_argument(
169
+ "--save_t5xxl",
170
+ action="store_true",
171
+ help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
172
+ )
173
+
174
+ parser.add_argument(
175
+ "--t5xxl_device",
176
+ type=str,
177
+ default=None,
178
+ help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
179
+ )
180
+ parser.add_argument(
181
+ "--t5xxl_dtype",
182
+ type=str,
183
+ default=None,
184
+ help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用",
185
+ )
186
+
187
+ parser.add_argument(
188
+ "--t5xxl_max_token_length",
189
+ type=int,
190
+ default=256,
191
+ help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256",
192
+ )
193
+ parser.add_argument(
194
+ "--apply_lg_attn_mask",
195
+ action="store_true",
196
+ help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスク(ゼロ埋め)を適用する",
197
+ )
198
+ parser.add_argument(
199
+ "--apply_t5_attn_mask",
200
+ action="store_true",
201
+ help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスク(ゼロ埋め)を適用する",
202
+ )
203
+ parser.add_argument(
204
+ "--clip_l_dropout_rate",
205
+ type=float,
206
+ default=0.0,
207
+ help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0",
208
+ )
209
+ parser.add_argument(
210
+ "--clip_g_dropout_rate",
211
+ type=float,
212
+ default=0.0,
213
+ help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0",
214
+ )
215
+ parser.add_argument(
216
+ "--t5_dropout_rate",
217
+ type=float,
218
+ default=0.0,
219
+ help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0",
220
+ )
221
+ parser.add_argument(
222
+ "--pos_emb_random_crop_rate",
223
+ type=float,
224
+ default=0.0,
225
+ help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
226
+ " / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
227
+ )
228
+ parser.add_argument(
229
+ "--enable_scaled_pos_embed",
230
+ action="store_true",
231
+ help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M"
232
+ " / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
233
+ )
234
+
235
+ # Dependencies on the Diffusers noise sampler have been removed for clarity in training
236
+
237
+ parser.add_argument(
238
+ "--training_shift",
239
+ type=float,
240
+ default=1.0,
241
+ help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。",
242
+ )
243
+
244
+
245
+ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
246
+ assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
247
+ if args.v_parameterization:
248
+ logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
249
+
250
+ if args.clip_skip is not None:
251
+ logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
252
+
253
+ # if args.multires_noise_iterations:
254
+ # logger.info(
255
+ # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
256
+ # )
257
+ # else:
258
+ # if args.noise_offset is None:
259
+ # args.noise_offset = DEFAULT_NOISE_OFFSET
260
+ # elif args.noise_offset != DEFAULT_NOISE_OFFSET:
261
+ # logger.info(
262
+ # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
263
+ # )
264
+ # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
265
+
266
+ assert (
267
+ not hasattr(args, "weighted_captions") or not args.weighted_captions
268
+ ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
269
+
270
+ if supportTextEncoderCaching:
271
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
272
+ args.cache_text_encoder_outputs = True
273
+ logger.warning(
274
+ "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
275
+ + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
276
+ )
277
+
278
+
279
+ # temporarily copied from sd3_minimal_inference.py
280
+
281
+
282
+ def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
283
+ start = sampling.timestep(sampling.sigma_max)
284
+ end = sampling.timestep(sampling.sigma_min)
285
+ timesteps = torch.linspace(start, end, steps)
286
+ sigs = []
287
+ for x in range(len(timesteps)):
288
+ ts = timesteps[x]
289
+ sigs.append(sampling.sigma(ts))
290
+ sigs += [0.0]
291
+ return torch.FloatTensor(sigs)
292
+
293
+
294
+ def max_denoise(model_sampling, sigmas):
295
+ max_sigma = float(model_sampling.sigma_max)
296
+ sigma = float(sigmas[0])
297
+ return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
298
+
299
+
300
+ def do_sample(
301
+ height: int,
302
+ width: int,
303
+ seed: int,
304
+ cond: Tuple[torch.Tensor, torch.Tensor],
305
+ neg_cond: Tuple[torch.Tensor, torch.Tensor],
306
+ mmdit: sd3_models.MMDiT,
307
+ steps: int,
308
+ guidance_scale: float,
309
+ dtype: torch.dtype,
310
+ device: str,
311
+ ):
312
+ latent = torch.zeros(1, 16, height // 8, width // 8, device=device)
313
+ latent = latent.to(dtype).to(device)
314
+
315
+ # noise = get_noise(seed, latent).to(device)
316
+ if seed is not None:
317
+ generator = torch.manual_seed(seed)
318
+ else:
319
+ generator = None
320
+ noise = (
321
+ torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu")
322
+ .to(latent.dtype)
323
+ .to(device)
324
+ )
325
+
326
+ model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
327
+
328
+ sigmas = get_all_sigmas(model_sampling, steps).to(device)
329
+
330
+ noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))
331
+
332
+ c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype)
333
+ y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype)
334
+
335
+ x = noise_scaled.to(device).to(dtype)
336
+ # print(x.shape)
337
+
338
+ # with torch.no_grad():
339
+ for i in tqdm(range(len(sigmas) - 1)):
340
+ sigma_hat = sigmas[i]
341
+
342
+ timestep = model_sampling.timestep(sigma_hat).float()
343
+ timestep = torch.FloatTensor([timestep, timestep]).to(device)
344
+
345
+ x_c_nc = torch.cat([x, x], dim=0)
346
+ # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
347
+
348
+ mmdit.prepare_block_swap_before_forward()
349
+ model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
350
+ model_output = model_output.float()
351
+ batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
352
+
353
+ pos_out, neg_out = batched.chunk(2)
354
+ denoised = neg_out + (pos_out - neg_out) * guidance_scale
355
+ # print(denoised.shape)
356
+
357
+ # d = to_d(x, sigma_hat, denoised)
358
+ dims_to_append = x.ndim - sigma_hat.ndim
359
+ sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
360
+ # print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
361
+ """Converts a denoiser output to a Karras ODE derivative."""
362
+ d = (x - denoised) / sigma_hat_dims
363
+
364
+ dt = sigmas[i + 1] - sigma_hat
365
+
366
+ # Euler method
367
+ x = x + d * dt
368
+ x = x.to(dtype)
369
+
370
+ mmdit.prepare_block_swap_before_forward()
371
+ return x
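+ # The sampler above batches cond and neg_cond for classifier-free guidance and takes plain Euler
+ # steps over the discrete-flow sigmas. With shift=3.0 the schedule is sigma(t) = 3t / (1 + 2t) for
+ # t in [0, 1]; e.g. t=0.5 gives sigma=0.75, so more of the trajectory is spent at high noise than
+ # with a uniform schedule.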
372
+
373
+
374
+ def sample_images(
375
+ accelerator: Accelerator,
376
+ args: argparse.Namespace,
377
+ epoch,
378
+ steps,
379
+ mmdit,
380
+ vae,
381
+ text_encoders,
382
+ sample_prompts_te_outputs,
383
+ prompt_replacement=None,
384
+ ):
385
+ if steps == 0:
386
+ if not args.sample_at_first:
387
+ return
388
+ else:
389
+ if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
390
+ return
391
+ if args.sample_every_n_epochs is not None:
392
+ # sample_every_n_steps is ignored here
393
+ if epoch is None or epoch % args.sample_every_n_epochs != 0:
394
+ return
395
+ else:
396
+ if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
397
+ return
398
+
399
+ logger.info("")
400
+ logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
401
+ if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
402
+ logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
403
+ return
404
+
405
+ distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
406
+
407
+ # unwrap unet and text_encoder(s)
408
+ mmdit = accelerator.unwrap_model(mmdit)
409
+ text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders]
410
+ # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
411
+
412
+ prompts = train_util.load_prompts(args.sample_prompts)
413
+
414
+ save_dir = args.output_dir + "/sample"
415
+ os.makedirs(save_dir, exist_ok=True)
416
+
417
+ # save random state to restore later
418
+ rng_state = torch.get_rng_state()
419
+ cuda_rng_state = None
420
+ try:
421
+ cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
422
+ except Exception:
423
+ pass
424
+
425
+ if distributed_state.num_processes <= 1:
426
+ # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
427
+ with torch.no_grad(), accelerator.autocast():
428
+ for prompt_dict in prompts:
429
+ sample_image_inference(
430
+ accelerator,
431
+ args,
432
+ mmdit,
433
+ text_encoders,
434
+ vae,
435
+ save_dir,
436
+ prompt_dict,
437
+ epoch,
438
+ steps,
439
+ sample_prompts_te_outputs,
440
+ prompt_replacement,
441
+ )
442
+ else:
443
+ # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
444
+ # prompt_dicts are assigned to the lists in process order, so that image generation timing roughly matches the enum order. This probably only works when steps and sampler are identical.
445
+ per_process_prompts = [] # list of lists
446
+ for i in range(distributed_state.num_processes):
447
+ per_process_prompts.append(prompts[i :: distributed_state.num_processes])
448
+
449
+ with torch.no_grad():
450
+ with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
451
+ for prompt_dict in prompt_dict_lists[0]:
452
+ sample_image_inference(
453
+ accelerator,
454
+ args,
455
+ mmdit,
456
+ text_encoders,
457
+ vae,
458
+ save_dir,
459
+ prompt_dict,
460
+ epoch,
461
+ steps,
462
+ sample_prompts_te_outputs,
463
+ prompt_replacement,
464
+ )
465
+
466
+ torch.set_rng_state(rng_state)
467
+ if cuda_rng_state is not None:
468
+ torch.cuda.set_rng_state(cuda_rng_state)
469
+
470
+ clean_memory_on_device(accelerator.device)
471
+
472
+
473
+ def sample_image_inference(
474
+ accelerator: Accelerator,
475
+ args: argparse.Namespace,
476
+ mmdit: sd3_models.MMDiT,
477
+ text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
478
+ vae: sd3_models.SDVAE,
479
+ save_dir,
480
+ prompt_dict,
481
+ epoch,
482
+ steps,
483
+ sample_prompts_te_outputs,
484
+ prompt_replacement,
485
+ ):
486
+ assert isinstance(prompt_dict, dict)
487
+ negative_prompt = prompt_dict.get("negative_prompt")
488
+ sample_steps = prompt_dict.get("sample_steps", 30)
489
+ width = prompt_dict.get("width", 512)
490
+ height = prompt_dict.get("height", 512)
491
+ scale = prompt_dict.get("scale", 7.5)
492
+ seed = prompt_dict.get("seed")
493
+ # controlnet_image = prompt_dict.get("controlnet_image")
494
+ prompt: str = prompt_dict.get("prompt", "")
495
+ # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
496
+
497
+ if prompt_replacement is not None:
498
+ prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
499
+ if negative_prompt is not None:
500
+ negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
501
+
502
+ if seed is not None:
503
+ torch.manual_seed(seed)
504
+ torch.cuda.manual_seed(seed)
505
+ else:
506
+ # True random sample image generation
507
+ torch.seed()
508
+ torch.cuda.seed()
509
+
510
+ if negative_prompt is None:
511
+ negative_prompt = ""
512
+
513
+ height = max(64, height - height % 8) # round to divisible by 8
514
+ width = max(64, width - width % 8) # round to divisible by 8
515
+ logger.info(f"prompt: {prompt}")
516
+ logger.info(f"negative_prompt: {negative_prompt}")
517
+ logger.info(f"height: {height}")
518
+ logger.info(f"width: {width}")
519
+ logger.info(f"sample_steps: {sample_steps}")
520
+ logger.info(f"scale: {scale}")
521
+ # logger.info(f"sample_sampler: {sampler_name}")
522
+ if seed is not None:
523
+ logger.info(f"seed: {seed}")
524
+
525
+ # encode prompts
526
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
527
+ encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
528
+
529
+ def encode_prompt(prpt):
530
+ text_encoder_conds = []
531
+ if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
532
+ text_encoder_conds = sample_prompts_te_outputs[prpt]
533
+ print(f"Using cached text encoder outputs for prompt: {prpt}")
534
+ if text_encoders is not None:
535
+ print(f"Encoding prompt: {prpt}")
536
+ tokens_and_masks = tokenize_strategy.tokenize(prpt)
537
+ # strategy has apply_t5_attn_mask option
538
+ encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
539
+
540
+ # if text_encoder_conds is not cached, use encoded_text_encoder_conds
541
+ if len(text_encoder_conds) == 0:
542
+ text_encoder_conds = encoded_text_encoder_conds
543
+ else:
544
+ # if encoded_text_encoder_conds is not None, update cached text_encoder_conds
545
+ for i in range(len(encoded_text_encoder_conds)):
546
+ if encoded_text_encoder_conds[i] is not None:
547
+ text_encoder_conds[i] = encoded_text_encoder_conds[i]
548
+ return text_encoder_conds
549
+
550
+ lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt)
551
+ cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
552
+
553
+ # encode negative prompts
554
+ lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt)
555
+ neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
556
+
557
+ # sample image
558
+ clean_memory_on_device(accelerator.device)
559
+ with accelerator.autocast(), torch.no_grad():
560
+ # mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype.
561
+ latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device)
562
+
563
+ # latent to image
564
+ clean_memory_on_device(accelerator.device)
565
+ org_vae_device = vae.device # will be on cpu
566
+ vae.to(accelerator.device)
567
+ latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
568
+ image = vae.decode(latents)
569
+ vae.to(org_vae_device)
570
+ clean_memory_on_device(accelerator.device)
571
+
572
+ image = image.float()
573
+ image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
574
+ decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
575
+ decoded_np = decoded_np.astype(np.uint8)
576
+
577
+ image = Image.fromarray(decoded_np)
578
+ # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
579
+ # but adding 'enum' to the filename should be enough
580
+
581
+ ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
582
+ num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
583
+ seed_suffix = "" if seed is None else f"_{seed}"
584
+ i: int = prompt_dict["enum"]
585
+ img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
586
+ image.save(os.path.join(save_dir, img_filename))
587
+
588
+ # send images to wandb if enabled
589
+ if "wandb" in [tracker.name for tracker in accelerator.trackers]:
590
+ wandb_tracker = accelerator.get_tracker("wandb")
591
+
592
+ import wandb
593
+
594
+ # not to commit images to avoid inconsistency between training and logging steps
595
+ wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
596
+
597
+
598
+ # region Diffusers
599
+
600
+
601
+ from dataclasses import dataclass
602
+ from typing import Optional, Tuple, Union
603
+
604
+ import numpy as np
605
+ import torch
606
+
607
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
608
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
609
+ from diffusers.utils.torch_utils import randn_tensor
610
+ from diffusers.utils import BaseOutput
611
+
612
+
613
+ @dataclass
614
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
615
+ """
616
+ Output class for the scheduler's `step` function output.
617
+
618
+ Args:
619
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
620
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
621
+ denoising loop.
622
+ """
623
+
624
+ prev_sample: torch.FloatTensor
625
+
626
+
627
+ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
628
+ """
629
+ Euler scheduler.
630
+
631
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
632
+ methods the library implements for all schedulers such as loading and saving.
633
+
634
+ Args:
635
+ num_train_timesteps (`int`, defaults to 1000):
636
+ The number of diffusion steps to train the model.
637
+ timestep_spacing (`str`, defaults to `"linspace"`):
638
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
639
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
640
+ shift (`float`, defaults to 1.0):
641
+ The shift value for the timestep schedule.
642
+ """
643
+
644
+ _compatibles = []
645
+ order = 1
646
+
647
+ @register_to_config
648
+ def __init__(
649
+ self,
650
+ num_train_timesteps: int = 1000,
651
+ shift: float = 1.0,
652
+ ):
653
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
654
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
655
+
656
+ sigmas = timesteps / num_train_timesteps
657
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
658
+
659
+ self.timesteps = sigmas * num_train_timesteps
660
+
661
+ self._step_index = None
662
+ self._begin_index = None
663
+
664
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
665
+ self.sigma_min = self.sigmas[-1].item()
666
+ self.sigma_max = self.sigmas[0].item()
667
+
668
+ @property
669
+ def step_index(self):
670
+ """
671
+ The index counter for current timestep. It will increase 1 after each scheduler step.
672
+ """
673
+ return self._step_index
674
+
675
+ @property
676
+ def begin_index(self):
677
+ """
678
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
679
+ """
680
+ return self._begin_index
681
+
682
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
683
+ def set_begin_index(self, begin_index: int = 0):
684
+ """
685
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
686
+
687
+ Args:
688
+ begin_index (`int`):
689
+ The begin index for the scheduler.
690
+ """
691
+ self._begin_index = begin_index
692
+
693
+ def scale_noise(
694
+ self,
695
+ sample: torch.FloatTensor,
696
+ timestep: Union[float, torch.FloatTensor],
697
+ noise: Optional[torch.FloatTensor] = None,
698
+ ) -> torch.FloatTensor:
699
+ """
700
+ Forward process in flow-matching
701
+
702
+ Args:
703
+ sample (`torch.FloatTensor`):
704
+ The input sample.
705
+ timestep (`int`, *optional*):
706
+ The current timestep in the diffusion chain.
707
+
708
+ Returns:
709
+ `torch.FloatTensor`:
710
+ A scaled input sample.
711
+ """
712
+ if self.step_index is None:
713
+ self._init_step_index(timestep)
714
+
715
+ sigma = self.sigmas[self.step_index]
716
+ sample = sigma * noise + (1.0 - sigma) * sample
717
+
718
+ return sample
719
+
720
+ def _sigma_to_t(self, sigma):
721
+ return sigma * self.config.num_train_timesteps
722
+
723
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
724
+ """
725
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
726
+
727
+ Args:
728
+ num_inference_steps (`int`):
729
+ The number of diffusion steps used when generating samples with a pre-trained model.
730
+ device (`str` or `torch.device`, *optional*):
731
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
732
+ """
733
+ self.num_inference_steps = num_inference_steps
734
+
735
+ timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps)
736
+
737
+ sigmas = timesteps / self.config.num_train_timesteps
738
+ sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
739
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
740
+
741
+ timesteps = sigmas * self.config.num_train_timesteps
742
+ self.timesteps = timesteps.to(device=device)
743
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
744
+
745
+ self._step_index = None
746
+ self._begin_index = None
747
+
748
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
749
+ if schedule_timesteps is None:
750
+ schedule_timesteps = self.timesteps
751
+
752
+ indices = (schedule_timesteps == timestep).nonzero()
753
+
754
+ # The sigma index that is taken for the **very** first `step`
755
+ # is always the second index (or the last index if there is only 1)
756
+ # This way we can ensure we don't accidentally skip a sigma in
757
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
758
+ pos = 1 if len(indices) > 1 else 0
759
+
760
+ return indices[pos].item()
761
+
762
+ def _init_step_index(self, timestep):
763
+ if self.begin_index is None:
764
+ if isinstance(timestep, torch.Tensor):
765
+ timestep = timestep.to(self.timesteps.device)
766
+ self._step_index = self.index_for_timestep(timestep)
767
+ else:
768
+ self._step_index = self._begin_index
769
+
770
+ def step(
771
+ self,
772
+ model_output: torch.FloatTensor,
773
+ timestep: Union[float, torch.FloatTensor],
774
+ sample: torch.FloatTensor,
775
+ s_churn: float = 0.0,
776
+ s_tmin: float = 0.0,
777
+ s_tmax: float = float("inf"),
778
+ s_noise: float = 1.0,
779
+ generator: Optional[torch.Generator] = None,
780
+ return_dict: bool = True,
781
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
782
+ """
783
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
784
+ process from the learned model outputs (most often the predicted noise).
785
+
786
+ Args:
787
+ model_output (`torch.FloatTensor`):
788
+ The direct output from learned diffusion model.
789
+ timestep (`float`):
790
+ The current discrete timestep in the diffusion chain.
791
+ sample (`torch.FloatTensor`):
792
+ A current instance of a sample created by the diffusion process.
793
+ s_churn (`float`):
794
+ s_tmin (`float`):
795
+ s_tmax (`float`):
796
+ s_noise (`float`, defaults to 1.0):
797
+ Scaling factor for noise added to the sample.
798
+ generator (`torch.Generator`, *optional*):
799
+ A random number generator.
800
+ return_dict (`bool`):
801
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
802
+ tuple.
803
+
804
+ Returns:
805
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
806
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
807
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
808
+ """
809
+
810
+ if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
811
+ raise ValueError(
812
+ (
813
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
814
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
815
+ " one of the `scheduler.timesteps` as a timestep."
816
+ ),
817
+ )
818
+
819
+ if self.step_index is None:
820
+ self._init_step_index(timestep)
821
+
822
+ # Upcast to avoid precision issues when computing prev_sample
823
+ sample = sample.to(torch.float32)
824
+
825
+ sigma = self.sigmas[self.step_index]
826
+
827
+ gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
828
+
829
+ noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator)
830
+
831
+ eps = noise * s_noise
832
+ sigma_hat = sigma * (gamma + 1)
833
+
834
+ if gamma > 0:
835
+ sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
836
+
837
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
838
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
839
+ # backwards compatibility
840
+
841
+ # if self.config.prediction_type == "vector_field":
842
+
843
+ denoised = sample - model_output * sigma
844
+ # 2. Convert to an ODE derivative
845
+ derivative = (sample - denoised) / sigma_hat
846
+
847
+ dt = self.sigmas[self.step_index + 1] - sigma_hat
848
+
849
+ prev_sample = sample + derivative * dt
850
+ # Cast sample back to model compatible dtype
851
+ prev_sample = prev_sample.to(model_output.dtype)
852
+
853
+ # upon completion increase step index by one
854
+ self._step_index += 1
855
+
856
+ if not return_dict:
857
+ return (prev_sample,)
858
+
859
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
860
+
861
+ def __len__(self):
862
+ return self.config.num_train_timesteps
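+ # Minimal usage sketch (variable names are illustrative; the MMDiT call mirrors do_sample above):
+ #   scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
+ #   scheduler.set_timesteps(28, device=latents.device)
+ #   for t in scheduler.timesteps:
+ #       model_output = mmdit(latents, t.expand(latents.shape[0]), context=context, y=y)
+ #       latents = scheduler.step(model_output, t, latents).prev_sample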
863
+
864
+
865
+ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
866
+ sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
867
+ schedule_timesteps = noise_scheduler.timesteps.to(device)
868
+ timesteps = timesteps.to(device)
869
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
870
+
871
+ sigma = sigmas[step_indices].flatten()
872
+ while len(sigma.shape) < n_dim:
873
+ sigma = sigma.unsqueeze(-1)
874
+ return sigma
875
+
876
+
877
+ def compute_density_for_timestep_sampling(
878
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
879
+ ):
880
+ """Compute the density for sampling the timesteps when doing SD3 training.
881
+
882
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
883
+
884
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
885
+ """
886
+ if weighting_scheme == "logit_normal":
887
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
888
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
889
+ u = torch.nn.functional.sigmoid(u)
890
+ elif weighting_scheme == "mode":
891
+ u = torch.rand(size=(batch_size,), device="cpu")
892
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
893
+ else:
894
+ u = torch.rand(size=(batch_size,), device="cpu")
895
+ return u
896
+
897
+
898
+ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
899
+ """Computes loss weighting scheme for SD3 training.
900
+
901
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
902
+
903
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
904
+ """
905
+ if weighting_scheme == "sigma_sqrt":
906
+ weighting = (sigmas**-2.0).float()
907
+ elif weighting_scheme == "cosmap":
908
+ bot = 1 - 2 * sigmas + 2 * sigmas**2
909
+ weighting = 2 / (math.pi * bot)
910
+ else:
911
+ weighting = torch.ones_like(sigmas)
912
+ return weighting
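+ # Worked example: with the "cosmap" scheme at sigmas=0.5, bot = 1 - 2*0.5 + 2*0.25 = 0.5, so the
+ # weighting is 2 / (pi * 0.5) ≈ 1.27; "sigma_sqrt" at the same point gives 0.5 ** -2.0 = 4.0.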
913
+
914
+
915
+ # endregion
916
+
917
+
918
+ def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
919
+ bsz = latents.shape[0]
920
+
921
+ # Sample a random timestep for each image
922
+ # for weighting schemes where we sample timesteps non-uniformly
923
+ u = compute_density_for_timestep_sampling(
924
+ weighting_scheme=args.weighting_scheme,
925
+ batch_size=bsz,
926
+ logit_mean=args.logit_mean,
927
+ logit_std=args.logit_std,
928
+ mode_scale=args.mode_scale,
929
+ )
930
+ t_min = args.min_timestep if args.min_timestep is not None else 0
931
+ t_max = args.max_timestep if args.max_timestep is not None else 1000
932
+ shift = args.training_shift
933
+
934
+ # weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details)
935
+ u = (u * shift) / (1 + (shift - 1) * u)
936
+
937
+ indices = (u * (t_max - t_min) + t_min).long()
938
+ timesteps = indices.to(device=device, dtype=dtype)
939
+
940
+ # sigmas according to flowmatching
941
+ sigmas = timesteps / 1000
942
+ sigmas = sigmas.view(-1, 1, 1, 1)
943
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
944
+
945
+ return noisy_model_input, timesteps, sigmas
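+ # Worked example for --training_shift: with shift=3.0, a uniform draw u=0.5 is remapped to
+ # (0.5 * 3) / (1 + 2 * 0.5) = 0.75, i.e. timestep ≈ 750 of 1000, biasing training toward the noisy
+ # end (overall structure); shift < 1 biases toward the less-noisy end (details).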
library/sd3_utils.py ADDED
@@ -0,0 +1,302 @@
1
+ from dataclasses import dataclass
2
+ import math
3
+ import re
4
+ from typing import Dict, List, Optional, Union
5
+ import torch
6
+ import safetensors
7
+ from safetensors.torch import load_file
8
+ from accelerate import init_empty_weights
9
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig
10
+
11
+ from .utils import setup_logging
12
+
13
+ setup_logging()
14
+ import logging
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ from library import sd3_models
19
+
20
+ # TODO move some of functions to model_util.py
21
+ from library import sdxl_model_util
22
+
23
+ # region models
24
+
25
+ # TODO remove dependency on flux_utils
26
+ from library.utils import load_safetensors
27
+ from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
28
+
29
+
30
+ def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
31
+ logger.info(f"Analyzing state dict state...")
32
+
33
+ # analyze configs
34
+ patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
35
+ depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
36
+ num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
37
+ pos_embed_max_size = round(math.sqrt(num_patches))
38
+ adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
39
+ context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
40
+ qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None
41
+
42
+ # x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
43
+ x_block_self_attn_layers = []
44
+ re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
45
+ for key in list(state_dict.keys()):
46
+ m = re_attn.search(key)
47
+ if m:
48
+ x_block_self_attn_layers.append(int(m.group(1)))
49
+
50
+ context_embedder_in_features = context_shape[1]
51
+ context_embedder_out_features = context_shape[0]
52
+
53
+ # only supports 3-5-large, medium or 3-medium
54
+ if qk_norm is not None:
55
+ if len(x_block_self_attn_layers) == 0:
56
+ model_type = "3-5-large"
57
+ else:
58
+ model_type = "3-5-medium"
59
+ else:
60
+ model_type = "3-medium"
61
+
62
+ params = sd3_models.SD3Params(
63
+ patch_size=patch_size,
64
+ depth=depth,
65
+ num_patches=num_patches,
66
+ pos_embed_max_size=pos_embed_max_size,
67
+ adm_in_channels=adm_in_channels,
68
+ qk_norm=qk_norm,
69
+ x_block_self_attn_layers=x_block_self_attn_layers,
70
+ context_embedder_in_features=context_embedder_in_features,
71
+ context_embedder_out_features=context_embedder_out_features,
72
+ model_type=model_type,
73
+ )
74
+ logger.info(f"Analyzed state dict state: {params}")
75
+ return params
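+ # The analysis relies on fixed facts of the MMDiT layout: x_embedder.proj.weight has shape
+ # [hidden_size, 16, patch_size, patch_size] with hidden_size == 64 * depth, so depth falls out of an
+ # integer division; pos_embed stores a square grid, hence pos_embed_max_size = sqrt(num_patches);
+ # and the presence of qk-norm keys separates the SD3.5 variants from the original 3-medium.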
76
+
77
+
78
+ def load_mmdit(
79
+ state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch"
80
+ ) -> sd3_models.MMDiT:
81
+ mmdit_sd = {}
82
+
83
+ mmdit_prefix = "model.diffusion_model."
84
+ for k in list(state_dict.keys()):
85
+ if k.startswith(mmdit_prefix):
86
+ mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k)
87
+
88
+ # load MMDiT
89
+ logger.info("Building MMDit")
90
+ params = analyze_state_dict_state(mmdit_sd)
91
+ with init_empty_weights():
92
+ mmdit = sd3_models.create_sd3_mmdit(params, attn_mode)
93
+
94
+ logger.info("Loading state dict...")
95
+ info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True)
96
+ logger.info(f"Loaded MMDiT: {info}")
97
+ return mmdit
98
+
99
+
100
+ def load_clip_l(
101
+ clip_l_path: Optional[str],
102
+ dtype: Optional[Union[str, torch.dtype]],
103
+ device: Union[str, torch.device],
104
+ disable_mmap: bool = False,
105
+ state_dict: Optional[Dict] = None,
106
+ ):
107
+ clip_l_sd = None
108
+ if clip_l_path is None:
109
+ if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
110
+ # found clip_l: remove prefix "text_encoders.clip_l."
111
+ logger.info("clip_l is included in the checkpoint")
112
+ clip_l_sd = {}
113
+ prefix = "text_encoders.clip_l."
114
+ for k in list(state_dict.keys()):
115
+ if k.startswith(prefix):
116
+ clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
117
+ elif clip_l_path is None:
118
+ logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided")
119
+ return None
120
+
121
+ # load clip_l
122
+ logger.info("Building CLIP-L")
123
+ config = CLIPTextConfig(
124
+ vocab_size=49408,
125
+ hidden_size=768,
126
+ intermediate_size=3072,
127
+ num_hidden_layers=12,
128
+ num_attention_heads=12,
129
+ max_position_embeddings=77,
130
+ hidden_act="quick_gelu",
131
+ layer_norm_eps=1e-05,
132
+ dropout=0.0,
133
+ attention_dropout=0.0,
134
+ initializer_range=0.02,
135
+ initializer_factor=1.0,
136
+ pad_token_id=1,
137
+ bos_token_id=0,
138
+ eos_token_id=2,
139
+ model_type="clip_text_model",
140
+ projection_dim=768,
141
+ # torch_dtype="float32",
142
+ # transformers_version="4.25.0.dev0",
143
+ )
144
+ with init_empty_weights():
145
+ clip = CLIPTextModelWithProjection(config)
146
+
147
+ if clip_l_sd is None:
148
+ logger.info(f"Loading state dict from {clip_l_path}")
149
+ clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
150
+
151
+ if "text_projection.weight" not in clip_l_sd:
152
+ logger.info("Adding text_projection.weight to clip_l_sd")
153
+ clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device)
154
+
155
+ info = clip.load_state_dict(clip_l_sd, strict=False, assign=True)
156
+ logger.info(f"Loaded CLIP-L: {info}")
157
+ return clip
158
+
159
+
160
+ def load_clip_g(
161
+ clip_g_path: Optional[str],
162
+ dtype: Optional[Union[str, torch.dtype]],
163
+ device: Union[str, torch.device],
164
+ disable_mmap: bool = False,
165
+ state_dict: Optional[Dict] = None,
166
+ ):
167
+ clip_g_sd = None
168
+ if state_dict is not None:
169
+ if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
170
+ # found clip_g: remove prefix "text_encoders.clip_g."
171
+ logger.info("clip_g is included in the checkpoint")
172
+ clip_g_sd = {}
173
+ prefix = "text_encoders.clip_g."
174
+ for k in list(state_dict.keys()):
175
+ if k.startswith(prefix):
176
+ clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
177
+ elif clip_g_path is None:
178
+ logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided")
179
+ return None
180
+
181
+ # load clip_g
182
+ logger.info("Building CLIP-G")
183
+ config = CLIPTextConfig(
184
+ vocab_size=49408,
185
+ hidden_size=1280,
186
+ intermediate_size=5120,
187
+ num_hidden_layers=32,
188
+ num_attention_heads=20,
189
+ max_position_embeddings=77,
190
+ hidden_act="gelu",
191
+ layer_norm_eps=1e-05,
192
+ dropout=0.0,
193
+ attention_dropout=0.0,
194
+ initializer_range=0.02,
195
+ initializer_factor=1.0,
196
+ pad_token_id=1,
197
+ bos_token_id=0,
198
+ eos_token_id=2,
199
+ model_type="clip_text_model",
200
+ projection_dim=1280,
201
+ # torch_dtype="float32",
202
+ # transformers_version="4.25.0.dev0",
203
+ )
204
+ with init_empty_weights():
205
+ clip = CLIPTextModelWithProjection(config)
206
+
207
+ if clip_g_sd is None:
208
+ logger.info(f"Loading state dict from {clip_g_path}")
209
+ clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
210
+ info = clip.load_state_dict(clip_g_sd, strict=False, assign=True)
211
+ logger.info(f"Loaded CLIP-G: {info}")
212
+ return clip
213
+
214
+
215
+ def load_t5xxl(
216
+ t5xxl_path: Optional[str],
217
+ dtype: Optional[Union[str, torch.dtype]],
218
+ device: Union[str, torch.device],
219
+ disable_mmap: bool = False,
220
+ state_dict: Optional[Dict] = None,
221
+ ):
222
+ t5xxl_sd = None
223
+ if state_dict is not None:
224
+ if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
225
+ # found t5xxl: remove prefix "text_encoders.t5xxl."
226
+ logger.info("t5xxl is included in the checkpoint")
227
+ t5xxl_sd = {}
228
+ prefix = "text_encoders.t5xxl."
229
+ for k in list(state_dict.keys()):
230
+ if k.startswith(prefix):
231
+ t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
232
+ elif t5xxl_path is None:
233
+ logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided")
234
+ return None
235
+
236
+ return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd)
237
+
238
+
239
+ def load_vae(
240
+ vae_path: Optional[str],
241
+ vae_dtype: Optional[Union[str, torch.dtype]],
242
+ device: Optional[Union[str, torch.device]],
243
+ disable_mmap: bool = False,
244
+ state_dict: Optional[Dict] = None,
245
+ ):
246
+ vae_sd = {}
247
+ if vae_path:
248
+ logger.info(f"Loading VAE from {vae_path}...")
249
+ vae_sd = load_safetensors(vae_path, device, disable_mmap)
250
+ else:
251
+ # remove prefix "first_stage_model."
252
+ vae_sd = {}
253
+ vae_prefix = "first_stage_model."
254
+ for k in list(state_dict.keys()):
255
+ if k.startswith(vae_prefix):
256
+ vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
257
+
258
+ logger.info("Building VAE")
259
+ vae = sd3_models.SDVAE(vae_dtype, device)
260
+ logger.info("Loading state dict...")
261
+ info = vae.load_state_dict(vae_sd)
262
+ logger.info(f"Loaded VAE: {info}")
263
+ vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype
264
+ return vae
265
+
266
+
267
+ # endregion
268
+
269
+
270
+ class ModelSamplingDiscreteFlow:
271
+ """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
272
+
273
+ def __init__(self, shift=1.0):
274
+ self.shift = shift
275
+ timesteps = 1000
276
+ self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1))
277
+
278
+ @property
279
+ def sigma_min(self):
280
+ return self.sigmas[0]
281
+
282
+ @property
283
+ def sigma_max(self):
284
+ return self.sigmas[-1]
285
+
286
+ def timestep(self, sigma):
287
+ return sigma * 1000
288
+
289
+ def sigma(self, timestep: torch.Tensor):
290
+ timestep = timestep / 1000.0
291
+ if self.shift == 1.0:
292
+ return timestep
293
+ return self.shift * timestep / (1 + (self.shift - 1) * timestep)
294
+
295
+ def calculate_denoised(self, sigma, model_output, model_input):
296
+ sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
297
+ return model_input - model_output * sigma
298
+
299
+ def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
300
+ # assert max_denoise is False, "max_denoise not implemented"
301
+ # max_denoise is always True, I'm not sure why it's there
302
+ return sigma * noise + (1.0 - sigma) * latent_image
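A worked example of the shifted schedule above, assuming ModelSamplingDiscreteFlow is in scope: with shift = 3.0 and timestep 500, t = 0.5 and sigma = 3*0.5 / (1 + 2*0.5) = 0.75, and noise_scaling mixes noise and latent as sigma*noise + (1 - sigma)*x. Note that timestep(sigma) simply returns sigma*1000, so it is not the exact inverse of sigma() when shift != 1.

import torch

ms = ModelSamplingDiscreteFlow(shift=3.0)

sigma = ms.sigma(torch.tensor(500.0))      # 3 * 0.5 / (1 + 2 * 0.5) = 0.75
print(float(sigma))                        # 0.75
print(float(ms.timestep(sigma)))           # 750.0 (just sigma * 1000)

latent = torch.zeros(1, 16, 32, 32)        # shape is arbitrary for this illustration
noise = torch.randn_like(latent)
noisy = ms.noise_scaling(sigma, noise, latent)  # 0.75 * noise + 0.25 * latent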
library/sdxl_lpw_stable_diffusion.py ADDED
@@ -0,0 +1,1271 @@
1
+ # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
2
+ # and modified to support SDXL
3
+
4
+ import inspect
5
+ import re
6
+ from typing import Callable, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import PIL.Image
10
+ import torch
11
+ from packaging import version
12
+ from tqdm import tqdm
13
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
14
+
15
+ from diffusers import SchedulerMixin, StableDiffusionPipeline
16
+ from diffusers.models import AutoencoderKL
17
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
18
+ from diffusers.utils import logging
19
+ from PIL import Image
20
+
21
+ from library import (
22
+ sdxl_model_util,
23
+ sdxl_train_util,
24
+ strategy_base,
25
+ strategy_sdxl,
26
+ train_util,
27
+ sdxl_original_unet,
28
+ sdxl_original_control_net,
29
+ )
30
+
31
+
32
+ try:
33
+ from diffusers.utils import PIL_INTERPOLATION
34
+ except ImportError:
35
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
36
+ PIL_INTERPOLATION = {
37
+ "linear": PIL.Image.Resampling.BILINEAR,
38
+ "bilinear": PIL.Image.Resampling.BILINEAR,
39
+ "bicubic": PIL.Image.Resampling.BICUBIC,
40
+ "lanczos": PIL.Image.Resampling.LANCZOS,
41
+ "nearest": PIL.Image.Resampling.NEAREST,
42
+ }
43
+ else:
44
+ PIL_INTERPOLATION = {
45
+ "linear": PIL.Image.LINEAR,
46
+ "bilinear": PIL.Image.BILINEAR,
47
+ "bicubic": PIL.Image.BICUBIC,
48
+ "lanczos": PIL.Image.LANCZOS,
49
+ "nearest": PIL.Image.NEAREST,
50
+ }
51
+ # ------------------------------------------------------------------------------
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+ re_attention = re.compile(
56
+ r"""
57
+ \\\(|
58
+ \\\)|
59
+ \\\[|
60
+ \\]|
61
+ \\\\|
62
+ \\|
63
+ \(|
64
+ \[|
65
+ :([+-]?[.\d]+)\)|
66
+ \)|
67
+ ]|
68
+ [^\\()\[\]:]+|
69
+ :
70
+ """,
71
+ re.X,
72
+ )
73
+
74
+
75
+ def parse_prompt_attention(text):
76
+ """
77
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
78
+ Accepted tokens are:
79
+ (abc) - increases attention to abc by a multiplier of 1.1
80
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
81
+ [abc] - decreases attention to abc by a multiplier of 1.1
82
+ \( - literal character '('
83
+ \[ - literal character '['
84
+ \) - literal character ')'
85
+ \] - literal character ']'
86
+ \\ - literal character '\'
87
+ anything else - just text
88
+ >>> parse_prompt_attention('normal text')
89
+ [['normal text', 1.0]]
90
+ >>> parse_prompt_attention('an (important) word')
91
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
92
+ >>> parse_prompt_attention('(unbalanced')
93
+ [['unbalanced', 1.1]]
94
+ >>> parse_prompt_attention('\(literal\]')
95
+ [['(literal]', 1.0]]
96
+ >>> parse_prompt_attention('(unnecessary)(parens)')
97
+ [['unnecessaryparens', 1.1]]
98
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
99
+ [['a ', 1.0],
100
+ ['house', 1.5730000000000004],
101
+ [' ', 1.1],
102
+ ['on', 1.0],
103
+ [' a ', 1.1],
104
+ ['hill', 0.55],
105
+ [', sun, ', 1.1],
106
+ ['sky', 1.4641000000000006],
107
+ ['.', 1.1]]
108
+ """
109
+
110
+ res = []
111
+ round_brackets = []
112
+ square_brackets = []
113
+
114
+ round_bracket_multiplier = 1.1
115
+ square_bracket_multiplier = 1 / 1.1
116
+
117
+ def multiply_range(start_position, multiplier):
118
+ for p in range(start_position, len(res)):
119
+ res[p][1] *= multiplier
120
+
121
+ for m in re_attention.finditer(text):
122
+ text = m.group(0)
123
+ weight = m.group(1)
124
+
125
+ if text.startswith("\\"):
126
+ res.append([text[1:], 1.0])
127
+ elif text == "(":
128
+ round_brackets.append(len(res))
129
+ elif text == "[":
130
+ square_brackets.append(len(res))
131
+ elif weight is not None and len(round_brackets) > 0:
132
+ multiply_range(round_brackets.pop(), float(weight))
133
+ elif text == ")" and len(round_brackets) > 0:
134
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
135
+ elif text == "]" and len(square_brackets) > 0:
136
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
137
+ else:
138
+ res.append([text, 1.0])
139
+
140
+ for pos in round_brackets:
141
+ multiply_range(pos, round_bracket_multiplier)
142
+
143
+ for pos in square_brackets:
144
+ multiply_range(pos, square_bracket_multiplier)
145
+
146
+ if len(res) == 0:
147
+ res = [["", 1.0]]
148
+
149
+ # merge runs of identical weights
150
+ i = 0
151
+ while i + 1 < len(res):
152
+ if res[i][1] == res[i + 1][1]:
153
+ res[i][0] += res[i + 1][0]
154
+ res.pop(i + 1)
155
+ else:
156
+ i += 1
157
+
158
+ return res
159
+
160
+
161
+ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
162
+ r"""
163
+ Tokenize a list of prompts and return its tokens with weights of each token.
164
+
165
+ No padding, starting or ending token is included.
166
+ """
167
+ tokens = []
168
+ weights = []
169
+ truncated = False
170
+ for text in prompt:
171
+ texts_and_weights = parse_prompt_attention(text)
172
+ text_token = []
173
+ text_weight = []
174
+ for word, weight in texts_and_weights:
175
+ # tokenize and discard the starting and the ending token
176
+ token = pipe.tokenizer(word).input_ids[1:-1]
177
+ text_token += token
178
+ # copy the weight by length of token
179
+ text_weight += [weight] * len(token)
180
+ # stop if the text is too long (longer than truncation limit)
181
+ if len(text_token) > max_length:
182
+ truncated = True
183
+ break
184
+ # truncate
185
+ if len(text_token) > max_length:
186
+ truncated = True
187
+ text_token = text_token[:max_length]
188
+ text_weight = text_weight[:max_length]
189
+ tokens.append(text_token)
190
+ weights.append(text_weight)
191
+ if truncated:
192
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
193
+ return tokens, weights
194
+
195
+
196
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
197
+ r"""
198
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
199
+ """
200
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
201
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
202
+ for i in range(len(tokens)):
203
+ tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
204
+ if no_boseos_middle:
205
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
206
+ else:
207
+ w = []
208
+ if len(weights[i]) == 0:
209
+ w = [1.0] * weights_length
210
+ else:
211
+ for j in range(max_embeddings_multiples):
212
+ w.append(1.0) # weight for starting token in this chunk
213
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
214
+ w.append(1.0) # weight for ending token in this chunk
215
+ w += [1.0] * (weights_length - len(w))
216
+ weights[i] = w[:]
217
+
218
+ return tokens, weights
219
+
220
+
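A concrete illustration of the padding above, assuming pad_tokens_and_weights is in scope; the token ids are made up, and 49406/49407 are CLIP's usual BOS/EOS ids.

tokens = [[320, 1125]]      # hypothetical ids for a two-token prompt
weights = [[1.0, 1.2]]      # the second token was up-weighted

tokens, weights = pad_tokens_and_weights(
    tokens, weights, max_length=7, bos=49406, eos=49407, pad=49407,
    no_boseos_middle=True, chunk_length=7,
)
# tokens  -> [[49406, 320, 1125, 49407, 49407, 49407, 49407]]
# weights -> [[1.0, 1.0, 1.2, 1.0, 1.0, 1.0, 1.0]]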
221
+ def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device):
222
+ if not is_sdxl_text_encoder2:
223
+ # text_encoder1: same as SD1/2
224
+ enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
225
+ hidden_states = enc_out["hidden_states"][11]
226
+ pool = None
227
+ else:
228
+ # text_encoder2
229
+ enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
230
+ hidden_states = enc_out["hidden_states"][-2] # penultimate layer
231
+ # pool = enc_out["text_embeds"]
232
+ pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id)
233
+ hidden_states = hidden_states.to(device)
234
+ if pool is not None:
235
+ pool = pool.to(device)
236
+ return hidden_states, pool
237
+
238
+
239
+ def get_unweighted_text_embeddings(
240
+ pipe: StableDiffusionPipeline,
241
+ text_input: torch.Tensor,
242
+ chunk_length: int,
243
+ clip_skip: int,
244
+ eos: int,
245
+ pad: int,
246
+ is_sdxl_text_encoder2: bool,
247
+ no_boseos_middle: Optional[bool] = True,
248
+ ):
249
+ """
250
+ When the token sequence is longer than the capacity of the text encoder,
251
+ it is split into chunks and each chunk is sent to the text encoder individually.
252
+ """
253
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
254
+ text_pool = None
255
+ if max_embeddings_multiples > 1:
256
+ text_embeddings = []
257
+ for i in range(max_embeddings_multiples):
258
+ # extract the i-th chunk
259
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
260
+
261
+ # cover the head and the tail by the starting and the ending tokens
262
+ text_input_chunk[:, 0] = text_input[0, 0]
263
+ if pad == eos: # v1
264
+ text_input_chunk[:, -1] = text_input[0, -1]
265
+ else: # v2
266
+ for j in range(len(text_input_chunk)):
267
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # the chunk ends with a normal token, not EOS/PAD
268
+ text_input_chunk[j, -1] = eos
269
+ if text_input_chunk[j, 1] == pad: # only BOS, the rest is PAD
270
+ text_input_chunk[j, 1] = eos
271
+
272
+ text_embedding, current_text_pool = get_hidden_states(
273
+ pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device
274
+ )
275
+ if text_pool is None:
276
+ text_pool = current_text_pool
277
+
278
+ if no_boseos_middle:
279
+ if i == 0:
280
+ # discard the ending token
281
+ text_embedding = text_embedding[:, :-1]
282
+ elif i == max_embeddings_multiples - 1:
283
+ # discard the starting token
284
+ text_embedding = text_embedding[:, 1:]
285
+ else:
286
+ # discard both starting and ending tokens
287
+ text_embedding = text_embedding[:, 1:-1]
288
+
289
+ text_embeddings.append(text_embedding)
290
+ text_embeddings = torch.concat(text_embeddings, axis=1)
291
+ else:
292
+ text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device)
293
+ return text_embeddings, text_pool
294
+
295
+
296
+ def get_weighted_text_embeddings(
297
+ pipe, # : SdxlStableDiffusionLongPromptWeightingPipeline,
298
+ prompt: Union[str, List[str]],
299
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
300
+ max_embeddings_multiples: Optional[int] = 3,
301
+ no_boseos_middle: Optional[bool] = False,
302
+ skip_parsing: Optional[bool] = False,
303
+ skip_weighting: Optional[bool] = False,
304
+ clip_skip=None,
305
+ is_sdxl_text_encoder2=False,
306
+ ):
307
+ r"""
308
+ Prompts can be assigned with local weights using brackets. For example,
309
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
310
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
311
+
312
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
313
+
314
+ Args:
315
+ pipe (`StableDiffusionPipeline`):
316
+ Pipe to provide access to the tokenizer and the text encoder.
317
+ prompt (`str` or `List[str]`):
318
+ The prompt or prompts to guide the image generation.
319
+ uncond_prompt (`str` or `List[str]`):
320
+ The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
321
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
322
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
323
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
324
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
325
+ If the length of the text tokens is a multiple of the text encoder's capacity, whether to keep the starting and
327
+ ending token in each of the middle chunks.
327
+ skip_parsing (`bool`, *optional*, defaults to `False`):
328
+ Skip the parsing of brackets.
329
+ skip_weighting (`bool`, *optional*, defaults to `False`):
330
+ Skip the weighting. When the parsing is skipped, it is forced True.
331
+ """
332
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
333
+ if isinstance(prompt, str):
334
+ prompt = [prompt]
335
+
336
+ if not skip_parsing:
337
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
338
+ if uncond_prompt is not None:
339
+ if isinstance(uncond_prompt, str):
340
+ uncond_prompt = [uncond_prompt]
341
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
342
+ else:
343
+ prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
344
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
345
+ if uncond_prompt is not None:
346
+ if isinstance(uncond_prompt, str):
347
+ uncond_prompt = [uncond_prompt]
348
+ uncond_tokens = [
349
+ token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
350
+ ]
351
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
352
+
353
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
354
+ max_length = max([len(token) for token in prompt_tokens])
355
+ if uncond_prompt is not None:
356
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
357
+
358
+ max_embeddings_multiples = min(
359
+ max_embeddings_multiples,
360
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
361
+ )
362
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
363
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
364
+
365
+ # pad the length of tokens and weights
366
+ bos = pipe.tokenizer.bos_token_id
367
+ eos = pipe.tokenizer.eos_token_id
368
+ pad = pipe.tokenizer.pad_token_id
369
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
370
+ prompt_tokens,
371
+ prompt_weights,
372
+ max_length,
373
+ bos,
374
+ eos,
375
+ pad,
376
+ no_boseos_middle=no_boseos_middle,
377
+ chunk_length=pipe.tokenizer.model_max_length,
378
+ )
379
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
380
+ if uncond_prompt is not None:
381
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
382
+ uncond_tokens,
383
+ uncond_weights,
384
+ max_length,
385
+ bos,
386
+ eos,
387
+ pad,
388
+ no_boseos_middle=no_boseos_middle,
389
+ chunk_length=pipe.tokenizer.model_max_length,
390
+ )
391
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
392
+
393
+ # get the embeddings
394
+ text_embeddings, text_pool = get_unweighted_text_embeddings(
395
+ pipe,
396
+ prompt_tokens,
397
+ pipe.tokenizer.model_max_length,
398
+ clip_skip,
399
+ eos,
400
+ pad,
401
+ is_sdxl_text_encoder2,
402
+ no_boseos_middle=no_boseos_middle,
403
+ )
404
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
405
+
406
+ if uncond_prompt is not None:
407
+ uncond_embeddings, uncond_pool = get_unweighted_text_embeddings(
408
+ pipe,
409
+ uncond_tokens,
410
+ pipe.tokenizer.model_max_length,
411
+ clip_skip,
412
+ eos,
413
+ pad,
414
+ is_sdxl_text_encoder2,
415
+ no_boseos_middle=no_boseos_middle,
416
+ )
417
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
418
+
419
+ # assign weights to the prompts and normalize in the sense of mean
420
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
421
+ if (not skip_parsing) and (not skip_weighting):
422
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
423
+ text_embeddings *= prompt_weights.unsqueeze(-1)
424
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
425
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
426
+ if uncond_prompt is not None:
427
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
428
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
429
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
430
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
431
+
432
+ if uncond_prompt is not None:
433
+ return text_embeddings, text_pool, uncond_embeddings, uncond_pool
434
+ return text_embeddings, text_pool, None, None
435
+
436
+
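A minimal sketch of how this entry point might be called; `pipe` is assumed to be an already-built pipeline object exposing `tokenizer`, `text_encoder`, and `device`, and the prompts are placeholders.

text_emb, text_pool, uncond_emb, uncond_pool = get_weighted_text_embeddings(
    pipe,                                        # assumed: a pre-built pipeline object
    prompt="a (very beautiful:1.3) landscape, golden hour",
    uncond_prompt="blurry, low quality",
    max_embeddings_multiples=3,                  # allow up to 3 chunks of 75 content tokens
)
# text_emb is weighted per token and then rescaled so its mean matches the unweighted
# embedding; with is_sdxl_text_encoder2=True the pooled output is returned as well.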
437
+ def preprocess_image(image):
438
+ w, h = image.size
439
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
440
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
441
+ image = np.array(image).astype(np.float32) / 255.0
442
+ image = image[None].transpose(0, 3, 1, 2)
443
+ image = torch.from_numpy(image)
444
+ return 2.0 * image - 1.0
445
+
446
+
447
+ def preprocess_mask(mask, scale_factor=8):
448
+ mask = mask.convert("L")
449
+ w, h = mask.size
450
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
451
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
452
+ mask = np.array(mask).astype(np.float32) / 255.0
453
+ mask = np.tile(mask, (4, 1, 1))
454
+ mask = mask[None].transpose(0, 1, 2, 3) # identity transpose, kept for parity with the original implementation
455
+ mask = 1 - mask # repaint white, keep black
456
+ mask = torch.from_numpy(mask)
457
+ return mask
458
+
459
+
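A quick check of the two helpers above, assuming they are in scope; the sizes are arbitrary. Dimensions are floored to multiples of 32, pixel values are mapped to [-1, 1], and the mask is downscaled by the VAE factor, replicated to 4 latent channels, and inverted so that white regions get repainted.

from PIL import Image

img = Image.new("RGB", (781, 517))              # 781x517 -> floored to 768x512
mask = Image.new("L", (781, 517), 255)          # all-white mask: repaint everything

img_t = preprocess_image(img)                   # torch.Size([1, 3, 512, 768]), values in [-1, 1]
mask_t = preprocess_mask(mask, scale_factor=8)  # torch.Size([1, 4, 64, 96]), white mapped to 0.0

print(img_t.shape, mask_t.shape)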
460
+ def prepare_controlnet_image(
461
+ image: PIL.Image.Image,
462
+ width: int,
463
+ height: int,
464
+ batch_size: int,
465
+ num_images_per_prompt: int,
466
+ device: torch.device,
467
+ dtype: torch.dtype,
468
+ do_classifier_free_guidance: bool = False,
469
+ guess_mode: bool = False,
470
+ ):
471
+ if not isinstance(image, torch.Tensor):
472
+ if isinstance(image, PIL.Image.Image):
473
+ image = [image]
474
+
475
+ if isinstance(image[0], PIL.Image.Image):
476
+ images = []
477
+
478
+ for image_ in image:
479
+ image_ = image_.convert("RGB")
480
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
481
+ image_ = np.array(image_)
482
+ image_ = image_[None, :]
483
+ images.append(image_)
484
+
485
+ image = images
486
+
487
+ image = np.concatenate(image, axis=0)
488
+ image = np.array(image).astype(np.float32) / 255.0
489
+ image = image.transpose(0, 3, 1, 2)
490
+ image = torch.from_numpy(image)
491
+ elif isinstance(image[0], torch.Tensor):
492
+ image = torch.cat(image, dim=0)
493
+
494
+ image_batch_size = image.shape[0]
495
+
496
+ if image_batch_size == 1:
497
+ repeat_by = batch_size
498
+ else:
499
+ # image batch size is the same as prompt batch size
500
+ repeat_by = num_images_per_prompt
501
+
502
+ image = image.repeat_interleave(repeat_by, dim=0)
503
+
504
+ image = image.to(device=device, dtype=dtype)
505
+
506
+ if do_classifier_free_guidance and not guess_mode:
507
+ image = torch.cat([image] * 2)
508
+
509
+ return image
510
+
511
+
512
+ class SdxlStableDiffusionLongPromptWeightingPipeline:
513
+ r"""
514
+ Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
515
+ weighting in prompt.
516
+
517
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
518
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
519
+
520
+ Args:
521
+ vae ([`AutoencoderKL`]):
522
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
523
+ text_encoder ([`CLIPTextModel`]):
524
+ Frozen text-encoder. Stable Diffusion uses the text portion of
525
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
526
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
527
+ tokenizer (`CLIPTokenizer`):
528
+ Tokenizer of class
529
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
530
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
531
+ scheduler ([`SchedulerMixin`]):
532
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
533
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
534
+ safety_checker ([`StableDiffusionSafetyChecker`]):
535
+ Classification module that estimates whether generated images could be considered offensive or harmful.
536
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
537
+ feature_extractor ([`CLIPFeatureExtractor`]):
538
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
539
+ """
540
+
541
+ # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
542
+
543
+ def __init__(
544
+ self,
545
+ vae: AutoencoderKL,
546
+ text_encoder: List[CLIPTextModel],
547
+ tokenizer: List[CLIPTokenizer],
548
+ unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet],
549
+ scheduler: SchedulerMixin,
550
+ # clip_skip: int,
551
+ safety_checker: StableDiffusionSafetyChecker,
552
+ feature_extractor: CLIPFeatureExtractor,
553
+ requires_safety_checker: bool = True,
554
+ clip_skip: int = 1,
555
+ ):
556
+ # clip skip is ignored currently
557
+ self.tokenizer = tokenizer[0]
558
+ self.text_encoder = text_encoder[0]
559
+ self.unet = unet
560
+ self.scheduler = scheduler
561
+ self.safety_checker = safety_checker
562
+ self.feature_extractor = feature_extractor
563
+ self.requires_safety_checker = requires_safety_checker
564
+ self.vae = vae
565
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
566
+ self.progress_bar = lambda x: tqdm(x, leave=False)
567
+
568
+ self.clip_skip = clip_skip
569
+ self.tokenizers = tokenizer
570
+ self.text_encoders = text_encoder
571
+
572
+ # self.__init__additional__()
573
+
574
+ # def __init__additional__(self):
575
+ # if not hasattr(self, "vae_scale_factor"):
576
+ # setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
577
+
578
+ def to(self, device=None, dtype=None):
579
+ if device is not None:
580
+ self.device = device
581
+ # self.vae.to(device=self.device)
582
+ if dtype is not None:
583
+ self.dtype = dtype
584
+
585
+ # do not move Text Encoders to device, because Text Encoder should be on CPU
586
+
587
+ @property
588
+ def _execution_device(self):
589
+ r"""
590
+ Returns the device on which the pipeline's models will be executed. After calling
591
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
592
+ hooks.
593
+ """
594
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
595
+ return self.device
596
+ for module in self.unet.modules():
597
+ if (
598
+ hasattr(module, "_hf_hook")
599
+ and hasattr(module._hf_hook, "execution_device")
600
+ and module._hf_hook.execution_device is not None
601
+ ):
602
+ return torch.device(module._hf_hook.execution_device)
603
+ return self.device
604
+
605
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
606
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
607
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
608
+
609
+ if strength < 0 or strength > 1:
610
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
611
+
612
+ if height % 8 != 0 or width % 8 != 0:
613
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
614
+
615
+ if (callback_steps is None) or (
616
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
617
+ ):
618
+ raise ValueError(
619
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
620
+ )
621
+
622
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
623
+ if is_text2img:
624
+ return self.scheduler.timesteps.to(device), num_inference_steps
625
+ else:
626
+ # get the original timestep using init_timestep
627
+ offset = self.scheduler.config.get("steps_offset", 0)
628
+ init_timestep = int(num_inference_steps * strength) + offset
629
+ init_timestep = min(init_timestep, num_inference_steps)
630
+
631
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
632
+ timesteps = self.scheduler.timesteps[t_start:].to(device)
633
+ return timesteps, num_inference_steps - t_start
634
+
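A worked example of the img2img branch above: with 50 inference steps, strength 0.8 and a steps_offset of 0, init_timestep = int(50 * 0.8) = 40 and t_start = 50 - 40 = 10, so only the last 40 scheduler timesteps are used and 40 effective steps are reported; the text2img branch returns all timesteps unchanged.

num_inference_steps, strength, offset = 50, 0.8, 0
init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)  # 40
t_start = max(num_inference_steps - init_timestep + offset, 0)                          # 10
print(init_timestep, t_start, num_inference_steps - t_start)                            # 40 10 40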
635
+ def run_safety_checker(self, image, device, dtype):
636
+ if self.safety_checker is not None:
637
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
638
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
639
+ else:
640
+ has_nsfw_concept = None
641
+ return image, has_nsfw_concept
642
+
643
+ def decode_latents(self, latents):
644
+ with torch.no_grad():
645
+ latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
646
+
647
+ # print("post_quant_conv dtype:", self.vae.post_quant_conv.weight.dtype) # torch.float32
648
+ # x = torch.nn.functional.conv2d(latents, self.vae.post_quant_conv.weight.detach(), stride=1, padding=0)
649
+ # print("latents dtype:", latents.dtype, "x dtype:", x.dtype) # torch.float32, torch.float16
650
+ # self.vae.to("cpu")
651
+ # self.vae.set_use_memory_efficient_attention_xformers(False)
652
+ # image = self.vae.decode(latents.to("cpu")).sample
653
+
654
+ image = self.vae.decode(latents.to(self.vae.dtype)).sample
655
+ image = (image / 2 + 0.5).clamp(0, 1)
656
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
657
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
658
+ return image
659
+
660
+ def prepare_extra_step_kwargs(self, generator, eta):
661
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
662
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
663
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
664
+ # and should be between [0, 1]
665
+
666
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
667
+ extra_step_kwargs = {}
668
+ if accepts_eta:
669
+ extra_step_kwargs["eta"] = eta
670
+
671
+ # check if the scheduler accepts generator
672
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
673
+ if accepts_generator:
674
+ extra_step_kwargs["generator"] = generator
675
+ return extra_step_kwargs
676
+
677
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
678
+ if image is None:
679
+ shape = (
680
+ batch_size,
681
+ self.unet.in_channels,
682
+ height // self.vae_scale_factor,
683
+ width // self.vae_scale_factor,
684
+ )
685
+
686
+ if latents is None:
687
+ if device.type == "mps":
688
+ # randn does not work reproducibly on mps
689
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
690
+ else:
691
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
692
+ else:
693
+ if latents.shape != shape:
694
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
695
+ latents = latents.to(device)
696
+
697
+ # scale the initial noise by the standard deviation required by the scheduler
698
+ latents = latents * self.scheduler.init_noise_sigma
699
+ return latents, None, None
700
+ else:
701
+ init_latent_dist = self.vae.encode(image).latent_dist
702
+ init_latents = init_latent_dist.sample(generator=generator)
703
+ init_latents = sdxl_model_util.VAE_SCALE_FACTOR * init_latents
704
+ init_latents = torch.cat([init_latents] * batch_size, dim=0)
705
+ init_latents_orig = init_latents
706
+ shape = init_latents.shape
707
+
708
+ # add noise to latents using the timesteps
709
+ if device.type == "mps":
710
+ noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
711
+ else:
712
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
713
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
714
+ return latents, init_latents_orig, noise
715
+
716
+ @torch.no_grad()
717
+ def __call__(
718
+ self,
719
+ prompt: Union[str, List[str]],
720
+ negative_prompt: Optional[Union[str, List[str]]] = None,
721
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
722
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
723
+ height: int = 512,
724
+ width: int = 512,
725
+ num_inference_steps: int = 50,
726
+ guidance_scale: float = 7.5,
727
+ strength: float = 0.8,
728
+ num_images_per_prompt: Optional[int] = 1,
729
+ eta: float = 0.0,
730
+ generator: Optional[torch.Generator] = None,
731
+ latents: Optional[torch.FloatTensor] = None,
732
+ max_embeddings_multiples: Optional[int] = 3,
733
+ output_type: Optional[str] = "pil",
734
+ return_dict: bool = True,
735
+ controlnet: sdxl_original_control_net.SdxlControlNet = None,
736
+ controlnet_image=None,
737
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
738
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
739
+ callback_steps: int = 1,
740
+ ):
741
+ r"""
742
+ Function invoked when calling the pipeline for generation.
743
+
744
+ Args:
745
+ prompt (`str` or `List[str]`):
746
+ The prompt or prompts to guide the image generation.
747
+ negative_prompt (`str` or `List[str]`, *optional*):
748
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
749
+ if `guidance_scale` is less than `1`).
750
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
751
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
752
+ process.
753
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
754
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
755
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
756
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
757
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
758
+ height (`int`, *optional*, defaults to 512):
759
+ The height in pixels of the generated image.
760
+ width (`int`, *optional*, defaults to 512):
761
+ The width in pixels of the generated image.
762
+ num_inference_steps (`int`, *optional*, defaults to 50):
763
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
764
+ expense of slower inference.
765
+ guidance_scale (`float`, *optional*, defaults to 7.5):
766
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
767
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
768
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
769
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
770
+ usually at the expense of lower image quality.
771
+ strength (`float`, *optional*, defaults to 0.8):
772
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
773
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
774
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
775
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
776
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
777
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
778
+ The number of images to generate per prompt.
779
+ eta (`float`, *optional*, defaults to 0.0):
780
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
781
+ [`schedulers.DDIMScheduler`], will be ignored for others.
782
+ generator (`torch.Generator`, *optional*):
783
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
784
+ deterministic.
785
+ latents (`torch.FloatTensor`, *optional*):
786
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
787
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
788
+ tensor will ge generated by sampling using the supplied random `generator`.
789
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
790
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
791
+ output_type (`str`, *optional*, defaults to `"pil"`):
792
+ The output format of the generated image. Choose between
793
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
794
+ return_dict (`bool`, *optional*, defaults to `True`):
795
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
796
+ plain tuple.
797
+ controlnet (`diffusers.ControlNetModel`, *optional*):
798
+ A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
799
+ controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
800
+ `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
801
+ inference.
802
+ callback (`Callable`, *optional*):
803
+ A function that will be called every `callback_steps` steps during inference. The function will be
804
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
805
+ is_cancelled_callback (`Callable`, *optional*):
806
+ A function that will be called every `callback_steps` steps during inference. If the function returns
807
+ `True`, the inference will be cancelled.
808
+ callback_steps (`int`, *optional*, defaults to 1):
809
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
810
+ called at every step.
811
+
812
+ Returns:
813
+ `None` if cancelled by `is_cancelled_callback`,
814
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
815
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
816
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
817
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
818
+ (nsfw) content, according to the `safety_checker`.
819
+ """
820
+ if controlnet is not None and controlnet_image is None:
821
+ raise ValueError("controlnet_image must be provided if controlnet is not None.")
822
+
823
+ # 0. Default height and width to unet
824
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
825
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
826
+
827
+ # 1. Check inputs. Raise error if not correct
828
+ self.check_inputs(prompt, height, width, strength, callback_steps)
829
+
830
+ # 2. Define call parameters
831
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
832
+ device = self._execution_device
833
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
834
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
835
+ # corresponds to doing no classifier free guidance.
836
+ do_classifier_free_guidance = guidance_scale > 1.0
837
+
838
+ # 3. Encode input prompt
839
+ tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
840
+ encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
841
+
842
+ text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt)
843
+ hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights(
844
+ tokenize_strategy, self.text_encoders, text_input_ids, text_weights
845
+ )
846
+ text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
847
+
848
+ if do_classifier_free_guidance:
849
+ input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "")
850
+ hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights(
851
+ tokenize_strategy, self.text_encoders, input_ids, weights
852
+ )
853
+ uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
854
+ else:
855
+ uncond_embeddings = None
856
+ uncond_pool = None
857
+
858
+ unet_dtype = self.unet.dtype
859
+ dtype = unet_dtype
860
+ if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
861
+ dtype = torch.float16
862
+ self.unet.to(dtype)
863
+
864
+ # 4. Preprocess image and mask
865
+ if isinstance(image, PIL.Image.Image):
866
+ image = preprocess_image(image)
867
+ if image is not None:
868
+ image = image.to(device=self.device, dtype=dtype)
869
+ if isinstance(mask_image, PIL.Image.Image):
870
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
871
+ if mask_image is not None:
872
+ mask = mask_image.to(device=self.device, dtype=dtype)
873
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
874
+ else:
875
+ mask = None
876
+
877
+ # ControlNet is not working yet in SDXL, but keep the code here for future use
878
+ if controlnet_image is not None:
879
+ controlnet_image = prepare_controlnet_image(
880
+ controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
881
+ )
882
+
883
+ # 5. set timesteps
884
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
885
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
886
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
887
+
888
+ # 6. Prepare latent variables
889
+ latents, init_latents_orig, noise = self.prepare_latents(
890
+ image,
891
+ latent_timestep,
892
+ batch_size * num_images_per_prompt,
893
+ height,
894
+ width,
895
+ dtype,
896
+ device,
897
+ generator,
898
+ latents,
899
+ )
900
+
901
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
902
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
903
+
904
+ # create size embs and concat embeddings for SDXL
905
+ orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype)
906
+ crop_size = torch.zeros_like(orig_size)
907
+ target_size = orig_size
908
+ embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype)
909
+
910
+ # make conditionings
911
+ text_pool = text_pool.to(device, dtype)
912
+ if do_classifier_free_guidance:
913
+ text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype)
914
+
915
+ uncond_pool = uncond_pool.to(device, dtype)
916
+ cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype)
917
+ uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype)
918
+ vector_embedding = torch.cat([uncond_vector, cond_vector])
919
+ else:
920
+ text_embedding = text_embeddings.to(device, dtype)
921
+ vector_embedding = torch.cat([text_pool, embs], dim=1)
922
+
923
+ # 8. Denoising loop
924
+ for i, t in enumerate(self.progress_bar(timesteps)):
925
+ # expand the latents if we are doing classifier free guidance
926
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
927
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
928
+
929
+ # FIXME SD1 ControlNet is not working
930
+
931
+ # predict the noise residual
932
+ if controlnet is not None:
933
+ input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image)
934
+ noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add)
935
+ else:
936
+ noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
937
+ noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
938
+
939
+ # perform guidance
940
+ if do_classifier_free_guidance:
941
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
942
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
943
+
944
+ # compute the previous noisy sample x_t -> x_t-1
945
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
946
+
947
+ if mask is not None:
948
+ # masking
949
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
950
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
951
+
952
+ # call the callback, if provided
953
+ if i % callback_steps == 0:
954
+ if callback is not None:
955
+ callback(i, t, latents)
956
+ if is_cancelled_callback is not None and is_cancelled_callback():
957
+ return None
958
+
959
+ self.unet.to(unet_dtype)
960
+ return latents
961
+
962
+ def latents_to_image(self, latents):
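For orientation, a hedged sketch of how the pipeline might be driven end to end; `__call__` returns latents (or None when cancelled via `is_cancelled_callback`), and `latents_to_image` below converts them to PIL images. The construction of `pipe` and the output file name are assumptions.

latents = pipe(                              # pipe: a pre-built SdxlStableDiffusionLongPromptWeightingPipeline
    prompt="a (very detailed:1.2) photo of a lighthouse at dusk",
    negative_prompt="blurry, low quality",
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=7.0,
)
if latents is not None:                      # None means the run was cancelled
    images = pipe.latents_to_image(latents)
    images[0].save("lighthouse.png")         # hypothetical output path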
963
+ # 9. Post-processing
964
+ image = self.decode_latents(latents.to(self.vae.dtype))
965
+ image = self.numpy_to_pil(image)
966
+ return image
967
+
968
+ # copy from pil_utils.py
969
+ def numpy_to_pil(self, images: np.ndarray) -> List[Image.Image]:
970
+ """
971
+ Convert a numpy image or a batch of images to a list of PIL images.
972
+ """
973
+ if images.ndim == 3:
974
+ images = images[None, ...]
975
+ images = (images * 255).round().astype("uint8")
976
+ if images.shape[-1] == 1:
977
+ # special case for grayscale (single channel) images
978
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
979
+ else:
980
+ pil_images = [Image.fromarray(image) for image in images]
981
+
982
+ return pil_images
983
+
984
+ def text2img(
985
+ self,
986
+ prompt: Union[str, List[str]],
987
+ negative_prompt: Optional[Union[str, List[str]]] = None,
988
+ height: int = 512,
989
+ width: int = 512,
990
+ num_inference_steps: int = 50,
991
+ guidance_scale: float = 7.5,
992
+ num_images_per_prompt: Optional[int] = 1,
993
+ eta: float = 0.0,
994
+ generator: Optional[torch.Generator] = None,
995
+ latents: Optional[torch.FloatTensor] = None,
996
+ max_embeddings_multiples: Optional[int] = 3,
997
+ output_type: Optional[str] = "pil",
998
+ return_dict: bool = True,
999
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1000
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1001
+ callback_steps: int = 1,
1002
+ ):
1003
+ r"""
1004
+ Function for text-to-image generation.
1005
+ Args:
1006
+ prompt (`str` or `List[str]`):
1007
+ The prompt or prompts to guide the image generation.
1008
+ negative_prompt (`str` or `List[str]`, *optional*):
1009
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1010
+ if `guidance_scale` is less than `1`).
1011
+ height (`int`, *optional*, defaults to 512):
1012
+ The height in pixels of the generated image.
1013
+ width (`int`, *optional*, defaults to 512):
1014
+ The width in pixels of the generated image.
1015
+ num_inference_steps (`int`, *optional*, defaults to 50):
1016
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1017
+ expense of slower inference.
1018
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1019
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1020
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1021
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1022
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1023
+ usually at the expense of lower image quality.
1024
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1025
+ The number of images to generate per prompt.
1026
+ eta (`float`, *optional*, defaults to 0.0):
1027
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1028
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1029
+ generator (`torch.Generator`, *optional*):
1030
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1031
+ deterministic.
1032
+ latents (`torch.FloatTensor`, *optional*):
1033
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1034
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1035
+ tensor will be generated by sampling using the supplied random `generator`.
1036
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1037
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1038
+ output_type (`str`, *optional*, defaults to `"pil"`):
1039
+ The output format of the generated image. Choose between
1040
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1041
+ return_dict (`bool`, *optional*, defaults to `True`):
1042
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1043
+ plain tuple.
1044
+ callback (`Callable`, *optional*):
1045
+ A function that will be called every `callback_steps` steps during inference. The function will be
1046
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1047
+ is_cancelled_callback (`Callable`, *optional*):
1048
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1049
+ `True`, the inference will be cancelled.
1050
+ callback_steps (`int`, *optional*, defaults to 1):
1051
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1052
+ called at every step.
1053
+ Returns:
1054
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1055
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1056
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1057
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1058
+ (nsfw) content, according to the `safety_checker`.
1059
+ """
1060
+ return self.__call__(
1061
+ prompt=prompt,
1062
+ negative_prompt=negative_prompt,
1063
+ height=height,
1064
+ width=width,
1065
+ num_inference_steps=num_inference_steps,
1066
+ guidance_scale=guidance_scale,
1067
+ num_images_per_prompt=num_images_per_prompt,
1068
+ eta=eta,
1069
+ generator=generator,
1070
+ latents=latents,
1071
+ max_embeddings_multiples=max_embeddings_multiples,
1072
+ output_type=output_type,
1073
+ return_dict=return_dict,
1074
+ callback=callback,
1075
+ is_cancelled_callback=is_cancelled_callback,
1076
+ callback_steps=callback_steps,
1077
+ )
1078
+
1079
+ def img2img(
1080
+ self,
1081
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1082
+ prompt: Union[str, List[str]],
1083
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1084
+ strength: float = 0.8,
1085
+ num_inference_steps: Optional[int] = 50,
1086
+ guidance_scale: Optional[float] = 7.5,
1087
+ num_images_per_prompt: Optional[int] = 1,
1088
+ eta: Optional[float] = 0.0,
1089
+ generator: Optional[torch.Generator] = None,
1090
+ max_embeddings_multiples: Optional[int] = 3,
1091
+ output_type: Optional[str] = "pil",
1092
+ return_dict: bool = True,
1093
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1094
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1095
+ callback_steps: int = 1,
1096
+ ):
1097
+ r"""
1098
+ Function for image-to-image generation.
1099
+ Args:
1100
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1101
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1102
+ process.
1103
+ prompt (`str` or `List[str]`):
1104
+ The prompt or prompts to guide the image generation.
1105
+ negative_prompt (`str` or `List[str]`, *optional*):
1106
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1107
+ if `guidance_scale` is less than `1`).
1108
+ strength (`float`, *optional*, defaults to 0.8):
1109
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1110
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1111
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1112
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
1113
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1114
+ num_inference_steps (`int`, *optional*, defaults to 50):
1115
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1116
+ expense of slower inference. This parameter will be modulated by `strength`.
1117
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1118
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1119
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1120
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1121
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1122
+ usually at the expense of lower image quality.
1123
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1124
+ The number of images to generate per prompt.
1125
+ eta (`float`, *optional*, defaults to 0.0):
1126
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1127
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1128
+ generator (`torch.Generator`, *optional*):
1129
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1130
+ deterministic.
1131
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1132
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1133
+ output_type (`str`, *optional*, defaults to `"pil"`):
1134
+ The output format of the generated image. Choose between
1135
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1136
+ return_dict (`bool`, *optional*, defaults to `True`):
1137
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1138
+ plain tuple.
1139
+ callback (`Callable`, *optional*):
1140
+ A function that will be called every `callback_steps` steps during inference. The function will be
1141
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1142
+ is_cancelled_callback (`Callable`, *optional*):
1143
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1144
+ `True`, the inference will be cancelled.
1145
+ callback_steps (`int`, *optional*, defaults to 1):
1146
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1147
+ called at every step.
1148
+ Returns:
1149
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1150
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1151
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1152
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1153
+ (nsfw) content, according to the `safety_checker`.
1154
+ """
1155
+ return self.__call__(
1156
+ prompt=prompt,
1157
+ negative_prompt=negative_prompt,
1158
+ image=image,
1159
+ num_inference_steps=num_inference_steps,
1160
+ guidance_scale=guidance_scale,
1161
+ strength=strength,
1162
+ num_images_per_prompt=num_images_per_prompt,
1163
+ eta=eta,
1164
+ generator=generator,
1165
+ max_embeddings_multiples=max_embeddings_multiples,
1166
+ output_type=output_type,
1167
+ return_dict=return_dict,
1168
+ callback=callback,
1169
+ is_cancelled_callback=is_cancelled_callback,
1170
+ callback_steps=callback_steps,
1171
+ )
1172
+
1173
+ def inpaint(
1174
+ self,
1175
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1176
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1177
+ prompt: Union[str, List[str]],
1178
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1179
+ strength: float = 0.8,
1180
+ num_inference_steps: Optional[int] = 50,
1181
+ guidance_scale: Optional[float] = 7.5,
1182
+ num_images_per_prompt: Optional[int] = 1,
1183
+ eta: Optional[float] = 0.0,
1184
+ generator: Optional[torch.Generator] = None,
1185
+ max_embeddings_multiples: Optional[int] = 3,
1186
+ output_type: Optional[str] = "pil",
1187
+ return_dict: bool = True,
1188
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1189
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1190
+ callback_steps: int = 1,
1191
+ ):
1192
+ r"""
1193
+ Function for inpainting.
1194
+ Args:
1195
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1196
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1197
+ process. This is the image whose masked region will be inpainted.
1198
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1199
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1200
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1201
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1202
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1203
+ prompt (`str` or `List[str]`):
1204
+ The prompt or prompts to guide the image generation.
1205
+ negative_prompt (`str` or `List[str]`, *optional*):
1206
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1207
+ if `guidance_scale` is less than `1`).
1208
+ strength (`float`, *optional*, defaults to 0.8):
1209
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1210
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1211
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1212
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1213
+ num_inference_steps (`int`, *optional*, defaults to 50):
1214
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1215
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1216
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1217
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1218
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
1219
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1220
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1221
+ usually at the expense of image quality.
1222
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1223
+ The number of images to generate per prompt.
1224
+ eta (`float`, *optional*, defaults to 0.0):
1225
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1226
+ [`schedulers.DDIMScheduler`]; it is ignored for other schedulers.
1227
+ generator (`torch.Generator`, *optional*):
1228
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1229
+ deterministic.
1230
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1231
+ The maximum length of the prompt embeddings, expressed as a multiple of the text encoder's maximum output length.
1232
+ output_type (`str`, *optional*, defaults to `"pil"`):
1233
+ The output format of the generated image. Choose between
1234
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1235
+ return_dict (`bool`, *optional*, defaults to `True`):
1236
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1237
+ plain tuple.
1238
+ callback (`Callable`, *optional*):
1239
+ A function that will be called every `callback_steps` steps during inference. The function will be
1240
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1241
+ is_cancelled_callback (`Callable`, *optional*):
1242
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1243
+ `True`, the inference will be cancelled.
1244
+ callback_steps (`int`, *optional*, defaults to 1):
1245
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1246
+ called at every step.
1247
+ Returns:
1248
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1249
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1250
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1251
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1252
+ (nsfw) content, according to the `safety_checker`.
1253
+ """
1254
+ return self.__call__(
1255
+ prompt=prompt,
1256
+ negative_prompt=negative_prompt,
1257
+ image=image,
1258
+ mask_image=mask_image,
1259
+ num_inference_steps=num_inference_steps,
1260
+ guidance_scale=guidance_scale,
1261
+ strength=strength,
1262
+ num_images_per_prompt=num_images_per_prompt,
1263
+ eta=eta,
1264
+ generator=generator,
1265
+ max_embeddings_multiples=max_embeddings_multiples,
1266
+ output_type=output_type,
1267
+ return_dict=return_dict,
1268
+ callback=callback,
1269
+ is_cancelled_callback=is_cancelled_callback,
1270
+ callback_steps=callback_steps,
1271
+ )
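The `img2img` and `inpaint` helpers above are thin wrappers that forward their arguments to `__call__`. A minimal usage sketch follows; it assumes `pipe` is an already-constructed instance of the long-prompt-weighting pipeline defined in this file, and that `input.png` / `mask.png` are placeholder files.

# Usage sketch (assumption: `pipe` is an instance of the pipeline above,
# with its models loaded and moved to the target device).
import torch
from PIL import Image

init_image = Image.open("input.png").convert("RGB")   # placeholder starting image
mask_image = Image.open("mask.png").convert("L")       # placeholder mask: white = repaint, black = keep

generator = torch.Generator(device="cuda").manual_seed(0)  # deterministic generation

result = pipe.inpaint(
    image=init_image,
    mask_image=mask_image,
    prompt="a cat sitting on a wooden bench",
    strength=0.8,
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=generator,
    output_type="pil",
    return_dict=True,
)
result.images[0].save("inpainted.png")  # first generated image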
library/sdxl_model_util.py ADDED
@@ -0,0 +1,583 @@
1
+ import torch
2
+ import safetensors
3
+ from accelerate import init_empty_weights
4
+ from accelerate.utils.modeling import set_module_tensor_to_device
5
+ from safetensors.torch import load_file, save_file
6
+ from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
7
+ from typing import List
8
+ from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
9
+ from library import model_util
10
+ from library import sdxl_original_unet
11
+ from library.utils import setup_logging
12
+
13
+ setup_logging()
14
+ import logging
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ VAE_SCALE_FACTOR = 0.13025
19
+ MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
20
+
21
+ # reference model used to load the Diffusers configuration
22
+ DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"
23
+
24
+ DIFFUSERS_SDXL_UNET_CONFIG = {
25
+ "act_fn": "silu",
26
+ "addition_embed_type": "text_time",
27
+ "addition_embed_type_num_heads": 64,
28
+ "addition_time_embed_dim": 256,
29
+ "attention_head_dim": [5, 10, 20],
30
+ "block_out_channels": [320, 640, 1280],
31
+ "center_input_sample": False,
32
+ "class_embed_type": None,
33
+ "class_embeddings_concat": False,
34
+ "conv_in_kernel": 3,
35
+ "conv_out_kernel": 3,
36
+ "cross_attention_dim": 2048,
37
+ "cross_attention_norm": None,
38
+ "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
39
+ "downsample_padding": 1,
40
+ "dual_cross_attention": False,
41
+ "encoder_hid_dim": None,
42
+ "encoder_hid_dim_type": None,
43
+ "flip_sin_to_cos": True,
44
+ "freq_shift": 0,
45
+ "in_channels": 4,
46
+ "layers_per_block": 2,
47
+ "mid_block_only_cross_attention": None,
48
+ "mid_block_scale_factor": 1,
49
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
50
+ "norm_eps": 1e-05,
51
+ "norm_num_groups": 32,
52
+ "num_attention_heads": None,
53
+ "num_class_embeds": None,
54
+ "only_cross_attention": False,
55
+ "out_channels": 4,
56
+ "projection_class_embeddings_input_dim": 2816,
57
+ "resnet_out_scale_factor": 1.0,
58
+ "resnet_skip_time_act": False,
59
+ "resnet_time_scale_shift": "default",
60
+ "sample_size": 128,
61
+ "time_cond_proj_dim": None,
62
+ "time_embedding_act_fn": None,
63
+ "time_embedding_dim": None,
64
+ "time_embedding_type": "positional",
65
+ "timestep_post_act": None,
66
+ "transformer_layers_per_block": [1, 2, 10],
67
+ "up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
68
+ "upcast_attention": False,
69
+ "use_linear_projection": True,
70
+ }
71
+
72
+
73
+ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
74
+ SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
75
+
76
+ # basically the same as SD2's; logit_scale is used later, so it is returned as well
77
+ # logit_scale is used when saving the checkpoint
78
+ def convert_key(key):
79
+ # common conversion
80
+ key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
81
+ key = key.replace(SDXL_KEY_PREFIX, "text_model.")
82
+
83
+ if "resblocks" in key:
84
+ # resblocks conversion
85
+ key = key.replace(".resblocks.", ".layers.")
86
+ if ".ln_" in key:
87
+ key = key.replace(".ln_", ".layer_norm")
88
+ elif ".mlp." in key:
89
+ key = key.replace(".c_fc.", ".fc1.")
90
+ key = key.replace(".c_proj.", ".fc2.")
91
+ elif ".attn.out_proj" in key:
92
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
93
+ elif ".attn.in_proj" in key:
94
+ key = None # special case, handled later
95
+ else:
96
+ raise ValueError(f"unexpected key in SD: {key}")
97
+ elif ".positional_embedding" in key:
98
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
99
+ elif ".text_projection" in key:
100
+ key = key.replace("text_model.text_projection", "text_projection.weight")
101
+ elif ".logit_scale" in key:
102
+ key = None # handled later
103
+ elif ".token_embedding" in key:
104
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
105
+ elif ".ln_final" in key:
106
+ key = key.replace(".ln_final", ".final_layer_norm")
107
+ # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
108
+ elif ".embeddings.position_ids" in key:
109
+ key = None # remove this key: position_ids is not used in newer transformers
110
+ return key
111
+
112
+ keys = list(checkpoint.keys())
113
+ new_sd = {}
114
+ for key in keys:
115
+ new_key = convert_key(key)
116
+ if new_key is None:
117
+ continue
118
+ new_sd[new_key] = checkpoint[key]
119
+
120
+ # convert attention weights
121
+ for key in keys:
122
+ if ".resblocks" in key and ".attn.in_proj_" in key:
123
+ # split into three (q, k, v)
124
+ values = torch.chunk(checkpoint[key], 3)
125
+
126
+ key_suffix = ".weight" if "weight" in key else ".bias"
127
+ key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
128
+ key_pfx = key_pfx.replace("_weight", "")
129
+ key_pfx = key_pfx.replace("_bias", "")
130
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
131
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
132
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
133
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
134
+
135
+ # logit_scale is not part of the Diffusers model, but return it separately so it can be restored when saving
136
+ logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
137
+
138
+ # temporary workaround for text_projection.weight.weight for Playground-v2
139
+ if "text_projection.weight.weight" in new_sd:
140
+ logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
141
+ new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
142
+ del new_sd["text_projection.weight.weight"]
143
+
144
+ return new_sd, logit_scale
145
+
146
+
147
+ # load state_dict without allocating new tensors
148
+ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
149
+ # dtype will use fp32 as default
150
+ missing_keys = list(model.state_dict().keys() - state_dict.keys())
151
+ unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
152
+
153
+ # similar to model.load_state_dict()
154
+ if not missing_keys and not unexpected_keys:
155
+ for k in list(state_dict.keys()):
156
+ set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
157
+ return "<All keys matched successfully>"
158
+
159
+ # error_msgs
160
+ error_msgs: List[str] = []
161
+ if missing_keys:
162
+ error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
163
+ if unexpected_keys:
164
+ error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))
165
+
166
+ raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
167
+
168
+
169
+ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
170
+ # model_version is reserved for future use
171
+ # dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
172
+
173
+ # Load the state dict
174
+ if model_util.is_safetensors(ckpt_path):
175
+ checkpoint = None
176
+ if disable_mmap:
177
+ state_dict = safetensors.torch.load(open(ckpt_path, "rb").read())
178
+ else:
179
+ try:
180
+ state_dict = load_file(ckpt_path, device=map_location)
181
+ except Exception:
182
+ state_dict = load_file(ckpt_path) # prevent device invalid Error
183
+ epoch = None
184
+ global_step = None
185
+ else:
186
+ checkpoint = torch.load(ckpt_path, map_location=map_location)
187
+ if "state_dict" in checkpoint:
188
+ state_dict = checkpoint["state_dict"]
189
+ epoch = checkpoint.get("epoch", 0)
190
+ global_step = checkpoint.get("global_step", 0)
191
+ else:
192
+ state_dict = checkpoint
193
+ epoch = 0
194
+ global_step = 0
195
+ checkpoint = None
196
+
197
+ # U-Net
198
+ logger.info("building U-Net")
199
+ with init_empty_weights():
200
+ unet = sdxl_original_unet.SdxlUNet2DConditionModel()
201
+
202
+ logger.info("loading U-Net from checkpoint")
203
+ unet_sd = {}
204
+ for k in list(state_dict.keys()):
205
+ if k.startswith("model.diffusion_model."):
206
+ unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
207
+ info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
208
+ logger.info(f"U-Net: {info}")
209
+
210
+ # Text Encoders
211
+ logger.info("building text encoders")
212
+
213
+ # Text Encoder 1 is the same as Stability AI's SDXL
214
+ text_model1_cfg = CLIPTextConfig(
215
+ vocab_size=49408,
216
+ hidden_size=768,
217
+ intermediate_size=3072,
218
+ num_hidden_layers=12,
219
+ num_attention_heads=12,
220
+ max_position_embeddings=77,
221
+ hidden_act="quick_gelu",
222
+ layer_norm_eps=1e-05,
223
+ dropout=0.0,
224
+ attention_dropout=0.0,
225
+ initializer_range=0.02,
226
+ initializer_factor=1.0,
227
+ pad_token_id=1,
228
+ bos_token_id=0,
229
+ eos_token_id=2,
230
+ model_type="clip_text_model",
231
+ projection_dim=768,
232
+ # torch_dtype="float32",
233
+ # transformers_version="4.25.0.dev0",
234
+ )
235
+ with init_empty_weights():
236
+ text_model1 = CLIPTextModel._from_config(text_model1_cfg)
237
+
238
+ # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
239
+ # Note: The HuggingFace tokenizer differs from SDXL's, so open clip's tokenizer must be used.
240
+ text_model2_cfg = CLIPTextConfig(
241
+ vocab_size=49408,
242
+ hidden_size=1280,
243
+ intermediate_size=5120,
244
+ num_hidden_layers=32,
245
+ num_attention_heads=20,
246
+ max_position_embeddings=77,
247
+ hidden_act="gelu",
248
+ layer_norm_eps=1e-05,
249
+ dropout=0.0,
250
+ attention_dropout=0.0,
251
+ initializer_range=0.02,
252
+ initializer_factor=1.0,
253
+ pad_token_id=1,
254
+ bos_token_id=0,
255
+ eos_token_id=2,
256
+ model_type="clip_text_model",
257
+ projection_dim=1280,
258
+ # torch_dtype="float32",
259
+ # transformers_version="4.25.0.dev0",
260
+ )
261
+ with init_empty_weights():
262
+ text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
263
+
264
+ logger.info("loading text encoders from checkpoint")
265
+ te1_sd = {}
266
+ te2_sd = {}
267
+ for k in list(state_dict.keys()):
268
+ if k.startswith("conditioner.embedders.0.transformer."):
269
+ te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
270
+ elif k.startswith("conditioner.embedders.1.model."):
271
+ te2_sd[k] = state_dict.pop(k)
272
+
273
+ # remove position_ids: including it causes an error with the latest transformers
274
+ if "text_model.embeddings.position_ids" in te1_sd:
275
+ te1_sd.pop("text_model.embeddings.position_ids")
276
+
277
+ info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
278
+ logger.info(f"text encoder 1: {info1}")
279
+
280
+ converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
281
+ info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
282
+ logger.info(f"text encoder 2: {info2}")
283
+
284
+ # prepare vae
285
+ logger.info("building VAE")
286
+ vae_config = model_util.create_vae_diffusers_config()
287
+ with init_empty_weights():
288
+ vae = AutoencoderKL(**vae_config)
289
+
290
+ logger.info("loading VAE from checkpoint")
291
+ converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
292
+ info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
293
+ logger.info(f"VAE: {info}")
294
+
295
+ ckpt_info = (epoch, global_step) if epoch is not None else None
296
+ return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
297
+
298
+
299
+ def make_unet_conversion_map():
300
+ unet_conversion_map_layer = []
301
+
302
+ for i in range(3): # num_blocks is 3 in sdxl
303
+ # loop over downblocks/upblocks
304
+ for j in range(2):
305
+ # loop over resnets/attentions for downblocks
306
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
307
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
308
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
309
+
310
+ if i < 3:
311
+ # no attention layers in down_blocks.3
312
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
313
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
314
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
315
+
316
+ for j in range(3):
317
+ # loop over resnets/attentions for upblocks
318
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
319
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
320
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
321
+
322
+ # if i > 0: commented out for sdxl
323
+ # no attention layers in up_blocks.0
324
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
325
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
326
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
327
+
328
+ if i < 3:
329
+ # no downsample in down_blocks.3
330
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
331
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
332
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
333
+
334
+ # no upsample in up_blocks.3
335
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
336
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
337
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
338
+
339
+ hf_mid_atn_prefix = "mid_block.attentions.0."
340
+ sd_mid_atn_prefix = "middle_block.1."
341
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
342
+
343
+ for j in range(2):
344
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
345
+ sd_mid_res_prefix = f"middle_block.{2*j}."
346
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
347
+
348
+ unet_conversion_map_resnet = [
349
+ # (stable-diffusion, HF Diffusers)
350
+ ("in_layers.0.", "norm1."),
351
+ ("in_layers.2.", "conv1."),
352
+ ("out_layers.0.", "norm2."),
353
+ ("out_layers.3.", "conv2."),
354
+ ("emb_layers.1.", "time_emb_proj."),
355
+ ("skip_connection.", "conv_shortcut."),
356
+ ]
357
+
358
+ unet_conversion_map = []
359
+ for sd, hf in unet_conversion_map_layer:
360
+ if "resnets" in hf:
361
+ for sd_res, hf_res in unet_conversion_map_resnet:
362
+ unet_conversion_map.append((sd + sd_res, hf + hf_res))
363
+ else:
364
+ unet_conversion_map.append((sd, hf))
365
+
366
+ for j in range(2):
367
+ hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
368
+ sd_time_embed_prefix = f"time_embed.{j*2}."
369
+ unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
370
+
371
+ for j in range(2):
372
+ hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
373
+ sd_label_embed_prefix = f"label_emb.0.{j*2}."
374
+ unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
375
+
376
+ unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
377
+ unet_conversion_map.append(("out.0.", "conv_norm_out."))
378
+ unet_conversion_map.append(("out.2.", "conv_out."))
379
+
380
+ return unet_conversion_map
381
+
382
+
383
+ def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
384
+ unet_conversion_map = make_unet_conversion_map()
385
+
386
+ conversion_map = {hf: sd for sd, hf in unet_conversion_map}
387
+ return convert_unet_state_dict(du_sd, conversion_map)
388
+
389
+
390
+ def convert_unet_state_dict(src_sd, conversion_map):
391
+ converted_sd = {}
392
+ for src_key, value in src_sd.items():
393
+ # iterating over the whole map for every key would be slow, so trim elements from the right while searching for a matching prefix
394
+ src_key_fragments = src_key.split(".")[:-1] # remove weight/bias
395
+ while len(src_key_fragments) > 0:
396
+ src_key_prefix = ".".join(src_key_fragments) + "."
397
+ if src_key_prefix in conversion_map:
398
+ converted_prefix = conversion_map[src_key_prefix]
399
+ converted_key = converted_prefix + src_key[len(src_key_prefix) :]
400
+ converted_sd[converted_key] = value
401
+ break
402
+ src_key_fragments.pop(-1)
403
+ assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
404
+
405
+ return converted_sd
406
+
407
+
408
+ def convert_sdxl_unet_state_dict_to_diffusers(sd):
409
+ unet_conversion_map = make_unet_conversion_map()
410
+
411
+ conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
412
+ return convert_unet_state_dict(sd, conversion_dict)
413
+
414
+
415
+ def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
416
+ def convert_key(key):
417
+ # remove position_ids
418
+ if ".position_ids" in key:
419
+ return None
420
+
421
+ # common
422
+ key = key.replace("text_model.encoder.", "transformer.")
423
+ key = key.replace("text_model.", "")
424
+ if "layers" in key:
425
+ # resblocks conversion
426
+ key = key.replace(".layers.", ".resblocks.")
427
+ if ".layer_norm" in key:
428
+ key = key.replace(".layer_norm", ".ln_")
429
+ elif ".mlp." in key:
430
+ key = key.replace(".fc1.", ".c_fc.")
431
+ key = key.replace(".fc2.", ".c_proj.")
432
+ elif ".self_attn.out_proj" in key:
433
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
434
+ elif ".self_attn." in key:
435
+ key = None # special case, handled later
436
+ else:
437
+ raise ValueError(f"unexpected key in Diffusers model: {key}")
438
+ elif ".position_embedding" in key:
439
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
440
+ elif ".token_embedding" in key:
441
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
442
+ elif "text_projection" in key: # no dot in key
443
+ key = key.replace("text_projection.weight", "text_projection")
444
+ elif "final_layer_norm" in key:
445
+ key = key.replace("final_layer_norm", "ln_final")
446
+ return key
447
+
448
+ keys = list(checkpoint.keys())
449
+ new_sd = {}
450
+ for key in keys:
451
+ new_key = convert_key(key)
452
+ if new_key is None:
453
+ continue
454
+ new_sd[new_key] = checkpoint[key]
455
+
456
+ # convert attention weights
457
+ for key in keys:
458
+ if "layers" in key and "q_proj" in key:
459
+ # concatenate the three (q, k, v)
460
+ key_q = key
461
+ key_k = key.replace("q_proj", "k_proj")
462
+ key_v = key.replace("q_proj", "v_proj")
463
+
464
+ value_q = checkpoint[key_q]
465
+ value_k = checkpoint[key_k]
466
+ value_v = checkpoint[key_v]
467
+ value = torch.cat([value_q, value_k, value_v])
468
+
469
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
470
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
471
+ new_sd[new_key] = value
472
+
473
+ if logit_scale is not None:
474
+ new_sd["logit_scale"] = logit_scale
475
+
476
+ return new_sd
477
+
478
+
479
+ def save_stable_diffusion_checkpoint(
480
+ output_file,
481
+ text_encoder1,
482
+ text_encoder2,
483
+ unet,
484
+ epochs,
485
+ steps,
486
+ ckpt_info,
487
+ vae,
488
+ logit_scale,
489
+ metadata,
490
+ save_dtype=None,
491
+ ):
492
+ state_dict = {}
493
+
494
+ def update_sd(prefix, sd):
495
+ for k, v in sd.items():
496
+ key = prefix + k
497
+ if save_dtype is not None:
498
+ v = v.detach().clone().to("cpu").to(save_dtype)
499
+ state_dict[key] = v
500
+
501
+ # Convert the UNet model
502
+ update_sd("model.diffusion_model.", unet.state_dict())
503
+
504
+ # Convert the text encoders
505
+ update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
506
+
507
+ text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale)
508
+ update_sd("conditioner.embedders.1.model.", text_enc2_dict)
509
+
510
+ # Convert the VAE
511
+ vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
512
+ update_sd("first_stage_model.", vae_dict)
513
+
514
+ # Put together new checkpoint
515
+ key_count = len(state_dict.keys())
516
+ new_ckpt = {"state_dict": state_dict}
517
+
518
+ # epoch and global_step are sometimes not int
519
+ if ckpt_info is not None:
520
+ epochs += ckpt_info[0]
521
+ steps += ckpt_info[1]
522
+
523
+ new_ckpt["epoch"] = epochs
524
+ new_ckpt["global_step"] = steps
525
+
526
+ if model_util.is_safetensors(output_file):
527
+ save_file(state_dict, output_file, metadata)
528
+ else:
529
+ torch.save(new_ckpt, output_file)
530
+
531
+ return key_count
532
+
533
+
534
+ def save_diffusers_checkpoint(
535
+ output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
536
+ ):
537
+ from diffusers import StableDiffusionXLPipeline
538
+
539
+ # convert U-Net
540
+ unet_sd = unet.state_dict()
541
+ du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
542
+
543
+ diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
544
+ if save_dtype is not None:
545
+ diffusers_unet.to(save_dtype)
546
+ diffusers_unet.load_state_dict(du_unet_sd)
547
+
548
+ # create pipeline to save
549
+ if pretrained_model_name_or_path is None:
550
+ pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL
551
+
552
+ scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
553
+ tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
554
+ tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
555
+ if vae is None:
556
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
557
+
558
+ # prevent local path from being saved
559
+ def remove_name_or_path(model):
560
+ if hasattr(model, "config"):
561
+ model.config._name_or_path = None
562
+ model.config._name_or_path = None
563
+
564
+ remove_name_or_path(diffusers_unet)
565
+ remove_name_or_path(text_encoder1)
566
+ remove_name_or_path(text_encoder2)
567
+ remove_name_or_path(scheduler)
568
+ remove_name_or_path(tokenizer1)
569
+ remove_name_or_path(tokenizer2)
570
+ remove_name_or_path(vae)
571
+
572
+ pipeline = StableDiffusionXLPipeline(
573
+ unet=diffusers_unet,
574
+ text_encoder=text_encoder1,
575
+ text_encoder_2=text_encoder2,
576
+ vae=vae,
577
+ scheduler=scheduler,
578
+ tokenizer=tokenizer1,
579
+ tokenizer_2=tokenizer2,
580
+ )
581
+ if save_dtype is not None:
582
+ pipeline.to(None, save_dtype)
583
+ pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
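A round-trip sketch of the two main entry points in this module: loading an SDXL checkpoint into its components and saving it back in the single-file (SAI) layout. The file paths are placeholders.

import torch
from library import sdxl_model_util

# Load U-Net, both text encoders, VAE, logit_scale and (epoch, step) info from a single-file checkpoint.
te1, te2, vae, unet, logit_scale, ckpt_info = sdxl_model_util.load_models_from_sdxl_checkpoint(
    sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0,
    "sd_xl_base_1.0.safetensors",  # placeholder input path
    map_location="cpu",
    dtype=torch.float16,           # U-Net / VAE dtype; text encoders stay fp32
)

# Save the components back to the SAI single-file layout.
sdxl_model_util.save_stable_diffusion_checkpoint(
    "sdxl_resaved.safetensors",    # placeholder output path
    te1,
    te2,
    unet,
    epochs=0,
    steps=0,
    ckpt_info=ckpt_info,
    vae=vae,
    logit_scale=logit_scale,
    metadata=None,
    save_dtype=torch.float16,
)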
library/sdxl_original_control_net.py ADDED
@@ -0,0 +1,272 @@
1
+ # some parts are modified from Diffusers library (Apache License 2.0)
2
+
3
+ import math
4
+ from types import SimpleNamespace
5
+ from typing import Any, Optional
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from einops import rearrange
11
+ from library.utils import setup_logging
12
+
13
+ setup_logging()
14
+ import logging
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ from library import sdxl_original_unet
19
+ from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl
20
+
21
+
22
+ class ControlNetConditioningEmbedding(nn.Module):
23
+ def __init__(self):
24
+ super().__init__()
25
+
26
+ dims = [16, 32, 96, 256]
27
+
28
+ self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1)
29
+ self.blocks = nn.ModuleList([])
30
+
31
+ for i in range(len(dims) - 1):
32
+ channel_in = dims[i]
33
+ channel_out = dims[i + 1]
34
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
35
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
36
+
37
+ self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1)
38
+ nn.init.zeros_(self.conv_out.weight) # zero module weight
39
+ nn.init.zeros_(self.conv_out.bias) # zero module bias
40
+
41
+ def forward(self, x):
42
+ x = self.conv_in(x)
43
+ x = F.silu(x)
44
+ for block in self.blocks:
45
+ x = block(x)
46
+ x = F.silu(x)
47
+ x = self.conv_out(x)
48
+ return x
49
+
50
+
51
+ class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel):
52
+ def __init__(self, multiplier: Optional[float] = None, **kwargs):
53
+ super().__init__(**kwargs)
54
+ self.multiplier = multiplier
55
+
56
+ # remove unet layers
57
+ self.output_blocks = nn.ModuleList([])
58
+ del self.out
59
+
60
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding()
61
+
62
+ dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280]
63
+ self.controlnet_down_blocks = nn.ModuleList([])
64
+ for dim in dims:
65
+ self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1))
66
+ nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight
67
+ nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias
68
+
69
+ self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1)
70
+ nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight
71
+ nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias
72
+
73
+ def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel):
74
+ unet_sd = unet.state_dict()
75
+ unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")}
76
+ sd = super().state_dict()
77
+ sd.update(unet_sd)
78
+ info = super().load_state_dict(sd, strict=True, assign=True)
79
+ return info
80
+
81
+ def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any:
82
+ # convert state_dict to SAI format
83
+ unet_sd = {}
84
+ for k in list(state_dict.keys()):
85
+ if not k.startswith("controlnet_"):
86
+ unet_sd[k] = state_dict.pop(k)
87
+ unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd)
88
+ state_dict.update(unet_sd)
89
+ super().load_state_dict(state_dict, strict=strict, assign=assign)
90
+
91
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
92
+ # convert state_dict to Diffusers format
93
+ state_dict = super().state_dict(destination, prefix, keep_vars)
94
+ control_net_sd = {}
95
+ for k in list(state_dict.keys()):
96
+ if k.startswith("controlnet_"):
97
+ control_net_sd[k] = state_dict.pop(k)
98
+ state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
99
+ state_dict.update(control_net_sd)
100
+ return state_dict
101
+
102
+ def forward(
103
+ self,
104
+ x: torch.Tensor,
105
+ timesteps: Optional[torch.Tensor] = None,
106
+ context: Optional[torch.Tensor] = None,
107
+ y: Optional[torch.Tensor] = None,
108
+ cond_image: Optional[torch.Tensor] = None,
109
+ **kwargs,
110
+ ) -> torch.Tensor:
111
+ # broadcast timesteps to batch dimension
112
+ timesteps = timesteps.expand(x.shape[0])
113
+
114
+ t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
115
+ t_emb = t_emb.to(x.dtype)
116
+ emb = self.time_embed(t_emb)
117
+
118
+ assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
119
+ assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
120
+ emb = emb + self.label_emb(y)
121
+
122
+ def call_module(module, h, emb, context):
123
+ x = h
124
+ for layer in module:
125
+ if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
126
+ x = layer(x, emb)
127
+ elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
128
+ x = layer(x, context)
129
+ else:
130
+ x = layer(x)
131
+ return x
132
+
133
+ h = x
134
+ multiplier = self.multiplier if self.multiplier is not None else 1.0
135
+ hs = []
136
+ for i, module in enumerate(self.input_blocks):
137
+ h = call_module(module, h, emb, context)
138
+ if i == 0:
139
+ h = self.controlnet_cond_embedding(cond_image) + h
140
+ hs.append(self.controlnet_down_blocks[i](h) * multiplier)
141
+
142
+ h = call_module(self.middle_block, h, emb, context)
143
+ h = self.controlnet_mid_block(h) * multiplier
144
+
145
+ return hs, h
146
+
147
+
148
+ class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel):
149
+ """
150
+ This class is for training purposes only.
151
+ """
152
+
153
+ def __init__(self, **kwargs):
154
+ super().__init__(**kwargs)
155
+
156
+ def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
157
+ # broadcast timesteps to batch dimension
158
+ timesteps = timesteps.expand(x.shape[0])
159
+
160
+ hs = []
161
+ t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
162
+ t_emb = t_emb.to(x.dtype)
163
+ emb = self.time_embed(t_emb)
164
+
165
+ assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
166
+ assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
167
+ emb = emb + self.label_emb(y)
168
+
169
+ def call_module(module, h, emb, context):
170
+ x = h
171
+ for layer in module:
172
+ if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
173
+ x = layer(x, emb)
174
+ elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
175
+ x = layer(x, context)
176
+ else:
177
+ x = layer(x)
178
+ return x
179
+
180
+ h = x
181
+ for module in self.input_blocks:
182
+ h = call_module(module, h, emb, context)
183
+ hs.append(h)
184
+
185
+ h = call_module(self.middle_block, h, emb, context)
186
+ h = h + mid_add
187
+
188
+ for module in self.output_blocks:
189
+ resi = hs.pop() + input_resi_add.pop()
190
+ h = torch.cat([h, resi], dim=1)
191
+ h = call_module(module, h, emb, context)
192
+
193
+ h = h.type(x.dtype)
194
+ h = call_module(self.out, h, emb, context)
195
+
196
+ return h
197
+
198
+
199
+ if __name__ == "__main__":
200
+ import time
201
+
202
+ logger.info("create unet")
203
+ unet = SdxlControlledUNet()
204
+ unet.to("cuda", torch.bfloat16)
205
+ unet.set_use_sdpa(True)
206
+ unet.set_gradient_checkpointing(True)
207
+ unet.train()
208
+
209
+ logger.info("create control_net")
210
+ control_net = SdxlControlNet()
211
+ control_net.to("cuda")
212
+ control_net.set_use_sdpa(True)
213
+ control_net.set_gradient_checkpointing(True)
214
+ control_net.train()
215
+
216
+ logger.info("Initialize control_net from unet")
217
+ control_net.init_from_unet(unet)
218
+
219
+ unet.requires_grad_(False)
220
+ control_net.requires_grad_(True)
221
+
222
+ # pseudo training loop for checking memory usage
223
+ logger.info("preparing optimizer")
224
+
225
+ # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
226
+
227
+ import bitsandbytes
228
+
229
+ optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working
230
+ # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
231
+ # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
232
+
233
+ # import transformers
234
+ # optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
235
+
236
+ scaler = torch.cuda.amp.GradScaler(enabled=True)
237
+
238
+ logger.info("start training")
239
+ steps = 10
240
+ batch_size = 1
241
+
242
+ for step in range(steps):
243
+ logger.info(f"step {step}")
244
+ if step == 1:
245
+ time_start = time.perf_counter()
246
+
247
+ x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
248
+ t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda")
249
+ txt = torch.randn(batch_size, 77, 2048).cuda()
250
+ vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
251
+ cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda()
252
+
253
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
254
+ input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img)
255
+ output = unet(x, t, txt, vector, input_resi_add, mid_add)
256
+ target = torch.randn_like(output)
257
+ loss = torch.nn.functional.mse_loss(output, target)
258
+
259
+ scaler.scale(loss).backward()
260
+ scaler.step(optimizer)
261
+ scaler.update()
262
+ optimizer.zero_grad(set_to_none=True)
263
+
264
+ time_end = time.perf_counter()
265
+ logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
266
+
267
+ logger.info("finish training")
268
+ sd = control_net.state_dict()
269
+
270
+ from safetensors.torch import save_file
271
+
272
+ save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors")
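For reference, an inference-style sketch of the ControlNet wiring, mirroring the shapes in the pseudo training loop above: `SdxlControlNet` returns the per-input-block residuals and the middle-block residual, which `SdxlControlledUNet` consumes as `input_resi_add` and `mid_add`. The tensors here are random placeholders, not real latents or a real denoising loop.

import torch
from library import sdxl_original_unet
from library.sdxl_original_control_net import SdxlControlNet, SdxlControlledUNet

unet = SdxlControlledUNet().to("cuda", torch.bfloat16).eval()
control_net = SdxlControlNet().to("cuda", torch.bfloat16).eval()

with torch.no_grad():
    x = torch.randn(1, 4, 128, 128, device="cuda", dtype=torch.bfloat16)       # latents for 1024x1024
    t = torch.randint(0, 1000, (1,), device="cuda")                             # timestep
    ctx = torch.randn(1, 77, 2048, device="cuda", dtype=torch.bfloat16)         # text context
    y = torch.randn(1, sdxl_original_unet.ADM_IN_CHANNELS, device="cuda", dtype=torch.bfloat16)
    cond = torch.rand(1, 3, 1024, 1024, device="cuda", dtype=torch.bfloat16)    # conditioning image

    # ControlNet produces additive residuals; the controlled U-Net applies them per block.
    input_resi_add, mid_add = control_net(x, t, ctx, y, cond)
    noise_pred = unet(x, t, ctx, y, input_resi_add, mid_add)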
library/sdxl_original_unet.py ADDED
@@ -0,0 +1,1292 @@
1
+ # U-Net for sd_xl_base, based on the Diffusers code
2
+ # the state dict format matches SDXL
3
+
4
+ """
5
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
6
+ params:
7
+ adm_in_channels: 2816
8
+ num_classes: sequential
9
+ use_checkpoint: True
10
+ in_channels: 4
11
+ out_channels: 4
12
+ model_channels: 320
13
+ attention_resolutions: [4, 2]
14
+ num_res_blocks: 2
15
+ channel_mult: [1, 2, 4]
16
+ num_head_channels: 64
17
+ use_spatial_transformer: True
18
+ use_linear_in_transformer: True
19
+ transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
20
+ context_dim: 2048
21
+ spatial_transformer_attn_type: softmax-xformers
22
+ legacy: False
23
+ """
24
+
25
+ import math
26
+ from types import SimpleNamespace
27
+ from typing import Any, Optional
28
+ import torch
29
+ import torch.utils.checkpoint
30
+ from torch import nn
31
+ from torch.nn import functional as F
32
+ from einops import rearrange
33
+ from library.utils import setup_logging
34
+
35
+ setup_logging()
36
+ import logging
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ IN_CHANNELS: int = 4
41
+ OUT_CHANNELS: int = 4
42
+ ADM_IN_CHANNELS: int = 2816
43
+ CONTEXT_DIM: int = 2048
44
+ MODEL_CHANNELS: int = 320
45
+ TIME_EMBED_DIM = 320 * 4
46
+
47
+ USE_REENTRANT = True
48
+
49
+ # region memory efficient attention
50
+
51
+ # CrossAttention using FlashAttention
52
+ # based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
53
+ # LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
54
+
55
+ # constants
56
+
57
+ EPSILON = 1e-6
58
+
59
+ # helper functions
60
+
61
+
62
+ def exists(val):
63
+ return val is not None
64
+
65
+
66
+ def default(val, d):
67
+ return val if exists(val) else d
68
+
69
+
70
+ # flash attention forwards and backwards
71
+
72
+ # https://arxiv.org/abs/2205.14135
73
+
74
+
75
+ class FlashAttentionFunction(torch.autograd.Function):
76
+ @staticmethod
77
+ @torch.no_grad()
78
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
79
+ """Algorithm 2 in the paper"""
80
+
81
+ device = q.device
82
+ dtype = q.dtype
83
+ max_neg_value = -torch.finfo(q.dtype).max
84
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
85
+
86
+ o = torch.zeros_like(q)
87
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
88
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
89
+
90
+ scale = q.shape[-1] ** -0.5
91
+
92
+ if not exists(mask):
93
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
94
+ else:
95
+ mask = rearrange(mask, "b n -> b 1 1 n")
96
+ mask = mask.split(q_bucket_size, dim=-1)
97
+
98
+ row_splits = zip(
99
+ q.split(q_bucket_size, dim=-2),
100
+ o.split(q_bucket_size, dim=-2),
101
+ mask,
102
+ all_row_sums.split(q_bucket_size, dim=-2),
103
+ all_row_maxes.split(q_bucket_size, dim=-2),
104
+ )
105
+
106
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
107
+ q_start_index = ind * q_bucket_size - qk_len_diff
108
+
109
+ col_splits = zip(
110
+ k.split(k_bucket_size, dim=-2),
111
+ v.split(k_bucket_size, dim=-2),
112
+ )
113
+
114
+ for k_ind, (kc, vc) in enumerate(col_splits):
115
+ k_start_index = k_ind * k_bucket_size
116
+
117
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
118
+
119
+ if exists(row_mask):
120
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
121
+
122
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
123
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
124
+ q_start_index - k_start_index + 1
125
+ )
126
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
127
+
128
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
129
+ attn_weights -= block_row_maxes
130
+ exp_weights = torch.exp(attn_weights)
131
+
132
+ if exists(row_mask):
133
+ exp_weights.masked_fill_(~row_mask, 0.0)
134
+
135
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
136
+
137
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
138
+
139
+ exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc)
140
+
141
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
142
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
143
+
144
+ new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
145
+
146
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
147
+
148
+ row_maxes.copy_(new_row_maxes)
149
+ row_sums.copy_(new_row_sums)
150
+
151
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
152
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
153
+
154
+ return o
155
+
156
+ @staticmethod
157
+ @torch.no_grad()
158
+ def backward(ctx, do):
159
+ """Algorithm 4 in the paper"""
160
+
161
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
162
+ q, k, v, o, l, m = ctx.saved_tensors
163
+
164
+ device = q.device
165
+
166
+ max_neg_value = -torch.finfo(q.dtype).max
167
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
168
+
169
+ dq = torch.zeros_like(q)
170
+ dk = torch.zeros_like(k)
171
+ dv = torch.zeros_like(v)
172
+
173
+ row_splits = zip(
174
+ q.split(q_bucket_size, dim=-2),
175
+ o.split(q_bucket_size, dim=-2),
176
+ do.split(q_bucket_size, dim=-2),
177
+ mask,
178
+ l.split(q_bucket_size, dim=-2),
179
+ m.split(q_bucket_size, dim=-2),
180
+ dq.split(q_bucket_size, dim=-2),
181
+ )
182
+
183
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
184
+ q_start_index = ind * q_bucket_size - qk_len_diff
185
+
186
+ col_splits = zip(
187
+ k.split(k_bucket_size, dim=-2),
188
+ v.split(k_bucket_size, dim=-2),
189
+ dk.split(k_bucket_size, dim=-2),
190
+ dv.split(k_bucket_size, dim=-2),
191
+ )
192
+
193
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
194
+ k_start_index = k_ind * k_bucket_size
195
+
196
+ attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
197
+
198
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
199
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu(
200
+ q_start_index - k_start_index + 1
201
+ )
202
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
203
+
204
+ exp_attn_weights = torch.exp(attn_weights - mc)
205
+
206
+ if exists(row_mask):
207
+ exp_attn_weights.masked_fill_(~row_mask, 0.0)
208
+
209
+ p = exp_attn_weights / lc
210
+
211
+ dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
212
+ dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
213
+
214
+ D = (doc * oc).sum(dim=-1, keepdims=True)
215
+ ds = p * scale * (dp - D)
216
+
217
+ dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
218
+ dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
219
+
220
+ dqc.add_(dq_chunk)
221
+ dkc.add_(dk_chunk)
222
+ dvc.add_(dv_chunk)
223
+
224
+ return dq, dk, dv, None, None, None, None
225
+
226
+
227
+ # endregion
228
+
229
+
230
+ def get_parameter_dtype(parameter: torch.nn.Module):
231
+ return next(parameter.parameters()).dtype
232
+
233
+
234
+ def get_parameter_device(parameter: torch.nn.Module):
235
+ return next(parameter.parameters()).device
236
+
237
+
238
+ def get_timestep_embedding(
239
+ timesteps: torch.Tensor,
240
+ embedding_dim: int,
241
+ downscale_freq_shift: float = 1,
242
+ scale: float = 1,
243
+ max_period: int = 10000,
244
+ ):
245
+ """
246
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
247
+
248
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
249
+ These may be fractional.
250
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings.
251
+ :return: an [N x dim] Tensor of positional embeddings.
252
+ """
253
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
254
+
255
+ half_dim = embedding_dim // 2
256
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
257
+ exponent = exponent / (half_dim - downscale_freq_shift)
258
+
259
+ emb = torch.exp(exponent)
260
+ emb = timesteps[:, None].float() * emb[None, :]
261
+
262
+ # scale embeddings
263
+ emb = scale * emb
264
+
265
+ # concat sine and cosine embeddings: flipped from Diffusers original ver because always flip_sin_to_cos=True
266
+ emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1)
267
+
268
+ # zero pad
269
+ if embedding_dim % 2 == 1:
270
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
271
+ return emb
272
+
273
+
274
+ # Deep Shrink: this function is intentionally kept local (not shared) to minimize dependencies.
275
+ def resize_like(x, target, mode="bicubic", align_corners=False):
276
+ org_dtype = x.dtype
277
+ if org_dtype == torch.bfloat16:
278
+ x = x.to(torch.float32)
279
+
280
+ if x.shape[-2:] != target.shape[-2:]:
281
+ if mode == "nearest":
282
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode)
283
+ else:
284
+ x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
285
+
286
+ if org_dtype == torch.bfloat16:
287
+ x = x.to(org_dtype)
288
+ return x
289
+
290
+
291
+ class GroupNorm32(nn.GroupNorm):
292
+ def forward(self, x):
293
+ if self.weight.dtype != torch.float32:
294
+ return super().forward(x)
295
+ return super().forward(x.float()).type(x.dtype)
296
+
297
+
298
+ class ResnetBlock2D(nn.Module):
299
+ def __init__(
300
+ self,
301
+ in_channels,
302
+ out_channels,
303
+ ):
304
+ super().__init__()
305
+ self.in_channels = in_channels
306
+ self.out_channels = out_channels
307
+
308
+ self.in_layers = nn.Sequential(
309
+ GroupNorm32(32, in_channels),
310
+ nn.SiLU(),
311
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
312
+ )
313
+
314
+ self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(TIME_EMBED_DIM, out_channels))
315
+
316
+ self.out_layers = nn.Sequential(
317
+ GroupNorm32(32, out_channels),
318
+ nn.SiLU(),
319
+ nn.Identity(), # to make state_dict compatible with original model
320
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
321
+ )
322
+
323
+ if in_channels != out_channels:
324
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
325
+ else:
326
+ self.skip_connection = nn.Identity()
327
+
328
+ self.gradient_checkpointing = False
329
+
330
+ def forward_body(self, x, emb):
331
+ h = self.in_layers(x)
332
+ emb_out = self.emb_layers(emb).type(h.dtype)
333
+ h = h + emb_out[:, :, None, None]
334
+ h = self.out_layers(h)
335
+ x = self.skip_connection(x)
336
+ return x + h
337
+
338
+ def forward(self, x, emb):
339
+ if self.training and self.gradient_checkpointing:
340
+ # logger.info("ResnetBlock2D: gradient_checkpointing")
341
+
342
+ def create_custom_forward(func):
343
+ def custom_forward(*inputs):
344
+ return func(*inputs)
345
+
346
+ return custom_forward
347
+
348
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT)
349
+ else:
350
+ x = self.forward_body(x, emb)
351
+
352
+ return x
353
+
354
+
355
+ class Downsample2D(nn.Module):
356
+ def __init__(self, channels, out_channels):
357
+ super().__init__()
358
+
359
+ self.channels = channels
360
+ self.out_channels = out_channels
361
+
362
+ self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1)
363
+
364
+ self.gradient_checkpointing = False
365
+
366
+ def forward_body(self, hidden_states):
367
+ assert hidden_states.shape[1] == self.channels
368
+ hidden_states = self.op(hidden_states)
369
+
370
+ return hidden_states
371
+
372
+ def forward(self, hidden_states):
373
+ if self.training and self.gradient_checkpointing:
374
+ # logger.info("Downsample2D: gradient_checkpointing")
375
+
376
+ def create_custom_forward(func):
377
+ def custom_forward(*inputs):
378
+ return func(*inputs)
379
+
380
+ return custom_forward
381
+
382
+ hidden_states = torch.utils.checkpoint.checkpoint(
383
+ create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT
384
+ )
385
+ else:
386
+ hidden_states = self.forward_body(hidden_states)
387
+
388
+ return hidden_states
389
+
390
+
391
+ class CrossAttention(nn.Module):
392
+ def __init__(
393
+ self,
394
+ query_dim: int,
395
+ cross_attention_dim: Optional[int] = None,
396
+ heads: int = 8,
397
+ dim_head: int = 64,
398
+ upcast_attention: bool = False,
399
+ ):
400
+ super().__init__()
401
+ inner_dim = dim_head * heads
402
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
403
+ self.upcast_attention = upcast_attention
404
+
405
+ self.scale = dim_head**-0.5
406
+ self.heads = heads
407
+
408
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
409
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
410
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)
411
+
412
+ self.to_out = nn.ModuleList([])
413
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
414
+ # no dropout here
415
+
416
+ self.use_memory_efficient_attention_xformers = False
417
+ self.use_memory_efficient_attention_mem_eff = False
418
+ self.use_sdpa = False
419
+
420
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
421
+ self.use_memory_efficient_attention_xformers = xformers
422
+ self.use_memory_efficient_attention_mem_eff = mem_eff
423
+
424
+ def set_use_sdpa(self, sdpa):
425
+ self.use_sdpa = sdpa
426
+
427
+ def reshape_heads_to_batch_dim(self, tensor):
428
+ batch_size, seq_len, dim = tensor.shape
429
+ head_size = self.heads
430
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
431
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
432
+ return tensor
433
+
434
+ def reshape_batch_dim_to_heads(self, tensor):
435
+ batch_size, seq_len, dim = tensor.shape
436
+ head_size = self.heads
437
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
438
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
439
+ return tensor
440
+
441
+ def forward(self, hidden_states, context=None, mask=None):
442
+ if self.use_memory_efficient_attention_xformers:
443
+ return self.forward_memory_efficient_xformers(hidden_states, context, mask)
444
+ if self.use_memory_efficient_attention_mem_eff:
445
+ return self.forward_memory_efficient_mem_eff(hidden_states, context, mask)
446
+ if self.use_sdpa:
447
+ return self.forward_sdpa(hidden_states, context, mask)
448
+
449
+ query = self.to_q(hidden_states)
450
+ context = context if context is not None else hidden_states
451
+ key = self.to_k(context)
452
+ value = self.to_v(context)
453
+
454
+ query = self.reshape_heads_to_batch_dim(query)
455
+ key = self.reshape_heads_to_batch_dim(key)
456
+ value = self.reshape_heads_to_batch_dim(value)
457
+
458
+ hidden_states = self._attention(query, key, value)
459
+
460
+ # linear proj
461
+ hidden_states = self.to_out[0](hidden_states)
462
+ # hidden_states = self.to_out[1](hidden_states) # no dropout
463
+ return hidden_states
464
+
465
+ def _attention(self, query, key, value):
466
+ if self.upcast_attention:
467
+ query = query.float()
468
+ key = key.float()
469
+
470
+ attention_scores = torch.baddbmm(
471
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
472
+ query,
473
+ key.transpose(-1, -2),
474
+ beta=0,
475
+ alpha=self.scale,
476
+ )
477
+ attention_probs = attention_scores.softmax(dim=-1)
478
+
479
+ # cast back to the original dtype
480
+ attention_probs = attention_probs.to(value.dtype)
481
+
482
+ # compute attention output
483
+ hidden_states = torch.bmm(attention_probs, value)
484
+
485
+ # reshape hidden_states
486
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
487
+ return hidden_states
488
+
489
+ # TODO support Hypernetworks
490
+ def forward_memory_efficient_xformers(self, x, context=None, mask=None):
491
+ import xformers.ops
492
+
493
+ h = self.heads
494
+ q_in = self.to_q(x)
495
+ context = context if context is not None else x
496
+ context = context.to(x.dtype)
497
+ k_in = self.to_k(context)
498
+ v_in = self.to_v(context)
499
+
500
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in))
501
+ del q_in, k_in, v_in
502
+
503
+ q = q.contiguous()
504
+ k = k.contiguous()
505
+ v = v.contiguous()
506
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # automatically picks the best available kernel
507
+ del q, k, v
508
+
509
+ out = rearrange(out, "b n h d -> b n (h d)", h=h)
510
+
511
+ out = self.to_out[0](out)
512
+ return out
513
+
514
+ def forward_memory_efficient_mem_eff(self, x, context=None, mask=None):
515
+ flash_func = FlashAttentionFunction
516
+
517
+ q_bucket_size = 512
518
+ k_bucket_size = 1024
519
+
520
+ h = self.heads
521
+ q = self.to_q(x)
522
+ context = context if context is not None else x
523
+ context = context.to(x.dtype)
524
+ k = self.to_k(context)
525
+ v = self.to_v(context)
526
+ del context, x
527
+
528
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
529
+
530
+ out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
531
+
532
+ out = rearrange(out, "b h n d -> b n (h d)")
533
+
534
+ out = self.to_out[0](out)
535
+ return out
536
+
537
+ def forward_sdpa(self, x, context=None, mask=None):
538
+ h = self.heads
539
+ q_in = self.to_q(x)
540
+ context = context if context is not None else x
541
+ context = context.to(x.dtype)
542
+ k_in = self.to_k(context)
543
+ v_in = self.to_v(context)
544
+
545
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in))
546
+ del q_in, k_in, v_in
547
+
548
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
549
+
550
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
551
+
552
+ out = self.to_out[0](out)
553
+ return out
554
+
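The three attention paths above (manual baddbmm/softmax/bmm, xformers, SDPA) compute the same scaled dot-product attention. A minimal sketch (assumption: plain PyTorch >= 2.0, random tensors) showing the manual formulation matches F.scaled_dot_product_attention up to floating-point error:

import torch
import torch.nn.functional as F

heads, seq, dim_head = 8, 16, 64
# (batch*heads, seq, dim_head), the layout produced by reshape_heads_to_batch_dim
q = torch.randn(2 * heads, seq, dim_head)
k = torch.randn(2 * heads, seq, dim_head)
v = torch.randn(2 * heads, seq, dim_head)

scale = dim_head ** -0.5
manual = torch.softmax(torch.bmm(q, k.transpose(-1, -2)) * scale, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)  # default scale is also 1/sqrt(dim_head)

print(torch.allclose(manual, fused, atol=1e-5))  # expected: True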
555
+
556
+ # feedforward
557
+ class GEGLU(nn.Module):
558
+ r"""
559
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
560
+
561
+ Parameters:
562
+ dim_in (`int`): The number of channels in the input.
563
+ dim_out (`int`): The number of channels in the output.
564
+ """
565
+
566
+ def __init__(self, dim_in: int, dim_out: int):
567
+ super().__init__()
568
+ self.proj = nn.Linear(dim_in, dim_out * 2)
569
+
570
+ def gelu(self, gate):
571
+ if gate.device.type != "mps":
572
+ return F.gelu(gate)
573
+ # mps: gelu is not implemented for float16
574
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
575
+
576
+ def forward(self, hidden_states):
577
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
578
+ return hidden_states * self.gelu(gate)
579
+
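GEGLU doubles the projection width and uses one half, passed through GELU, to gate the other half. A minimal sketch of the same computation in plain PyTorch (hypothetical dimensions):

import torch
import torch.nn.functional as F

dim_in, dim_out = 320, 1280
proj = torch.nn.Linear(dim_in, dim_out * 2)
x = torch.randn(2, 77, dim_in)

hidden, gate = proj(x).chunk(2, dim=-1)  # split the doubled projection
out = hidden * F.gelu(gate)              # GELU of one half gates the other
print(out.shape)  # torch.Size([2, 77, 1280])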
580
+
581
+ class FeedForward(nn.Module):
582
+ def __init__(
583
+ self,
584
+ dim: int,
585
+ ):
586
+ super().__init__()
587
+ inner_dim = int(dim * 4) # mult is always 4
588
+
589
+ self.net = nn.ModuleList([])
590
+ # project in
591
+ self.net.append(GEGLU(dim, inner_dim))
592
+ # project dropout
593
+ self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0
594
+ # project out
595
+ self.net.append(nn.Linear(inner_dim, dim))
596
+
597
+ def forward(self, hidden_states):
598
+ for module in self.net:
599
+ hidden_states = module(hidden_states)
600
+ return hidden_states
601
+
602
+
603
+ class BasicTransformerBlock(nn.Module):
604
+ def __init__(
605
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False
606
+ ):
607
+ super().__init__()
608
+
609
+ self.gradient_checkpointing = False
610
+
611
+ # 1. Self-Attn
612
+ self.attn1 = CrossAttention(
613
+ query_dim=dim,
614
+ cross_attention_dim=None,
615
+ heads=num_attention_heads,
616
+ dim_head=attention_head_dim,
617
+ upcast_attention=upcast_attention,
618
+ )
619
+ self.ff = FeedForward(dim)
620
+
621
+ # 2. Cross-Attn
622
+ self.attn2 = CrossAttention(
623
+ query_dim=dim,
624
+ cross_attention_dim=cross_attention_dim,
625
+ heads=num_attention_heads,
626
+ dim_head=attention_head_dim,
627
+ upcast_attention=upcast_attention,
628
+ )
629
+
630
+ self.norm1 = nn.LayerNorm(dim)
631
+ self.norm2 = nn.LayerNorm(dim)
632
+
633
+ # 3. Feed-forward
634
+ self.norm3 = nn.LayerNorm(dim)
635
+
636
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool):
637
+ self.attn1.set_use_memory_efficient_attention(xformers, mem_eff)
638
+ self.attn2.set_use_memory_efficient_attention(xformers, mem_eff)
639
+
640
+ def set_use_sdpa(self, sdpa: bool):
641
+ self.attn1.set_use_sdpa(sdpa)
642
+ self.attn2.set_use_sdpa(sdpa)
643
+
644
+ def forward_body(self, hidden_states, context=None, timestep=None):
645
+ # 1. Self-Attention
646
+ norm_hidden_states = self.norm1(hidden_states)
647
+
648
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
649
+
650
+ # 2. Cross-Attention
651
+ norm_hidden_states = self.norm2(hidden_states)
652
+ hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
653
+
654
+ # 3. Feed-forward
655
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
656
+
657
+ return hidden_states
658
+
659
+ def forward(self, hidden_states, context=None, timestep=None):
660
+ if self.training and self.gradient_checkpointing:
661
+ # logger.info("BasicTransformerBlock: checkpointing")
662
+
663
+ def create_custom_forward(func):
664
+ def custom_forward(*inputs):
665
+ return func(*inputs)
666
+
667
+ return custom_forward
668
+
669
+ output = torch.utils.checkpoint.checkpoint(
670
+ create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT
671
+ )
672
+ else:
673
+ output = self.forward_body(hidden_states, context, timestep)
674
+
675
+ return output
676
+
677
+
678
+ class Transformer2DModel(nn.Module):
679
+ def __init__(
680
+ self,
681
+ num_attention_heads: int = 16,
682
+ attention_head_dim: int = 88,
683
+ in_channels: Optional[int] = None,
684
+ cross_attention_dim: Optional[int] = None,
685
+ use_linear_projection: bool = False,
686
+ upcast_attention: bool = False,
687
+ num_transformer_layers: int = 1,
688
+ ):
689
+ super().__init__()
690
+ self.in_channels = in_channels
691
+ self.num_attention_heads = num_attention_heads
692
+ self.attention_head_dim = attention_head_dim
693
+ inner_dim = num_attention_heads * attention_head_dim
694
+ self.use_linear_projection = use_linear_projection
695
+
696
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
697
+ # self.norm = GroupNorm32(32, in_channels, eps=1e-6, affine=True)
698
+
699
+ if use_linear_projection:
700
+ self.proj_in = nn.Linear(in_channels, inner_dim)
701
+ else:
702
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
703
+
704
+ blocks = []
705
+ for _ in range(num_transformer_layers):
706
+ blocks.append(
707
+ BasicTransformerBlock(
708
+ inner_dim,
709
+ num_attention_heads,
710
+ attention_head_dim,
711
+ cross_attention_dim=cross_attention_dim,
712
+ upcast_attention=upcast_attention,
713
+ )
714
+ )
715
+
716
+ self.transformer_blocks = nn.ModuleList(blocks)
717
+
718
+ if use_linear_projection:
719
+ self.proj_out = nn.Linear(in_channels, inner_dim)
720
+ else:
721
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
722
+
723
+ self.gradient_checkpointing = False
724
+
725
+ def set_use_memory_efficient_attention(self, xformers, mem_eff):
726
+ for transformer in self.transformer_blocks:
727
+ transformer.set_use_memory_efficient_attention(xformers, mem_eff)
728
+
729
+ def set_use_sdpa(self, sdpa):
730
+ for transformer in self.transformer_blocks:
731
+ transformer.set_use_sdpa(sdpa)
732
+
733
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None):
734
+ # 1. Input
735
+ batch, _, height, width = hidden_states.shape
736
+ residual = hidden_states
737
+
738
+ hidden_states = self.norm(hidden_states)
739
+ if not self.use_linear_projection:
740
+ hidden_states = self.proj_in(hidden_states)
741
+ inner_dim = hidden_states.shape[1]
742
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
743
+ else:
744
+ inner_dim = hidden_states.shape[1]
745
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
746
+ hidden_states = self.proj_in(hidden_states)
747
+
748
+ # 2. Blocks
749
+ for block in self.transformer_blocks:
750
+ hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
751
+
752
+ # 3. Output
753
+ if not self.use_linear_projection:
754
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
755
+ hidden_states = self.proj_out(hidden_states)
756
+ else:
757
+ hidden_states = self.proj_out(hidden_states)
758
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
759
+
760
+ output = hidden_states + residual
761
+
762
+ return output
763
+
764
+
765
+ class Upsample2D(nn.Module):
766
+ def __init__(self, channels, out_channels):
767
+ super().__init__()
768
+ self.channels = channels
769
+ self.out_channels = out_channels
770
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
771
+
772
+ self.gradient_checkpointing = False
773
+
774
+ def forward_body(self, hidden_states, output_size=None):
775
+ assert hidden_states.shape[1] == self.channels
776
+
777
+ # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
778
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
779
+ # https://github.com/pytorch/pytorch/issues/86679
780
+ dtype = hidden_states.dtype
781
+ if dtype == torch.bfloat16:
782
+ hidden_states = hidden_states.to(torch.float32)
783
+
784
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
785
+ if hidden_states.shape[0] >= 64:
786
+ hidden_states = hidden_states.contiguous()
787
+
788
+ # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2`
789
+ if output_size is None:
790
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
791
+ else:
792
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
793
+
794
+ # If the input is bfloat16, we cast back to bfloat16
795
+ if dtype == torch.bfloat16:
796
+ hidden_states = hidden_states.to(dtype)
797
+
798
+ hidden_states = self.conv(hidden_states)
799
+
800
+ return hidden_states
801
+
802
+ def forward(self, hidden_states, output_size=None):
803
+ if self.training and self.gradient_checkpointing:
804
+ # logger.info("Upsample2D: gradient_checkpointing")
805
+
806
+ def create_custom_forward(func):
807
+ def custom_forward(*inputs):
808
+ return func(*inputs)
809
+
810
+ return custom_forward
811
+
812
+ hidden_states = torch.utils.checkpoint.checkpoint(
813
+ create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT
814
+ )
815
+ else:
816
+ hidden_states = self.forward_body(hidden_states, output_size)
817
+
818
+ return hidden_states
819
+
820
+
821
+ class SdxlUNet2DConditionModel(nn.Module):
822
+ _supports_gradient_checkpointing = True
823
+
824
+ def __init__(
825
+ self,
826
+ **kwargs,
827
+ ):
828
+ super().__init__()
829
+
830
+ self.in_channels = IN_CHANNELS
831
+ self.out_channels = OUT_CHANNELS
832
+ self.model_channels = MODEL_CHANNELS
833
+ self.time_embed_dim = TIME_EMBED_DIM
834
+ self.adm_in_channels = ADM_IN_CHANNELS
835
+
836
+ self.gradient_checkpointing = False
837
+ # self.sample_size = sample_size
838
+
839
+ # time embedding
840
+ self.time_embed = nn.Sequential(
841
+ nn.Linear(self.model_channels, self.time_embed_dim),
842
+ nn.SiLU(),
843
+ nn.Linear(self.time_embed_dim, self.time_embed_dim),
844
+ )
845
+
846
+ # label embedding
847
+ self.label_emb = nn.Sequential(
848
+ nn.Sequential(
849
+ nn.Linear(self.adm_in_channels, self.time_embed_dim),
850
+ nn.SiLU(),
851
+ nn.Linear(self.time_embed_dim, self.time_embed_dim),
852
+ )
853
+ )
854
+
855
+ # input
856
+ self.input_blocks = nn.ModuleList(
857
+ [
858
+ nn.Sequential(
859
+ nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=(1, 1)),
860
+ )
861
+ ]
862
+ )
863
+
864
+ # level 0
865
+ for i in range(2):
866
+ layers = [
867
+ ResnetBlock2D(
868
+ in_channels=1 * self.model_channels,
869
+ out_channels=1 * self.model_channels,
870
+ ),
871
+ ]
872
+ self.input_blocks.append(nn.ModuleList(layers))
873
+
874
+ self.input_blocks.append(
875
+ nn.Sequential(
876
+ Downsample2D(
877
+ channels=1 * self.model_channels,
878
+ out_channels=1 * self.model_channels,
879
+ ),
880
+ )
881
+ )
882
+
883
+ # level 1
884
+ for i in range(2):
885
+ layers = [
886
+ ResnetBlock2D(
887
+ in_channels=(1 if i == 0 else 2) * self.model_channels,
888
+ out_channels=2 * self.model_channels,
889
+ ),
890
+ Transformer2DModel(
891
+ num_attention_heads=2 * self.model_channels // 64,
892
+ attention_head_dim=64,
893
+ in_channels=2 * self.model_channels,
894
+ num_transformer_layers=2,
895
+ use_linear_projection=True,
896
+ cross_attention_dim=2048,
897
+ ),
898
+ ]
899
+ self.input_blocks.append(nn.ModuleList(layers))
900
+
901
+ self.input_blocks.append(
902
+ nn.Sequential(
903
+ Downsample2D(
904
+ channels=2 * self.model_channels,
905
+ out_channels=2 * self.model_channels,
906
+ ),
907
+ )
908
+ )
909
+
910
+ # level 2
911
+ for i in range(2):
912
+ layers = [
913
+ ResnetBlock2D(
914
+ in_channels=(2 if i == 0 else 4) * self.model_channels,
915
+ out_channels=4 * self.model_channels,
916
+ ),
917
+ Transformer2DModel(
918
+ num_attention_heads=4 * self.model_channels // 64,
919
+ attention_head_dim=64,
920
+ in_channels=4 * self.model_channels,
921
+ num_transformer_layers=10,
922
+ use_linear_projection=True,
923
+ cross_attention_dim=2048,
924
+ ),
925
+ ]
926
+ self.input_blocks.append(nn.ModuleList(layers))
927
+
928
+ # mid
929
+ self.middle_block = nn.ModuleList(
930
+ [
931
+ ResnetBlock2D(
932
+ in_channels=4 * self.model_channels,
933
+ out_channels=4 * self.model_channels,
934
+ ),
935
+ Transformer2DModel(
936
+ num_attention_heads=4 * self.model_channels // 64,
937
+ attention_head_dim=64,
938
+ in_channels=4 * self.model_channels,
939
+ num_transformer_layers=10,
940
+ use_linear_projection=True,
941
+ cross_attention_dim=2048,
942
+ ),
943
+ ResnetBlock2D(
944
+ in_channels=4 * self.model_channels,
945
+ out_channels=4 * self.model_channels,
946
+ ),
947
+ ]
948
+ )
949
+
950
+ # output
951
+ self.output_blocks = nn.ModuleList([])
952
+
953
+ # level 2
954
+ for i in range(3):
955
+ layers = [
956
+ ResnetBlock2D(
957
+ in_channels=4 * self.model_channels + (4 if i <= 1 else 2) * self.model_channels,
958
+ out_channels=4 * self.model_channels,
959
+ ),
960
+ Transformer2DModel(
961
+ num_attention_heads=4 * self.model_channels // 64,
962
+ attention_head_dim=64,
963
+ in_channels=4 * self.model_channels,
964
+ num_transformer_layers=10,
965
+ use_linear_projection=True,
966
+ cross_attention_dim=2048,
967
+ ),
968
+ ]
969
+ if i == 2:
970
+ layers.append(
971
+ Upsample2D(
972
+ channels=4 * self.model_channels,
973
+ out_channels=4 * self.model_channels,
974
+ )
975
+ )
976
+
977
+ self.output_blocks.append(nn.ModuleList(layers))
978
+
979
+ # level 1
980
+ for i in range(3):
981
+ layers = [
982
+ ResnetBlock2D(
983
+ in_channels=2 * self.model_channels + (4 if i == 0 else (2 if i == 1 else 1)) * self.model_channels,
984
+ out_channels=2 * self.model_channels,
985
+ ),
986
+ Transformer2DModel(
987
+ num_attention_heads=2 * self.model_channels // 64,
988
+ attention_head_dim=64,
989
+ in_channels=2 * self.model_channels,
990
+ num_transformer_layers=2,
991
+ use_linear_projection=True,
992
+ cross_attention_dim=2048,
993
+ ),
994
+ ]
995
+ if i == 2:
996
+ layers.append(
997
+ Upsample2D(
998
+ channels=2 * self.model_channels,
999
+ out_channels=2 * self.model_channels,
1000
+ )
1001
+ )
1002
+
1003
+ self.output_blocks.append(nn.ModuleList(layers))
1004
+
1005
+ # level 0
1006
+ for i in range(3):
1007
+ layers = [
1008
+ ResnetBlock2D(
1009
+ in_channels=1 * self.model_channels + (2 if i == 0 else 1) * self.model_channels,
1010
+ out_channels=1 * self.model_channels,
1011
+ ),
1012
+ ]
1013
+
1014
+ self.output_blocks.append(nn.ModuleList(layers))
1015
+
1016
+ # output
1017
+ self.out = nn.ModuleList(
1018
+ [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
1019
+ )
1020
+
1021
+ # region diffusers compatibility
1022
+ def prepare_config(self):
1023
+ self.config = SimpleNamespace()
1024
+
1025
+ @property
1026
+ def dtype(self) -> torch.dtype:
1027
+ # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
1028
+ return get_parameter_dtype(self)
1029
+
1030
+ @property
1031
+ def device(self) -> torch.device:
1032
+ # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).
1033
+ return get_parameter_device(self)
1034
+
1035
+ def set_attention_slice(self, slice_size):
1036
+ raise NotImplementedError("Attention slicing is not supported for this model.")
1037
+
1038
+ def is_gradient_checkpointing(self) -> bool:
1039
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
1040
+
1041
+ def enable_gradient_checkpointing(self):
1042
+ self.gradient_checkpointing = True
1043
+ self.set_gradient_checkpointing(value=True)
1044
+
1045
+ def disable_gradient_checkpointing(self):
1046
+ self.gradient_checkpointing = False
1047
+ self.set_gradient_checkpointing(value=False)
1048
+
1049
+ def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None:
1050
+ blocks = self.input_blocks + [self.middle_block] + self.output_blocks
1051
+ for block in blocks:
1052
+ for module in block:
1053
+ if hasattr(module, "set_use_memory_efficient_attention"):
1054
+ # logger.info(module.__class__.__name__)
1055
+ module.set_use_memory_efficient_attention(xformers, mem_eff)
1056
+
1057
+ def set_use_sdpa(self, sdpa: bool) -> None:
1058
+ blocks = self.input_blocks + [self.middle_block] + self.output_blocks
1059
+ for block in blocks:
1060
+ for module in block:
1061
+ if hasattr(module, "set_use_sdpa"):
1062
+ module.set_use_sdpa(sdpa)
1063
+
1064
+ def set_gradient_checkpointing(self, value=False):
1065
+ blocks = self.input_blocks + [self.middle_block] + self.output_blocks
1066
+ for block in blocks:
1067
+ for module in block.modules():
1068
+ if hasattr(module, "gradient_checkpointing"):
1069
+ # logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
1070
+ module.gradient_checkpointing = value
1071
+
1072
+ # endregion
1073
+
1074
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
1075
+ # broadcast timesteps to batch dimension
1076
+ timesteps = timesteps.expand(x.shape[0])
1077
+
1078
+ hs = []
1079
+ t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
1080
+ t_emb = t_emb.to(x.dtype)
1081
+ emb = self.time_embed(t_emb)
1082
+
1083
+ assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
1084
+ assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
1085
+ # assert x.dtype == self.dtype
1086
+ emb = emb + self.label_emb(y)
1087
+
1088
+ def call_module(module, h, emb, context):
1089
+ x = h
1090
+ for layer in module:
1091
+ # logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
1092
+ if isinstance(layer, ResnetBlock2D):
1093
+ x = layer(x, emb)
1094
+ elif isinstance(layer, Transformer2DModel):
1095
+ x = layer(x, context)
1096
+ else:
1097
+ x = layer(x)
1098
+ return x
1099
+
1100
+ # h = x.type(self.dtype)
1101
+ h = x
1102
+
1103
+ for module in self.input_blocks:
1104
+ h = call_module(module, h, emb, context)
1105
+ hs.append(h)
1106
+
1107
+ h = call_module(self.middle_block, h, emb, context)
1108
+
1109
+ for module in self.output_blocks:
1110
+ h = torch.cat([h, hs.pop()], dim=1)
1111
+ h = call_module(module, h, emb, context)
1112
+
1113
+ h = h.type(x.dtype)
1114
+ h = call_module(self.out, h, emb, context)
1115
+
1116
+ return h
1117
+
1118
+
1119
+ class InferSdxlUNet2DConditionModel:
1120
+ def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
1121
+ self.delegate = original_unet
1122
+
1123
+ # override original model's forward method: because forward is not called by `__call__`
1124
+ # overriding __call__ alone is not enough, because nn.Module gives forward special handling
1125
+ self.delegate.forward = self.forward
1126
+
1127
+ # Deep Shrink
1128
+ self.ds_depth_1 = None
1129
+ self.ds_depth_2 = None
1130
+ self.ds_timesteps_1 = None
1131
+ self.ds_timesteps_2 = None
1132
+ self.ds_ratio = None
1133
+
1134
+ # call original model's methods
1135
+ def __getattr__(self, name):
1136
+ return getattr(self.delegate, name)
1137
+
1138
+ def __call__(self, *args, **kwargs):
1139
+ return self.delegate(*args, **kwargs)
1140
+
1141
+ def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
1142
+ if ds_depth_1 is None:
1143
+ logger.info("Deep Shrink is disabled.")
1144
+ self.ds_depth_1 = None
1145
+ self.ds_timesteps_1 = None
1146
+ self.ds_depth_2 = None
1147
+ self.ds_timesteps_2 = None
1148
+ self.ds_ratio = None
1149
+ else:
1150
+ logger.info(
1151
+ f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
1152
+ )
1153
+ self.ds_depth_1 = ds_depth_1
1154
+ self.ds_timesteps_1 = ds_timesteps_1
1155
+ self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
1156
+ self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
1157
+ self.ds_ratio = ds_ratio
1158
+
1159
+ def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
1160
+ r"""
1161
+ current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet.
1162
+ """
1163
+ _self = self.delegate
1164
+
1165
+ # broadcast timesteps to batch dimension
1166
+ timesteps = timesteps.expand(x.shape[0])
1167
+
1168
+ hs = []
1169
+ t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
1170
+ t_emb = t_emb.to(x.dtype)
1171
+ emb = _self.time_embed(t_emb)
1172
+
1173
+ assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
1174
+ assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
1175
+ # assert x.dtype == _self.dtype
1176
+ emb = emb + _self.label_emb(y)
1177
+
1178
+ def call_module(module, h, emb, context):
1179
+ x = h
1180
+ for layer in module:
1181
+ # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
1182
+ if isinstance(layer, ResnetBlock2D):
1183
+ x = layer(x, emb)
1184
+ elif isinstance(layer, Transformer2DModel):
1185
+ x = layer(x, context)
1186
+ else:
1187
+ x = layer(x)
1188
+ return x
1189
+
1190
+ # h = x.type(self.dtype)
1191
+ h = x
1192
+
1193
+ for depth, module in enumerate(_self.input_blocks):
1194
+ # Deep Shrink
1195
+ if self.ds_depth_1 is not None:
1196
+ if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
1197
+ self.ds_depth_2 is not None
1198
+ and depth == self.ds_depth_2
1199
+ and timesteps[0] < self.ds_timesteps_1
1200
+ and timesteps[0] >= self.ds_timesteps_2
1201
+ ):
1202
+ # print("downsample", h.shape, self.ds_ratio)
1203
+ org_dtype = h.dtype
1204
+ if org_dtype == torch.bfloat16:
1205
+ h = h.to(torch.float32)
1206
+ h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
1207
+
1208
+ h = call_module(module, h, emb, context)
1209
+ hs.append(h)
1210
+
1211
+ h = call_module(_self.middle_block, h, emb, context)
1212
+ if mid_add is not None:
1213
+ h = h + mid_add
1214
+
1215
+ for module in _self.output_blocks:
1216
+ # Deep Shrink
1217
+ if self.ds_depth_1 is not None:
1218
+ if hs[-1].shape[-2:] != h.shape[-2:]:
1219
+ # print("upsample", h.shape, hs[-1].shape)
1220
+ h = resize_like(h, hs[-1])
1221
+
1222
+ resi = hs.pop()
1223
+ if input_resi_add is not None:
1224
+ resi = resi + input_resi_add.pop()
1225
+
1226
+ h = torch.cat([h, resi], dim=1)
1227
+ h = call_module(module, h, emb, context)
1228
+
1229
+ # Deep Shrink: in case of depth 0
1230
+ if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
1231
+ # print("upsample", h.shape, x.shape)
1232
+ h = resize_like(h, x)
1233
+
1234
+ h = h.type(x.dtype)
1235
+ h = call_module(_self.out, h, emb, context)
1236
+
1237
+ return h
1238
+
1239
+
1240
+ if __name__ == "__main__":
1241
+ import time
1242
+
1243
+ logger.info("create unet")
1244
+ unet = SdxlUNet2DConditionModel()
1245
+
1246
+ unet.to("cuda")
1247
+ unet.set_use_memory_efficient_attention(True, False)
1248
+ unet.set_gradient_checkpointing(True)
1249
+ unet.train()
1250
+
1251
+ # pseudo training loop for checking memory usage
1252
+ logger.info("preparing optimizer")
1253
+
1254
+ # optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
1255
+
1256
+ # import bitsandbytes
1257
+ # optimizer = bitsandbytes.adam.Adam8bit(unet.parameters(), lr=1e-3) # not working
1258
+ # optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
1259
+ # optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
1260
+
1261
+ import transformers
1262
+
1263
+ optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
1264
+
1265
+ scaler = torch.cuda.amp.GradScaler(enabled=True)
1266
+
1267
+ logger.info("start training")
1268
+ steps = 10
1269
+ batch_size = 1
1270
+
1271
+ for step in range(steps):
1272
+ logger.info(f"step {step}")
1273
+ if step == 1:
1274
+ time_start = time.perf_counter()
1275
+
1276
+ x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
1277
+ t = torch.randint(low=0, high=10, size=(batch_size,), device="cuda")
1278
+ ctx = torch.randn(batch_size, 77, 2048).cuda()
1279
+ y = torch.randn(batch_size, ADM_IN_CHANNELS).cuda()
1280
+
1281
+ with torch.cuda.amp.autocast(enabled=True):
1282
+ output = unet(x, t, ctx, y)
1283
+ target = torch.randn_like(output)
1284
+ loss = torch.nn.functional.mse_loss(output, target)
1285
+
1286
+ scaler.scale(loss).backward()
1287
+ scaler.step(optimizer)
1288
+ scaler.update()
1289
+ optimizer.zero_grad(set_to_none=True)
1290
+
1291
+ time_end = time.perf_counter()
1292
+ logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
library/sdxl_train_util.py ADDED
@@ -0,0 +1,382 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from library.device_utils import init_ipex, clean_memory_on_device
8
+
9
+ init_ipex()
10
+
11
+ from accelerate import init_empty_weights
12
+ from tqdm import tqdm
13
+ from transformers import CLIPTokenizer
14
+ from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
15
+ from .utils import setup_logging
16
+
17
+ setup_logging()
18
+ import logging
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
23
+ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
24
+
25
+ # DEFAULT_NOISE_OFFSET = 0.0357
26
+
27
+
28
+ def load_target_model(args, accelerator, model_version: str, weight_dtype):
29
+ model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
30
+ for pi in range(accelerator.state.num_processes):
31
+ if pi == accelerator.state.local_process_index:
32
+ logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
33
+
34
+ (
35
+ load_stable_diffusion_format,
36
+ text_encoder1,
37
+ text_encoder2,
38
+ vae,
39
+ unet,
40
+ logit_scale,
41
+ ckpt_info,
42
+ ) = _load_target_model(
43
+ args.pretrained_model_name_or_path,
44
+ args.vae,
45
+ model_version,
46
+ weight_dtype,
47
+ accelerator.device if args.lowram else "cpu",
48
+ model_dtype,
49
+ args.disable_mmap_load_safetensors,
50
+ )
51
+
52
+ # work on low-ram device
53
+ if args.lowram:
54
+ text_encoder1.to(accelerator.device)
55
+ text_encoder2.to(accelerator.device)
56
+ unet.to(accelerator.device)
57
+ vae.to(accelerator.device)
58
+
59
+ clean_memory_on_device(accelerator.device)
60
+ accelerator.wait_for_everyone()
61
+
62
+ return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
63
+
64
+
65
+ def _load_target_model(
66
+ name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
67
+ ):
68
+ # model_dtype only work with full fp16/bf16
69
+ name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
70
+ load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
71
+
72
+ if load_stable_diffusion_format:
73
+ logger.info(f"load StableDiffusion checkpoint: {name_or_path}")
74
+ (
75
+ text_encoder1,
76
+ text_encoder2,
77
+ vae,
78
+ unet,
79
+ logit_scale,
80
+ ckpt_info,
81
+ ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
82
+ else:
83
+ # Diffusers model is loaded to CPU
84
+ from diffusers import StableDiffusionXLPipeline
85
+
86
+ variant = "fp16" if weight_dtype == torch.float16 else None
87
+ logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
88
+ try:
89
+ try:
90
+ pipe = StableDiffusionXLPipeline.from_pretrained(
91
+ name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
92
+ )
93
+ except EnvironmentError as ex:
94
+ if variant is not None:
95
+ logger.info("try to load fp32 model")
96
+ pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
97
+ else:
98
+ raise ex
99
+ except EnvironmentError as ex:
100
+ logger.error(
101
+ f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
102
+ )
103
+ raise ex
104
+
105
+ text_encoder1 = pipe.text_encoder
106
+ text_encoder2 = pipe.text_encoder_2
107
+
108
+ # convert to fp32 for cache text_encoders outputs
109
+ if text_encoder1.dtype != torch.float32:
110
+ text_encoder1 = text_encoder1.to(dtype=torch.float32)
111
+ if text_encoder2.dtype != torch.float32:
112
+ text_encoder2 = text_encoder2.to(dtype=torch.float32)
113
+
114
+ vae = pipe.vae
115
+ unet = pipe.unet
116
+ del pipe
117
+
118
+ # Diffusers U-Net to original U-Net
119
+ state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
120
+ with init_empty_weights():
121
+ unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
122
+ sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
123
+ logger.info("U-Net converted to original U-Net")
124
+
125
+ logit_scale = None
126
+ ckpt_info = None
127
+
128
+ # load the VAE
129
+ if vae_path is not None:
130
+ vae = model_util.load_vae(vae_path, weight_dtype)
131
+ logger.info("additional VAE loaded")
132
+
133
+ return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
134
+
135
+
136
+ def load_tokenizers(args: argparse.Namespace):
137
+ logger.info("prepare tokenizers")
138
+
139
+ original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
140
+ tokenizers = []
141
+ for i, original_path in enumerate(original_paths):
142
+ tokenizer: CLIPTokenizer = None
143
+ if args.tokenizer_cache_dir:
144
+ local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
145
+ if os.path.exists(local_tokenizer_path):
146
+ logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
147
+ tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
148
+
149
+ if tokenizer is None:
150
+ tokenizer = CLIPTokenizer.from_pretrained(original_path)
151
+
152
+ if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
153
+ logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
154
+ tokenizer.save_pretrained(local_tokenizer_path)
155
+
156
+ if i == 1:
157
+ tokenizer.pad_token_id = 0 # fix pad token id to make same as open clip tokenizer
158
+
159
+ tokenizers.append(tokenizer)
160
+
161
+ if hasattr(args, "max_token_length") and args.max_token_length is not None:
162
+ logger.info(f"update token length: {args.max_token_length}")
163
+
164
+ return tokenizers
165
+
166
+
167
+ def match_mixed_precision(args, weight_dtype):
168
+ if args.full_fp16:
169
+ assert (
170
+ weight_dtype == torch.float16
171
+ ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
172
+ return weight_dtype
173
+ elif args.full_bf16:
174
+ assert (
175
+ weight_dtype == torch.bfloat16
176
+ ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
177
+ return weight_dtype
178
+ else:
179
+ return None
180
+
181
+
182
+ def timestep_embedding(timesteps, dim, max_period=10000):
183
+ """
184
+ Create sinusoidal timestep embeddings.
185
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
186
+ These may be fractional.
187
+ :param dim: the dimension of the output.
188
+ :param max_period: controls the minimum frequency of the embeddings.
189
+ :return: an [N x dim] Tensor of positional embeddings.
190
+ """
191
+ half = dim // 2
192
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
193
+ device=timesteps.device
194
+ )
195
+ args = timesteps[:, None].float() * freqs[None]
196
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
197
+ if dim % 2:
198
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
199
+ return embedding
200
+
201
+
202
+ def get_timestep_embedding(x, outdim):
203
+ assert len(x.shape) == 2
204
+ b, dims = x.shape[0], x.shape[1]
205
+ x = torch.flatten(x)
206
+ emb = timestep_embedding(x, outdim)
207
+ emb = torch.reshape(emb, (b, dims * outdim))
208
+ return emb
209
+
210
+
211
+ def get_size_embeddings(orig_size, crop_size, target_size, device):
212
+ emb1 = get_timestep_embedding(orig_size, 256)
213
+ emb2 = get_timestep_embedding(crop_size, 256)
214
+ emb3 = get_timestep_embedding(target_size, 256)
215
+ vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
216
+ return vector
217
+
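Each of orig_size, crop_size and target_size is a (B, 2) tensor of pixel sizes; every value gets a 256-dim sinusoidal embedding, so the concatenated vector is B x (3 * 2 * 256) = B x 1536 and forms the size-conditioning part of the UNet's y input. A minimal sketch (assumption: this module is importable as library.sdxl_train_util with the repo's dependencies installed):

import torch
from library.sdxl_train_util import get_size_embeddings

bsz = 2
orig_size = torch.tensor([[1024, 1024]] * bsz)
crop_size = torch.tensor([[0, 0]] * bsz)
target_size = torch.tensor([[1024, 1024]] * bsz)

vector = get_size_embeddings(orig_size, crop_size, target_size, device="cpu")
print(vector.shape)  # torch.Size([2, 1536])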
218
+
219
+ def save_sd_model_on_train_end(
220
+ args: argparse.Namespace,
221
+ src_path: str,
222
+ save_stable_diffusion_format: bool,
223
+ use_safetensors: bool,
224
+ save_dtype: torch.dtype,
225
+ epoch: int,
226
+ global_step: int,
227
+ text_encoder1,
228
+ text_encoder2,
229
+ unet,
230
+ vae,
231
+ logit_scale,
232
+ ckpt_info,
233
+ ):
234
+ def sd_saver(ckpt_file, epoch_no, global_step):
235
+ sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
236
+ sdxl_model_util.save_stable_diffusion_checkpoint(
237
+ ckpt_file,
238
+ text_encoder1,
239
+ text_encoder2,
240
+ unet,
241
+ epoch_no,
242
+ global_step,
243
+ ckpt_info,
244
+ vae,
245
+ logit_scale,
246
+ sai_metadata,
247
+ save_dtype,
248
+ )
249
+
250
+ def diffusers_saver(out_dir):
251
+ sdxl_model_util.save_diffusers_checkpoint(
252
+ out_dir,
253
+ text_encoder1,
254
+ text_encoder2,
255
+ unet,
256
+ src_path,
257
+ vae,
258
+ use_safetensors=use_safetensors,
259
+ save_dtype=save_dtype,
260
+ )
261
+
262
+ train_util.save_sd_model_on_train_end_common(
263
+ args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
264
+ )
265
+
266
+
267
+ # Saving at epoch end and at step intervals is unified here because the metadata contains epoch/step and the arguments are identical.
268
+ # on_epoch_end: True = called at the end of an epoch, False = called after a number of steps
269
+ def save_sd_model_on_epoch_end_or_stepwise(
270
+ args: argparse.Namespace,
271
+ on_epoch_end: bool,
272
+ accelerator,
273
+ src_path,
274
+ save_stable_diffusion_format: bool,
275
+ use_safetensors: bool,
276
+ save_dtype: torch.dtype,
277
+ epoch: int,
278
+ num_train_epochs: int,
279
+ global_step: int,
280
+ text_encoder1,
281
+ text_encoder2,
282
+ unet,
283
+ vae,
284
+ logit_scale,
285
+ ckpt_info,
286
+ ):
287
+ def sd_saver(ckpt_file, epoch_no, global_step):
288
+ sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
289
+ sdxl_model_util.save_stable_diffusion_checkpoint(
290
+ ckpt_file,
291
+ text_encoder1,
292
+ text_encoder2,
293
+ unet,
294
+ epoch_no,
295
+ global_step,
296
+ ckpt_info,
297
+ vae,
298
+ logit_scale,
299
+ sai_metadata,
300
+ save_dtype,
301
+ )
302
+
303
+ def diffusers_saver(out_dir):
304
+ sdxl_model_util.save_diffusers_checkpoint(
305
+ out_dir,
306
+ text_encoder1,
307
+ text_encoder2,
308
+ unet,
309
+ src_path,
310
+ vae,
311
+ use_safetensors=use_safetensors,
312
+ save_dtype=save_dtype,
313
+ )
314
+
315
+ train_util.save_sd_model_on_epoch_end_or_stepwise_common(
316
+ args,
317
+ on_epoch_end,
318
+ accelerator,
319
+ save_stable_diffusion_format,
320
+ use_safetensors,
321
+ epoch,
322
+ num_train_epochs,
323
+ global_step,
324
+ sd_saver,
325
+ diffusers_saver,
326
+ )
327
+
328
+
329
+ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
330
+ parser.add_argument(
331
+ "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
332
+ )
333
+ parser.add_argument(
334
+ "--cache_text_encoder_outputs_to_disk",
335
+ action="store_true",
336
+ help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
337
+ )
338
+ parser.add_argument(
339
+ "--disable_mmap_load_safetensors",
340
+ action="store_true",
341
+ help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
342
+ )
343
+
344
+
345
+ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
346
+ assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
347
+ if args.v_parameterization:
348
+ logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
349
+
350
+ if args.clip_skip is not None:
351
+ logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
352
+
353
+ # if args.multires_noise_iterations:
354
+ # logger.info(
355
+ # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
356
+ # )
357
+ # else:
358
+ # if args.noise_offset is None:
359
+ # args.noise_offset = DEFAULT_NOISE_OFFSET
360
+ # elif args.noise_offset != DEFAULT_NOISE_OFFSET:
361
+ # logger.info(
362
+ # f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
363
+ # )
364
+ # logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
365
+
366
+ # assert (
367
+ # not hasattr(args, "weighted_captions") or not args.weighted_captions
368
+ # ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
369
+
370
+ if supportTextEncoderCaching:
371
+ if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
372
+ args.cache_text_encoder_outputs = True
373
+ logger.warning(
374
+ "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
375
+ + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
376
+ )
377
+
378
+
379
+ def sample_images(*args, **kwargs):
380
+ from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
381
+
382
+ return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
library/slicing_vae.py ADDED
@@ -0,0 +1,682 @@
1
+ # Modified from Diffusers to reduce VRAM usage
2
+
3
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from dataclasses import dataclass
17
+ from typing import Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
27
+ from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
28
+ from diffusers.models.autoencoder_kl import AutoencoderKLOutput
29
+ from .utils import setup_logging
30
+ setup_logging()
31
+ import logging
32
+ logger = logging.getLogger(__name__)
33
+
34
+ def slice_h(x, num_slices):
35
+ # slice with pad 1 both sides: to eliminate side effect of padding of conv2d
36
+ # (each slice carries a one-row overlap with its neighbors)
37
+ # works with either NCHW or NHWC
38
+ size = (x.shape[2] + num_slices - 1) // num_slices
39
+ sliced = []
40
+ for i in range(num_slices):
41
+ if i == 0:
42
+ sliced.append(x[:, :, : size + 1, :])
43
+ else:
44
+ end = size * (i + 1) + 1
45
+ if x.shape[2] - end < 3: # if the last slice is too small for conv2d, use the rest of the tensor
46
+ end = x.shape[2]
47
+ sliced.append(x[:, :, size * i - 1 : end, :])
48
+ if end >= x.shape[2]:
49
+ break
50
+ return sliced
51
+
52
+
53
+ def cat_h(sliced):
54
+ # concatenate, dropping the padded rows
55
+ cat = []
56
+ for i, x in enumerate(sliced):
57
+ if i == 0:
58
+ cat.append(x[:, :, :-1, :])
59
+ elif i == len(sliced) - 1:
60
+ cat.append(x[:, :, 1:, :])
61
+ else:
62
+ cat.append(x[:, :, 1:-1, :])
63
+ del x
64
+ x = torch.cat(cat, dim=2)
65
+ return x
66
+
67
+
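slice_h keeps a one-row overlap between neighboring slices so that conv2d padding artifacts stay inside each slice, and cat_h drops exactly those overlap rows again, so the pair is a lossless round trip. A minimal check (assumption: the module is importable as library.slicing_vae with its dependencies installed):

import torch
from library.slicing_vae import cat_h, slice_h

x = torch.randn(1, 8, 64, 64)        # NCHW latent-like tensor
sliced = slice_h(x, num_slices=4)    # overlapping horizontal slices
restored = cat_h(sliced)             # overlap rows removed on concatenation
print(torch.equal(x, restored))      # expected: True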
68
+ def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
69
+ assert _self.upsample is None and _self.downsample is None
70
+ assert _self.norm1.num_groups == _self.norm2.num_groups
71
+ assert temb is None
72
+
73
+ # make sure norms are on cpu
74
+ org_device = input_tensor.device
75
+ cpu_device = torch.device("cpu")
76
+ _self.norm1.to(cpu_device)
77
+ _self.norm2.to(cpu_device)
78
+
79
+ # workaround: GroupNorm does not run in fp16 on the CPU
80
+ org_dtype = input_tensor.dtype
81
+ if org_dtype == torch.float16:
82
+ _self.norm1.to(torch.float32)
83
+ _self.norm2.to(torch.float32)
84
+
85
+ # move all tensors to the CPU
86
+ input_tensor = input_tensor.to(cpu_device)
87
+ hidden_states = input_tensor
88
+
89
+ # this (sliced norm, below) seems to give different results...
90
+ # def sliced_norm1(norm, x):
91
+ # num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups
92
+ # sliced_tensor = torch.chunk(x, num_div, dim=1)
93
+ # sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
94
+ # sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
95
+ # logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
96
+ # normed_tensor = []
97
+ # for i in range(num_div):
98
+ # n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
99
+ # normed_tensor.append(n)
100
+ # del n
101
+ # x = torch.cat(normed_tensor, dim=1)
102
+ # return num_div, x
103
+
104
+ # splitting the norm changes the result, so only the norm is left unsliced; it runs on the CPU because the GPU would run out of VRAM (fortunately it is not too slow on the CPU)
105
+ if org_dtype == torch.float16:
106
+ hidden_states = hidden_states.to(torch.float32)
107
+ hidden_states = _self.norm1(hidden_states) # run on cpu
108
+ if org_dtype == torch.float16:
109
+ hidden_states = hidden_states.to(torch.float16)
110
+
111
+ sliced = slice_h(hidden_states, num_slices)
112
+ del hidden_states
113
+
114
+ for i in range(len(sliced)):
115
+ x = sliced[i]
116
+ sliced[i] = None
117
+
118
+ # move only the chunk being computed to the GPU (same pattern below)
119
+ x = x.to(org_device)
120
+ x = _self.nonlinearity(x)
121
+ x = _self.conv1(x)
122
+ x = x.to(cpu_device)
123
+ sliced[i] = x
124
+ del x
125
+
126
+ hidden_states = cat_h(sliced)
127
+ del sliced
128
+
129
+ if org_dtype == torch.float16:
130
+ hidden_states = hidden_states.to(torch.float32)
131
+ hidden_states = _self.norm2(hidden_states) # run on cpu
132
+ if org_dtype == torch.float16:
133
+ hidden_states = hidden_states.to(torch.float16)
134
+
135
+ sliced = slice_h(hidden_states, num_slices)
136
+ del hidden_states
137
+
138
+ for i in range(len(sliced)):
139
+ x = sliced[i]
140
+ sliced[i] = None
141
+
142
+ x = x.to(org_device)
143
+ x = _self.nonlinearity(x)
144
+ x = _self.dropout(x)
145
+ x = _self.conv2(x)
146
+ x = x.to(cpu_device)
147
+ sliced[i] = x
148
+ del x
149
+
150
+ hidden_states = cat_h(sliced)
151
+ del sliced
152
+
153
+ # make shortcut
154
+ if _self.conv_shortcut is not None:
155
+ sliced = list(torch.chunk(input_tensor, num_slices, dim=2)) # no padding in conv_shortcut, so slice normally
156
+ del input_tensor
157
+
158
+ for i in range(len(sliced)):
159
+ x = sliced[i]
160
+ sliced[i] = None
161
+
162
+ x = x.to(org_device)
163
+ x = _self.conv_shortcut(x)
164
+ x = x.to(cpu_device)
165
+ sliced[i] = x
166
+ del x
167
+
168
+ input_tensor = torch.cat(sliced, dim=2)
169
+ del sliced
170
+
171
+ output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor
172
+
173
+ output_tensor = output_tensor.to(org_device) # the next layer computes on the GPU
174
+ return output_tensor
175
+
176
+
177
+ class SlicingEncoder(nn.Module):
178
+ def __init__(
179
+ self,
180
+ in_channels=3,
181
+ out_channels=3,
182
+ down_block_types=("DownEncoderBlock2D",),
183
+ block_out_channels=(64,),
184
+ layers_per_block=2,
185
+ norm_num_groups=32,
186
+ act_fn="silu",
187
+ double_z=True,
188
+ num_slices=2,
189
+ ):
190
+ super().__init__()
191
+ self.layers_per_block = layers_per_block
192
+
193
+ self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
194
+
195
+ self.mid_block = None
196
+ self.down_blocks = nn.ModuleList([])
197
+
198
+ # down
199
+ output_channel = block_out_channels[0]
200
+ for i, down_block_type in enumerate(down_block_types):
201
+ input_channel = output_channel
202
+ output_channel = block_out_channels[i]
203
+ is_final_block = i == len(block_out_channels) - 1
204
+
205
+ down_block = get_down_block(
206
+ down_block_type,
207
+ num_layers=self.layers_per_block,
208
+ in_channels=input_channel,
209
+ out_channels=output_channel,
210
+ add_downsample=not is_final_block,
211
+ resnet_eps=1e-6,
212
+ downsample_padding=0,
213
+ resnet_act_fn=act_fn,
214
+ resnet_groups=norm_num_groups,
215
+ attention_head_dim=output_channel,
216
+ temb_channels=None,
217
+ )
218
+ self.down_blocks.append(down_block)
219
+
220
+ # mid
221
+ self.mid_block = UNetMidBlock2D(
222
+ in_channels=block_out_channels[-1],
223
+ resnet_eps=1e-6,
224
+ resnet_act_fn=act_fn,
225
+ output_scale_factor=1,
226
+ resnet_time_scale_shift="default",
227
+ attention_head_dim=block_out_channels[-1],
228
+ resnet_groups=norm_num_groups,
229
+ temb_channels=None,
230
+ )
231
+ self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # use Diffusers' xformers implementation for now
232
+
233
+ # out
234
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
235
+ self.conv_act = nn.SiLU()
236
+
237
+ conv_out_channels = 2 * out_channels if double_z else out_channels
238
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
239
+
240
+ # replace forward of ResBlocks
241
+ def wrapper(func, module, num_slices):
242
+ def forward(*args, **kwargs):
243
+ return func(module, num_slices, *args, **kwargs)
244
+
245
+ return forward
246
+
247
+ self.num_slices = num_slices
248
+ div = num_slices / (2 ** (len(self.down_blocks) - 1)) # deeper layers do not need as many slices, so reduce the count accordingly
249
+ # logger.info(f"initial divisor: {div}")
250
+ if div >= 2:
251
+ div = int(div)
252
+ for resnet in self.mid_block.resnets:
253
+ resnet.forward = wrapper(resblock_forward, resnet, div)
254
+ # midblock doesn't have downsample
255
+
256
+ for i, down_block in enumerate(self.down_blocks[::-1]):
257
+ if div >= 2:
258
+ div = int(div)
259
+ # logger.info(f"down block: {i} divisor: {div}")
260
+ for resnet in down_block.resnets:
261
+ resnet.forward = wrapper(resblock_forward, resnet, div)
262
+ if down_block.downsamplers is not None:
263
+ # logger.info("has downsample")
264
+ for downsample in down_block.downsamplers:
265
+ downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
266
+ div *= 2
267
+
268
+ def forward(self, x):
269
+ sample = x
270
+ del x
271
+
272
+ org_device = sample.device
273
+ cpu_device = torch.device("cpu")
274
+
275
+ # sample = self.conv_in(sample)
276
+ sample = sample.to(cpu_device)
277
+ sliced = slice_h(sample, self.num_slices)
278
+ del sample
279
+
280
+ for i in range(len(sliced)):
281
+ x = sliced[i]
282
+ sliced[i] = None
283
+
284
+ x = x.to(org_device)
285
+ x = self.conv_in(x)
286
+ x = x.to(cpu_device)
287
+ sliced[i] = x
288
+ del x
289
+
290
+ sample = cat_h(sliced)
291
+ del sliced
292
+
293
+ sample = sample.to(org_device)
294
+
295
+ # down
296
+ for down_block in self.down_blocks:
297
+ sample = down_block(sample)
298
+
299
+ # middle
300
+ sample = self.mid_block(sample)
301
+
302
+ # post-process
303
+ # this could also be sliced to save memory, but it probably does not use much, so it is left as is
304
+ sample = self.conv_norm_out(sample)
305
+ sample = self.conv_act(sample)
306
+ sample = self.conv_out(sample)
307
+
308
+ return sample
309
+
310
+ def downsample_forward(self, _self, num_slices, hidden_states):
311
+ assert hidden_states.shape[1] == _self.channels
312
+ assert _self.use_conv and _self.padding == 0
313
+ logger.info(f"downsample forward {num_slices} {hidden_states.shape}")
314
+
315
+ org_device = hidden_states.device
316
+ cpu_device = torch.device("cpu")
317
+
318
+ hidden_states = hidden_states.to(cpu_device)
319
+ pad = (0, 1, 0, 1)
320
+ hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
321
+
322
+ # slice with even number because of stride 2
323
+ # (the slice height must be even so that the stride-2 downsampling stays aligned)
324
+ # slice with pad 1 both sides: to eliminate side effect of padding of conv2d
325
+ size = (hidden_states.shape[2] + num_slices - 1) // num_slices
326
+ size = size + 1 if size % 2 == 1 else size
327
+
328
+ sliced = []
329
+ for i in range(num_slices):
330
+ if i == 0:
331
+ sliced.append(hidden_states[:, :, : size + 1, :])
332
+ else:
333
+ end = size * (i + 1) + 1
334
+ if hidden_states.shape[2] - end < 4: # if the last slice is too small, use the rest of the tensor
335
+ end = hidden_states.shape[2]
336
+ sliced.append(hidden_states[:, :, size * i - 1 : end, :])
337
+ if end >= hidden_states.shape[2]:
338
+ break
339
+ del hidden_states
340
+
341
+ for i in range(len(sliced)):
342
+ x = sliced[i]
343
+ sliced[i] = None
344
+
345
+ x = x.to(org_device)
346
+ x = _self.conv(x)
347
+ x = x.to(cpu_device)
348
+
349
+ # this part has a different style because it was written by Copilot
350
+ if i == 0:
351
+ hidden_states = x
352
+ else:
353
+ hidden_states = torch.cat([hidden_states, x], dim=2)
354
+
355
+ hidden_states = hidden_states.to(org_device)
356
+ # logger.info(f"downsample forward done {hidden_states.shape}")
357
+ return hidden_states
358
+
359
+
360
+ class SlicingDecoder(nn.Module):
361
+ def __init__(
362
+ self,
363
+ in_channels=3,
364
+ out_channels=3,
365
+ up_block_types=("UpDecoderBlock2D",),
366
+ block_out_channels=(64,),
367
+ layers_per_block=2,
368
+ norm_num_groups=32,
369
+ act_fn="silu",
370
+ num_slices=2,
371
+ ):
372
+ super().__init__()
373
+ self.layers_per_block = layers_per_block
374
+
375
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
376
+
377
+ self.mid_block = None
378
+ self.up_blocks = nn.ModuleList([])
379
+
380
+ # mid
381
+ self.mid_block = UNetMidBlock2D(
382
+ in_channels=block_out_channels[-1],
383
+ resnet_eps=1e-6,
384
+ resnet_act_fn=act_fn,
385
+ output_scale_factor=1,
386
+ resnet_time_scale_shift="default",
387
+ attention_head_dim=block_out_channels[-1],
388
+ resnet_groups=norm_num_groups,
389
+ temb_channels=None,
390
+ )
391
+ self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # use Diffusers' xformers attention for now
392
+
393
+ # up
394
+ reversed_block_out_channels = list(reversed(block_out_channels))
395
+ output_channel = reversed_block_out_channels[0]
396
+ for i, up_block_type in enumerate(up_block_types):
397
+ prev_output_channel = output_channel
398
+ output_channel = reversed_block_out_channels[i]
399
+
400
+ is_final_block = i == len(block_out_channels) - 1
401
+
402
+ up_block = get_up_block(
403
+ up_block_type,
404
+ num_layers=self.layers_per_block + 1,
405
+ in_channels=prev_output_channel,
406
+ out_channels=output_channel,
407
+ prev_output_channel=None,
408
+ add_upsample=not is_final_block,
409
+ resnet_eps=1e-6,
410
+ resnet_act_fn=act_fn,
411
+ resnet_groups=norm_num_groups,
412
+ attention_head_dim=output_channel,
413
+ temb_channels=None,
414
+ )
415
+ self.up_blocks.append(up_block)
416
+ prev_output_channel = output_channel
417
+
418
+ # out
419
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
420
+ self.conv_act = nn.SiLU()
421
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
422
+
423
+ # replace forward of ResBlocks
424
+ def wrapper(func, module, num_slices):
425
+ def forward(*args, **kwargs):
426
+ return func(module, num_slices, *args, **kwargs)
427
+
428
+ return forward
429
+
430
+ self.num_slices = num_slices
431
+ div = num_slices / (2 ** (len(self.up_blocks) - 1))
432
+ logger.info(f"initial divisor: {div}")
433
+ if div >= 2:
434
+ div = int(div)
435
+ for resnet in self.mid_block.resnets:
436
+ resnet.forward = wrapper(resblock_forward, resnet, div)
437
+ # midblock doesn't have upsample
438
+
439
+ for i, up_block in enumerate(self.up_blocks):
440
+ if div >= 2:
441
+ div = int(div)
442
+ # logger.info(f"up block: {i} divisor: {div}")
443
+ for resnet in up_block.resnets:
444
+ resnet.forward = wrapper(resblock_forward, resnet, div)
445
+ if up_block.upsamplers is not None:
446
+ # logger.info("has upsample")
447
+ for upsample in up_block.upsamplers:
448
+ upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
449
+ div *= 2
450
+
451
+ def forward(self, z):
452
+ sample = z
453
+ del z
454
+ sample = self.conv_in(sample)
455
+
456
+ # middle
457
+ sample = self.mid_block(sample)
458
+
459
+ # up
460
+ for i, up_block in enumerate(self.up_blocks):
461
+ sample = up_block(sample)
462
+
463
+ # post-process
464
+ sample = self.conv_norm_out(sample)
465
+ sample = self.conv_act(sample)
466
+
467
+ # conv_out with slicing because of VRAM usage
468
+ # (conv_out is very VRAM-hungry, so it is applied slice by slice)
469
+ org_device = sample.device
470
+ cpu_device = torch.device("cpu")
471
+ sample = sample.to(cpu_device)
472
+
473
+ sliced = slice_h(sample, self.num_slices)
474
+ del sample
475
+ for i in range(len(sliced)):
476
+ x = sliced[i]
477
+ sliced[i] = None
478
+
479
+ x = x.to(org_device)
480
+ x = self.conv_out(x)
481
+ x = x.to(cpu_device)
482
+ sliced[i] = x
483
+ sample = cat_h(sliced)
484
+ del sliced
485
+
486
+ sample = sample.to(org_device)
487
+ return sample
488
+
489
+ def upsample_forward(self, _self, num_slices, hidden_states, output_size=None):
490
+ assert hidden_states.shape[1] == _self.channels
491
+ assert _self.use_conv_transpose == False and _self.use_conv
492
+
493
+ org_dtype = hidden_states.dtype
494
+ org_device = hidden_states.device
495
+ cpu_device = torch.device("cpu")
496
+
497
+ hidden_states = hidden_states.to(cpu_device)
498
+ sliced = slice_h(hidden_states, num_slices)
499
+ del hidden_states
500
+
501
+ for i in range(len(sliced)):
502
+ x = sliced[i]
503
+ sliced[i] = None
504
+
505
+ x = x.to(org_device)
506
+
507
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
508
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
509
+ # https://github.com/pytorch/pytorch/issues/86679
510
+ # hopefully this is fixed in PyTorch 2
511
+ if org_dtype == torch.bfloat16:
512
+ x = x.to(torch.float32)
513
+
514
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
515
+
516
+ if org_dtype == torch.bfloat16:
517
+ x = x.to(org_dtype)
518
+
519
+ x = _self.conv(x)
520
+
521
+ # the slices were upsampled 2x, so the 1-row overlap to trim becomes 2 rows
522
+ if i == 0:
523
+ x = x[:, :, :-2, :]
524
+ elif i == num_slices - 1:
525
+ x = x[:, :, 2:, :]
526
+ else:
527
+ x = x[:, :, 2:-2, :]
528
+
529
+ x = x.to(cpu_device)
530
+ sliced[i] = x
531
+ del x
532
+
533
+ hidden_states = torch.cat(sliced, dim=2)
534
+ # logger.info(f"us hidden_states {hidden_states.shape}")
535
+ del sliced
536
+
537
+ hidden_states = hidden_states.to(org_device)
538
+ return hidden_states
539
+
540
+
541
+ class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
542
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
543
+ and Max Welling.
544
+
545
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
546
+ implements for all models (such as downloading or saving, etc.)
547
+
548
+ Parameters:
549
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
550
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
551
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
552
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
553
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
554
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
555
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
556
+ obj:`(64,)`): Tuple of block output channels.
557
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
558
+ latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
559
+ sample_size (`int`, *optional*, defaults to `32`): TODO
560
+ """
561
+
562
+ @register_to_config
563
+ def __init__(
564
+ self,
565
+ in_channels: int = 3,
566
+ out_channels: int = 3,
567
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
568
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
569
+ block_out_channels: Tuple[int] = (64,),
570
+ layers_per_block: int = 1,
571
+ act_fn: str = "silu",
572
+ latent_channels: int = 4,
573
+ norm_num_groups: int = 32,
574
+ sample_size: int = 32,
575
+ num_slices: int = 16,
576
+ ):
577
+ super().__init__()
578
+
579
+ # pass init params to Encoder
580
+ self.encoder = SlicingEncoder(
581
+ in_channels=in_channels,
582
+ out_channels=latent_channels,
583
+ down_block_types=down_block_types,
584
+ block_out_channels=block_out_channels,
585
+ layers_per_block=layers_per_block,
586
+ act_fn=act_fn,
587
+ norm_num_groups=norm_num_groups,
588
+ double_z=True,
589
+ num_slices=num_slices,
590
+ )
591
+
592
+ # pass init params to Decoder
593
+ self.decoder = SlicingDecoder(
594
+ in_channels=latent_channels,
595
+ out_channels=out_channels,
596
+ up_block_types=up_block_types,
597
+ block_out_channels=block_out_channels,
598
+ layers_per_block=layers_per_block,
599
+ norm_num_groups=norm_num_groups,
600
+ act_fn=act_fn,
601
+ num_slices=num_slices,
602
+ )
603
+
604
+ self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
605
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
606
+ self.use_slicing = False
607
+
608
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
609
+ h = self.encoder(x)
610
+ moments = self.quant_conv(h)
611
+ posterior = DiagonalGaussianDistribution(moments)
612
+
613
+ if not return_dict:
614
+ return (posterior,)
615
+
616
+ return AutoencoderKLOutput(latent_dist=posterior)
617
+
618
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
619
+ z = self.post_quant_conv(z)
620
+ dec = self.decoder(z)
621
+
622
+ if not return_dict:
623
+ return (dec,)
624
+
625
+ return DecoderOutput(sample=dec)
626
+
627
+ # note: this is slicing along the batch dimension, which is confusingly different from the height slicing above
628
+ def enable_slicing(self):
629
+ r"""
630
+ Enable sliced VAE decoding.
631
+
632
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
633
+ steps. This is useful to save some memory and allow larger batch sizes.
634
+ """
635
+ self.use_slicing = True
636
+
637
+ def disable_slicing(self):
638
+ r"""
639
+ Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
640
+ decoding in one step.
641
+ """
642
+ self.use_slicing = False
643
+
644
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
645
+ if self.use_slicing and z.shape[0] > 1:
646
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
647
+ decoded = torch.cat(decoded_slices)
648
+ else:
649
+ decoded = self._decode(z).sample
650
+
651
+ if not return_dict:
652
+ return (decoded,)
653
+
654
+ return DecoderOutput(sample=decoded)
655
+
656
+ def forward(
657
+ self,
658
+ sample: torch.FloatTensor,
659
+ sample_posterior: bool = False,
660
+ return_dict: bool = True,
661
+ generator: Optional[torch.Generator] = None,
662
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
663
+ r"""
664
+ Args:
665
+ sample (`torch.FloatTensor`): Input sample.
666
+ sample_posterior (`bool`, *optional*, defaults to `False`):
667
+ Whether to sample from the posterior.
668
+ return_dict (`bool`, *optional*, defaults to `True`):
669
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
670
+ """
671
+ x = sample
672
+ posterior = self.encode(x).latent_dist
673
+ if sample_posterior:
674
+ z = posterior.sample(generator=generator)
675
+ else:
676
+ z = posterior.mode()
677
+ dec = self.decode(z).sample
678
+
679
+ if not return_dict:
680
+ return (dec,)
681
+
682
+ return DecoderOutput(sample=dec)
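
For orientation, here is a minimal usage sketch of the SlicingAutoencoderKL defined above. The constructor arguments mirror diffusers' AutoencoderKL; the block/channel values below are the usual SD 1.x VAE settings and are an assumption for this example, and since the constructor enables xformers attention for the mid block, xformers has to be installed for it to run.

```python
# Hedged sketch: build the sliced VAE and decode a latent in horizontal strips.
# block_out_channels etc. are typical SD 1.x VAE values, not something this diff pins down;
# num_slices trades peak VRAM for extra CPU<->GPU transfers.
import torch
from library.slicing_vae import SlicingAutoencoderKL

vae = SlicingAutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D",) * 4,
    up_block_types=("UpDecoderBlock2D",) * 4,
    block_out_channels=(128, 256, 512, 512),
    layers_per_block=2,
    latent_channels=4,
    sample_size=512,
    num_slices=16,  # more slices -> lower peak VRAM, more host/device copies
)

with torch.no_grad():
    latents = torch.randn(1, 4, 64, 64)      # 512x512 image -> 64x64 latent
    image = vae.decode(latents).sample       # decoded slice by slice along the height axis
    print(image.shape)                       # torch.Size([1, 3, 512, 512])
```

In practice the weights would be loaded from an existing VAE checkpoint; the sketch only shows how the slicing parameters plug into the same config surface as the standard AutoencoderKL.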
library/strategy_base.py ADDED
@@ -0,0 +1,570 @@
 
1
+ # base class for platform strategies. this file defines the interface for strategies
2
+
3
+ import os
4
+ import re
5
+ from typing import Any, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
10
+
11
+
12
+ # TODO remove circular import by moving ImageInfo to a separate file
13
+ # from library.train_util import ImageInfo
14
+
15
+ from library.utils import setup_logging
16
+
17
+ setup_logging()
18
+ import logging
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class TokenizeStrategy:
24
+ _strategy = None # strategy instance: actual strategy class
25
+
26
+ _re_attention = re.compile(
27
+ r"""\\\(|
28
+ \\\)|
29
+ \\\[|
30
+ \\]|
31
+ \\\\|
32
+ \\|
33
+ \(|
34
+ \[|
35
+ :([+-]?[.\d]+)\)|
36
+ \)|
37
+ ]|
38
+ [^\\()\[\]:]+|
39
+ :
40
+ """,
41
+ re.X,
42
+ )
43
+
44
+ @classmethod
45
+ def set_strategy(cls, strategy):
46
+ if cls._strategy is not None:
47
+ raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
48
+ cls._strategy = strategy
49
+
50
+ @classmethod
51
+ def get_strategy(cls) -> Optional["TokenizeStrategy"]:
52
+ return cls._strategy
53
+
54
+ def _load_tokenizer(
55
+ self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
56
+ ) -> Any:
57
+ tokenizer = None
58
+ if tokenizer_cache_dir:
59
+ local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
60
+ if os.path.exists(local_tokenizer_path):
61
+ logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
62
+ tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2
63
+
64
+ if tokenizer is None:
65
+ tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)
66
+
67
+ if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
68
+ logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
69
+ tokenizer.save_pretrained(local_tokenizer_path)
70
+
71
+ return tokenizer
72
+
73
+ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
74
+ raise NotImplementedError
75
+
76
+ def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
77
+ """
78
+ returns: [tokens1, tokens2, ...], [weights1, weights2, ...]
79
+ """
80
+ raise NotImplementedError
81
+
82
+ def _get_weighted_input_ids(
83
+ self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None
84
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
85
+ """
86
+ max_length includes starting and ending tokens.
87
+ """
88
+
89
+ def parse_prompt_attention(text):
90
+ """
91
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
92
+ Accepted tokens are:
93
+ (abc) - increases attention to abc by a multiplier of 1.1
94
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
95
+ [abc] - decreases attention to abc by a multiplier of 1.1
96
+ \( - literal character '('
97
+ \[ - literal character '['
98
+ \) - literal character ')'
99
+ \] - literal character ']'
100
+ \\ - literal character '\'
101
+ anything else - just text
102
+ >>> parse_prompt_attention('normal text')
103
+ [['normal text', 1.0]]
104
+ >>> parse_prompt_attention('an (important) word')
105
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
106
+ >>> parse_prompt_attention('(unbalanced')
107
+ [['unbalanced', 1.1]]
108
+ >>> parse_prompt_attention('\(literal\]')
109
+ [['(literal]', 1.0]]
110
+ >>> parse_prompt_attention('(unnecessary)(parens)')
111
+ [['unnecessaryparens', 1.1]]
112
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
113
+ [['a ', 1.0],
114
+ ['house', 1.5730000000000004],
115
+ [' ', 1.1],
116
+ ['on', 1.0],
117
+ [' a ', 1.1],
118
+ ['hill', 0.55],
119
+ [', sun, ', 1.1],
120
+ ['sky', 1.4641000000000006],
121
+ ['.', 1.1]]
122
+ """
123
+
124
+ res = []
125
+ round_brackets = []
126
+ square_brackets = []
127
+
128
+ round_bracket_multiplier = 1.1
129
+ square_bracket_multiplier = 1 / 1.1
130
+
131
+ def multiply_range(start_position, multiplier):
132
+ for p in range(start_position, len(res)):
133
+ res[p][1] *= multiplier
134
+
135
+ for m in TokenizeStrategy._re_attention.finditer(text):
136
+ text = m.group(0)
137
+ weight = m.group(1)
138
+
139
+ if text.startswith("\\"):
140
+ res.append([text[1:], 1.0])
141
+ elif text == "(":
142
+ round_brackets.append(len(res))
143
+ elif text == "[":
144
+ square_brackets.append(len(res))
145
+ elif weight is not None and len(round_brackets) > 0:
146
+ multiply_range(round_brackets.pop(), float(weight))
147
+ elif text == ")" and len(round_brackets) > 0:
148
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
149
+ elif text == "]" and len(square_brackets) > 0:
150
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
151
+ else:
152
+ res.append([text, 1.0])
153
+
154
+ for pos in round_brackets:
155
+ multiply_range(pos, round_bracket_multiplier)
156
+
157
+ for pos in square_brackets:
158
+ multiply_range(pos, square_bracket_multiplier)
159
+
160
+ if len(res) == 0:
161
+ res = [["", 1.0]]
162
+
163
+ # merge runs of identical weights
164
+ i = 0
165
+ while i + 1 < len(res):
166
+ if res[i][1] == res[i + 1][1]:
167
+ res[i][0] += res[i + 1][0]
168
+ res.pop(i + 1)
169
+ else:
170
+ i += 1
171
+
172
+ return res
173
+
174
+ def get_prompts_with_weights(text: str, max_length: int):
175
+ r"""
176
+ Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token.
177
+
178
+ No padding, starting or ending token is included.
179
+ """
180
+ truncated = False
181
+
182
+ texts_and_weights = parse_prompt_attention(text)
183
+ tokens = []
184
+ weights = []
185
+ for word, weight in texts_and_weights:
186
+ # tokenize and discard the starting and the ending token
187
+ token = tokenizer(word).input_ids[1:-1]
188
+ tokens += token
189
+ # copy the weight by length of token
190
+ weights += [weight] * len(token)
191
+ # stop if the text is too long (longer than truncation limit)
192
+ if len(tokens) > max_length:
193
+ truncated = True
194
+ break
195
+ # truncate
196
+ if len(tokens) > max_length:
197
+ truncated = True
198
+ tokens = tokens[:max_length]
199
+ weights = weights[:max_length]
200
+ if truncated:
201
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
202
+ return tokens, weights
203
+
204
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad):
205
+ r"""
206
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
207
+ """
208
+ tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens))
209
+ weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights))
210
+ return tokens, weights
211
+
212
+ if max_length is None:
213
+ max_length = tokenizer.model_max_length
214
+
215
+ tokens, weights = get_prompts_with_weights(text, max_length - 2)
216
+ tokens, weights = pad_tokens_and_weights(
217
+ tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id
218
+ )
219
+ return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0)
220
+
221
+ def _get_input_ids(
222
+ self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False
223
+ ) -> torch.Tensor:
224
+ """
225
+ for SD1.5/2.0/SDXL
226
+ TODO support batch input
227
+ """
228
+ if max_length is None:
229
+ max_length = tokenizer.model_max_length - 2
230
+
231
+ if weighted:
232
+ input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length)
233
+ else:
234
+ input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
235
+
236
+ if max_length > tokenizer.model_max_length:
237
+ input_ids = input_ids.squeeze(0)
238
+ iids_list = []
239
+ if tokenizer.pad_token_id == tokenizer.eos_token_id:
240
+ # v1
241
+ # when the prompt exceeds 77 tokens, the ids look like "<BOS> .... <EOS> <EOS> <EOS>" (e.g. 227 tokens in total), so convert them into chunks of "<BOS>...<EOS>"
242
+ # the AUTOMATIC1111 web UI seems to also split on "," etc., but keep it simple here
243
+ for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75)
244
+ ids_chunk = (
245
+ input_ids[0].unsqueeze(0),
246
+ input_ids[i : i + tokenizer.model_max_length - 2],
247
+ input_ids[-1].unsqueeze(0),
248
+ )
249
+ ids_chunk = torch.cat(ids_chunk)
250
+ iids_list.append(ids_chunk)
251
+ else:
252
+ # v2 or SDXL
253
+ # when the prompt exceeds 77 tokens, the ids look like "<BOS> .... <EOS> <PAD> <PAD>..." (e.g. 227 tokens in total), so convert them into chunks of "<BOS>...<EOS> <PAD> <PAD> ..."
254
+ for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
255
+ ids_chunk = (
256
+ input_ids[0].unsqueeze(0), # BOS
257
+ input_ids[i : i + tokenizer.model_max_length - 2],
258
+ input_ids[-1].unsqueeze(0),
259
+ ) # PAD or EOS
260
+ ids_chunk = torch.cat(ids_chunk)
261
+
262
+ # if the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing needs to be done
263
+ # if the chunk ends with x <PAD/EOS>, change the last token to <EOS> (no effective change if it was already x <EOS>)
264
+ if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
265
+ ids_chunk[-1] = tokenizer.eos_token_id
266
+ # if the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
267
+ if ids_chunk[1] == tokenizer.pad_token_id:
268
+ ids_chunk[1] = tokenizer.eos_token_id
269
+
270
+ iids_list.append(ids_chunk)
271
+
272
+ input_ids = torch.stack(iids_list) # 3,77
273
+
274
+ if weighted:
275
+ weights = weights.squeeze(0)
276
+ new_weights = torch.ones(input_ids.shape)
277
+ for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
278
+ b = i // (tokenizer.model_max_length - 2)
279
+ new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2]
280
+ weights = new_weights
281
+
282
+ if weighted:
283
+ return input_ids, weights
284
+ return input_ids
285
+
286
+
287
+ class TextEncodingStrategy:
288
+ _strategy = None # strategy instance: actual strategy class
289
+
290
+ @classmethod
291
+ def set_strategy(cls, strategy):
292
+ if cls._strategy is not None:
293
+ raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
294
+ cls._strategy = strategy
295
+
296
+ @classmethod
297
+ def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
298
+ return cls._strategy
299
+
300
+ def encode_tokens(
301
+ self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
302
+ ) -> List[torch.Tensor]:
303
+ """
304
+ Encode tokens into embeddings and outputs.
305
+ :param tokens: list of token tensors for each TextModel
306
+ :return: list of output embeddings for each architecture
307
+ """
308
+ raise NotImplementedError
309
+
310
+ def encode_tokens_with_weights(
311
+ self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
312
+ ) -> List[torch.Tensor]:
313
+ """
314
+ Encode tokens into embeddings and outputs.
315
+ :param tokens: list of token tensors for each TextModel
316
+ :param weights: list of weight tensors for each TextModel
317
+ :return: list of output embeddings for each architecture
318
+ """
319
+ raise NotImplementedError
320
+
321
+
322
+ class TextEncoderOutputsCachingStrategy:
323
+ _strategy = None # strategy instance: actual strategy class
324
+
325
+ def __init__(
326
+ self,
327
+ cache_to_disk: bool,
328
+ batch_size: Optional[int],
329
+ skip_disk_cache_validity_check: bool,
330
+ is_partial: bool = False,
331
+ is_weighted: bool = False,
332
+ ) -> None:
333
+ self._cache_to_disk = cache_to_disk
334
+ self._batch_size = batch_size
335
+ self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
336
+ self._is_partial = is_partial
337
+ self._is_weighted = is_weighted
338
+
339
+ @classmethod
340
+ def set_strategy(cls, strategy):
341
+ if cls._strategy is not None:
342
+ raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
343
+ cls._strategy = strategy
344
+
345
+ @classmethod
346
+ def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
347
+ return cls._strategy
348
+
349
+ @property
350
+ def cache_to_disk(self):
351
+ return self._cache_to_disk
352
+
353
+ @property
354
+ def batch_size(self):
355
+ return self._batch_size
356
+
357
+ @property
358
+ def is_partial(self):
359
+ return self._is_partial
360
+
361
+ @property
362
+ def is_weighted(self):
363
+ return self._is_weighted
364
+
365
+ def get_outputs_npz_path(self, image_abs_path: str) -> str:
366
+ raise NotImplementedError
367
+
368
+ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
369
+ raise NotImplementedError
370
+
371
+ def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
372
+ raise NotImplementedError
373
+
374
+ def cache_batch_outputs(
375
+ self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
376
+ ):
377
+ raise NotImplementedError
378
+
379
+
380
+ class LatentsCachingStrategy:
381
+ # TODO commonize utillity functions to this class, such as npz handling etc.
382
+
383
+ _strategy = None # strategy instance: actual strategy class
384
+
385
+ def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
386
+ self._cache_to_disk = cache_to_disk
387
+ self._batch_size = batch_size
388
+ self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
389
+
390
+ @classmethod
391
+ def set_strategy(cls, strategy):
392
+ if cls._strategy is not None:
393
+ raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
394
+ cls._strategy = strategy
395
+
396
+ @classmethod
397
+ def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
398
+ return cls._strategy
399
+
400
+ @property
401
+ def cache_to_disk(self):
402
+ return self._cache_to_disk
403
+
404
+ @property
405
+ def batch_size(self):
406
+ return self._batch_size
407
+
408
+ @property
409
+ def cache_suffix(self):
410
+ raise NotImplementedError
411
+
412
+ def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
413
+ w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
414
+ return int(w), int(h)
415
+
416
+ def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
417
+ raise NotImplementedError
418
+
419
+ def is_disk_cached_latents_expected(
420
+ self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
421
+ ) -> bool:
422
+ raise NotImplementedError
423
+
424
+ def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
425
+ raise NotImplementedError
426
+
427
+ def _default_is_disk_cached_latents_expected(
428
+ self,
429
+ latents_stride: int,
430
+ bucket_reso: Tuple[int, int],
431
+ npz_path: str,
432
+ flip_aug: bool,
433
+ alpha_mask: bool,
434
+ multi_resolution: bool = False,
435
+ ):
436
+ if not self.cache_to_disk:
437
+ return False
438
+ if not os.path.exists(npz_path):
439
+ return False
440
+ if self.skip_disk_cache_validity_check:
441
+ return True
442
+
443
+ expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
444
+
445
+ # e.g. "_32x64", HxW
446
+ key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
447
+
448
+ try:
449
+ npz = np.load(npz_path)
450
+ if "latents" + key_reso_suffix not in npz:
451
+ return False
452
+ if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
453
+ return False
454
+ if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
455
+ return False
456
+ except Exception as e:
457
+ logger.error(f"Error loading file: {npz_path}")
458
+ raise e
459
+
460
+ return True
461
+
462
+ # TODO remove circular dependency for ImageInfo
463
+ def _default_cache_batch_latents(
464
+ self,
465
+ encode_by_vae,
466
+ vae_device,
467
+ vae_dtype,
468
+ image_infos: List,
469
+ flip_aug: bool,
470
+ alpha_mask: bool,
471
+ random_crop: bool,
472
+ multi_resolution: bool = False,
473
+ ):
474
+ """
475
+ Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
476
+ """
477
+ from library import train_util # import here to avoid circular import
478
+
479
+ img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
480
+ image_infos, alpha_mask, random_crop
481
+ )
482
+ img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
483
+
484
+ with torch.no_grad():
485
+ latents_tensors = encode_by_vae(img_tensor).to("cpu")
486
+ if flip_aug:
487
+ img_tensor = torch.flip(img_tensor, dims=[3])
488
+ with torch.no_grad():
489
+ flipped_latents = encode_by_vae(img_tensor).to("cpu")
490
+ else:
491
+ flipped_latents = [None] * len(latents_tensors)
492
+
493
+ # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
494
+ for i in range(len(image_infos)):
495
+ info = image_infos[i]
496
+ latents = latents_tensors[i]
497
+ flipped_latent = flipped_latents[i]
498
+ alpha_mask = alpha_masks[i]
499
+ original_size = original_sizes[i]
500
+ crop_ltrb = crop_ltrbs[i]
501
+
502
+ latents_size = latents.shape[1:3] # H, W
503
+ key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
504
+
505
+ if self.cache_to_disk:
506
+ self.save_latents_to_disk(
507
+ info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
508
+ )
509
+ else:
510
+ info.latents_original_size = original_size
511
+ info.latents_crop_ltrb = crop_ltrb
512
+ info.latents = latents
513
+ if flip_aug:
514
+ info.latents_flipped = flipped_latent
515
+ info.alpha_mask = alpha_mask
516
+
517
+ def load_latents_from_disk(
518
+ self, npz_path: str, bucket_reso: Tuple[int, int]
519
+ ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
520
+ """
521
+ for SD/SDXL
522
+ """
523
+ return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
524
+
525
+ def _default_load_latents_from_disk(
526
+ self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
527
+ ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
528
+ if latents_stride is None:
529
+ key_reso_suffix = ""
530
+ else:
531
+ latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
532
+ key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
533
+
534
+ npz = np.load(npz_path)
535
+ if "latents" + key_reso_suffix not in npz:
536
+ raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
537
+
538
+ latents = npz["latents" + key_reso_suffix]
539
+ original_size = npz["original_size" + key_reso_suffix].tolist()
540
+ crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
541
+ flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
542
+ alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
543
+ return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
544
+
545
+ def save_latents_to_disk(
546
+ self,
547
+ npz_path,
548
+ latents_tensor,
549
+ original_size,
550
+ crop_ltrb,
551
+ flipped_latents_tensor=None,
552
+ alpha_mask=None,
553
+ key_reso_suffix="",
554
+ ):
555
+ kwargs = {}
556
+
557
+ if os.path.exists(npz_path):
558
+ # load existing npz and update it
559
+ npz = np.load(npz_path)
560
+ for key in npz.files:
561
+ kwargs[key] = npz[key]
562
+
563
+ kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
564
+ kwargs["original_size" + key_reso_suffix] = np.array(original_size)
565
+ kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
566
+ if flipped_latents_tensor is not None:
567
+ kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
568
+ if alpha_mask is not None:
569
+ kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
570
+ np.savez(npz_path, **kwargs)
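
As a quick illustration of the prompt-attention syntax handled by `TokenizeStrategy._get_weighted_input_ids` above, the sketch below runs the helper directly on the base class. The CLIP tokenizer id is the standard openai/clip-vit-large-patch14 and downloading it from the Hub is assumed to work; the exact token/weight positions shown in comments are indicative only.

```python
# Hedged sketch: exercise the (text:weight) syntax parsed by parse_prompt_attention above.
from transformers import CLIPTokenizer
from library.strategy_base import TokenizeStrategy

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
strategy = TokenizeStrategy()  # the base class is enough; only the weighting helper is used

ids, weights = strategy._get_weighted_input_ids(tokenizer, "a (important:1.5) word", max_length=77)
print(ids.shape, weights.shape)  # torch.Size([1, 77]) torch.Size([1, 77])
print(weights[0, :5])            # the token(s) for "important" carry weight 1.5, the rest 1.0
```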
library/strategy_flux.py ADDED
@@ -0,0 +1,271 @@
 
 
1
+ import os
2
+ import glob
3
+ from typing import Any, List, Optional, Tuple, Union
4
+ import torch
5
+ import numpy as np
6
+ from transformers import CLIPTokenizer, T5TokenizerFast
7
+
8
+ from library import flux_utils, train_util
9
+ from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
10
+
11
+ from library.utils import setup_logging
12
+
13
+ setup_logging()
14
+ import logging
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
20
+ T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
21
+
22
+
23
+ class FluxTokenizeStrategy(TokenizeStrategy):
24
+ def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
25
+ self.t5xxl_max_length = t5xxl_max_length
26
+ self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
27
+ self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
28
+
29
+ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
30
+ text = [text] if isinstance(text, str) else text
31
+
32
+ l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
33
+ t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
34
+
35
+ t5_attn_mask = t5_tokens["attention_mask"]
36
+ l_tokens = l_tokens["input_ids"]
37
+ t5_tokens = t5_tokens["input_ids"]
38
+
39
+ return [l_tokens, t5_tokens, t5_attn_mask]
40
+
41
+
42
+ class FluxTextEncodingStrategy(TextEncodingStrategy):
43
+ def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None:
44
+ """
45
+ Args:
46
+ apply_t5_attn_mask: Default value for apply_t5_attn_mask.
47
+ """
48
+ self.apply_t5_attn_mask = apply_t5_attn_mask
49
+
50
+ def encode_tokens(
51
+ self,
52
+ tokenize_strategy: TokenizeStrategy,
53
+ models: List[Any],
54
+ tokens: List[torch.Tensor],
55
+ apply_t5_attn_mask: Optional[bool] = None,
56
+ ) -> List[torch.Tensor]:
57
+ # supports single model inference
58
+
59
+ if apply_t5_attn_mask is None:
60
+ apply_t5_attn_mask = self.apply_t5_attn_mask
61
+
62
+ clip_l, t5xxl = models if len(models) == 2 else (models[0], None)
63
+ l_tokens, t5_tokens = tokens[:2]
64
+ t5_attn_mask = tokens[2] if len(tokens) > 2 else None
65
+
66
+ # clip_l is None when using T5 only
67
+ if clip_l is not None and l_tokens is not None:
68
+ l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
69
+ else:
70
+ l_pooled = None
71
+
72
+ # t5xxl is None when using CLIP only
73
+ if t5xxl is not None and t5_tokens is not None:
74
+ # t5_out is [b, max length, 4096]
75
+ attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
76
+ t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
77
+ # if zero_pad_t5_output:
78
+ # t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
79
+ txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
80
+ else:
81
+ t5_out = None
82
+ txt_ids = None
83
+ t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
84
+
85
+ return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
86
+
87
+
88
+ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
89
+ FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
90
+
91
+ def __init__(
92
+ self,
93
+ cache_to_disk: bool,
94
+ batch_size: int,
95
+ skip_disk_cache_validity_check: bool,
96
+ is_partial: bool = False,
97
+ apply_t5_attn_mask: bool = False,
98
+ ) -> None:
99
+ super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
100
+ self.apply_t5_attn_mask = apply_t5_attn_mask
101
+
102
+ self.warn_fp8_weights = False
103
+
104
+ def get_outputs_npz_path(self, image_abs_path: str) -> str:
105
+ return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
106
+
107
+ def is_disk_cached_outputs_expected(self, npz_path: str):
108
+ if not self.cache_to_disk:
109
+ return False
110
+ if not os.path.exists(npz_path):
111
+ return False
112
+ if self.skip_disk_cache_validity_check:
113
+ return True
114
+
115
+ try:
116
+ npz = np.load(npz_path)
117
+ if "l_pooled" not in npz:
118
+ return False
119
+ if "t5_out" not in npz:
120
+ return False
121
+ if "txt_ids" not in npz:
122
+ return False
123
+ if "t5_attn_mask" not in npz:
124
+ return False
125
+ if "apply_t5_attn_mask" not in npz:
126
+ return False
127
+ npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
128
+ if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
129
+ return False
130
+ except Exception as e:
131
+ logger.error(f"Error loading file: {npz_path}")
132
+ raise e
133
+
134
+ return True
135
+
136
+ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
137
+ data = np.load(npz_path)
138
+ l_pooled = data["l_pooled"]
139
+ t5_out = data["t5_out"]
140
+ txt_ids = data["txt_ids"]
141
+ t5_attn_mask = data["t5_attn_mask"]
142
+ # apply_t5_attn_mask should be same as self.apply_t5_attn_mask
143
+ return [l_pooled, t5_out, txt_ids, t5_attn_mask]
144
+
145
+ def cache_batch_outputs(
146
+ self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
147
+ ):
148
+ if not self.warn_fp8_weights:
149
+ if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
150
+ logger.warning(
151
+ "T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
152
+ " / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
153
+ )
154
+ self.warn_fp8_weights = True
155
+
156
+ flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
157
+ captions = [info.caption for info in infos]
158
+
159
+ tokens_and_masks = tokenize_strategy.tokenize(captions)
160
+ with torch.no_grad():
161
+ # attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
162
+ l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
163
+
164
+ if l_pooled.dtype == torch.bfloat16:
165
+ l_pooled = l_pooled.float()
166
+ if t5_out.dtype == torch.bfloat16:
167
+ t5_out = t5_out.float()
168
+ if txt_ids.dtype == torch.bfloat16:
169
+ txt_ids = txt_ids.float()
170
+
171
+ l_pooled = l_pooled.cpu().numpy()
172
+ t5_out = t5_out.cpu().numpy()
173
+ txt_ids = txt_ids.cpu().numpy()
174
+ t5_attn_mask = tokens_and_masks[2].cpu().numpy()
175
+
176
+ for i, info in enumerate(infos):
177
+ l_pooled_i = l_pooled[i]
178
+ t5_out_i = t5_out[i]
179
+ txt_ids_i = txt_ids[i]
180
+ t5_attn_mask_i = t5_attn_mask[i]
181
+ apply_t5_attn_mask_i = self.apply_t5_attn_mask
182
+
183
+ if self.cache_to_disk:
184
+ np.savez(
185
+ info.text_encoder_outputs_npz,
186
+ l_pooled=l_pooled_i,
187
+ t5_out=t5_out_i,
188
+ txt_ids=txt_ids_i,
189
+ t5_attn_mask=t5_attn_mask_i,
190
+ apply_t5_attn_mask=apply_t5_attn_mask_i,
191
+ )
192
+ else:
193
+ # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
194
+ info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
195
+
196
+
197
+ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
198
+ FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
199
+
200
+ def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
201
+ super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
202
+
203
+ @property
204
+ def cache_suffix(self) -> str:
205
+ return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
206
+
207
+ def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
208
+ return (
209
+ os.path.splitext(absolute_path)[0]
210
+ + f"_{image_size[0]:04d}x{image_size[1]:04d}"
211
+ + FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
212
+ )
213
+
214
+ def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
215
+ return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
216
+
217
+ def load_latents_from_disk(
218
+ self, npz_path: str, bucket_reso: Tuple[int, int]
219
+ ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
220
+ return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
221
+
222
+ # TODO remove circular dependency for ImageInfo
223
+ def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
224
+ encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
225
+ vae_device = vae.device
226
+ vae_dtype = vae.dtype
227
+
228
+ self._default_cache_batch_latents(
229
+ encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
230
+ )
231
+
232
+ if not train_util.HIGH_VRAM:
233
+ train_util.clean_memory_on_device(vae.device)
234
+
235
+
236
+ if __name__ == "__main__":
237
+ # test code for FluxTokenizeStrategy
238
+ # note: FLUX uses only the CLIP-L and T5-XXL tokenizers (there is no CLIP-G here)
239
+ strategy = FluxTokenizeStrategy(256)
240
+ text = "hello world"
241
+
242
+ l_tokens, t5_tokens, t5_attn_mask = strategy.tokenize(text)
243
+ # print(l_tokens.shape)
244
+ print(l_tokens)
245
+ print(t5_tokens)
246
+ print(t5_attn_mask)
247
+
248
+ texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
249
+ l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
250
+ t5_tokens_2 = strategy.t5xxl(
251
+ texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
252
+ )
253
+ print(l_tokens_2)
254
+ print(t5_tokens_2)
255
+
256
+ # compare
257
+ print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
258
+ print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
259
+ print(torch.allclose(t5_attn_mask, t5_tokens_2["attention_mask"][0]))
260
+
261
+ text = ",".join(["hello world! this is long text"] * 50)
262
+ l_tokens, t5_tokens, t5_attn_mask = strategy.tokenize(text)
263
+ print(l_tokens)
264
+ print(t5_tokens)
265
+ print(t5_attn_mask)
266
+
267
+ print(f"model max length l: {strategy.clip_l.model_max_length}")
268
+ print(f"model max length t5: {strategy.t5xxl.model_max_length}")
269
+
270
+
271
+
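
For context, here is a hedged sketch of how the Flux strategies above are meant to be wired together before caching or training. The clip_l / t5xxl model objects are assumed to be loaded elsewhere (e.g. via library.flux_utils), so the actual encode call is left commented out, and downloading the tokenizers from the Hub is assumed to work.

```python
# Hedged wiring sketch for the strategies defined above.
from library.strategy_flux import (
    FluxTokenizeStrategy,
    FluxTextEncodingStrategy,
    FluxLatentsCachingStrategy,
)

tokenize_strategy = FluxTokenizeStrategy(t5xxl_max_length=512)
encoding_strategy = FluxTextEncodingStrategy(apply_t5_attn_mask=True)
latents_strategy = FluxLatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False)

tokens = tokenize_strategy.tokenize("a photo of a cat")  # [clip_l ids, t5xxl ids, t5 attention mask]
print([t.shape for t in tokens])                         # (1, 77), (1, 512), (1, 512)

# clip_l and t5xxl are assumed to be already-loaded text encoders:
# l_pooled, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
#     tokenize_strategy, [clip_l, t5xxl], tokens
# )
```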
library/strategy_sd.py ADDED
@@ -0,0 +1,171 @@
 
 
1
+ import glob
2
+ import os
3
+ from typing import Any, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ from transformers import CLIPTokenizer
7
+ from library import train_util
8
+ from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
9
+ from library.utils import setup_logging
10
+
11
+ setup_logging()
12
+ import logging
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ TOKENIZER_ID = "openai/clip-vit-large-patch14"
18
+ V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # only the tokenizer is used from this repo; v2 and v2.1 share the same tokenizer spec
19
+
20
+
21
+ class SdTokenizeStrategy(TokenizeStrategy):
22
+ def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
23
+ """
24
+ max_length does not include <BOS> and <EOS> (None, 75, 150, 225)
25
+ """
26
+ logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer")
27
+ if v2:
28
+ self.tokenizer = self._load_tokenizer(
29
+ CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir
30
+ )
31
+ else:
32
+ self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
33
+
34
+ if max_length is None:
35
+ self.max_length = self.tokenizer.model_max_length
36
+ else:
37
+ self.max_length = max_length + 2
38
+
39
+ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
40
+ text = [text] if isinstance(text, str) else text
41
+ return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
42
+
43
+ def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
44
+ text = [text] if isinstance(text, str) else text
45
+ tokens_list = []
46
+ weights_list = []
47
+ for t in text:
48
+ tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
49
+ tokens_list.append(tokens)
50
+ weights_list.append(weights)
51
+ return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
52
+
53
+
54
+ class SdTextEncodingStrategy(TextEncodingStrategy):
55
+ def __init__(self, clip_skip: Optional[int] = None) -> None:
56
+ self.clip_skip = clip_skip
57
+
58
+ def encode_tokens(
59
+ self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
60
+ ) -> List[torch.Tensor]:
61
+ text_encoder = models[0]
62
+ tokens = tokens[0]
63
+ sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy
64
+
65
+ # tokens: b,n,77
66
+ b_size = tokens.size()[0]
67
+ max_token_length = tokens.size()[1] * tokens.size()[2]
68
+ model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
69
+ tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
70
+
71
+ tokens = tokens.to(text_encoder.device)
72
+
73
+ if self.clip_skip is None:
74
+ encoder_hidden_states = text_encoder(tokens)[0]
75
+ else:
76
+ enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
77
+ encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
78
+ encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
79
+
80
+ # bs*3, 77, 768 or 1024
81
+ encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
82
+
83
+ if max_token_length != model_max_length:
84
+ v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
85
+ if not v1:
86
+ # v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
87
+ states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
88
+ for i in range(1, max_token_length, model_max_length):
89
+ chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # <BOS> の後から 最後の前まで
90
+ if i > 0:
91
+ for j in range(len(chunk)):
92
+ if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token:
93
+ # 空、つまり <BOS> <EOS> <PAD> ...のパターン
94
+ chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
95
+ states_list.append(chunk) # <BOS> の後から <EOS> の前まで
96
+ states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
97
+ encoder_hidden_states = torch.cat(states_list, dim=1)
98
+ else:
99
+ # v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
100
+ states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
101
+ for i in range(1, max_token_length, model_max_length):
102
+ states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # <BOS> の後から <EOS> の前まで
103
+ states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
104
+ encoder_hidden_states = torch.cat(states_list, dim=1)
105
+
106
+ return [encoder_hidden_states]
107
+
108
+ def encode_tokens_with_weights(
109
+ self,
110
+ tokenize_strategy: TokenizeStrategy,
111
+ models: List[Any],
112
+ tokens_list: List[torch.Tensor],
113
+ weights_list: List[torch.Tensor],
114
+ ) -> List[torch.Tensor]:
115
+ encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]
116
+
117
+ weights = weights_list[0].to(encoder_hidden_states.device)
118
+
119
+ # apply weights
120
+ if weights.shape[1] == 1: # no max_token_length
121
+ # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
122
+ encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
123
+ else:
124
+ # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
125
+ for i in range(weights.shape[1]):
126
+ encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[
127
+ :, i, 1:-1
128
+ ].unsqueeze(-1)
129
+
130
+ return [encoder_hidden_states]
131
+
132
+
133
+ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
134
+ # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
135
+ # and we keep the old npz for the backward compatibility.
136
+
137
+ SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
138
+ SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
139
+ SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
140
+
141
+ def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
142
+ super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
143
+ self.sd = sd
144
+ self.suffix = (
145
+ SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
146
+ )
147
+
148
+ @property
149
+ def cache_suffix(self) -> str:
150
+ return self.suffix
151
+
152
+ def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
153
+ # support old .npz
154
+ old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
155
+ if os.path.exists(old_npz_file):
156
+ return old_npz_file
157
+ return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
158
+
159
+ def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
160
+ return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
161
+
162
+ # TODO remove circular dependency for ImageInfo
163
+ def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
164
+ encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
165
+ vae_device = vae.device
166
+ vae_dtype = vae.dtype
167
+
168
+ self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
169
+
170
+ if not train_util.HIGH_VRAM:
171
+ train_util.clean_memory_on_device(vae.device)
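
To make the chunking behaviour of SdTokenizeStrategy concrete: with max_length=225 the tokenizer produces three 75-token chunks per prompt, each re-wrapped with <BOS>/<EOS>, which SdTextEncodingStrategy later flattens to (batch*3, 77) before calling the text encoder. A small sketch follows; downloading the CLIP tokenizer from the Hub is assumed to work, and the prompt text is purely illustrative.

```python
# Hedged sketch of the long-prompt chunking implemented in TokenizeStrategy._get_input_ids.
from library.strategy_sd import SdTokenizeStrategy

strategy = SdTokenizeStrategy(v2=False, max_length=225)
long_prompt = ", ".join(["a highly detailed painting of a castle on a hill"] * 30)

(ids,) = strategy.tokenize(long_prompt)
print(ids.shape)  # torch.Size([1, 3, 77]): batch of 1, three <BOS>...<EOS> chunks of 77 ids
```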
library/strategy_sd3.py ADDED
@@ -0,0 +1,420 @@
 
 
1
+ import os
2
+ import glob
3
+ import random
4
+ from typing import Any, List, Optional, Tuple, Union
5
+ import torch
6
+ import numpy as np
7
+ from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
8
+
9
+ from library import sd3_utils, train_util
10
+ from library import sd3_models
11
+ from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
12
+
13
+ from library.utils import setup_logging
14
+
15
+ setup_logging()
16
+ import logging
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
22
+ CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
23
+ T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
24
+
25
+
26
+ class Sd3TokenizeStrategy(TokenizeStrategy):
27
+ def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
28
+ self.t5xxl_max_length = t5xxl_max_length
29
+ self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
30
+ self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
31
+ self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
32
+ self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g
33
+
34
+ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
35
+ text = [text] if isinstance(text, str) else text
36
+
37
+ l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
38
+ g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
39
+ t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
40
+
41
+ l_attn_mask = l_tokens["attention_mask"]
42
+ g_attn_mask = g_tokens["attention_mask"]
43
+ t5_attn_mask = t5_tokens["attention_mask"]
44
+ l_tokens = l_tokens["input_ids"]
45
+ g_tokens = g_tokens["input_ids"]
46
+ t5_tokens = t5_tokens["input_ids"]
47
+
48
+ return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask]
49
+
50
+
51
+ class Sd3TextEncodingStrategy(TextEncodingStrategy):
52
+ def __init__(
53
+ self,
54
+ apply_lg_attn_mask: Optional[bool] = None,
55
+ apply_t5_attn_mask: Optional[bool] = None,
56
+ l_dropout_rate: float = 0.0,
57
+ g_dropout_rate: float = 0.0,
58
+ t5_dropout_rate: float = 0.0,
59
+ ) -> None:
60
+ """
61
+ Args:
62
+ apply_t5_attn_mask: Default value for apply_t5_attn_mask.
63
+ """
64
+ self.apply_lg_attn_mask = apply_lg_attn_mask
65
+ self.apply_t5_attn_mask = apply_t5_attn_mask
66
+ self.l_dropout_rate = l_dropout_rate
67
+ self.g_dropout_rate = g_dropout_rate
68
+ self.t5_dropout_rate = t5_dropout_rate
69
+
70
+ def encode_tokens(
71
+ self,
72
+ tokenize_strategy: TokenizeStrategy,
73
+ models: List[Any],
74
+ tokens: List[torch.Tensor],
75
+ apply_lg_attn_mask: Optional[bool] = False,
76
+ apply_t5_attn_mask: Optional[bool] = False,
77
+ enable_dropout: bool = True,
78
+ ) -> List[torch.Tensor]:
79
+ """
80
+ returned embeddings are not masked
81
+ """
82
+ clip_l, clip_g, t5xxl = models
83
+ clip_l: Optional[CLIPTextModel]
84
+ clip_g: Optional[CLIPTextModelWithProjection]
85
+ t5xxl: Optional[T5EncoderModel]
86
+
87
+ if apply_lg_attn_mask is None:
88
+ apply_lg_attn_mask = self.apply_lg_attn_mask
89
+ if apply_t5_attn_mask is None:
90
+ apply_t5_attn_mask = self.apply_t5_attn_mask
91
+
92
+ l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens
93
+
94
+ # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings
95
+
96
+ if l_tokens is None or clip_l is None:
97
+ assert g_tokens is None, "g_tokens must be None if l_tokens is None"
98
+ lg_out = None
99
+ lg_pooled = None
100
+ l_attn_mask = None
101
+ g_attn_mask = None
102
+ else:
103
+ assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
104
+
105
+ # drop some members of the batch: we do not call clip_l and clip_g for dropped members
106
+ batch_size, l_seq_len = l_tokens.shape
107
+ g_seq_len = g_tokens.shape[1]
108
+
109
+ non_drop_l_indices = []
110
+ non_drop_g_indices = []
111
+ for i in range(l_tokens.shape[0]):
112
+ drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
113
+ drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
114
+ if not drop_l:
115
+ non_drop_l_indices.append(i)
116
+ if not drop_g:
117
+ non_drop_g_indices.append(i)
118
+
119
+ # filter out dropped members
120
+ if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size:
121
+ l_tokens = l_tokens[non_drop_l_indices]
122
+ l_attn_mask = l_attn_mask[non_drop_l_indices]
123
+ if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size:
124
+ g_tokens = g_tokens[non_drop_g_indices]
125
+ g_attn_mask = g_attn_mask[non_drop_g_indices]
126
+
127
+ # call clip_l for non-dropped members
128
+ if len(non_drop_l_indices) > 0:
129
+ nd_l_attn_mask = l_attn_mask.to(clip_l.device)
130
+ prompt_embeds = clip_l(
131
+ l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
132
+ )
133
+ nd_l_pooled = prompt_embeds[0]
134
+ nd_l_out = prompt_embeds.hidden_states[-2]
135
+ if len(non_drop_g_indices) > 0:
136
+ nd_g_attn_mask = g_attn_mask.to(clip_g.device)
137
+ prompt_embeds = clip_g(
138
+ g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
139
+ )
140
+ nd_g_pooled = prompt_embeds[0]
141
+ nd_g_out = prompt_embeds.hidden_states[-2]
142
+
143
+ # fill in the dropped members
144
+ if len(non_drop_l_indices) == batch_size:
145
+ l_pooled = nd_l_pooled
146
+ l_out = nd_l_out
147
+ else:
148
+ # model output is always float32 because the models are wrapped with Accelerator
149
+ l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32)
150
+ l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32)
151
+ l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype)
152
+ if len(non_drop_l_indices) > 0:
153
+ l_pooled[non_drop_l_indices] = nd_l_pooled
154
+ l_out[non_drop_l_indices] = nd_l_out
155
+ l_attn_mask[non_drop_l_indices] = nd_l_attn_mask
156
+
157
+ if len(non_drop_g_indices) == batch_size:
158
+ g_pooled = nd_g_pooled
159
+ g_out = nd_g_out
160
+ else:
161
+ g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32)
162
+ g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32)
163
+ g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype)
164
+ if len(non_drop_g_indices) > 0:
165
+ g_pooled[non_drop_g_indices] = nd_g_pooled
166
+ g_out[non_drop_g_indices] = nd_g_out
167
+ g_attn_mask[non_drop_g_indices] = nd_g_attn_mask
168
+
169
+ lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1)
170
+ lg_out = torch.cat([l_out, g_out], dim=-1)
171
+
172
+ if t5xxl is None or t5_tokens is None:
173
+ t5_out = None
174
+ t5_attn_mask = None
175
+ else:
176
+ # drop some members of the batch: we do not call t5xxl for dropped members
177
+ batch_size, t5_seq_len = t5_tokens.shape
178
+ non_drop_t5_indices = []
179
+ for i in range(t5_tokens.shape[0]):
180
+ drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
181
+ if not drop_t5:
182
+ non_drop_t5_indices.append(i)
183
+
184
+ # filter out dropped members
185
+ if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size:
186
+ t5_tokens = t5_tokens[non_drop_t5_indices]
187
+ t5_attn_mask = t5_attn_mask[non_drop_t5_indices]
188
+
189
+ # call t5xxl for non-dropped members
190
+ if len(non_drop_t5_indices) > 0:
191
+ nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device)
192
+ nd_t5_out, _ = t5xxl(
193
+ t5_tokens.to(t5xxl.device),
194
+ nd_t5_attn_mask if apply_t5_attn_mask else None,
195
+ return_dict=False,
196
+ output_hidden_states=True,
197
+ )
198
+
199
+ # fill in the dropped members
200
+ if len(non_drop_t5_indices) == batch_size:
201
+ t5_out = nd_t5_out
202
+ else:
203
+ t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32)
204
+ t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype)
205
+ if len(non_drop_t5_indices) > 0:
206
+ t5_out[non_drop_t5_indices] = nd_t5_out
207
+ t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask
208
+
209
+ # masks are used for attention masking in transformer
210
+ return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
211
+
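+ # Usage sketch (illustrative; clip_l, clip_g and t5xxl are assumed to be loaded elsewhere):
+ #   with torch.no_grad():
+ #       lg_out, t5_out, lg_pooled, l_mask, g_mask, t5_mask = encoding_strategy.encode_tokens(
+ #           tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, enable_dropout=False
+ #       )
+ #   # lg_out: (b, 77, 768 + 1280), lg_pooled: (b, 768 + 1280), t5_out: (b, t5_len, 4096)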
212
+ def drop_cached_text_encoder_outputs(
213
+ self,
214
+ lg_out: torch.Tensor,
215
+ t5_out: torch.Tensor,
216
+ lg_pooled: torch.Tensor,
217
+ l_attn_mask: torch.Tensor,
218
+ g_attn_mask: torch.Tensor,
219
+ t5_attn_mask: torch.Tensor,
220
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
221
+ # dropout for cached outputs: zero out the embeddings (and attention masks) of randomly selected samples according to the configured dropout rates
222
+ if lg_out is not None:
223
+ for i in range(lg_out.shape[0]):
224
+ drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate
225
+ if drop_l:
226
+ lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768])
227
+ lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768])
228
+ if l_attn_mask is not None:
229
+ l_attn_mask[i] = torch.zeros_like(l_attn_mask[i])
230
+ drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate
231
+ if drop_g:
232
+ lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:])
233
+ lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:])
234
+ if g_attn_mask is not None:
235
+ g_attn_mask[i] = torch.zeros_like(g_attn_mask[i])
236
+
237
+ if t5_out is not None:
238
+ for i in range(t5_out.shape[0]):
239
+ drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate
240
+ if drop_t5:
241
+ t5_out[i] = torch.zeros_like(t5_out[i])
242
+ if t5_attn_mask is not None:
243
+ t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
244
+
245
+ return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
246
+
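+ # Note (sketch): cached outputs are always written with dropout disabled, so this method re-applies
+ # per-sample caption dropout at load time by zeroing the CLIP-L / CLIP-G / T5 slices independently:
+ #   outputs = encoding_strategy.drop_cached_text_encoder_outputs(*cached_outputs)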
247
+ def concat_encodings(
248
+ self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
249
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
250
+ lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
251
+ if t5_out is None:
252
+ t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
253
+ return torch.cat([lg_out, t5_out], dim=-2), lg_pooled
254
+
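+ # Shape sketch (illustrative): lg_out (b, 77, 2048) is zero-padded on the channel dim to 4096 and
+ # concatenated with t5_out (b, t5_len, 4096) along the sequence dim:
+ #   context, pooled = encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled)
+ #   # context: (b, 77 + t5_len, 4096), pooled: (b, 2048)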
255
+
256
+ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
257
+ SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
258
+
259
+ def __init__(
260
+ self,
261
+ cache_to_disk: bool,
262
+ batch_size: int,
263
+ skip_disk_cache_validity_check: bool,
264
+ is_partial: bool = False,
265
+ apply_lg_attn_mask: bool = False,
266
+ apply_t5_attn_mask: bool = False,
267
+ ) -> None:
268
+ super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
269
+ self.apply_lg_attn_mask = apply_lg_attn_mask
270
+ self.apply_t5_attn_mask = apply_t5_attn_mask
271
+
272
+ def get_outputs_npz_path(self, image_abs_path: str) -> str:
273
+ return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
274
+
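+ # Path sketch (illustrative): for an image "/data/train/0001.png" the cached text encoder outputs
+ # are written next to it as "/data/train/0001_sd3_te.npz".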
275
+ def is_disk_cached_outputs_expected(self, npz_path: str):
276
+ if not self.cache_to_disk:
277
+ return False
278
+ if not os.path.exists(npz_path):
279
+ return False
280
+ if self.skip_disk_cache_validity_check:
281
+ return True
282
+
283
+ try:
284
+ npz = np.load(npz_path)
285
+ if "lg_out" not in npz:
286
+ return False
287
+ if "lg_pooled" not in npz:
288
+ return False
289
+ if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
290
+ return False
291
+ if "apply_lg_attn_mask" not in npz:
292
+ return False
293
+ if "t5_out" not in npz:
294
+ return False
295
+ if "t5_attn_mask" not in npz:
296
+ return False
297
+ npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
298
+ if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
299
+ return False
300
+ if "apply_t5_attn_mask" not in npz:
301
+ return False
302
+ npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
303
+ if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
304
+ return False
305
+ except Exception as e:
306
+ logger.error(f"Error loading file: {npz_path}")
307
+ raise e
308
+
309
+ return True
310
+
311
+ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
312
+ data = np.load(npz_path)
313
+ lg_out = data["lg_out"]
314
+ lg_pooled = data["lg_pooled"]
315
+ t5_out = data["t5_out"]
316
+
317
+ l_attn_mask = data["clip_l_attn_mask"]
318
+ g_attn_mask = data["clip_g_attn_mask"]
319
+ t5_attn_mask = data["t5_attn_mask"]
320
+
321
+ # apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
322
+ return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
323
+
324
+ def cache_batch_outputs(
325
+ self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
326
+ ):
327
+ sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
328
+ captions = [info.caption for info in infos]
329
+
330
+ tokens_and_masks = tokenize_strategy.tokenize(captions)
331
+ with torch.no_grad():
332
+ # always disable dropout during caching
333
+ lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
334
+ tokenize_strategy,
335
+ models,
336
+ tokens_and_masks,
337
+ apply_lg_attn_mask=self.apply_lg_attn_mask,
338
+ apply_t5_attn_mask=self.apply_t5_attn_mask,
339
+ enable_dropout=False,
340
+ )
341
+
342
+ if lg_out.dtype == torch.bfloat16:
343
+ lg_out = lg_out.float()
344
+ if lg_pooled.dtype == torch.bfloat16:
345
+ lg_pooled = lg_pooled.float()
346
+ if t5_out.dtype == torch.bfloat16:
347
+ t5_out = t5_out.float()
348
+
349
+ lg_out = lg_out.cpu().numpy()
350
+ lg_pooled = lg_pooled.cpu().numpy()
351
+ t5_out = t5_out.cpu().numpy()
352
+
353
+ l_attn_mask = tokens_and_masks[3].cpu().numpy()
354
+ g_attn_mask = tokens_and_masks[4].cpu().numpy()
355
+ t5_attn_mask = tokens_and_masks[5].cpu().numpy()
356
+
357
+ for i, info in enumerate(infos):
358
+ lg_out_i = lg_out[i]
359
+ t5_out_i = t5_out[i]
360
+ lg_pooled_i = lg_pooled[i]
361
+ l_attn_mask_i = l_attn_mask[i]
362
+ g_attn_mask_i = g_attn_mask[i]
363
+ t5_attn_mask_i = t5_attn_mask[i]
364
+ apply_lg_attn_mask = self.apply_lg_attn_mask
365
+ apply_t5_attn_mask = self.apply_t5_attn_mask
366
+
367
+ if self.cache_to_disk:
368
+ np.savez(
369
+ info.text_encoder_outputs_npz,
370
+ lg_out=lg_out_i,
371
+ lg_pooled=lg_pooled_i,
372
+ t5_out=t5_out_i,
373
+ clip_l_attn_mask=l_attn_mask_i,
374
+ clip_g_attn_mask=g_attn_mask_i,
375
+ t5_attn_mask=t5_attn_mask_i,
376
+ apply_lg_attn_mask=apply_lg_attn_mask,
377
+ apply_t5_attn_mask=apply_t5_attn_mask,
378
+ )
379
+ else:
380
+ # it's fine that attn mask is not None. it's overwritten before calling the model if necessary
381
+ info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
382
+
383
+
384
+ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
385
+ SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
386
+
387
+ def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
388
+ super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
389
+
390
+ @property
391
+ def cache_suffix(self) -> str:
392
+ return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
393
+
394
+ def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
395
+ return (
396
+ os.path.splitext(absolute_path)[0]
397
+ + f"_{image_size[0]:04d}x{image_size[1]:04d}"
398
+ + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
399
+ )
400
+
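+ # Path sketch (illustrative): for "/data/train/0001.png" at bucket resolution (1024, 768) the latents
+ # cache file becomes "/data/train/0001_1024x0768_sd3.npz".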
401
+ def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
402
+ return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
403
+
404
+ def load_latents_from_disk(
405
+ self, npz_path: str, bucket_reso: Tuple[int, int]
406
+ ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
407
+ return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
408
+
409
+ # TODO remove circular dependency for ImageInfo
410
+ def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
411
+ encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
412
+ vae_device = vae.device
413
+ vae_dtype = vae.dtype
414
+
415
+ self._default_cache_batch_latents(
416
+ encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
417
+ )
418
+
419
+ if not train_util.HIGH_VRAM:
420
+ train_util.clean_memory_on_device(vae.device)
library/strategy_sdxl.py ADDED
@@ -0,0 +1,306 @@
1
+ import os
2
+ from typing import Any, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
7
+ from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
8
+
9
+
10
+ from library.utils import setup_logging
11
+
12
+ setup_logging()
13
+ import logging
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
19
+ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
20
+
21
+
22
+ class SdxlTokenizeStrategy(TokenizeStrategy):
23
+ def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
24
+ self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
25
+ self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
26
+ self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
27
+
28
+ if max_length is None:
29
+ self.max_length = self.tokenizer1.model_max_length
30
+ else:
31
+ self.max_length = max_length + 2
32
+
33
+ def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
34
+ text = [text] if isinstance(text, str) else text
35
+ return (
36
+ torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0),
37
+ torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
38
+ )
39
+
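+ # Shape sketch (illustrative): tokenize() stacks one tensor per caption; downstream
+ # _get_hidden_states_sdxl() treats each result as (batch, n, 77), where n == 1 unless a longer
+ # max token length splits the prompt into multiple 75-token chunks.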
40
+ def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
41
+ text = [text] if isinstance(text, str) else text
42
+ tokens1_list, tokens2_list = [], []
43
+ weights1_list, weights2_list = [], []
44
+ for t in text:
45
+ tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True)
46
+ tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True)
47
+ tokens1_list.append(tokens1)
48
+ tokens2_list.append(tokens2)
49
+ weights1_list.append(weights1)
50
+ weights2_list.append(weights2)
51
+ return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [
52
+ torch.stack(weights1_list, dim=0),
53
+ torch.stack(weights2_list, dim=0),
54
+ ]
55
+
56
+
57
+ class SdxlTextEncodingStrategy(TextEncodingStrategy):
58
+ def __init__(self) -> None:
59
+ pass
60
+
61
+ def _pool_workaround(
62
+ self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
63
+ ):
64
+ r"""
65
+ workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
66
+ instead of the hidden states for the EOS token
67
+ If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output
68
+
69
+ Original code from CLIP's pooling function:
70
+
71
+ \# text_embeds.shape = [batch_size, sequence_length, transformer.width]
72
+ \# take features from the eot embedding (eot_token is the highest number in each sequence)
73
+ \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
74
+ pooled_output = last_hidden_state[
75
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
76
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
77
+ ]
78
+ """
79
+
80
+ # input_ids: b*n,77
81
+ # find index for EOS token
82
+
83
+ # Following code is not working if one of the input_ids has multiple EOS tokens (very odd case)
84
+ # eos_token_index = torch.where(input_ids == eos_token_id)[1]
85
+ # eos_token_index = eos_token_index.to(device=last_hidden_state.device)
86
+
87
+ # Create a mask where the EOS tokens are
88
+ eos_token_mask = (input_ids == eos_token_id).int()
89
+
90
+ # Use argmax to find the last index of the EOS token for each element in the batch
91
+ eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine
92
+ eos_token_index = eos_token_index.to(device=last_hidden_state.device)
93
+
94
+ # get hidden states for EOS token
95
+ pooled_output = last_hidden_state[
96
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
97
+ ]
98
+
99
+ # apply projection: projection may be of different dtype than last_hidden_state
100
+ pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
101
+ pooled_output = pooled_output.to(last_hidden_state.dtype)
102
+
103
+ return pooled_output
104
+
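+ # Sketch of the EOS lookup above (illustrative): for input_ids [[49406, 320, 49407, 0, 0]]
+ # (BOS, one token, EOS, then pad id 0) the mask is [[0, 0, 1, 0, 0]], argmax(dim=1) returns 2,
+ # and the hidden state at that position is projected to give the pooled output.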
105
+ def _get_hidden_states_sdxl(
106
+ self,
107
+ input_ids1: torch.Tensor,
108
+ input_ids2: torch.Tensor,
109
+ tokenizer1: CLIPTokenizer,
110
+ tokenizer2: CLIPTokenizer,
111
+ text_encoder1: Union[CLIPTextModel, torch.nn.Module],
112
+ text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module],
113
+ unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None,
114
+ ):
115
+ # input_ids: b,n,77 -> b*n, 77
116
+ b_size = input_ids1.size()[0]
117
+ if input_ids1.size()[1] == 1:
118
+ max_token_length = None
119
+ else:
120
+ max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
121
+ input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
122
+ input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
123
+ input_ids1 = input_ids1.to(text_encoder1.device)
124
+ input_ids2 = input_ids2.to(text_encoder2.device)
125
+
126
+ # text_encoder1
127
+ enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
128
+ hidden_states1 = enc_out["hidden_states"][11]
129
+
130
+ # text_encoder2
131
+ enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
132
+ hidden_states2 = enc_out["hidden_states"][-2] # penultimate layer
133
+
134
+ # pool2 = enc_out["text_embeds"]
135
+ unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2
136
+ pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
137
+
138
+ # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
139
+ n_size = 1 if max_token_length is None else max_token_length // 75
140
+ hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
141
+ hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
142
+
143
+ if max_token_length is not None:
144
+ # bs*3, 77, 768 or 1280
145
+ # encoder1: merge the concatenated <BOS>...<EOS> chunks back into a single <BOS>...<EOS> sequence
146
+ states_list = [hidden_states1[:, 0].unsqueeze(1)] # <BOS>
147
+ for i in range(1, max_token_length, tokenizer1.model_max_length):
148
+ states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # from after <BOS> to before <EOS>
149
+ states_list.append(hidden_states1[:, -1].unsqueeze(1)) # <EOS>
150
+ hidden_states1 = torch.cat(states_list, dim=1)
151
+
152
+ # encoder2: merge the concatenated <BOS>...<EOS> <PAD> ... chunks back into <BOS>...<EOS> <PAD> ... (honestly not sure this implementation is right)
153
+ states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
154
+ for i in range(1, max_token_length, tokenizer2.model_max_length):
155
+ chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # from after <BOS> to before the last token
156
+ # this causes an error:
157
+ # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
158
+ # if i > 1:
159
+ # for j in range(len(chunk)): # batch_size
160
+ # if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # empty prompt, i.e. the <BOS> <EOS> <PAD> ... pattern
161
+ # chunk[j, 0] = chunk[j, 1] # copy the value of the following <PAD>
162
+ states_list.append(chunk) # from after <BOS> to before <EOS>
163
+ states_list.append(hidden_states2[:, -1].unsqueeze(1)) # either <EOS> or <PAD>
164
+ hidden_states2 = torch.cat(states_list, dim=1)
165
+
166
+ # for pool2, use the first of the n chunks
167
+ pool2 = pool2[::n_size]
168
+
169
+ return hidden_states1, hidden_states2, pool2
170
+
171
+ def encode_tokens(
172
+ self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
173
+ ) -> List[torch.Tensor]:
174
+ """
175
+ Args:
176
+ tokenize_strategy: TokenizeStrategy
177
+ models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)].
178
+ If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required
179
+ tokens: List of tokens, for text_encoder1 and text_encoder2
180
+ """
181
+ if len(models) == 2:
182
+ text_encoder1, text_encoder2 = models
183
+ unwrapped_text_encoder2 = None
184
+ else:
185
+ text_encoder1, text_encoder2, unwrapped_text_encoder2 = models
186
+ tokens1, tokens2 = tokens
187
+ sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy
188
+ tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2
189
+
190
+ hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl(
191
+ tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2
192
+ )
193
+ return [hidden_states1, hidden_states2, pool2]
194
+
195
+ def encode_tokens_with_weights(
196
+ self,
197
+ tokenize_strategy: TokenizeStrategy,
198
+ models: List[Any],
199
+ tokens_list: List[torch.Tensor],
200
+ weights_list: List[torch.Tensor],
201
+ ) -> List[torch.Tensor]:
202
+ hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list)
203
+
204
+ weights_list = [weights.to(hidden_states1.device) for weights in weights_list]
205
+
206
+ # apply weights
207
+ if weights_list[0].shape[1] == 1: # no max_token_length
208
+ # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 1280)
209
+ hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2)
210
+ hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2)
211
+ else:
212
+ # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 1280)
213
+ for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]):
214
+ for i in range(weight.shape[1]):
215
+ hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[
216
+ :, i, 1:-1
217
+ ].unsqueeze(-1)
218
+
219
+ return [hidden_states1, hidden_states2, pool2]
220
+
221
+
222
+ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
223
+ SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
224
+
225
+ def __init__(
226
+ self,
227
+ cache_to_disk: bool,
228
+ batch_size: int,
229
+ skip_disk_cache_validity_check: bool,
230
+ is_partial: bool = False,
231
+ is_weighted: bool = False,
232
+ ) -> None:
233
+ super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
234
+
235
+ def get_outputs_npz_path(self, image_abs_path: str) -> str:
236
+ return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
237
+
238
+ def is_disk_cached_outputs_expected(self, npz_path: str):
239
+ if not self.cache_to_disk:
240
+ return False
241
+ if not os.path.exists(npz_path):
242
+ return False
243
+ if self.skip_disk_cache_validity_check:
244
+ return True
245
+
246
+ try:
247
+ npz = np.load(npz_path)
248
+ if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
249
+ return False
250
+ except Exception as e:
251
+ logger.error(f"Error loading file: {npz_path}")
252
+ raise e
253
+
254
+ return True
255
+
256
+ def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
257
+ data = np.load(npz_path)
258
+ hidden_state1 = data["hidden_state1"]
259
+ hidden_state2 = data["hidden_state2"]
260
+ pool2 = data["pool2"]
261
+ return [hidden_state1, hidden_state2, pool2]
262
+
263
+ def cache_batch_outputs(
264
+ self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
265
+ ):
266
+ sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
267
+ captions = [info.caption for info in infos]
268
+
269
+ if self.is_weighted:
270
+ tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
271
+ with torch.no_grad():
272
+ hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights(
273
+ tokenize_strategy, models, tokens_list, weights_list
274
+ )
275
+ else:
276
+ tokens1, tokens2 = tokenize_strategy.tokenize(captions)
277
+ with torch.no_grad():
278
+ hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
279
+ tokenize_strategy, models, [tokens1, tokens2]
280
+ )
281
+
282
+ if hidden_state1.dtype == torch.bfloat16:
283
+ hidden_state1 = hidden_state1.float()
284
+ if hidden_state2.dtype == torch.bfloat16:
285
+ hidden_state2 = hidden_state2.float()
286
+ if pool2.dtype == torch.bfloat16:
287
+ pool2 = pool2.float()
288
+
289
+ hidden_state1 = hidden_state1.cpu().numpy()
290
+ hidden_state2 = hidden_state2.cpu().numpy()
291
+ pool2 = pool2.cpu().numpy()
292
+
293
+ for i, info in enumerate(infos):
294
+ hidden_state1_i = hidden_state1[i]
295
+ hidden_state2_i = hidden_state2[i]
296
+ pool2_i = pool2[i]
297
+
298
+ if self.cache_to_disk:
299
+ np.savez(
300
+ info.text_encoder_outputs_npz,
301
+ hidden_state1=hidden_state1_i,
302
+ hidden_state2=hidden_state2_i,
303
+ pool2=pool2_i,
304
+ )
305
+ else:
306
+ info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
library/train_util.py ADDED
The diff for this file is too large to render. See raw diff
 
library/utils.py ADDED
@@ -0,0 +1,582 @@
1
+ import logging
2
+ import sys
3
+ import threading
4
+ from typing import *
5
+ import json
6
+ import struct
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torchvision import transforms
11
+ from diffusers import EulerAncestralDiscreteScheduler
12
+ import diffusers.schedulers.scheduling_euler_ancestral_discrete
13
+ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
14
+ import cv2
15
+ from PIL import Image
16
+ import numpy as np
17
+ from safetensors.torch import load_file
18
+
19
+
20
+ def fire_in_thread(f, *args, **kwargs):
21
+ threading.Thread(target=f, args=args, kwargs=kwargs).start()
22
+
23
+
24
+ # region Logging
25
+
26
+
27
+ def add_logging_arguments(parser):
28
+ parser.add_argument(
29
+ "--console_log_level",
30
+ type=str,
31
+ default=None,
32
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
33
+ help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
34
+ )
35
+ parser.add_argument(
36
+ "--console_log_file",
37
+ type=str,
38
+ default=None,
39
+ help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
40
+ )
41
+ parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")
42
+
43
+
44
+ def setup_logging(args=None, log_level=None, reset=False):
45
+ if logging.root.handlers:
46
+ if reset:
47
+ # remove all handlers
48
+ for handler in logging.root.handlers[:]:
49
+ logging.root.removeHandler(handler)
50
+ else:
51
+ return
52
+
53
+ # log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
54
+ if log_level is None and args is not None:
55
+ log_level = args.console_log_level
56
+ if log_level is None:
57
+ log_level = "INFO"
58
+ log_level = getattr(logging, log_level)
59
+
60
+ msg_init = None
61
+ if args is not None and args.console_log_file:
62
+ handler = logging.FileHandler(args.console_log_file, mode="w")
63
+ else:
64
+ handler = None
65
+ if not args or not args.console_log_simple:
66
+ try:
67
+ from rich.logging import RichHandler
68
+ from rich.console import Console
70
+
71
+ handler = RichHandler(console=Console(stderr=True))
72
+ except ImportError:
73
+ # print("rich is not installed, using basic logging")
74
+ msg_init = "rich is not installed, using basic logging"
75
+
76
+ if handler is None:
77
+ handler = logging.StreamHandler(sys.stdout) # same as print
78
+ handler.propagate = False
79
+
80
+ formatter = logging.Formatter(
81
+ fmt="%(message)s",
82
+ datefmt="%Y-%m-%d %H:%M:%S",
83
+ )
84
+ handler.setFormatter(formatter)
85
+ logging.root.setLevel(log_level)
86
+ logging.root.addHandler(handler)
87
+
88
+ if msg_init is not None:
89
+ logger = logging.getLogger(__name__)
90
+ logger.info(msg_init)
91
+
92
+
93
+ # endregion
94
+
95
+ # region PyTorch utils
96
+
97
+
98
+ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
99
+ assert layer_to_cpu.__class__ == layer_to_cuda.__class__
100
+
101
+ weight_swap_jobs = []
102
+ for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
103
+ if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
104
+ weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
105
+
106
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
107
+
108
+ stream = torch.cuda.Stream()
109
+ with torch.cuda.stream(stream):
110
+ # cuda to cpu
111
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
112
+ cuda_data_view.record_stream(stream)
113
+ module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
114
+
115
+ stream.synchronize()
116
+
117
+ # cpu to cuda
118
+ for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
119
+ cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
120
+ module_to_cuda.weight.data = cuda_data_view
121
+
122
+ stream.synchronize()
123
+ torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
124
+
125
+
126
+ def weighs_to_device(layer: nn.Module, device: torch.device):
127
+ for module in layer.modules():
128
+ if hasattr(module, "weight") and module.weight is not None:
129
+ module.weight.data = module.weight.data.to(device, non_blocking=True)
130
+
131
+
132
+ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
133
+ """
134
+ Convert a string to a torch.dtype
135
+
136
+ Args:
137
+ s: string representation of the dtype
138
+ default_dtype: default dtype to return if s is None
139
+
140
+ Returns:
141
+ torch.dtype: the corresponding torch.dtype
142
+
143
+ Raises:
144
+ ValueError: if the dtype is not supported
145
+
146
+ Examples:
147
+ >>> str_to_dtype("float32")
148
+ torch.float32
149
+ >>> str_to_dtype("fp32")
150
+ torch.float32
151
+ >>> str_to_dtype("float16")
152
+ torch.float16
153
+ >>> str_to_dtype("fp16")
154
+ torch.float16
155
+ >>> str_to_dtype("bfloat16")
156
+ torch.bfloat16
157
+ >>> str_to_dtype("bf16")
158
+ torch.bfloat16
159
+ >>> str_to_dtype("fp8")
160
+ torch.float8_e4m3fn
161
+ >>> str_to_dtype("fp8_e4m3fn")
162
+ torch.float8_e4m3fn
163
+ >>> str_to_dtype("fp8_e4m3fnuz")
164
+ torch.float8_e4m3fnuz
165
+ >>> str_to_dtype("fp8_e5m2")
166
+ torch.float8_e5m2
167
+ >>> str_to_dtype("fp8_e5m2fnuz")
168
+ torch.float8_e5m2fnuz
169
+ """
170
+ if s is None:
171
+ return default_dtype
172
+ if s in ["bf16", "bfloat16"]:
173
+ return torch.bfloat16
174
+ elif s in ["fp16", "float16"]:
175
+ return torch.float16
176
+ elif s in ["fp32", "float32", "float"]:
177
+ return torch.float32
178
+ elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
179
+ return torch.float8_e4m3fn
180
+ elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
181
+ return torch.float8_e4m3fnuz
182
+ elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
183
+ return torch.float8_e5m2
184
+ elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
185
+ return torch.float8_e5m2fnuz
186
+ elif s in ["fp8", "float8"]:
187
+ return torch.float8_e4m3fn # default fp8
188
+ else:
189
+ raise ValueError(f"Unsupported dtype: {s}")
190
+
191
+
192
+ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
193
+ """
194
+ memory efficient save file
195
+ """
196
+
197
+ _TYPES = {
198
+ torch.float64: "F64",
199
+ torch.float32: "F32",
200
+ torch.float16: "F16",
201
+ torch.bfloat16: "BF16",
202
+ torch.int64: "I64",
203
+ torch.int32: "I32",
204
+ torch.int16: "I16",
205
+ torch.int8: "I8",
206
+ torch.uint8: "U8",
207
+ torch.bool: "BOOL",
208
+ getattr(torch, "float8_e5m2", None): "F8_E5M2",
209
+ getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
210
+ }
211
+ _ALIGN = 256
212
+
213
+ def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
214
+ validated = {}
215
+ for key, value in metadata.items():
216
+ if not isinstance(key, str):
217
+ raise ValueError(f"Metadata key must be a string, got {type(key)}")
218
+ if not isinstance(value, str):
219
+ print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
220
+ validated[key] = str(value)
221
+ else:
222
+ validated[key] = value
223
+ return validated
224
+
225
+ print(f"Using memory efficient save file: {filename}")
226
+
227
+ header = {}
228
+ offset = 0
229
+ if metadata:
230
+ header["__metadata__"] = validate_metadata(metadata)
231
+ for k, v in tensors.items():
232
+ if v.numel() == 0: # empty tensor
233
+ header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
234
+ else:
235
+ size = v.numel() * v.element_size()
236
+ header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
237
+ offset += size
238
+
239
+ hjson = json.dumps(header).encode("utf-8")
240
+ hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
241
+
242
+ with open(filename, "wb") as f:
243
+ f.write(struct.pack("<Q", len(hjson)))
244
+ f.write(hjson)
245
+
246
+ for k, v in tensors.items():
247
+ if v.numel() == 0:
248
+ continue
249
+ if v.is_cuda:
250
+ # Direct GPU to disk save
251
+ with torch.cuda.device(v.device):
252
+ if v.dim() == 0: # if scalar, need to add a dimension to work with view
253
+ v = v.unsqueeze(0)
254
+ tensor_bytes = v.contiguous().view(torch.uint8)
255
+ tensor_bytes.cpu().numpy().tofile(f)
256
+ else:
257
+ # CPU tensor save
258
+ if v.dim() == 0: # if scalar, need to add a dimension to work with view
259
+ v = v.unsqueeze(0)
260
+ v.contiguous().view(torch.uint8).numpy().tofile(f)
261
+
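+ # Layout sketch of the file written above (illustrative): safetensors format, i.e. an 8-byte
+ # little-endian header length, the JSON header (space-padded so 8 + len(header) is a multiple of 256),
+ # then raw tensor bytes at the recorded offsets. A minimal header read-back, assuming "model.safetensors"
+ # was produced by this function:
+ #   with open("model.safetensors", "rb") as f:
+ #       n = struct.unpack("<Q", f.read(8))[0]
+ #       header = json.loads(f.read(n))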
262
+
263
+ class MemoryEfficientSafeOpen:
264
+ # does not support metadata loading
265
+ def __init__(self, filename):
266
+ self.filename = filename
267
+ self.header, self.header_size = self._read_header()
268
+ self.file = open(filename, "rb")
269
+
270
+ def __enter__(self):
271
+ return self
272
+
273
+ def __exit__(self, exc_type, exc_val, exc_tb):
274
+ self.file.close()
275
+
276
+ def keys(self):
277
+ return [k for k in self.header.keys() if k != "__metadata__"]
278
+
279
+ def get_tensor(self, key):
280
+ if key not in self.header:
281
+ raise KeyError(f"Tensor '{key}' not found in the file")
282
+
283
+ metadata = self.header[key]
284
+ offset_start, offset_end = metadata["data_offsets"]
285
+
286
+ if offset_start == offset_end:
287
+ tensor_bytes = None
288
+ else:
289
+ # adjust offset by header size
290
+ self.file.seek(self.header_size + 8 + offset_start)
291
+ tensor_bytes = self.file.read(offset_end - offset_start)
292
+
293
+ return self._deserialize_tensor(tensor_bytes, metadata)
294
+
295
+ def _read_header(self):
296
+ with open(self.filename, "rb") as f:
297
+ header_size = struct.unpack("<Q", f.read(8))[0]
298
+ header_json = f.read(header_size).decode("utf-8")
299
+ return json.loads(header_json), header_size
300
+
301
+ def _deserialize_tensor(self, tensor_bytes, metadata):
302
+ dtype = self._get_torch_dtype(metadata["dtype"])
303
+ shape = metadata["shape"]
304
+
305
+ if tensor_bytes is None:
306
+ byte_tensor = torch.empty(0, dtype=torch.uint8)
307
+ else:
308
+ tensor_bytes = bytearray(tensor_bytes) # make it writable
309
+ byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
310
+
311
+ # process float8 types
312
+ if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
313
+ return self._convert_float8(byte_tensor, metadata["dtype"], shape)
314
+
315
+ # convert to the target dtype and reshape
316
+ return byte_tensor.view(dtype).reshape(shape)
317
+
318
+ @staticmethod
319
+ def _get_torch_dtype(dtype_str):
320
+ dtype_map = {
321
+ "F64": torch.float64,
322
+ "F32": torch.float32,
323
+ "F16": torch.float16,
324
+ "BF16": torch.bfloat16,
325
+ "I64": torch.int64,
326
+ "I32": torch.int32,
327
+ "I16": torch.int16,
328
+ "I8": torch.int8,
329
+ "U8": torch.uint8,
330
+ "BOOL": torch.bool,
331
+ }
332
+ # add float8 types if available
333
+ if hasattr(torch, "float8_e5m2"):
334
+ dtype_map["F8_E5M2"] = torch.float8_e5m2
335
+ if hasattr(torch, "float8_e4m3fn"):
336
+ dtype_map["F8_E4M3"] = torch.float8_e4m3fn
337
+ return dtype_map.get(dtype_str)
338
+
339
+ @staticmethod
340
+ def _convert_float8(byte_tensor, dtype_str, shape):
341
+ if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
342
+ return byte_tensor.view(torch.float8_e5m2).reshape(shape)
343
+ elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
344
+ return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
345
+ else:
346
+ # # convert to float16 if float8 is not supported
347
+ # print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
348
+ # return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
349
+ raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
350
+
351
+
352
+ def load_safetensors(
353
+ path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
354
+ ) -> dict[str, torch.Tensor]:
355
+ if disable_mmap:
356
+ # return safetensors.torch.load(open(path, "rb").read())
357
+ # use experimental loader
358
+ # logger.info(f"Loading without mmap (experimental)")
359
+ state_dict = {}
360
+ with MemoryEfficientSafeOpen(path) as f:
361
+ for key in f.keys():
362
+ state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
363
+ return state_dict
364
+ else:
365
+ try:
366
+ state_dict = load_file(path, device=device)
367
+ except:
368
+ state_dict = load_file(path) # prevent device invalid Error
369
+ if dtype is not None:
370
+ for key in state_dict.keys():
371
+ state_dict[key] = state_dict[key].to(dtype=dtype)
372
+ return state_dict
373
+
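+ # Usage sketch (illustrative; the path is hypothetical):
+ #   sd = load_safetensors("/path/to/model.safetensors", device="cpu", disable_mmap=True, dtype=torch.bfloat16)
+ # With disable_mmap=True the MemoryEfficientSafeOpen reader above is used instead of safetensors'
+ # mmap-based load_file, casting each tensor to the requested dtype as it is read.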
374
+
375
+ # endregion
376
+
377
+ # region Image utils
378
+
379
+
380
+ def pil_resize(image, size, interpolation=Image.LANCZOS):
381
+ has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
382
+
383
+ if has_alpha:
384
+ pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
385
+ else:
386
+ pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
387
+
388
+ resized_pil = pil_image.resize(size, interpolation)
389
+
390
+ # Convert back to cv2 format
391
+ if has_alpha:
392
+ resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
393
+ else:
394
+ resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
395
+
396
+ return resized_cv2
397
+
398
+
399
+ # endregion
400
+
401
+ # TODO make inf_utils.py
402
+ # region Gradual Latent hires fix
403
+
404
+
405
+ class GradualLatent:
406
+ def __init__(
407
+ self,
408
+ ratio,
409
+ start_timesteps,
410
+ every_n_steps,
411
+ ratio_step,
412
+ s_noise=1.0,
413
+ gaussian_blur_ksize=None,
414
+ gaussian_blur_sigma=0.5,
415
+ gaussian_blur_strength=0.5,
416
+ unsharp_target_x=True,
417
+ ):
418
+ self.ratio = ratio
419
+ self.start_timesteps = start_timesteps
420
+ self.every_n_steps = every_n_steps
421
+ self.ratio_step = ratio_step
422
+ self.s_noise = s_noise
423
+ self.gaussian_blur_ksize = gaussian_blur_ksize
424
+ self.gaussian_blur_sigma = gaussian_blur_sigma
425
+ self.gaussian_blur_strength = gaussian_blur_strength
426
+ self.unsharp_target_x = unsharp_target_x
427
+
428
+ def __str__(self) -> str:
429
+ return (
430
+ f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
431
+ + f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
432
+ + f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
433
+ + f"unsharp_target_x={self.unsharp_target_x})"
434
+ )
435
+
436
+ def apply_unsharp_mask(self, x: torch.Tensor):
437
+ if self.gaussian_blur_ksize is None:
438
+ return x
439
+ blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
440
+ # mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
441
+ mask = (x - blurred) * self.gaussian_blur_strength
442
+ sharpened = x + mask
443
+ return sharpened
444
+
445
+ def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
446
+ org_dtype = x.dtype
447
+ if org_dtype == torch.bfloat16:
448
+ x = x.float()
449
+
450
+ x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)
451
+
452
+ # apply unsharp mask / アンシャープマスクを適用する
453
+ if unsharp and self.gaussian_blur_ksize:
454
+ x = self.apply_unsharp_mask(x)
455
+
456
+ return x
457
+
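+ # Usage sketch (illustrative; values are arbitrary): the scheduler subclass below consumes this via
+ #   gl = GradualLatent(ratio=0.5, start_timesteps=800, every_n_steps=5, ratio_step=0.125)
+ #   scheduler.set_gradual_latent_params((64, 64), gl)
+ # after which step() interpolates the latents to the given size (the caller is expected to grow the
+ # size as sampling progresses).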
458
+
459
+ class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
460
+ def __init__(self, *args, **kwargs):
461
+ super().__init__(*args, **kwargs)
462
+ self.resized_size = None
463
+ self.gradual_latent = None
464
+
465
+ def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
466
+ self.resized_size = size
467
+ self.gradual_latent = gradual_latent
468
+
469
+ def step(
470
+ self,
471
+ model_output: torch.FloatTensor,
472
+ timestep: Union[float, torch.FloatTensor],
473
+ sample: torch.FloatTensor,
474
+ generator: Optional[torch.Generator] = None,
475
+ return_dict: bool = True,
476
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
477
+ """
478
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
479
+ process from the learned model outputs (most often the predicted noise).
480
+
481
+ Args:
482
+ model_output (`torch.FloatTensor`):
483
+ The direct output from learned diffusion model.
484
+ timestep (`float`):
485
+ The current discrete timestep in the diffusion chain.
486
+ sample (`torch.FloatTensor`):
487
+ A current instance of a sample created by the diffusion process.
488
+ generator (`torch.Generator`, *optional*):
489
+ A random number generator.
490
+ return_dict (`bool`):
491
+ Whether or not to return a
492
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
493
+
494
+ Returns:
495
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
496
+ If return_dict is `True`,
497
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
498
+ otherwise a tuple is returned where the first element is the sample tensor.
499
+
500
+ """
501
+
502
+ if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
503
+ raise ValueError(
504
+ (
505
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
506
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
507
+ " one of the `scheduler.timesteps` as a timestep."
508
+ ),
509
+ )
510
+
511
+ if not self.is_scale_input_called:
512
+ # logger.warning(
513
+ print(
514
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
515
+ "See `StableDiffusionPipeline` for a usage example."
516
+ )
517
+
518
+ if self.step_index is None:
519
+ self._init_step_index(timestep)
520
+
521
+ sigma = self.sigmas[self.step_index]
522
+
523
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
524
+ if self.config.prediction_type == "epsilon":
525
+ pred_original_sample = sample - sigma * model_output
526
+ elif self.config.prediction_type == "v_prediction":
527
+ # * c_out + input * c_skip
528
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
529
+ elif self.config.prediction_type == "sample":
530
+ raise NotImplementedError("prediction_type not implemented yet: sample")
531
+ else:
532
+ raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")
533
+
534
+ sigma_from = self.sigmas[self.step_index]
535
+ sigma_to = self.sigmas[self.step_index + 1]
536
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
537
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
538
+
539
+ # 2. Convert to an ODE derivative
540
+ derivative = (sample - pred_original_sample) / sigma
541
+
542
+ dt = sigma_down - sigma
543
+
544
+ device = model_output.device
545
+ if self.resized_size is None:
546
+ prev_sample = sample + derivative * dt
547
+
548
+ noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
549
+ model_output.shape, dtype=model_output.dtype, device=device, generator=generator
550
+ )
551
+ s_noise = 1.0
552
+ else:
553
+ print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
554
+ s_noise = self.gradual_latent.s_noise
555
+
556
+ if self.gradual_latent.unsharp_target_x:
557
+ prev_sample = sample + derivative * dt
558
+ prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
559
+ else:
560
+ sample = self.gradual_latent.interpolate(sample, self.resized_size)
561
+ derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
562
+ prev_sample = sample + derivative * dt
563
+
564
+ noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
565
+ (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
566
+ dtype=model_output.dtype,
567
+ device=device,
568
+ generator=generator,
569
+ )
570
+
571
+ prev_sample = prev_sample + noise * sigma_up * s_noise
572
+
573
+ # upon completion increase step index by one
574
+ self._step_index += 1
575
+
576
+ if not return_dict:
577
+ return (prev_sample,)
578
+
579
+ return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
580
+
581
+
582
+ # endregion