ginipick committed (verified)
Commit 7380a20 · 1 Parent(s): 37f1d99

Update app.py

Files changed (1)
  1. app.py +17 -204
app.py CHANGED
@@ -71,7 +71,6 @@ def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
71
  gaussian_1d = gaussian_1d / gaussian_1d.sum()
72
  gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
73
  kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
74
-
75
  return kernel
76
 
77
  def gaussian_filter(latents, kernel_size=3, sigma=1.0):
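
For context, a minimal sketch of how a kernel built this way is typically applied to latents as a depthwise blur. The 1-D kernel construction and the blur_latents helper below are illustrative reconstructions, not code from app.py (the gaussian_filter body lies outside this hunk):

    import torch
    import torch.nn.functional as F

    def make_gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
        # Normalized 1-D Gaussian -> outer product -> one 2-D kernel per channel.
        x = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
        gaussian_1d = torch.exp(-x ** 2 / (2 * sigma ** 2))
        gaussian_1d = gaussian_1d / gaussian_1d.sum()
        gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
        return gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)

    def blur_latents(latents, kernel_size=3, sigma=1.0):
        # Depthwise convolution (groups=channels) blurs each latent channel independently.
        channels = latents.shape[1]
        kernel = make_gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
        return F.conv2d(latents, kernel, padding=kernel_size // 2, groups=channels)
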
@@ -88,9 +87,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
88
  """
89
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
90
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
91
- # rescale the results from guidance (fixes overexposure)
92
  noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
93
- # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
94
  noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
95
  return noise_cfg
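
For context, rescale_noise_cfg implements the std-matching trick that counters overexposure at high guidance scales; a minimal sketch of where it sits relative to classifier-free guidance (the tensors are dummies, and rescale_noise_cfg is the function defined above):

    import torch

    uncond, text = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
    guided = uncond + 7.5 * (text - uncond)              # plain classifier-free guidance
    guided = rescale_noise_cfg(guided, text, guidance_rescale=0.7)
    # guidance_rescale=0.0 is a no-op; 1.0 fully matches the std of the text-conditional branch.
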
96
 
@@ -144,7 +141,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
144
  add_watermarker: Optional[bool] = None,
145
  ):
146
  super().__init__()
147
-
148
  self.register_modules(
149
  vae=vae,
150
  text_encoder=text_encoder,
@@ -158,27 +154,21 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
158
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
159
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
160
  self.default_sample_size = self.unet.config.sample_size
161
-
162
  add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
163
-
164
  if add_watermarker:
165
  self.watermark = StableDiffusionXLWatermarker()
166
  else:
167
  self.watermark = None
168
 
169
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
170
  def enable_vae_slicing(self):
171
  self.vae.enable_slicing()
172
 
173
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
174
  def disable_vae_slicing(self):
175
  self.vae.disable_slicing()
176
 
177
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
178
  def enable_vae_tiling(self):
179
  self.vae.enable_tiling()
180
 
181
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
182
  def disable_vae_tiling(self):
183
  self.vae.disable_tiling()
184
 
@@ -198,25 +188,20 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
198
  lora_scale: Optional[float] = None,
199
  ):
200
  device = device or self._execution_device
201
-
202
- # set lora scale so that monkey patched LoRA function of text encoder can correctly access it
203
  if lora_scale is not None and isinstance(self, LoraLoaderMixin):
204
  self._lora_scale = lora_scale
205
  adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
206
  adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
207
-
208
  if prompt is not None and isinstance(prompt, str):
209
  batch_size = 1
210
  elif prompt is not None and isinstance(prompt, list):
211
  batch_size = len(prompt)
212
  else:
213
  batch_size = prompt_embeds.shape[0]
214
-
215
  tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
216
  text_encoders = (
217
  [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
218
  )
219
-
220
  if prompt_embeds is None:
221
  prompt_2 = prompt_2 or prompt
222
  prompt_embeds_list = []
@@ -224,7 +209,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
224
  for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
225
  if isinstance(self, TextualInversionLoaderMixin):
226
  prompt = self.maybe_convert_prompt(prompt, tokenizer)
227
-
228
  text_inputs = tokenizer(
229
  prompt,
230
  padding="max_length",
@@ -232,10 +216,8 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
232
  truncation=True,
233
  return_tensors="pt",
234
  )
235
-
236
  text_input_ids = text_inputs.input_ids
237
  untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
238
-
239
  if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
240
  text_input_ids, untruncated_ids
241
  ):
@@ -244,19 +226,14 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
244
  "The following part of your input was truncated because CLIP can only handle sequences up to"
245
  f" {tokenizer.model_max_length} tokens: {removed_text}"
246
  )
247
-
248
  prompt_embeds = text_encoder(
249
  text_input_ids.to(device),
250
  output_hidden_states=True,
251
  )
252
-
253
  pooled_prompt_embeds = prompt_embeds[0]
254
  prompt_embeds = prompt_embeds.hidden_states[-2]
255
-
256
  prompt_embeds_list.append(prompt_embeds)
257
-
258
  prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
259
-
260
  zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
261
  if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
262
  negative_prompt_embeds = torch.zeros_like(prompt_embeds)
@@ -264,7 +241,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
264
  elif do_classifier_free_guidance and negative_prompt_embeds is None:
265
  negative_prompt = negative_prompt or ""
266
  negative_prompt_2 = negative_prompt_2 or negative_prompt
267
-
268
  if prompt is not None and type(prompt) is not type(negative_prompt):
269
  raise TypeError(
270
  f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
@@ -280,12 +256,10 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
280
  )
281
  else:
282
  uncond_tokens = [negative_prompt, negative_prompt_2]
283
-
284
  negative_prompt_embeds_list = []
285
  for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
286
  if isinstance(self, TextualInversionLoaderMixin):
287
  negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
288
-
289
  max_length = prompt_embeds.shape[1]
290
  uncond_input = tokenizer(
291
  negative_prompt,
@@ -294,29 +268,23 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
294
  truncation=True,
295
  return_tensors="pt",
296
  )
297
-
298
  negative_prompt_embeds = text_encoder(
299
  uncond_input.input_ids.to(device),
300
  output_hidden_states=True,
301
  )
302
  negative_pooled_prompt_embeds = negative_prompt_embeds[0]
303
  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
304
-
305
  negative_prompt_embeds_list.append(negative_prompt_embeds)
306
-
307
  negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
308
-
309
  prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
310
  bs_embed, seq_len, _ = prompt_embeds.shape
311
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
312
  prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
313
-
314
  if do_classifier_free_guidance:
315
  seq_len = negative_prompt_embeds.shape[1]
316
  negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
317
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
318
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
319
-
320
  pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
321
  bs_embed * num_images_per_prompt, -1
322
  )
@@ -324,7 +292,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
324
  negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
325
  bs_embed * num_images_per_prompt, -1
326
  )
327
-
328
  return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
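
For orientation, encode_prompt returns the concatenated penultimate hidden states of both text encoders plus the pooled output of the second encoder; a rough shape sketch with typical SDXL-base dimensions (the encoder widths are assumptions, not read from this repo's config):

    import torch

    batch, tokens = 1, 77
    clip_l_dim, clip_g_dim = 768, 1280                       # assumed encoder widths
    prompt_embeds = torch.cat(
        [torch.randn(batch, tokens, clip_l_dim),             # text_encoder hidden_states[-2]
         torch.randn(batch, tokens, clip_g_dim)],            # text_encoder_2 hidden_states[-2]
        dim=-1,
    )                                                        # (1, 77, 2048) -> UNet cross-attention
    pooled_prompt_embeds = torch.randn(batch, clip_g_dim)    # later used as add_text_embeds
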
329
 
330
  def prepare_extra_step_kwargs(self, generator, eta):
@@ -332,7 +299,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
332
  extra_step_kwargs = {}
333
  if accepts_eta:
334
  extra_step_kwargs["eta"] = eta
335
-
336
  accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
337
  if accepts_generator:
338
  extra_step_kwargs["generator"] = generator
@@ -355,14 +321,12 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
355
  ):
356
  if height % 8 != 0 or width % 8 != 0:
357
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
358
-
359
  if (callback_steps is None) or (
360
  callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
361
  ):
362
  raise ValueError(
363
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type {type(callback_steps)}."
364
  )
365
-
366
  if prompt is not None and prompt_embeds is not None:
367
  raise ValueError(
368
  f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to only forward one of the two."
@@ -379,7 +343,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
379
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
380
  elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
381
  raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
382
-
383
  if negative_prompt is not None and negative_prompt_embeds is not None:
384
  raise ValueError(
385
  f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to only forward one of the two."
@@ -388,26 +351,21 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
388
  raise ValueError(
389
  f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to only forward one of the two."
390
  )
391
-
392
  if prompt_embeds is not None and negative_prompt_embeds is not None:
393
  if prompt_embeds.shape != negative_prompt_embeds.shape:
394
  raise ValueError(
395
  "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds` {negative_prompt_embeds.shape}."
396
  )
397
-
398
  if prompt_embeds is not None and pooled_prompt_embeds is None:
399
  raise ValueError(
400
  "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
401
  )
402
-
403
  if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
404
  raise ValueError(
405
  "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
406
  )
407
-
408
  if max(height, width) % 1024 != 0:
409
  raise ValueError(f"the larger one of `height` and `width` has to be divisible by 1024 but are {height} and {width}.")
410
-
411
  if num_images_per_prompt != 1:
412
  warnings.warn("num_images_per_prompt != 1 is not supported by AccDiffusion and will be ignored.")
413
  num_images_per_prompt = 1
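
Taken together, the checks above mean AccDiffusion expects height and width divisible by 8, the larger side a multiple of 1024, and a single image per prompt. A quick illustrative sanity check:

    def resolution_ok(height: int, width: int) -> bool:
        # Mirrors check_inputs: divisible by 8 for the VAE, larger side a multiple of 1024.
        return height % 8 == 0 and width % 8 == 0 and max(height, width) % 1024 == 0

    assert resolution_ok(2048, 2048) and resolution_ok(4096, 2048)
    assert not resolution_ok(1920, 1080)
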
@@ -418,29 +376,24 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
418
  raise ValueError(
419
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}. Make sure the batch size matches the length of the generators."
420
  )
421
-
422
  if latents is None:
423
  latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
424
  else:
425
  latents = latents.to(device)
426
-
427
  latents = latents * self.scheduler.init_noise_sigma
428
  return latents
429
 
430
  def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
431
  add_time_ids = list(original_size + crops_coords_top_left + target_size)
432
-
433
  passed_add_embed_dim = (
434
  self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
435
  )
436
  expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
437
-
438
  if expected_add_embed_dim != passed_add_embed_dim:
439
  raise ValueError(
440
  f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. \
441
  The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
442
  )
443
-
444
  add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
445
  return add_time_ids
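
For reference, the added time ids pack the SDXL micro-conditioning as six numbers, original size + crop top-left + target size, and the check above guards against a UNet / text_encoder_2 projection mismatch. A minimal sketch with illustrative sizes:

    import torch

    original_size, crops_coords_top_left, target_size = (2048, 2048), (0, 0), (2048, 2048)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    # [2048, 2048, 0, 0, 2048, 2048]; indices 2 and 3 (crop top/left) are later
    # overwritten per tile in the multidiffusion loop.
    add_time_ids = torch.tensor([add_time_ids], dtype=torch.float16)
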
446
 
@@ -456,7 +409,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
456
  h_end = h_start + window_size
457
  w_start = int((i % num_blocks_width) * stride)
458
  w_end = w_start + window_size
459
-
460
  if h_end > height:
461
  h_start = int(h_start + height - h_end)
462
  h_end = int(height)
@@ -469,7 +421,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
469
  if w_start < 0:
470
  w_end = int(w_end - w_start)
471
  w_start = 0
472
-
473
  if random_jitter:
474
  jitter_range = (window_size - stride) // 4
475
  w_jitter = 0
@@ -480,7 +431,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
480
  w_jitter = random.randint(-jitter_range, 0)
481
  elif (w_start != 0) and (w_end == width):
482
  w_jitter = random.randint(0, jitter_range)
483
-
484
  if (h_start != 0) and (h_end != height):
485
  h_jitter = random.randint(-jitter_range, jitter_range)
486
  elif (h_start == 0) and (h_end != height):
@@ -491,7 +441,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
491
  h_end = h_end + h_jitter + jitter_range
492
  w_start = w_start + w_jitter + jitter_range
493
  w_end = w_end + w_jitter + jitter_range
494
-
495
  views.append((h_start, h_end, w_start, w_end))
496
  return views
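
get_views enumerates overlapping window crops over the latent grid, clamping windows at the borders and optionally jittering them (the caller compensates with padding). A simplified, jitter-free version of the same tiling idea, for illustration only:

    def simple_views(height, width, window_size=128, stride=64):
        # Row-major enumeration of (h_start, h_end, w_start, w_end) windows.
        views = []
        for h_start in range(0, max(height - window_size, 0) + 1, stride):
            for w_start in range(0, max(width - window_size, 0) + 1, stride):
                views.append((h_start, h_start + window_size, w_start, w_start + window_size))
        return views

    print(len(simple_views(256, 256)))  # 9 overlapping 128x128 windows on a 256x256 latent
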
497
 
@@ -527,7 +476,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
527
  continue
528
  cross_att_count += 1
529
  attn_procs[name] = P2PCrossAttnProcessor(controller=controller, place_in_unet=place_in_unet)
530
-
531
  self.unet.set_attn_processor(attn_procs)
532
  controller.num_att_layers = cross_att_count
533
  return ori_attn_processors
@@ -540,7 +488,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
540
  from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
541
  else:
542
  raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
543
-
544
  is_model_cpu_offload = False
545
  is_sequential_cpu_offload = False
546
  recursive = False
@@ -560,7 +507,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
560
  **kwargs,
561
  )
562
  self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
563
-
564
  text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
565
  if len(text_encoder_state_dict) > 0:
566
  self.load_lora_into_text_encoder(
@@ -570,7 +516,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
570
  prefix="text_encoder",
571
  lora_scale=self.lora_scale,
572
  )
573
-
574
  text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
575
  if len(text_encoder_2_state_dict) > 0:
576
  self.load_lora_into_text_encoder(
@@ -580,7 +525,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
580
  prefix="text_encoder_2",
581
  lora_scale=self.lora_scale,
582
  )
583
-
584
  if is_model_cpu_offload:
585
  self.enable_model_cpu_offload()
586
  elif is_sequential_cpu_offload:
@@ -599,24 +543,19 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
599
  safe_serialization: bool = True,
600
  ):
601
  state_dict = {}
602
-
603
  def pack_weights(layers, prefix):
604
  layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
605
  layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
606
  return layers_state_dict
607
-
608
  if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
609
  raise ValueError(
610
  "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
611
  )
612
-
613
  if unet_lora_layers:
614
  state_dict.update(pack_weights(unet_lora_layers, "unet"))
615
-
616
  if text_encoder_lora_layers and text_encoder_2_lora_layers:
617
  state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
618
  state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
619
-
620
  self.write_lora_layers(
621
  state_dict=state_dict,
622
  save_directory=save_directory,
@@ -697,39 +636,29 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
697
 
698
  Examples:
699
  """
700
-
701
  if debug:
702
  num_inference_steps = 1
703
-
704
  height = height or self.default_sample_size * self.vae_scale_factor
705
  width = width or self.default_sample_size * self.vae_scale_factor
706
-
707
  x1_size = self.default_sample_size * self.vae_scale_factor
708
-
709
  height_scale = height / x1_size
710
  width_scale = width / x1_size
711
  scale_num = int(max(height_scale, width_scale))
712
  aspect_ratio = min(height_scale, width_scale) / max(height_scale, width_scale)
713
-
714
  original_size = original_size or (height, width)
715
  target_size = target_size or (height, width)
716
-
717
  if attn_res is None:
718
- attn_res = int(np.ceil(self.default_sample_size * self.vae_scale_factor / 32)), int(np.ceil(self.default_sample_size * self.vae_scale_factor / 32))
+ attn_res = (int(np.ceil(self.default_sample_size * self.vae_scale_factor / 32)), int(np.ceil(self.default_sample_size * self.vae_scale_factor / 32)))
719
  self.attn_res = attn_res
720
-
721
  if lowvram:
722
  attention_map_device = torch.device("cpu")
723
  else:
724
  attention_map_device = self.device
725
-
726
  self.controller = create_controller(
727
  prompt, cross_attention_kwargs, num_inference_steps, tokenizer=self.tokenizer, device=attention_map_device, attn_res=self.attn_res
728
  )
729
-
730
  if save_attention_map or use_md_prompt:
731
  ori_attn_processors = self.register_attention_control(self.controller)
732
-
733
  self.check_inputs(
734
  prompt,
735
  prompt_2,
@@ -744,14 +673,12 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
744
  negative_pooled_prompt_embeds,
745
  num_images_per_prompt,
746
  )
747
-
748
  if prompt is not None and isinstance(prompt, str):
749
  batch_size = 1
750
  elif prompt is not None and isinstance(prompt, list):
751
  batch_size = len(prompt)
752
  else:
753
  batch_size = prompt_embeds.shape[0]
754
-
755
  device = self._execution_device
756
  self.lowvram = lowvram
757
  if self.lowvram:
@@ -759,13 +686,8 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
759
  self.unet.cpu()
760
  self.text_encoder.to(device)
761
  self.text_encoder_2.to(device)
762
-
763
  do_classifier_free_guidance = guidance_scale > 1.0
764
-
765
- text_encoder_lora_scale = (
766
- cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
767
- )
+ text_encoder_lora_scale = (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None)
768
-
769
  (
770
  prompt_embeds,
771
  negative_prompt_embeds,
@@ -785,10 +707,8 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
785
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
786
  lora_scale=text_encoder_lora_scale,
787
  )
788
-
789
  self.scheduler.set_timesteps(num_inference_steps, device=device)
790
  timesteps = self.scheduler.timesteps
791
-
792
  num_channels_latents = self.unet.config.in_channels
793
  latents = self.prepare_latents(
794
  batch_size * num_images_per_prompt,
@@ -800,15 +720,11 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
800
  generator,
801
  latents,
802
  )
803
-
804
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
805
-
806
  add_text_embeds = pooled_prompt_embeds
807
-
808
  add_time_ids = self._get_add_time_ids(
809
  original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
810
  )
811
-
812
  if negative_original_size is not None and negative_target_size is not None:
813
  negative_add_time_ids = self._get_add_time_ids(
814
  negative_original_size,
@@ -818,16 +734,12 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
818
  )
819
  else:
820
  negative_add_time_ids = add_time_ids
821
-
822
  if do_classifier_free_guidance:
823
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0).to(device)
824
  add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0).to(device)
825
  add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0).to(device).repeat(batch_size * num_images_per_prompt, 1)
826
-
827
  del negative_prompt_embeds, negative_pooled_prompt_embeds, negative_add_time_ids
828
-
829
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
830
-
831
  if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
832
  discrete_timestep_cutoff = int(
833
  round(
@@ -836,35 +748,26 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
836
  )
837
  num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
838
  timesteps = timesteps[:num_inference_steps]
839
-
840
  output_images = []
841
-
842
  ###################################################### Phase Initialization ########################################################
843
-
844
  if self.lowvram:
845
  self.text_encoder.cpu()
846
  self.text_encoder_2.cpu()
847
-
848
  if image_lr == None:
849
  print("### Phase 1 Denoising ###")
850
  with self.progress_bar(total=num_inference_steps) as progress_bar:
851
  for i, t in enumerate(timesteps):
852
-
853
  if self.lowvram:
854
  self.vae.cpu()
855
  self.unet.to(device)
856
-
857
  latents_for_view = latents
858
-
859
  latent_model_input = (
860
  latents.repeat_interleave(2, dim=0)
861
  if do_classifier_free_guidance
862
  else latents
863
  )
864
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
865
-
866
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
867
-
868
  noise_pred = self.unet(
869
  latent_model_input,
870
  t,
@@ -872,25 +775,19 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
872
  added_cond_kwargs=added_cond_kwargs,
873
  return_dict=False,
874
  )[0]
875
-
876
  if do_classifier_free_guidance:
877
  noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
878
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
879
-
880
  if do_classifier_free_guidance and guidance_rescale > 0.0:
881
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
882
-
883
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
884
-
885
  if t == 1 and use_md_prompt:
886
  md_prompts, views_attention = get_multidiffusion_prompts(tokenizer=self.tokenizer, prompts=[prompt], threthod=c, attention_store=self.controller, height=height//scale_num, width=width//scale_num, from_where=["up","down"], random_jitter=True, scale_num=scale_num)
887
-
888
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
889
  progress_bar.update()
890
  if callback is not None and i % callback_steps == 0:
891
  step_idx = i // getattr(self.scheduler, "order", 1)
892
  callback(step_idx, t, latents)
893
-
894
  del latents_for_view, latent_model_input, noise_pred, noise_pred_text, noise_pred_uncond
895
  if use_md_prompt or save_attention_map:
896
  self.recover_attention_control(ori_attn_processors=ori_attn_processors)
@@ -900,7 +797,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
900
  print("### Encoding Real Image ###")
901
  latents = self.vae.encode(image_lr)
902
  latents = latents.latent_dist.sample() * self.vae.config.scaling_factor
903
-
904
  anchor_mean = latents.mean()
905
  anchor_std = latents.std()
906
  if self.lowvram:
@@ -908,12 +804,10 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
908
  torch.cuda.empty_cache()
909
  if not output_type == "latent":
910
  needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
911
-
912
  if self.lowvram:
913
  needs_upcasting = False
914
  self.unet.cpu()
915
  self.vae.to(device)
916
-
917
  if needs_upcasting:
918
  self.upcast_vae()
919
  latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
@@ -924,15 +818,12 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
924
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
925
  if needs_upcasting:
926
  self.vae.to(dtype=torch.float16)
927
-
928
  image = self.image_processor.postprocess(image, output_type=output_type)
929
  if not os.path.exists(f'{result_path}'):
930
  os.makedirs(f'{result_path}')
931
-
932
  image_lr_save_path = f'{result_path}/{image[0].size[0]}_{image[0].size[1]}.png'
933
  image[0].save(image_lr_save_path)
934
  output_images.append(image[0])
935
-
936
  ####################################################### Phase Upscaling #####################################################
937
  if use_progressive_upscaling:
938
  if image_lr == None:
@@ -941,26 +832,21 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
941
  starting_scale = 1
942
  else:
943
  starting_scale = scale_num
944
-
945
  for current_scale_num in range(starting_scale, scale_num + 1):
946
  if self.lowvram:
947
  latents = latents.to(device)
948
  self.unet.to(device)
949
  torch.cuda.empty_cache()
950
-
951
  current_height = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
952
  current_width = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
953
-
954
  if height > width:
955
  current_width = int(current_width * aspect_ratio)
956
  else:
957
  current_height = int(current_height * aspect_ratio)
958
-
959
  if upscale_mode == "bicubic_latent" or debug:
960
  latents = F.interpolate(latents.to(device), size=(int(current_height / self.vae_scale_factor), int(current_width / self.vae_scale_factor)), mode='bicubic')
961
  else:
962
  raise NotImplementedError
963
-
964
  print("### Phase {} Denoising ###".format(current_scale_num))
965
  noise_latents = []
966
  noise = torch.randn_like(latents)
@@ -968,7 +854,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
968
  noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0))
969
  noise_latents.append(noise_latent)
970
  latents = noise_latents[0]
971
-
972
  with self.progress_bar(total=num_inference_steps) as progress_bar:
973
  for i, t in enumerate(timesteps):
974
  count = torch.zeros_like(latents)
@@ -977,7 +862,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
977
  if use_skip_residual:
978
  c1 = cosine_factor ** cosine_scale_1
979
  latents = latents * (1 - c1) + noise_latents[i] * c1
980
-
981
  if use_multidiffusion:
982
  if use_md_prompt:
983
  md_prompt_embeds_list = []
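
The loop above pre-computes a re-noised copy of the upscaled latent for every remaining timestep (noise_latents), and the skip residual blends each step's latent toward that noised anchor with a weight that shrinks as denoising proceeds. A minimal sketch, assuming cosine_factor is a per-step scalar decaying from roughly 1 to 0 over the schedule (its exact definition lives elsewhere in app.py) and using an illustrative cosine_scale_1:

    import torch

    def skip_residual(latents, noise_latent, cosine_factor, cosine_scale_1=3.0):
        # Early steps (cosine_factor ~ 1): keep mostly the re-noised global structure.
        # Late steps (cosine_factor ~ 0): keep mostly the locally denoised latent.
        c1 = cosine_factor ** cosine_scale_1
        return latents * (1 - c1) + noise_latent * c1

    x = skip_residual(torch.zeros(1, 4, 8, 8), torch.ones(1, 4, 8, 8), cosine_factor=0.9)
    print(x.mean().item())  # ~0.729 == 0.9 ** 3
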
@@ -1005,46 +889,33 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1005
  md_prompt_embeds_list.append(torch.cat([md_negative_prompt_embeds, md_prompt_embeds], dim=0).to(device))
1006
  md_add_text_embeds_list.append(torch.cat([md_negative_pooled_prompt_embeds, md_pooled_prompt_embeds], dim=0).to(device))
1007
  del md_negative_prompt_embeds, md_negative_pooled_prompt_embeds
1008
-
1009
  if use_md_prompt:
1010
  random_jitter = True
1011
  views = [(h_start*4, h_end*4, w_start*4, w_end*4) for h_start, h_end, w_start, w_end in views_attention[current_scale_num]]
1012
  else:
1013
  random_jitter = True
1014
  views = self.get_views(current_height, current_width, stride=stride, window_size=self.unet.config.sample_size, random_jitter=random_jitter)
1015
-
1016
  views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
1017
-
1018
  if use_md_prompt:
1019
  views_prompt_embeds_input = [md_prompt_embeds_list[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
1020
  views_add_text_embeds_input = [md_add_text_embeds_list[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
1021
-
1022
  if random_jitter:
1023
  jitter_range = int((self.unet.config.sample_size - stride) // 4)
1024
  latents_ = F.pad(latents, (jitter_range, jitter_range, jitter_range, jitter_range), 'constant', 0)
1025
  else:
1026
  latents_ = latents
1027
-
1028
  count_local = torch.zeros_like(latents_)
1029
  value_local = torch.zeros_like(latents_)
1030
-
1031
  for j, batch_view in enumerate(views_batch):
1032
  vb_size = len(batch_view)
1033
  latents_for_view = torch.cat(
1034
- [
1035
- latents_[:, :, h_start:h_end, w_start:w_end]
1036
- for h_start, h_end, w_start, w_end in batch_view
1037
- ]
1038
  )
1039
-
1040
  latent_model_input = latents_for_view
1041
- latent_model_input = (
1042
- latent_model_input.repeat_interleave(2, dim=0)
1043
- if do_classifier_free_guidance
1044
- else latent_model_input
1045
- )
1046
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1047
-
1048
  add_time_ids_input = []
1049
  for h_start, h_end, w_start, w_end in batch_view:
1050
  add_time_ids_ = add_time_ids.clone()
@@ -1052,7 +923,6 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1052
  add_time_ids_[:, 3] = w_start * self.vae_scale_factor
1053
  add_time_ids_input.append(add_time_ids_)
1054
  add_time_ids_input = torch.cat(add_time_ids_input)
1055
-
1056
  if not use_md_prompt:
1057
  prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
1058
  add_text_embeds_input = torch.cat([add_text_embeds] * vb_size)
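
Each tile carries its own copy of the SDXL micro-conditioning, with the crop offsets rewritten to the tile's position in pixel space. A minimal sketch of the per-view rewrite (the layout follows _get_add_time_ids above; setting index 2 from h_start is inferred by symmetry with the w_start line shown in this hunk):

    import torch

    vae_scale_factor = 8
    add_time_ids = torch.tensor([[2048., 2048., 0., 0., 2048., 2048.]])  # orig_h, orig_w, top, left, tgt_h, tgt_w
    h_start, w_start = 64, 128                          # tile origin in latent coordinates
    tile_time_ids = add_time_ids.clone()
    tile_time_ids[:, 2] = h_start * vae_scale_factor    # crop top in pixels (inferred)
    tile_time_ids[:, 3] = w_start * vae_scale_factor    # crop left in pixels (as in the hunk above)
    print(tile_time_ids)  # tensor([[2048., 2048., 512., 1024., 2048., 2048.]])
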
@@ -1075,30 +945,20 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1075
  added_cond_kwargs=md_added_cond_kwargs,
1076
  return_dict=False,
1077
  )[0]
1078
-
1079
  if do_classifier_free_guidance:
1080
  noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
1081
  noise_pred = noise_pred_uncond + multi_guidance_scale * (noise_pred_text - noise_pred_uncond)
1082
-
1083
  if do_classifier_free_guidance and guidance_rescale > 0.0:
1084
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1085
-
1086
  self.scheduler._init_step_index(t)
1087
- latents_denoised_batch = self.scheduler.step(
1088
- noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False)[0]
1089
-
1090
- for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
1091
- latents_denoised_batch.chunk(vb_size), batch_view
1092
- ):
1093
  value_local[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
1094
  count_local[:, :, h_start:h_end, w_start:w_end] += 1
1095
-
1096
  if random_jitter:
1097
  value_local = value_local[:, :, jitter_range:jitter_range + current_height // self.vae_scale_factor, jitter_range:jitter_range + current_width // self.vae_scale_factor]
1098
  count_local = count_local[:, :, jitter_range:jitter_range + current_height // self.vae_scale_factor, jitter_range:jitter_range + current_width // self.vae_scale_factor]
1099
-
1100
  noise_index = i + 1 if i != (len(timesteps) - 1) else i
1101
-
1102
  value_local = torch.where(count_local == 0, noise_latents[noise_index], value_local)
1103
  count_local = torch.where(count_local == 0, torch.ones_like(count_local), count_local)
1104
  if use_dilated_sampling:
@@ -1107,19 +967,15 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1107
  count += torch.ones_like(value_local) * (1 - c2)
1108
  else:
1109
  value += value_local / count_local
1110
- count += torch.ones_like(value_local)
1111
-
1112
  if use_dilated_sampling:
1113
  views = [[h, w] for h in range(current_scale_num) for w in range(current_scale_num)]
1114
  views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
1115
-
1116
  h_pad = (current_scale_num - (latents.size(2) % current_scale_num)) % current_scale_num
1117
  w_pad = (current_scale_num - (latents.size(3) % current_scale_num)) % current_scale_num
1118
  latents_ = F.pad(latents, (w_pad, 0, h_pad, 0), 'constant', 0)
1119
-
1120
  count_global = torch.zeros_like(latents_)
1121
  value_global = torch.zeros_like(latents_)
1122
-
1123
  if use_guassian:
1124
  c3 = 0.99 * cosine_factor ** cosine_scale_3 + 1e-2
1125
  std_, mean_ = latents_.std(), latents_.mean()
@@ -1127,47 +983,31 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1127
  latents_gaussian = (latents_gaussian - latents_gaussian.mean()) / latents_gaussian.std() * std_ + mean_
1128
  else:
1129
  latents_gaussian = latents_
1130
-
1131
  for j, batch_view in enumerate(views_batch):
1132
-
1133
  latents_for_view = torch.cat(
1134
- [
1135
- latents_[:, :, h::current_scale_num, w::current_scale_num]
1136
- for h, w in batch_view
1137
- ]
1138
  )
1139
-
1140
  latents_for_view_gaussian = torch.cat(
1141
- [
1142
- latents_gaussian[:, :, h::current_scale_num, w::current_scale_num]
1143
- for h, w in batch_view
1144
- ]
1145
  )
1146
-
1147
  if shuffle:
1148
  shape = latents_for_view.shape
1149
- # Fix: add parentheses around range(...).
1150
  shuffle_index = torch.stack([torch.randperm(shape[0]) for _ in range(latents_for_view.reshape(-1).shape[0]//shape[0])])
1151
  shuffle_index = shuffle_index.view(shape[1], shape[2], shape[3], shape[0])
1152
  original_index = torch.zeros_like(shuffle_index).scatter_(3, shuffle_index, torch.arange(shape[0]).repeat(shape[1], shape[2], shape[3], 1))
1153
  shuffle_index = shuffle_index.permute(3, 0, 1, 2).to(device)
1154
  original_index = original_index.permute(3, 0, 1, 2).to(device)
1155
  latents_for_view_gaussian = latents_for_view_gaussian.gather(0, shuffle_index)
1156
-
1157
  vb_size = latents_for_view.size(0)
1158
-
1159
  latent_model_input = latents_for_view_gaussian
1160
- latent_model_input = (
1161
- latent_model_input.repeat_interleave(2, dim=0)
1162
- if do_classifier_free_guidance
1163
- else latent_model_input
1164
- )
1165
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1166
-
1167
  prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
1168
  add_text_embeds_input = torch.cat([add_text_embeds] * vb_size)
1169
  add_time_ids_input = torch.cat([add_time_ids] * vb_size)
1170
-
1171
  added_cond_kwargs = {"text_embeds": add_text_embeds_input, "time_ids": add_time_ids_input}
1172
  noise_pred = self.unet(
1173
  latent_model_input,
@@ -1176,28 +1016,19 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1176
  added_cond_kwargs=added_cond_kwargs,
1177
  return_dict=False,
1178
  )[0]
1179
-
1180
  if do_classifier_free_guidance:
1181
  noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
1182
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1183
-
1184
  if do_classifier_free_guidance and guidance_rescale > 0.0:
1185
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1186
-
1187
  if shuffle:
1188
  noise_pred = noise_pred.gather(0, original_index)
1189
-
1190
  self.scheduler._init_step_index(t)
1191
  latents_denoised_batch = self.scheduler.step(noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False)[0]
1192
-
1193
- for latents_view_denoised, (h, w) in zip(
1194
- latents_denoised_batch.chunk(vb_size), batch_view
1195
- ):
1196
  value_global[:, :, h::current_scale_num, w::current_scale_num] += latents_view_denoised
1197
  count_global[:, :, h::current_scale_num, w::current_scale_num] += 1
1198
-
1199
  value_global = value_global[:, :, h_pad:, w_pad:]
1200
-
1201
  if use_multidiffusion:
1202
  c2 = cosine_factor ** cosine_scale_2
1203
  value += value_global * c2
@@ -1205,15 +1036,12 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1205
  else:
1206
  value += value_global
1207
  count += torch.ones_like(value_global)
1208
-
1209
  latents = torch.where(count > 0, value / count, value)
1210
-
1211
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1212
  progress_bar.update()
1213
  if callback is not None and i % callback_steps == 0:
1214
  step_idx = i // getattr(self.scheduler, "order", 1)
1215
  callback(step_idx, t, latents)
1216
-
1217
  latents = (latents - latents.mean()) / latents.std() * anchor_std + anchor_mean
1218
  if self.lowvram:
1219
  latents = latents.cpu()
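
After each upscaling phase the latent statistics are renormalized to the mean/std recorded from the Phase-1 (or VAE-encoded) latent, which keeps brightness and contrast from drifting across scales. The operation is the single line above; a dummy-tensor sketch:

    import torch

    anchor = torch.randn(1, 4, 128, 128)                 # Phase-1 latent (or encoded image_lr)
    anchor_mean, anchor_std = anchor.mean(), anchor.std()

    latents = torch.randn(1, 4, 256, 256) * 3 + 1.5      # statistics drifted during tiled denoising
    latents = (latents - latents.mean()) / latents.std() * anchor_std + anchor_mean
    assert torch.isclose(latents.mean(), anchor_mean, atol=1e-4)
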
@@ -1224,29 +1052,23 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
1224
  needs_upcasting = False
1225
  self.unet.cpu()
1226
  self.vae.to(device)
1227
-
1228
  if needs_upcasting:
1229
  self.upcast_vae()
1230
  latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1231
-
1232
  print("### Phase {} Decoding ###".format(current_scale_num))
1233
  if current_height > 2048 or current_width > 2048:
1234
  self.enable_vae_tiling()
1235
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1236
  else:
1237
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1238
-
1239
  image = self.image_processor.postprocess(image, output_type=output_type)
1240
  image[0].save(f'{result_path}/AccDiffusion_{current_scale_num}.png')
1241
  output_images.append(image[0])
1242
-
1243
  if needs_upcasting:
1244
  self.vae.to(dtype=torch.float16)
1245
  else:
1246
  image = latents
1247
-
1248
  self.maybe_free_model_hooks()
1249
-
1250
  return output_images
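
For completeness, a minimal sketch of driving the pipeline end to end, mirroring the __main__ block below. The checkpoint id and several argument values are illustrative, and result_path is assumed to be accepted by __call__ since it is used inside the denoising loop above:

    import torch

    # Illustrative call; argument names follow the infer() call further down.
    pipe = AccDiffusionSDXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    images = pipe(
        "A surreal landscape with floating islands",
        height=2048, width=2048,
        num_inference_steps=50, guidance_scale=7.5,
        use_multidiffusion=True, use_dilated_sampling=True, use_skip_residual=True,
        generator=torch.Generator("cuda").manual_seed(0),
        result_path="./output/example",
    )  # returns a list of PIL images, one per completed phase
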
1251
 
1252
 
@@ -1280,9 +1102,7 @@ if __name__ == "__main__":
1280
  ## others ##
1281
  parser.add_argument('--debug', default=False, action='store_true')
1282
  parser.add_argument('--experiment_name', default="AccDiffusion")
1283
-
1284
  args = parser.parse_args()
1285
-
1286
  pipe = AccDiffusionSDXLPipeline.from_pretrained(args.model_ckpt, torch_dtype=torch.float16).to("cuda")
1287
 
1288
  @spaces.GPU(duration=200)
@@ -1294,11 +1114,9 @@ if __name__ == "__main__":
1294
  "n_cross_replace": {"default_": 1.0, "confetti": 0.8},
1295
  }
1296
  generator = torch.Generator(device='cuda').manual_seed(seed)
1297
-
1298
  print(f"Prompt: {prompt}")
1299
  md5_hash = hashlib.md5(prompt.encode()).hexdigest()
1300
  result_path = f"./output/{args.experiment_name}/{md5_hash}/{width}_{height}_{seed}/"
1301
-
1302
  images = pipe(
1303
  prompt,
1304
  negative_prompt=args.negative_prompt,
@@ -1326,11 +1144,9 @@ if __name__ == "__main__":
1326
  debug=args.debug, save_attention_map=args.save_attention_map, use_md_prompt=use_md_prompt, c=args.c
1327
  )
1328
  print(images)
1329
-
1330
  return images
1331
 
1332
  MAX_SEED = np.iinfo(np.int32).max
1333
-
1334
  css = """
1335
  body {
1336
  background: linear-gradient(135deg, #2c3e50, #4ca1af);
@@ -1353,7 +1169,6 @@ if __name__ == "__main__":
1353
  visibility: hidden;
1354
  }
1355
  """
1356
-
1357
  with gr.Blocks(css=css) as demo:
1358
  with gr.Column(elem_id="col-container"):
1359
  gr.Markdown("<h1>AccDiffusion: Advanced AI Art Generator</h1>")
@@ -1363,7 +1178,6 @@ if __name__ == "__main__":
1363
  with gr.Row():
1364
  prompt = gr.Textbox(label="Prompt", placeholder="e.g., A surreal landscape with floating islands and vibrant colors.", lines=2, scale=4)
1365
  submit_btn = gr.Button("Generate", scale=1)
1366
-
1367
  with gr.Accordion("Advanced Settings", open=False):
1368
  with gr.Row():
1369
  resolution = gr.Radio(
@@ -1386,8 +1200,7 @@ if __name__ == "__main__":
1386
  use_progressive_upscaling = gr.Checkbox(label="Use Progressive Upscaling", value=False)
1387
  shuffle = gr.Checkbox(label="Shuffle", value=False)
1388
  use_md_prompt = gr.Checkbox(label="Use MD Prompt", value=False)
1389
-
1390
- output_images = gr.Gallery(label="Output Images", format="png").style(grid=[2], height="auto")
1391
  gr.Markdown("### Example Prompts")
1392
  gr.Examples(
1393
  examples = [
 
71
  gaussian_1d = gaussian_1d / gaussian_1d.sum()
72
  gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
73
  kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
 
74
  return kernel
75
 
76
  def gaussian_filter(latents, kernel_size=3, sigma=1.0):
 
87
  """
88
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
89
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
 
90
  noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
 
91
  noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
92
  return noise_cfg
93
 
 
141
  add_watermarker: Optional[bool] = None,
142
  ):
143
  super().__init__()
 
144
  self.register_modules(
145
  vae=vae,
146
  text_encoder=text_encoder,
 
154
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
155
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
156
  self.default_sample_size = self.unet.config.sample_size
 
157
  add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
 
158
  if add_watermarker:
159
  self.watermark = StableDiffusionXLWatermarker()
160
  else:
161
  self.watermark = None
162
 
 
163
  def enable_vae_slicing(self):
164
  self.vae.enable_slicing()
165
 
 
166
  def disable_vae_slicing(self):
167
  self.vae.disable_slicing()
168
 
 
169
  def enable_vae_tiling(self):
170
  self.vae.enable_tiling()
171
 
 
172
  def disable_vae_tiling(self):
173
  self.vae.disable_tiling()
174
 
 
188
  lora_scale: Optional[float] = None,
189
  ):
190
  device = device or self._execution_device
 
 
191
  if lora_scale is not None and isinstance(self, LoraLoaderMixin):
192
  self._lora_scale = lora_scale
193
  adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
194
  adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
 
195
  if prompt is not None and isinstance(prompt, str):
196
  batch_size = 1
197
  elif prompt is not None and isinstance(prompt, list):
198
  batch_size = len(prompt)
199
  else:
200
  batch_size = prompt_embeds.shape[0]
 
201
  tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
202
  text_encoders = (
203
  [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
204
  )
 
205
  if prompt_embeds is None:
206
  prompt_2 = prompt_2 or prompt
207
  prompt_embeds_list = []
 
209
  for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
210
  if isinstance(self, TextualInversionLoaderMixin):
211
  prompt = self.maybe_convert_prompt(prompt, tokenizer)
 
212
  text_inputs = tokenizer(
213
  prompt,
214
  padding="max_length",
 
216
  truncation=True,
217
  return_tensors="pt",
218
  )
 
219
  text_input_ids = text_inputs.input_ids
220
  untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
221
  if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
222
  text_input_ids, untruncated_ids
223
  ):
 
226
  "The following part of your input was truncated because CLIP can only handle sequences up to"
227
  f" {tokenizer.model_max_length} tokens: {removed_text}"
228
  )
 
229
  prompt_embeds = text_encoder(
230
  text_input_ids.to(device),
231
  output_hidden_states=True,
232
  )
 
233
  pooled_prompt_embeds = prompt_embeds[0]
234
  prompt_embeds = prompt_embeds.hidden_states[-2]
 
235
  prompt_embeds_list.append(prompt_embeds)
 
236
  prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
 
237
  zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
238
  if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
239
  negative_prompt_embeds = torch.zeros_like(prompt_embeds)
 
241
  elif do_classifier_free_guidance and negative_prompt_embeds is None:
242
  negative_prompt = negative_prompt or ""
243
  negative_prompt_2 = negative_prompt_2 or negative_prompt
 
244
  if prompt is not None and type(prompt) is not type(negative_prompt):
245
  raise TypeError(
246
  f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
 
256
  )
257
  else:
258
  uncond_tokens = [negative_prompt, negative_prompt_2]
 
259
  negative_prompt_embeds_list = []
260
  for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
261
  if isinstance(self, TextualInversionLoaderMixin):
262
  negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
 
263
  max_length = prompt_embeds.shape[1]
264
  uncond_input = tokenizer(
265
  negative_prompt,
 
268
  truncation=True,
269
  return_tensors="pt",
270
  )
 
271
  negative_prompt_embeds = text_encoder(
272
  uncond_input.input_ids.to(device),
273
  output_hidden_states=True,
274
  )
275
  negative_pooled_prompt_embeds = negative_prompt_embeds[0]
276
  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
277
  negative_prompt_embeds_list.append(negative_prompt_embeds)
 
278
  negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
279
  prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
280
  bs_embed, seq_len, _ = prompt_embeds.shape
281
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
282
  prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
283
  if do_classifier_free_guidance:
284
  seq_len = negative_prompt_embeds.shape[1]
285
  negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
286
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
287
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
288
  pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
289
  bs_embed * num_images_per_prompt, -1
290
  )
 
292
  negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
293
  bs_embed * num_images_per_prompt, -1
294
  )
 
295
  return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
296
 
297
  def prepare_extra_step_kwargs(self, generator, eta):
 
299
  extra_step_kwargs = {}
300
  if accepts_eta:
301
  extra_step_kwargs["eta"] = eta
 
302
  accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
303
  if accepts_generator:
304
  extra_step_kwargs["generator"] = generator
 
321
  ):
322
  if height % 8 != 0 or width % 8 != 0:
323
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
324
  if (callback_steps is None) or (
325
  callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
326
  ):
327
  raise ValueError(
328
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type {type(callback_steps)}."
329
  )
 
330
  if prompt is not None and prompt_embeds is not None:
331
  raise ValueError(
332
  f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to only forward one of the two."
 
343
  raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
344
  elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
345
  raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
 
346
  if negative_prompt is not None and negative_prompt_embeds is not None:
347
  raise ValueError(
348
  f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to only forward one of the two."
 
351
  raise ValueError(
352
  f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to only forward one of the two."
353
  )
 
354
  if prompt_embeds is not None and negative_prompt_embeds is not None:
355
  if prompt_embeds.shape != negative_prompt_embeds.shape:
356
  raise ValueError(
357
  "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds` {negative_prompt_embeds.shape}."
358
  )
 
359
  if prompt_embeds is not None and pooled_prompt_embeds is None:
360
  raise ValueError(
361
  "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
362
  )
 
363
  if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
364
  raise ValueError(
365
  "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
366
  )
 
367
  if max(height, width) % 1024 != 0:
368
  raise ValueError(f"the larger one of `height` and `width` has to be divisible by 1024 but are {height} and {width}.")
 
369
  if num_images_per_prompt != 1:
370
  warnings.warn("num_images_per_prompt != 1 is not supported by AccDiffusion and will be ignored.")
371
  num_images_per_prompt = 1
 
376
  raise ValueError(
377
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}. Make sure the batch size matches the length of the generators."
378
  )
 
379
  if latents is None:
380
  latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
381
  else:
382
  latents = latents.to(device)
 
383
  latents = latents * self.scheduler.init_noise_sigma
384
  return latents
385
 
386
  def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
387
  add_time_ids = list(original_size + crops_coords_top_left + target_size)
 
388
  passed_add_embed_dim = (
389
  self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
390
  )
391
  expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
 
392
  if expected_add_embed_dim != passed_add_embed_dim:
393
  raise ValueError(
394
  f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. \
395
  The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
396
  )
 
397
  add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
398
  return add_time_ids
399
 
 
409
  h_end = h_start + window_size
410
  w_start = int((i % num_blocks_width) * stride)
411
  w_end = w_start + window_size
 
412
  if h_end > height:
413
  h_start = int(h_start + height - h_end)
414
  h_end = int(height)
 
421
  if w_start < 0:
422
  w_end = int(w_end - w_start)
423
  w_start = 0
 
424
  if random_jitter:
425
  jitter_range = (window_size - stride) // 4
426
  w_jitter = 0
 
431
  w_jitter = random.randint(-jitter_range, 0)
432
  elif (w_start != 0) and (w_end == width):
433
  w_jitter = random.randint(0, jitter_range)
 
434
  if (h_start != 0) and (h_end != height):
435
  h_jitter = random.randint(-jitter_range, jitter_range)
436
  elif (h_start == 0) and (h_end != height):
 
441
  h_end = h_end + h_jitter + jitter_range
442
  w_start = w_start + w_jitter + jitter_range
443
  w_end = w_end + w_jitter + jitter_range
 
444
  views.append((h_start, h_end, w_start, w_end))
445
  return views
446
 
 
476
  continue
477
  cross_att_count += 1
478
  attn_procs[name] = P2PCrossAttnProcessor(controller=controller, place_in_unet=place_in_unet)
 
479
  self.unet.set_attn_processor(attn_procs)
480
  controller.num_att_layers = cross_att_count
481
  return ori_attn_processors
 
488
  from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
489
  else:
490
  raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
 
491
  is_model_cpu_offload = False
492
  is_sequential_cpu_offload = False
493
  recursive = False
 
507
  **kwargs,
508
  )
509
  self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
 
510
  text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
511
  if len(text_encoder_state_dict) > 0:
512
  self.load_lora_into_text_encoder(
 
516
  prefix="text_encoder",
517
  lora_scale=self.lora_scale,
518
  )
 
519
  text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
520
  if len(text_encoder_2_state_dict) > 0:
521
  self.load_lora_into_text_encoder(
 
525
  prefix="text_encoder_2",
526
  lora_scale=self.lora_scale,
527
  )
 
528
  if is_model_cpu_offload:
529
  self.enable_model_cpu_offload()
530
  elif is_sequential_cpu_offload:
 
543
  safe_serialization: bool = True,
544
  ):
545
  state_dict = {}
 
546
  def pack_weights(layers, prefix):
547
  layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
548
  layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
549
  return layers_state_dict
 
550
  if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
551
  raise ValueError(
552
  "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
553
  )
 
554
  if unet_lora_layers:
555
  state_dict.update(pack_weights(unet_lora_layers, "unet"))
 
556
  if text_encoder_lora_layers and text_encoder_2_lora_layers:
557
  state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
558
  state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
 
559
  self.write_lora_layers(
560
  state_dict=state_dict,
561
  save_directory=save_directory,
 
636
 
637
  Examples:
638
  """
 
639
  if debug:
640
  num_inference_steps = 1
 
641
  height = height or self.default_sample_size * self.vae_scale_factor
642
  width = width or self.default_sample_size * self.vae_scale_factor
 
643
  x1_size = self.default_sample_size * self.vae_scale_factor
 
644
  height_scale = height / x1_size
645
  width_scale = width / x1_size
646
  scale_num = int(max(height_scale, width_scale))
647
  aspect_ratio = min(height_scale, width_scale) / max(height_scale, width_scale)
 
648
  original_size = original_size or (height, width)
649
  target_size = target_size or (height, width)
 
650
  if attn_res is None:
651
+ attn_res = (int(np.ceil(self.default_sample_size * self.vae_scale_factor / 32)), int(np.ceil(self.default_sample_size * self.vae_scale_factor / 32)))
652
  self.attn_res = attn_res
 
653
  if lowvram:
654
  attention_map_device = torch.device("cpu")
655
  else:
656
  attention_map_device = self.device
 
657
  self.controller = create_controller(
658
  prompt, cross_attention_kwargs, num_inference_steps, tokenizer=self.tokenizer, device=attention_map_device, attn_res=self.attn_res
659
  )
 
660
  if save_attention_map or use_md_prompt:
661
  ori_attn_processors = self.register_attention_control(self.controller)
 
662
  self.check_inputs(
663
  prompt,
664
  prompt_2,
 
  negative_pooled_prompt_embeds,
  num_images_per_prompt,
  )

  if prompt is not None and isinstance(prompt, str):
  batch_size = 1
  elif prompt is not None and isinstance(prompt, list):
  batch_size = len(prompt)
  else:
  batch_size = prompt_embeds.shape[0]

  device = self._execution_device
  self.lowvram = lowvram
  if self.lowvram:

  self.unet.cpu()
  self.text_encoder.to(device)
  self.text_encoder_2.to(device)

  do_classifier_free_guidance = guidance_scale > 1.0
+ text_encoder_lora_scale = (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None)
  (
  prompt_embeds,
  negative_prompt_embeds,

  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
  lora_scale=text_encoder_lora_scale,
  )

  self.scheduler.set_timesteps(num_inference_steps, device=device)
  timesteps = self.scheduler.timesteps

  num_channels_latents = self.unet.config.in_channels
  latents = self.prepare_latents(
  batch_size * num_images_per_prompt,

  generator,
  latents,
  )

  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

  add_text_embeds = pooled_prompt_embeds

  add_time_ids = self._get_add_time_ids(
  original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
  )

  if negative_original_size is not None and negative_target_size is not None:
  negative_add_time_ids = self._get_add_time_ids(
  negative_original_size,

  )
  else:
  negative_add_time_ids = add_time_ids

  if do_classifier_free_guidance:
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0).to(device)
  add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0).to(device)
  add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0).to(device).repeat(batch_size * num_images_per_prompt, 1)

  del negative_prompt_embeds, negative_pooled_prompt_embeds, negative_add_time_ids

  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

  if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
  discrete_timestep_cutoff = int(
  round(

  )
  num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
  timesteps = timesteps[:num_inference_steps]

  output_images = []
 
  ###################################################### Phase Initialization ########################################################
 
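  # Phase 1: denoise a fresh latent at the base resolution, or, when image_lr is given,
  # skip denoising and encode the provided low-resolution image with the VAE instead.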
  if self.lowvram:
  self.text_encoder.cpu()
  self.text_encoder_2.cpu()

  if image_lr == None:
  print("### Phase 1 Denoising ###")
  with self.progress_bar(total=num_inference_steps) as progress_bar:
  for i, t in enumerate(timesteps):

  if self.lowvram:
  self.vae.cpu()
  self.unet.to(device)

  latents_for_view = latents

  latent_model_input = (
  latents.repeat_interleave(2, dim=0)
  if do_classifier_free_guidance
  else latents
  )
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

  noise_pred = self.unet(
  latent_model_input,
  t,

  added_cond_kwargs=added_cond_kwargs,
  return_dict=False,
  )[0]

  if do_classifier_free_guidance:
  noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

  if do_classifier_free_guidance and guidance_rescale > 0.0:
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
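  # At the final timestep the stored attention maps are decomposed into window-level
  # ("multidiffusion") prompts and the matching view coordinates for the later phases.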
  if t == 1 and use_md_prompt:
  md_prompts, views_attention = get_multidiffusion_prompts(tokenizer=self.tokenizer, prompts=[prompt], threthod=c, attention_store=self.controller, height=height//scale_num, width=width//scale_num, from_where=["up","down"], random_jitter=True, scale_num=scale_num)

  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
  progress_bar.update()
  if callback is not None and i % callback_steps == 0:
  step_idx = i // getattr(self.scheduler, "order", 1)
  callback(step_idx, t, latents)

  del latents_for_view, latent_model_input, noise_pred, noise_pred_text, noise_pred_uncond
  if use_md_prompt or save_attention_map:
  self.recover_attention_control(ori_attn_processors=ori_attn_processors)
 
  print("### Encoding Real Image ###")
  latents = self.vae.encode(image_lr)
  latents = latents.latent_dist.sample() * self.vae.config.scaling_factor

  anchor_mean = latents.mean()
  anchor_std = latents.std()
  if self.lowvram:

  torch.cuda.empty_cache()
  if not output_type == "latent":
  needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

  if self.lowvram:
  needs_upcasting = False
  self.unet.cpu()
  self.vae.to(device)

  if needs_upcasting:
  self.upcast_vae()
  latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
  if needs_upcasting:
  self.vae.to(dtype=torch.float16)

  image = self.image_processor.postprocess(image, output_type=output_type)
  if not os.path.exists(f'{result_path}'):
  os.makedirs(f'{result_path}')

  image_lr_save_path = f'{result_path}/{image[0].size[0]}_{image[0].size[1]}.png'
  image[0].save(image_lr_save_path)
  output_images.append(image[0])
 
  ####################################################### Phase Upscaling #####################################################
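  # Each phase enlarges the latent by one base-resolution step (bicubic in latent space),
  # re-noises it, and denoises it again while fusing the MultiDiffusion and
  # dilated-sampling predictions computed below.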
  if use_progressive_upscaling:
  if image_lr == None:

  starting_scale = 1
  else:
  starting_scale = scale_num

  for current_scale_num in range(starting_scale, scale_num + 1):
  if self.lowvram:
  latents = latents.to(device)
  self.unet.to(device)
  torch.cuda.empty_cache()

  current_height = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
  current_width = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num

  if height > width:
  current_width = int(current_width * aspect_ratio)
  else:
  current_height = int(current_height * aspect_ratio)

  if upscale_mode == "bicubic_latent" or debug:
  latents = F.interpolate(latents.to(device), size=(int(current_height / self.vae_scale_factor), int(current_width / self.vae_scale_factor)), mode='bicubic')
  else:
  raise NotImplementedError
 
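  # Pre-compute the noised version of the upscaled latent for each timestep; these serve
  # both as the starting latent (noise_latents[0]) and as the skip-residual targets.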
  print("### Phase {} Denoising ###".format(current_scale_num))
  noise_latents = []
  noise = torch.randn_like(latents)

  noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0))
  noise_latents.append(noise_latent)
  latents = noise_latents[0]

  with self.progress_bar(total=num_inference_steps) as progress_bar:
  for i, t in enumerate(timesteps):
  count = torch.zeros_like(latents)
 
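  # Skip residual: blend the running latent with the pre-noised upscaled latent using the
  # cosine-based weight c1, which is intended to bias early steps toward the upscaled structure.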
  if use_skip_residual:
  c1 = cosine_factor ** cosine_scale_1
  latents = latents * (1 - c1) + noise_latents[i] * c1
 
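  # MultiDiffusion branch: split the (optionally jittered and padded) latent into
  # overlapping windows, denoise each window with its own prompt embedding, and average
  # the per-window results into value_local / count_local.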
  if use_multidiffusion:
  if use_md_prompt:
  md_prompt_embeds_list = []

  md_prompt_embeds_list.append(torch.cat([md_negative_prompt_embeds, md_prompt_embeds], dim=0).to(device))
  md_add_text_embeds_list.append(torch.cat([md_negative_pooled_prompt_embeds, md_pooled_prompt_embeds], dim=0).to(device))
  del md_negative_prompt_embeds, md_negative_pooled_prompt_embeds

  if use_md_prompt:
  random_jitter = True
  views = [(h_start*4, h_end*4, w_start*4, w_end*4) for h_start, h_end, w_start, w_end in views_attention[current_scale_num]]
  else:
  random_jitter = True
  views = self.get_views(current_height, current_width, stride=stride, window_size=self.unet.config.sample_size, random_jitter=random_jitter)

  views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]

  if use_md_prompt:
  views_prompt_embeds_input = [md_prompt_embeds_list[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
  views_add_text_embeds_input = [md_add_text_embeds_list[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]

  if random_jitter:
  jitter_range = int((self.unet.config.sample_size - stride) // 4)
  latents_ = F.pad(latents, (jitter_range, jitter_range, jitter_range, jitter_range), 'constant', 0)
  else:
  latents_ = latents

  count_local = torch.zeros_like(latents_)
  value_local = torch.zeros_like(latents_)

  for j, batch_view in enumerate(views_batch):
  vb_size = len(batch_view)
  latents_for_view = torch.cat(
+ [latents_[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view]
  )

  latent_model_input = latents_for_view
+ latent_model_input = (latent_model_input.repeat_interleave(2, dim=0)
+ if do_classifier_free_guidance
+ else latent_model_input)

  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
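  # Per-window SDXL micro-conditioning: the crop-coordinate entries of add_time_ids are
  # overwritten with each window's top-left corner, converted to pixel coordinates.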
  add_time_ids_input = []
  for h_start, h_end, w_start, w_end in batch_view:
  add_time_ids_ = add_time_ids.clone()

  add_time_ids_[:, 3] = w_start * self.vae_scale_factor
  add_time_ids_input.append(add_time_ids_)
  add_time_ids_input = torch.cat(add_time_ids_input)

  if not use_md_prompt:
  prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
  add_text_embeds_input = torch.cat([add_text_embeds] * vb_size)

  added_cond_kwargs=md_added_cond_kwargs,
  return_dict=False,
  )[0]

  if do_classifier_free_guidance:
  noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
  noise_pred = noise_pred_uncond + multi_guidance_scale * (noise_pred_text - noise_pred_uncond)

  if do_classifier_free_guidance and guidance_rescale > 0.0:
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

  self.scheduler._init_step_index(t)
+ latents_denoised_batch = self.scheduler.step(noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False)[0]
+ for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(latents_denoised_batch.chunk(vb_size), batch_view):
  value_local[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
  count_local[:, :, h_start:h_end, w_start:w_end] += 1

  if random_jitter:
  value_local = value_local[:, :, jitter_range:jitter_range + current_height // self.vae_scale_factor, jitter_range:jitter_range + current_width // self.vae_scale_factor]
  count_local = count_local[:, :, jitter_range:jitter_range + current_height // self.vae_scale_factor, jitter_range:jitter_range + current_width // self.vae_scale_factor]

  noise_index = i + 1 if i != (len(timesteps) - 1) else i

  value_local = torch.where(count_local == 0, noise_latents[noise_index], value_local)
  count_local = torch.where(count_local == 0, torch.ones_like(count_local), count_local)
  if use_dilated_sampling:

  count += torch.ones_like(value_local) * (1 - c2)
  else:
  value += value_local / count_local
+ count += torch.ones_like(value_local)
 
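  # Dilated sampling branch: denoise strided sub-grids of the latent (optionally
  # Gaussian-smoothed and shuffled) so each step also sees a downsampled global view,
  # accumulating the results into value_global / count_global.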
  if use_dilated_sampling:
  views = [[h, w] for h in range(current_scale_num) for w in range(current_scale_num)]
  views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]

  h_pad = (current_scale_num - (latents.size(2) % current_scale_num)) % current_scale_num
  w_pad = (current_scale_num - (latents.size(3) % current_scale_num)) % current_scale_num
  latents_ = F.pad(latents, (w_pad, 0, h_pad, 0), 'constant', 0)

  count_global = torch.zeros_like(latents_)
  value_global = torch.zeros_like(latents_)

  if use_guassian:
  c3 = 0.99 * cosine_factor ** cosine_scale_3 + 1e-2
  std_, mean_ = latents_.std(), latents_.mean()

  latents_gaussian = (latents_gaussian - latents_gaussian.mean()) / latents_gaussian.std() * std_ + mean_
  else:
  latents_gaussian = latents_

  for j, batch_view in enumerate(views_batch):

  latents_for_view = torch.cat(
+ [latents_[:, :, h::current_scale_num, w::current_scale_num] for h, w in batch_view]
  )

  latents_for_view_gaussian = torch.cat(
+ [latents_gaussian[:, :, h::current_scale_num, w::current_scale_num] for h, w in batch_view]
  )
 
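  # Optional shuffle: entries are permuted across the view (batch) dimension with an
  # independent permutation per spatial location; the order is restored on noise_pred
  # via original_index after the UNet pass.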
  if shuffle:
  shape = latents_for_view.shape
+ # fix: add the missing parentheses around the range(...) call
  shuffle_index = torch.stack([torch.randperm(shape[0]) for _ in range(latents_for_view.reshape(-1).shape[0]//shape[0])])
  shuffle_index = shuffle_index.view(shape[1], shape[2], shape[3], shape[0])
  original_index = torch.zeros_like(shuffle_index).scatter_(3, shuffle_index, torch.arange(shape[0]).repeat(shape[1], shape[2], shape[3], 1))
  shuffle_index = shuffle_index.permute(3, 0, 1, 2).to(device)
  original_index = original_index.permute(3, 0, 1, 2).to(device)
  latents_for_view_gaussian = latents_for_view_gaussian.gather(0, shuffle_index)
 
  vb_size = latents_for_view.size(0)

  latent_model_input = latents_for_view_gaussian
+ latent_model_input = (latent_model_input.repeat_interleave(2, dim=0)
+ if do_classifier_free_guidance
+ else latent_model_input)

  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

  prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
  add_text_embeds_input = torch.cat([add_text_embeds] * vb_size)
  add_time_ids_input = torch.cat([add_time_ids] * vb_size)

  added_cond_kwargs = {"text_embeds": add_text_embeds_input, "time_ids": add_time_ids_input}
  noise_pred = self.unet(
  latent_model_input,

  added_cond_kwargs=added_cond_kwargs,
  return_dict=False,
  )[0]

  if do_classifier_free_guidance:
  noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

  if do_classifier_free_guidance and guidance_rescale > 0.0:
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

  if shuffle:
  noise_pred = noise_pred.gather(0, original_index)

  self.scheduler._init_step_index(t)
  latents_denoised_batch = self.scheduler.step(noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False)[0]
+ for latents_view_denoised, (h, w) in zip(latents_denoised_batch.chunk(vb_size), batch_view):
  value_global[:, :, h::current_scale_num, w::current_scale_num] += latents_view_denoised
  count_global[:, :, h::current_scale_num, w::current_scale_num] += 1

  value_global = value_global[:, :, h_pad:, w_pad:]

  if use_multidiffusion:
  c2 = cosine_factor ** cosine_scale_2
  value += value_global * c2

  else:
  value += value_global
  count += torch.ones_like(value_global)

  latents = torch.where(count > 0, value / count, value)

  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
  progress_bar.update()
  if callback is not None and i % callback_steps == 0:
  step_idx = i // getattr(self.scheduler, "order", 1)
  callback(step_idx, t, latents)
 
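  # Re-normalize the fused latent to the previously recorded anchor_mean / anchor_std so
  # the latent statistics stay stable across upscaling phases.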
  latents = (latents - latents.mean()) / latents.std() * anchor_std + anchor_mean
  if self.lowvram:
  latents = latents.cpu()

  needs_upcasting = False
  self.unet.cpu()
  self.vae.to(device)

  if needs_upcasting:
  self.upcast_vae()
  latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
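  # Decode with the (optionally upcast) VAE; tiling is enabled above 2048 px to bound
  # decoder memory, and each phase's image is appended to output_images.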
  print("### Phase {} Decoding ###".format(current_scale_num))
  if current_height > 2048 or current_width > 2048:
  self.enable_vae_tiling()
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
  else:
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

  image = self.image_processor.postprocess(image, output_type=output_type)
  image[0].save(f'{result_path}/AccDiffusion_{current_scale_num}.png')
  output_images.append(image[0])

  if needs_upcasting:
  self.vae.to(dtype=torch.float16)
  else:
  image = latents

  self.maybe_free_model_hooks()

  return output_images

  ## others ##
  parser.add_argument('--debug', default=False, action='store_true')
  parser.add_argument('--experiment_name', default="AccDiffusion")

  args = parser.parse_args()
 
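  # The pipeline is loaded once at startup in fp16 on CUDA; generation itself runs inside
  # the @spaces.GPU handler defined below.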
  pipe = AccDiffusionSDXLPipeline.from_pretrained(args.model_ckpt, torch_dtype=torch.float16).to("cuda")

  @spaces.GPU(duration=200)

  "n_cross_replace": {"default_": 1.0, "confetti": 0.8},
  }
  generator = torch.Generator(device='cuda').manual_seed(seed)
 
  print(f"Prompt: {prompt}")
  md5_hash = hashlib.md5(prompt.encode()).hexdigest()
  result_path = f"./output/{args.experiment_name}/{md5_hash}/{width}_{height}_{seed}/"

  images = pipe(
  prompt,
  negative_prompt=args.negative_prompt,

  debug=args.debug, save_attention_map=args.save_attention_map, use_md_prompt=use_md_prompt, c=args.c
  )
  print(images)

  return images

  MAX_SEED = np.iinfo(np.int32).max

  css = """
  body {
  background: linear-gradient(135deg, #2c3e50, #4ca1af);

  visibility: hidden;
  }
  """
 
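  # Gradio UI: a prompt box, a generate button, and advanced settings that map onto the
  # pipeline keyword arguments above; results are shown in a PNG gallery.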
  with gr.Blocks(css=css) as demo:
  with gr.Column(elem_id="col-container"):
  gr.Markdown("<h1>AccDiffusion: Advanced AI Art Generator</h1>")

  with gr.Row():
  prompt = gr.Textbox(label="Prompt", placeholder="e.g., A surreal landscape with floating islands and vibrant colors.", lines=2, scale=4)
  submit_btn = gr.Button("Generate", scale=1)

  with gr.Accordion("Advanced Settings", open=False):
  with gr.Row():
  resolution = gr.Radio(

  use_progressive_upscaling = gr.Checkbox(label="Use Progressive Upscaling", value=False)
  shuffle = gr.Checkbox(label="Shuffle", value=False)
  use_md_prompt = gr.Checkbox(label="Use MD Prompt", value=False)
+ output_images = gr.Gallery(label="Output Images", format="png")

  gr.Markdown("### Example Prompts")
  gr.Examples(
  examples = [