ginipick committed on
Commit d9dee7c · verified · 1 Parent(s): ccb2f86

Update app.py

Files changed (1)
  1. app.py +1071 -54
app.py CHANGED
@@ -113,7 +113,35 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
113
  """
114
  Pipeline for text-to-image generation using Stable Diffusion XL.
115
 
116
- [Class description omitted …]
117
  """
118
  model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
119
 
@@ -152,7 +180,469 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
152
  else:
153
  self.watermark = None
154
 
155
- # (remaining existing methods omitted …)
156
 
157
  @torch.no_grad()
158
  @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -213,17 +703,554 @@ class AccDiffusionSDXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoade
213
  seed: Optional[int] = None,
214
  c : Optional[float] = 0.3,
215
  ):
216
- r"""
217
- [Function description omitted …]
218
- """
219
- # (The existing __call__ implementation is kept unchanged here.)
220
- # ... (omitted)
221
- output_images = []
222
 
223
  ###################################################### Phase Initialization ########################################################
224
- # (omitted) actual denoising and upscaling section
225
 
226
- # Finally, save and return the images
227
  return output_images
228
 
229
 
@@ -260,36 +1287,8 @@ if __name__ == "__main__":
260
 
261
  args = parser.parse_args()
262
 
263
- # Load the pipeline (using the required model checkpoint)
264
  pipe = AccDiffusionSDXLPipeline.from_pretrained(args.model_ckpt, torch_dtype=torch.float16).to("cuda")
265
 
266
-
267
- # ----------------------- GRADIO INTERFACE (improved UI) -----------------------
268
-
269
- # CSS for the user interface (background, fonts, card styles, etc.)
270
- css = """
271
- body {
272
- background: linear-gradient(135deg, #2c3e50, #4ca1af);
273
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
274
- color: #ffffff;
275
- }
276
- #col-container {
277
- margin: 20px auto;
278
- padding: 20px;
279
- max-width: 900px;
280
- background-color: rgba(0, 0, 0, 0.5);
281
- border-radius: 12px;
282
- box-shadow: 0 4px 12px rgba(0,0,0,0.5);
283
- }
284
- h1, h2 {
285
- text-align: center;
286
- margin-bottom: 10px;
287
- }
288
- footer {
289
- visibility: hidden;
290
- }
291
- """
292
-
293
  @spaces.GPU(duration=200)
294
  def infer(prompt, resolution, num_inference_steps, guidance_scale, seed, use_multidiffusion, use_skip_residual, use_dilated_sampling, use_progressive_upscaling, shuffle, use_md_prompt, progress=gr.Progress(track_tqdm=True)):
295
  set_seed(seed)
@@ -319,8 +1318,7 @@ if __name__ == "__main__":
319
  cosine_scale_1=args.cosine_scale_1,
320
  cosine_scale_2=args.cosine_scale_2,
321
  cosine_scale_3=args.cosine_scale_3,
322
- sigma=args.sigma,
323
- use_guassian=args.use_guassian,
324
  multi_decoder=args.multi_decoder,
325
  upscale_mode=args.upscale_mode,
326
  use_multidiffusion=use_multidiffusion,
@@ -329,18 +1327,37 @@ if __name__ == "__main__":
329
  use_dilated_sampling=use_dilated_sampling,
330
  shuffle=shuffle,
331
  result_path=result_path,
332
- debug=args.debug,
333
- save_attention_map=args.save_attention_map,
334
- use_md_prompt=use_md_prompt,
335
- c=args.c
336
  )
337
  print(images)
338
 
339
  return images
340
-
341
-
342
  MAX_SEED = np.iinfo(np.int32).max
343
 
 
344
  with gr.Blocks(css=css) as demo:
345
  with gr.Column(elem_id="col-container"):
346
  gr.Markdown("<h1>AccDiffusion: Advanced AI Art Generator</h1>")
@@ -355,10 +1372,10 @@ if __name__ == "__main__":
355
  with gr.Row():
356
  resolution = gr.Radio(
357
  label="Resolution",
358
- choices=[
359
  "1024,1024", "2048,2048", "2048,1024", "1536,3072", "3072,3072", "4096,4096", "4096,2048"
360
  ],
361
- value="1024,1024",
362
  interactive=True
363
  )
364
  with gr.Column():
@@ -377,7 +1394,7 @@ if __name__ == "__main__":
377
  output_images = gr.Gallery(label="Output Images", format="png").style(grid=[2], height="auto")
378
  gr.Markdown("### Example Prompts")
379
  gr.Examples(
380
- examples=[
381
  ["A surreal landscape with floating islands and vibrant colors."],
382
  ["Cyberpunk cityscape at night with neon lights and futuristic architecture."],
383
  ["A majestic dragon soaring over a medieval castle amidst stormy skies."],
@@ -385,14 +1402,14 @@ if __name__ == "__main__":
385
  ["Abstract geometric patterns in vivid, pulsating colors."],
386
  ["A mystical forest illuminated by bioluminescent plants under a starry sky."]
387
  ],
388
- inputs=[prompt],
389
  label="Click an example to populate the prompt box."
390
  )
391
  submit_btn.click(
392
- fn=infer,
393
- inputs=[prompt, resolution, num_inference_steps, guidance_scale, seed,
394
- use_multidiffusion, use_skip_residual, use_dilated_sampling, use_progressive_upscaling, shuffle, use_md_prompt],
395
- outputs=[output_images],
396
  show_api=False
397
  )
398
  demo.launch(show_api=False, show_error=True)
 
113
  """
114
  Pipeline for text-to-image generation using Stable Diffusion XL.
115
 
116
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
117
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
118
+
119
+ In addition the pipeline inherits the following loading methods:
120
+ - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
121
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
122
+
123
+ as well as the following saving methods:
124
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
125
+
126
+ Args:
127
+ vae ([`AutoencoderKL`]):
128
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
129
+ text_encoder ([`CLIPTextModel`]):
130
+ Frozen text-encoder.
131
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
132
+ Second frozen text-encoder.
133
+ tokenizer (`CLIPTokenizer`):
134
+ Tokenizer.
135
+ tokenizer_2 (`CLIPTokenizer`):
136
+ Second Tokenizer.
137
+ unet ([`UNet2DConditionModel`]):
138
+ Conditional U-Net architecture.
139
+ scheduler ([`SchedulerMixin`]):
140
+ A scheduler to be used in combination with `unet`.
141
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
142
+ Whether the negative prompt embeddings shall be forced to always be set to 0.
143
+ add_watermarker (`bool`, *optional*):
144
+ Whether to use the invisible watermark library.
145
  """
146
  model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
147
 
 
180
  else:
181
  self.watermark = None
182
 
183
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
184
+ def enable_vae_slicing(self):
185
+ self.vae.enable_slicing()
186
+
187
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
188
+ def disable_vae_slicing(self):
189
+ self.vae.disable_slicing()
190
+
191
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
192
+ def enable_vae_tiling(self):
193
+ self.vae.enable_tiling()
194
+
195
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
196
+ def disable_vae_tiling(self):
197
+ self.vae.disable_tiling()
198
+
199
+ def encode_prompt(
200
+ self,
201
+ prompt: str,
202
+ prompt_2: Optional[str] = None,
203
+ device: Optional[torch.device] = None,
204
+ num_images_per_prompt: int = 1,
205
+ do_classifier_free_guidance: bool = True,
206
+ negative_prompt: Optional[str] = None,
207
+ negative_prompt_2: Optional[str] = None,
208
+ prompt_embeds: Optional[torch.FloatTensor] = None,
209
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
210
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
211
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
212
+ lora_scale: Optional[float] = None,
213
+ ):
214
+ device = device or self._execution_device
215
+
216
+ # set lora scale so that monkey patched LoRA function of text encoder can correctly access it
217
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
218
+ self._lora_scale = lora_scale
219
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
220
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
221
+
222
+ if prompt is not None and isinstance(prompt, str):
223
+ batch_size = 1
224
+ elif prompt is not None and isinstance(prompt, list):
225
+ batch_size = len(prompt)
226
+ else:
227
+ batch_size = prompt_embeds.shape[0]
228
+
229
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
230
+ text_encoders = (
231
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
232
+ )
233
+
234
+ if prompt_embeds is None:
235
+ prompt_2 = prompt_2 or prompt
236
+ prompt_embeds_list = []
237
+ prompts = [prompt, prompt_2]
238
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
239
+ if isinstance(self, TextualInversionLoaderMixin):
240
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
241
+
242
+ text_inputs = tokenizer(
243
+ prompt,
244
+ padding="max_length",
245
+ max_length=tokenizer.model_max_length,
246
+ truncation=True,
247
+ return_tensors="pt",
248
+ )
249
+
250
+ text_input_ids = text_inputs.input_ids
251
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
252
+
253
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
254
+ text_input_ids, untruncated_ids
255
+ ):
256
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
257
+ logger.warning(
258
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
259
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
260
+ )
261
+
262
+ prompt_embeds = text_encoder(
263
+ text_input_ids.to(device),
264
+ output_hidden_states=True,
265
+ )
266
+
267
+ pooled_prompt_embeds = prompt_embeds[0]
268
+ prompt_embeds = prompt_embeds.hidden_states[-2]
269
+
270
+ prompt_embeds_list.append(prompt_embeds)
271
+
272
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
273
+
274
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
275
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
276
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
277
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
278
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
279
+ negative_prompt = negative_prompt or ""
280
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
281
+
282
+ if prompt is not None and type(prompt) is not type(negative_prompt):
283
+ raise TypeError(
284
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
285
+ f" {type(prompt)}."
286
+ )
287
+ elif isinstance(negative_prompt, str):
288
+ uncond_tokens = [negative_prompt, negative_prompt_2]
289
+ elif batch_size != len(negative_prompt):
290
+ raise ValueError(
291
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
292
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
293
+ " the batch size of `prompt`."
294
+ )
295
+ else:
296
+ uncond_tokens = [negative_prompt, negative_prompt_2]
297
+
298
+ negative_prompt_embeds_list = []
299
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
300
+ if isinstance(self, TextualInversionLoaderMixin):
301
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
302
+
303
+ max_length = prompt_embeds.shape[1]
304
+ uncond_input = tokenizer(
305
+ negative_prompt,
306
+ padding="max_length",
307
+ max_length=max_length,
308
+ truncation=True,
309
+ return_tensors="pt",
310
+ )
311
+
312
+ negative_prompt_embeds = text_encoder(
313
+ uncond_input.input_ids.to(device),
314
+ output_hidden_states=True,
315
+ )
316
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
317
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
318
+
319
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
320
+
321
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
322
+
323
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
324
+ bs_embed, seq_len, _ = prompt_embeds.shape
325
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
326
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
327
+
328
+ if do_classifier_free_guidance:
329
+ seq_len = negative_prompt_embeds.shape[1]
330
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
331
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
332
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
333
+
334
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
335
+ bs_embed * num_images_per_prompt, -1
336
+ )
337
+ if do_classifier_free_guidance:
338
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
339
+ bs_embed * num_images_per_prompt, -1
340
+ )
341
+
342
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
343
+
344
+ def prepare_extra_step_kwargs(self, generator, eta):
345
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
346
+ extra_step_kwargs = {}
347
+ if accepts_eta:
348
+ extra_step_kwargs["eta"] = eta
349
+
350
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
351
+ if accepts_generator:
352
+ extra_step_kwargs["generator"] = generator
353
+ return extra_step_kwargs
354
+
355
+ def check_inputs(
356
+ self,
357
+ prompt,
358
+ prompt_2,
359
+ height,
360
+ width,
361
+ callback_steps,
362
+ negative_prompt=None,
363
+ negative_prompt_2=None,
364
+ prompt_embeds=None,
365
+ negative_prompt_embeds=None,
366
+ pooled_prompt_embeds=None,
367
+ negative_pooled_prompt_embeds=None,
368
+ num_images_per_prompt=None,
369
+ ):
370
+ if height % 8 != 0 or width % 8 != 0:
371
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
372
+
373
+ if (callback_steps is None) or (
374
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
375
+ ):
376
+ raise ValueError(
377
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type {type(callback_steps)}."
378
+ )
379
+
380
+ if prompt is not None and prompt_embeds is not None:
381
+ raise ValueError(
382
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to only forward one of the two."
383
+ )
384
+ elif prompt_2 is not None and prompt_embeds is not None:
385
+ raise ValueError(
386
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to only forward one of the two."
387
+ )
388
+ elif prompt is None and prompt_embeds is None:
389
+ raise ValueError(
390
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
391
+ )
392
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
393
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
394
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
395
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
396
+
397
+ if negative_prompt is not None and negative_prompt_embeds is not None:
398
+ raise ValueError(
399
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to only forward one of the two."
400
+ )
401
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
402
+ raise ValueError(
403
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to only forward one of the two."
404
+ )
405
+
406
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
407
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
408
+ raise ValueError(
409
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds` {negative_prompt_embeds.shape}."
410
+ )
411
+
412
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
413
+ raise ValueError(
414
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
415
+ )
416
+
417
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
418
+ raise ValueError(
419
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
420
+ )
421
+
422
+ if max(height, width) % 1024 != 0:
423
+ raise ValueError(f"the larger one of `height` and `width` has to be divisible by 1024 but are {height} and {width}.")
424
+
425
+ if num_images_per_prompt != 1:
426
+ warnings.warn("num_images_per_prompt != 1 is not supported by AccDiffusion and will be ignored.")
427
+ num_images_per_prompt = 1
428
+
429
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
430
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
431
+ if isinstance(generator, list) and len(generator) != batch_size:
432
+ raise ValueError(
433
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}. Make sure the batch size matches the length of the generators."
434
+ )
435
+
436
+ if latents is None:
437
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
438
+ else:
439
+ latents = latents.to(device)
440
+
441
+ latents = latents * self.scheduler.init_noise_sigma
442
+ return latents
443
+
444
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
445
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
446
+
447
+ passed_add_embed_dim = (
448
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
449
+ )
450
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
451
+
452
+ if expected_add_embed_dim != passed_add_embed_dim:
453
+ raise ValueError(
454
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. \
455
+ The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
456
+ )
457
+
458
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
459
+ return add_time_ids
460
+
461
+ def get_views(self, height, width, window_size=128, stride=64, random_jitter=False):
462
+ height //= self.vae_scale_factor
463
+ width //= self.vae_scale_factor
464
+ num_blocks_height = int((height - window_size) / stride - 1e-6) + 2 if height > window_size else 1
465
+ num_blocks_width = int((width - window_size) / stride - 1e-6) + 2 if width > window_size else 1
466
+ total_num_blocks = int(num_blocks_height * num_blocks_width)
467
+ views = []
468
+ for i in range(total_num_blocks):
469
+ h_start = int((i // num_blocks_width) * stride)
470
+ h_end = h_start + window_size
471
+ w_start = int((i % num_blocks_width) * stride)
472
+ w_end = w_start + window_size
473
+
474
+ if h_end > height:
475
+ h_start = int(h_start + height - h_end)
476
+ h_end = int(height)
477
+ if w_end > width:
478
+ w_start = int(w_start + width - w_end)
479
+ w_end = int(width)
480
+ if h_start < 0:
481
+ h_end = int(h_end - h_start)
482
+ h_start = 0
483
+ if w_start < 0:
484
+ w_end = int(w_end - w_start)
485
+ w_start = 0
486
+
487
+ if random_jitter:
488
+ jitter_range = (window_size - stride) // 4
489
+ w_jitter = 0
490
+ h_jitter = 0
491
+ if (w_start != 0) and (w_end != width):
492
+ w_jitter = random.randint(-jitter_range, jitter_range)
493
+ elif (w_start == 0) and (w_end != width):
494
+ w_jitter = random.randint(-jitter_range, 0)
495
+ elif (w_start != 0) and (w_end == width):
496
+ w_jitter = random.randint(0, jitter_range)
497
+
498
+ if (h_start != 0) and (h_end != height):
499
+ h_jitter = random.randint(-jitter_range, jitter_range)
500
+ elif (h_start == 0) and (h_end != height):
501
+ h_jitter = random.randint(-jitter_range, 0)
502
+ elif (h_start != 0) and (h_end == height):
503
+ h_jitter = random.randint(0, jitter_range)
504
+ h_start = h_start + h_jitter + jitter_range
505
+ h_end = h_end + h_jitter + jitter_range
506
+ w_start = w_start + w_jitter + jitter_range
507
+ w_end = w_end + w_jitter + jitter_range
508
+
509
+ views.append((h_start, h_end, w_start, w_end))
510
+ return views
511
+
512
+ def upcast_vae(self):
513
+ dtype = self.vae.dtype
514
+ self.vae.to(dtype=torch.float32)
515
+ use_torch_2_0_or_xformers = isinstance(
516
+ self.vae.decoder.mid_block.attentions[0].processor,
517
+ (
518
+ AttnProcessor2_0,
519
+ XFormersAttnProcessor,
520
+ LoRAXFormersAttnProcessor,
521
+ LoRAAttnProcessor2_0,
522
+ ),
523
+ )
524
+ if use_torch_2_0_or_xformers:
525
+ self.vae.post_quant_conv.to(dtype)
526
+ self.vae.decoder.conv_in.to(dtype)
527
+ self.vae.decoder.mid_block.to(dtype)
528
+
529
+ def register_attention_control(self, controller):
530
+ attn_procs = {}
531
+ cross_att_count = 0
532
+ ori_attn_processors = self.unet.attn_processors
533
+ for name in self.unet.attn_processors.keys():
534
+ if name.startswith("mid_block"):
535
+ place_in_unet = "mid"
536
+ elif name.startswith("up_blocks"):
537
+ place_in_unet = "up"
538
+ elif name.startswith("down_blocks"):
539
+ place_in_unet = "down"
540
+ else:
541
+ continue
542
+ cross_att_count += 1
543
+ attn_procs[name] = P2PCrossAttnProcessor(controller=controller, place_in_unet=place_in_unet)
544
+
545
+ self.unet.set_attn_processor(attn_procs)
546
+ controller.num_att_layers = cross_att_count
547
+ return ori_attn_processors
548
+
549
+ def recover_attention_control(self, ori_attn_processors):
550
+ self.unet.set_attn_processor(ori_attn_processors)
551
+
552
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
553
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
554
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
555
+ else:
556
+ raise ImportError("Offloading requires `accelerate v0.17.0` or higher.")
557
+
558
+ is_model_cpu_offload = False
559
+ is_sequential_cpu_offload = False
560
+ recursive = False
561
+ for _, component in self.components.items():
562
+ if isinstance(component, torch.nn.Module):
563
+ if hasattr(component, "_hf_hook"):
564
+ is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
565
+ is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
566
+ logger.info(
567
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
568
+ )
569
+ recursive = is_sequential_cpu_offload
570
+ remove_hook_from_module(component, recurse=recursive)
571
+ state_dict, network_alphas = self.lora_state_dict(
572
+ pretrained_model_name_or_path_or_dict,
573
+ unet_config=self.unet.config,
574
+ **kwargs,
575
+ )
576
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
577
+
578
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
579
+ if len(text_encoder_state_dict) > 0:
580
+ self.load_lora_into_text_encoder(
581
+ text_encoder_state_dict,
582
+ network_alphas=network_alphas,
583
+ text_encoder=self.text_encoder,
584
+ prefix="text_encoder",
585
+ lora_scale=self.lora_scale,
586
+ )
587
+
588
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
589
+ if len(text_encoder_2_state_dict) > 0:
590
+ self.load_lora_into_text_encoder(
591
+ text_encoder_2_state_dict,
592
+ network_alphas=network_alphas,
593
+ text_encoder=self.text_encoder_2,
594
+ prefix="text_encoder_2",
595
+ lora_scale=self.lora_scale,
596
+ )
597
+
598
+ if is_model_cpu_offload:
599
+ self.enable_model_cpu_offload()
600
+ elif is_sequential_cpu_offload:
601
+ self.enable_sequential_cpu_offload()
602
+
603
+ @classmethod
604
+ def save_lora_weights(
605
+ self,
606
+ save_directory: Union[str, os.PathLike],
607
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
608
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
609
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
610
+ is_main_process: bool = True,
611
+ weight_name: str = None,
612
+ save_function: Callable = None,
613
+ safe_serialization: bool = True,
614
+ ):
615
+ state_dict = {}
616
+
617
+ def pack_weights(layers, prefix):
618
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
619
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
620
+ return layers_state_dict
621
+
622
+ if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
623
+ raise ValueError(
624
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
625
+ )
626
+
627
+ if unet_lora_layers:
628
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
629
+
630
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
631
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
632
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
633
+
634
+ self.write_lora_layers(
635
+ state_dict=state_dict,
636
+ save_directory=save_directory,
637
+ is_main_process=is_main_process,
638
+ weight_name=weight_name,
639
+ save_function=save_function,
640
+ safe_serialization=safe_serialization,
641
+ )
642
+
643
+ def _remove_text_encoder_monkey_patch(self):
644
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
645
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
646
 
647
  @torch.no_grad()
648
  @replace_example_docstring(EXAMPLE_DOC_STRING)
 
703
  seed: Optional[int] = None,
704
  c : Optional[float] = 0.3,
705
  ):
706
+ if debug:
707
+ num_inference_steps = 1
708
+
709
+ height = height or self.default_sample_size * self.vae_scale_factor
710
+ width = width or self.default_sample_size * self.vae_scale_factor
711
+
712
+ x1_size = self.default_sample_size * self.vae_scale_factor
713
+
714
+ height_scale = height / x1_size
715
+ width_scale = width / x1_size
716
+ scale_num = int(max(height_scale, width_scale))
717
+ aspect_ratio = min(height_scale, width_scale) / max(height_scale, width_scale)
718
+
719
+ original_size = original_size or (height, width)
720
+ target_size = target_size or (height, width)
721
+
722
+ if attn_res is None:
723
+ attn_res = int(np.ceil(self.default_sample_size * self.vae_scale_factor / 32)), int(np.ceil(self.default_sample_size * self.vae_scale_factor / 32))
724
+ self.attn_res = attn_res
725
+
726
+ if lowvram:
727
+ attention_map_device = torch.device("cpu")
728
+ else:
729
+ attention_map_device = self.device
730
+
731
+ self.controller = create_controller(
732
+ prompt, cross_attention_kwargs, num_inference_steps, tokenizer=self.tokenizer, device=attention_map_device, attn_res=self.attn_res
733
+ )
734
+
735
+ if save_attention_map or use_md_prompt:
736
+ ori_attn_processors = self.register_attention_control(self.controller)
737
+
738
+ self.check_inputs(
739
+ prompt,
740
+ prompt_2,
741
+ height,
742
+ width,
743
+ callback_steps,
744
+ negative_prompt,
745
+ negative_prompt_2,
746
+ prompt_embeds,
747
+ negative_prompt_embeds,
748
+ pooled_prompt_embeds,
749
+ negative_pooled_prompt_embeds,
750
+ num_images_per_prompt,
751
+ )
752
+
753
+ if prompt is not None and isinstance(prompt, str):
754
+ batch_size = 1
755
+ elif prompt is not None and isinstance(prompt, list):
756
+ batch_size = len(prompt)
757
+ else:
758
+ batch_size = prompt_embeds.shape[0]
759
+
760
+ device = self._execution_device
761
+ self.lowvram = lowvram
762
+ if self.lowvram:
763
+ self.vae.cpu()
764
+ self.unet.cpu()
765
+ self.text_encoder.to(device)
766
+ self.text_encoder_2.to(device)
767
+
768
+ do_classifier_free_guidance = guidance_scale > 1.0
769
+
770
+ text_encoder_lora_scale = (
771
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
772
+ )
773
+
774
+ (
775
+ prompt_embeds,
776
+ negative_prompt_embeds,
777
+ pooled_prompt_embeds,
778
+ negative_pooled_prompt_embeds,
779
+ ) = self.encode_prompt(
780
+ prompt=prompt,
781
+ prompt_2=prompt_2,
782
+ device=device,
783
+ num_images_per_prompt=num_images_per_prompt,
784
+ do_classifier_free_guidance=do_classifier_free_guidance,
785
+ negative_prompt=negative_prompt,
786
+ negative_prompt_2=negative_prompt_2,
787
+ prompt_embeds=prompt_embeds,
788
+ negative_prompt_embeds=negative_prompt_embeds,
789
+ pooled_prompt_embeds=pooled_prompt_embeds,
790
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
791
+ lora_scale=text_encoder_lora_scale,
792
+ )
793
+
794
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
795
+ timesteps = self.scheduler.timesteps
796
+
797
+ num_channels_latents = self.unet.config.in_channels
798
+ latents = self.prepare_latents(
799
+ batch_size * num_images_per_prompt,
800
+ num_channels_latents,
801
+ height // scale_num,
802
+ width // scale_num,
803
+ prompt_embeds.dtype,
804
+ device,
805
+ generator,
806
+ latents,
807
+ )
808
+
809
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
810
+
811
+ add_text_embeds = pooled_prompt_embeds
812
+
813
+ add_time_ids = self._get_add_time_ids(
814
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
815
+ )
816
+
817
+ if negative_original_size is not None and negative_target_size is not None:
818
+ negative_add_time_ids = self._get_add_time_ids(
819
+ negative_original_size,
820
+ negative_crops_coords_top_left,
821
+ negative_target_size,
822
+ dtype=prompt_embeds.dtype,
823
+ )
824
+ else:
825
+ negative_add_time_ids = add_time_ids
826
 
827
+ if do_classifier_free_guidance:
828
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0).to(device)
829
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0).to(device)
830
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0).to(device).repeat(batch_size * num_images_per_prompt, 1)
831
+
832
+ del negative_prompt_embeds, negative_pooled_prompt_embeds, negative_add_time_ids
833
+
834
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
835
+
836
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
837
+ discrete_timestep_cutoff = int(
838
+ round(
839
+ self.scheduler.config.num_train_timesteps - (denoising_end * self.scheduler.config.num_train_timesteps)
840
+ )
841
+ )
842
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
843
+ timesteps = timesteps[:num_inference_steps]
844
+
845
+ output_images = []
846
+
847
  ###################################################### Phase Initialization ########################################################
 
848
 
849
+ if self.lowvram:
850
+ self.text_encoder.cpu()
851
+ self.text_encoder_2.cpu()
852
+
853
+ if image_lr == None:
854
+ print("### Phase 1 Denoising ###")
855
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
856
+ for i, t in enumerate(timesteps):
857
+
858
+ if self.lowvram:
859
+ self.vae.cpu()
860
+ self.unet.to(device)
861
+
862
+ latents_for_view = latents
863
+
864
+ latent_model_input = (
865
+ latents.repeat_interleave(2, dim=0)
866
+ if do_classifier_free_guidance
867
+ else latents
868
+ )
869
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
870
+
871
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
872
+
873
+ noise_pred = self.unet(
874
+ latent_model_input,
875
+ t,
876
+ encoder_hidden_states=prompt_embeds,
877
+ added_cond_kwargs=added_cond_kwargs,
878
+ return_dict=False,
879
+ )[0]
880
+
881
+ if do_classifier_free_guidance:
882
+ noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
883
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
884
+
885
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
886
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
887
+
888
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
889
+
890
+ if t == 1 and use_md_prompt:
891
+ md_prompts, views_attention = get_multidiffusion_prompts(tokenizer=self.tokenizer, prompts=[prompt], threthod=c, attention_store=self.controller, height=height//scale_num, width=width//scale_num, from_where=["up","down"], random_jitter=True, scale_num=scale_num)
892
+
893
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
894
+ progress_bar.update()
895
+ if callback is not None and i % callback_steps == 0:
896
+ step_idx = i // getattr(self.scheduler, "order", 1)
897
+ callback(step_idx, t, latents)
898
+
899
+ del latents_for_view, latent_model_input, noise_pred, noise_pred_text, noise_pred_uncond
900
+ if use_md_prompt or save_attention_map:
901
+ self.recover_attention_control(ori_attn_processors=ori_attn_processors)
902
+ del self.controller
903
+ torch.cuda.empty_cache()
904
+ else:
905
+ print("### Encoding Real Image ###")
906
+ latents = self.vae.encode(image_lr)
907
+ latents = latents.latent_dist.sample() * self.vae.config.scaling_factor
908
+
909
+ anchor_mean = latents.mean()
910
+ anchor_std = latents.std()
911
+ if self.lowvram:
912
+ latents = latents.cpu()
913
+ torch.cuda.empty_cache()
914
+ if not output_type == "latent":
915
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
916
+
917
+ if self.lowvram:
918
+ needs_upcasting = False
919
+ self.unet.cpu()
920
+ self.vae.to(device)
921
+
922
+ if needs_upcasting:
923
+ self.upcast_vae()
924
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
925
+ if self.lowvram and multi_decoder:
926
+ current_width_height = self.unet.config.sample_size * self.vae_scale_factor
927
+ image = self.tiled_decode(latents, current_width_height, current_width_height)
928
+ else:
929
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
930
+ if needs_upcasting:
931
+ self.vae.to(dtype=torch.float16)
932
+
933
+ image = self.image_processor.postprocess(image, output_type=output_type)
934
+ if not os.path.exists(f'{result_path}'):
935
+ os.makedirs(f'{result_path}')
936
+
937
+ image_lr_save_path = f'{result_path}/{image[0].size[0]}_{image[0].size[1]}.png'
938
+ image[0].save(image_lr_save_path)
939
+ output_images.append(image[0])
940
+
941
+ ####################################################### Phase Upscaling #####################################################
942
+ if use_progressive_upscaling:
943
+ if image_lr == None:
944
+ starting_scale = 2
945
+ else:
946
+ starting_scale = 1
947
+ else:
948
+ starting_scale = scale_num
949
+
950
+ for current_scale_num in range(starting_scale, scale_num + 1):
951
+ if self.lowvram:
952
+ latents = latents.to(device)
953
+ self.unet.to(device)
954
+ torch.cuda.empty_cache()
955
+
956
+ current_height = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
957
+ current_width = self.unet.config.sample_size * self.vae_scale_factor * current_scale_num
958
+
959
+ if height > width:
960
+ current_width = int(current_width * aspect_ratio)
961
+ else:
962
+ current_height = int(current_height * aspect_ratio)
963
+
964
+ if upscale_mode == "bicubic_latent" or debug:
965
+ latents = F.interpolate(latents.to(device), size=(int(current_height / self.vae_scale_factor), int(current_width / self.vae_scale_factor)), mode='bicubic')
966
+ else:
967
+ raise NotImplementedError
968
+
969
+ print("### Phase {} Denoising ###".format(current_scale_num))
970
+ noise_latents = []
971
+ noise = torch.randn_like(latents)
972
+ for timestep in timesteps:
973
+ noise_latent = self.scheduler.add_noise(latents, noise, timestep.unsqueeze(0))
974
+ noise_latents.append(noise_latent)
975
+ latents = noise_latents[0]
976
+
977
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
978
+ for i, t in enumerate(timesteps):
979
+ count = torch.zeros_like(latents)
980
+ value = torch.zeros_like(latents)
981
+ cosine_factor = 0.5 * (1 + torch.cos(torch.pi * (self.scheduler.config.num_train_timesteps - t) / self.scheduler.config.num_train_timesteps)).cpu()
982
+ if use_skip_residual:
983
+ c1 = cosine_factor ** cosine_scale_1
984
+ latents = latents * (1 - c1) + noise_latents[i] * c1
985
+
986
+ if use_multidiffusion:
987
+ if use_md_prompt:
988
+ md_prompt_embeds_list = []
989
+ md_add_text_embeds_list = []
990
+ for md_prompt in md_prompts[current_scale_num]:
991
+ (
992
+ md_prompt_embeds,
993
+ md_negative_prompt_embeds,
994
+ md_pooled_prompt_embeds,
995
+ md_negative_pooled_prompt_embeds,
996
+ ) = self.encode_prompt(
997
+ prompt=md_prompt,
998
+ prompt_2=prompt_2,
999
+ device=device,
1000
+ num_images_per_prompt=num_images_per_prompt,
1001
+ do_classifier_free_guidance=do_classifier_free_guidance,
1002
+ negative_prompt=negative_prompt,
1003
+ negative_prompt_2=negative_prompt_2,
1004
+ prompt_embeds=None,
1005
+ negative_prompt_embeds=None,
1006
+ pooled_prompt_embeds=None,
1007
+ negative_pooled_prompt_embeds=None,
1008
+ lora_scale=text_encoder_lora_scale,
1009
+ )
1010
+ md_prompt_embeds_list.append(torch.cat([md_negative_prompt_embeds, md_prompt_embeds], dim=0).to(device))
1011
+ md_add_text_embeds_list.append(torch.cat([md_negative_pooled_prompt_embeds, md_pooled_prompt_embeds], dim=0).to(device))
1012
+ del md_negative_prompt_embeds, md_negative_pooled_prompt_embeds
1013
+
1014
+ if use_md_prompt:
1015
+ random_jitter = True
1016
+ views = [(h_start*4, h_end*4, w_start*4, w_end*4) for h_start, h_end, w_start, w_end in views_attention[current_scale_num]]
1017
+ else:
1018
+ random_jitter = True
1019
+ views = self.get_views(current_height, current_width, stride=stride, window_size=self.unet.config.sample_size, random_jitter=random_jitter)
1020
+
1021
+ views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
1022
+
1023
+ if use_md_prompt:
1024
+ views_prompt_embeds_input = [md_prompt_embeds_list[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
1025
+ views_add_text_embeds_input = [md_add_text_embeds_list[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
1026
+
1027
+ if random_jitter:
1028
+ jitter_range = int((self.unet.config.sample_size - stride) // 4)
1029
+ latents_ = F.pad(latents, (jitter_range, jitter_range, jitter_range, jitter_range), 'constant', 0)
1030
+ else:
1031
+ latents_ = latents
1032
+
1033
+ count_local = torch.zeros_like(latents_)
1034
+ value_local = torch.zeros_like(latents_)
1035
+
1036
+ for j, batch_view in enumerate(views_batch):
1037
+ vb_size = len(batch_view)
1038
+ latents_for_view = torch.cat(
1039
+ [
1040
+ latents_[:, :, h_start:h_end, w_start:w_end]
1041
+ for h_start, h_end, w_start, w_end in batch_view
1042
+ ]
1043
+ )
1044
+
1045
+ latent_model_input = latents_for_view
1046
+ latent_model_input = (
1047
+ latent_model_input.repeat_interleave(2, dim=0)
1048
+ if do_classifier_free_guidance
1049
+ else latent_model_input
1050
+ )
1051
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1052
+
1053
+ add_time_ids_input = []
1054
+ for h_start, h_end, w_start, w_end in batch_view:
1055
+ add_time_ids_ = add_time_ids.clone()
1056
+ add_time_ids_[:, 2] = h_start * self.vae_scale_factor
1057
+ add_time_ids_[:, 3] = w_start * self.vae_scale_factor
1058
+ add_time_ids_input.append(add_time_ids_)
1059
+ add_time_ids_input = torch.cat(add_time_ids_input)
1060
+
1061
+ if not use_md_prompt:
1062
+ prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
1063
+ add_text_embeds_input = torch.cat([add_text_embeds] * vb_size)
1064
+ added_cond_kwargs = {"text_embeds": add_text_embeds_input, "time_ids": add_time_ids_input}
1065
+ noise_pred = self.unet(
1066
+ latent_model_input,
1067
+ t,
1068
+ encoder_hidden_states=prompt_embeds_input,
1069
+ added_cond_kwargs=added_cond_kwargs,
1070
+ return_dict=False,
1071
+ )[0]
1072
+ else:
1073
+ md_prompt_embeds_input = torch.cat(views_prompt_embeds_input[j])
1074
+ md_add_text_embeds_input = torch.cat(views_add_text_embeds_input[j])
1075
+ md_added_cond_kwargs = {"text_embeds": md_add_text_embeds_input, "time_ids": add_time_ids_input}
1076
+ noise_pred = self.unet(
1077
+ latent_model_input,
1078
+ t,
1079
+ encoder_hidden_states=md_prompt_embeds_input,
1080
+ added_cond_kwargs=md_added_cond_kwargs,
1081
+ return_dict=False,
1082
+ )[0]
1083
+
1084
+ if do_classifier_free_guidance:
1085
+ noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
1086
+ noise_pred = noise_pred_uncond + multi_guidance_scale * (noise_pred_text - noise_pred_uncond)
1087
+
1088
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
1089
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1090
+
1091
+ self.scheduler._init_step_index(t)
1092
+ latents_denoised_batch = self.scheduler.step(
1093
+ noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False)[0]
1094
+
1095
+ for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
1096
+ latents_denoised_batch.chunk(vb_size), batch_view
1097
+ ):
1098
+ value_local[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
1099
+ count_local[:, :, h_start:h_end, w_start:w_end] += 1
1100
+
1101
+ if random_jitter:
1102
+ value_local = value_local[: ,:, jitter_range: jitter_range + current_height // self.vae_scale_factor, jitter_range: jitter_range + current_width // self.vae_scale_factor]
1103
+ count_local = count_local[: ,:, jitter_range: jitter_range + current_height // self.vae_scale_factor, jitter_range: jitter_range + current_width // self.vae_scale_factor]
1104
+
1105
+ noise_index = i + 1 if i != (len(timesteps) - 1) else i
1106
+
1107
+ value_local = torch.where(count_local == 0, noise_latents[noise_index], value_local)
1108
+ count_local = torch.where(count_local == 0, torch.ones_like(count_local), count_local)
1109
+ if use_dilated_sampling:
1110
+ c2 = cosine_factor ** cosine_scale_2
1111
+ value += value_local / count_local * (1 - c2)
1112
+ count += torch.ones_like(value_local) * (1 - c2)
1113
+ else:
1114
+ value += value_local / count_local
1115
+ count += torch.ones_like(value_local)
1116
+
1117
+ if use_dilated_sampling:
1118
+ views = [[h, w] for h in range(current_scale_num) for w in range(current_scale_num)]
1119
+ views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
1120
+
1121
+ h_pad = (current_scale_num - (latents.size(2) % current_scale_num)) % current_scale_num
1122
+ w_pad = (current_scale_num - (latents.size(3) % current_scale_num)) % current_scale_num
1123
+ latents_ = F.pad(latents, (w_pad, 0, h_pad, 0), 'constant', 0)
1124
+
1125
+ count_global = torch.zeros_like(latents_)
1126
+ value_global = torch.zeros_like(latents_)
1127
+
1128
+ if use_guassian:
1129
+ c3 = 0.99 * cosine_factor ** cosine_scale_3 + 1e-2
1130
+ std_, mean_ = latents_.std(), latents_.mean()
1131
+ latents_gaussian = gaussian_filter(latents_, kernel_size=(2*current_scale_num-1), sigma=sigma*c3)
1132
+ latents_gaussian = (latents_gaussian - latents_gaussian.mean()) / latents_gaussian.std() * std_ + mean_
1133
+ else:
1134
+ latents_gaussian = latents_
1135
+
1136
+ for j, batch_view in enumerate(views_batch):
1137
+
1138
+ latents_for_view = torch.cat(
1139
+ [
1140
+ latents_[:, :, h::current_scale_num, w::current_scale_num]
1141
+ for h, w in batch_view
1142
+ ]
1143
+ )
1144
+
1145
+ latents_for_view_gaussian = torch.cat(
1146
+ [
1147
+ latents_gaussian[:, :, h::current_scale_num, w::current_scale_num]
1148
+ for h, w in batch_view
1149
+ ]
1150
+ )
1151
+
1152
+ if shuffle:
1153
+ shape = latents_for_view.shape
1154
+ shuffle_index = torch.stack([torch.randperm(shape[0]) for _ in range(latents_for_view.reshape(-1).shape[0]//shape[0])])
1155
+ shuffle_index = shuffle_index.view(shape[1],shape[2],shape[3],shape[0])
1156
+ original_index = torch.zeros_like(shuffle_index).scatter_(3, shuffle_index, torch.arange(shape[0]).repeat(shape[1], shape[2], shape[3], 1))
1157
+ shuffle_index = shuffle_index.permute(3,0,1,2).to(device)
1158
+ original_index = original_index.permute(3,0,1,2).to(device)
1159
+ latents_for_view_gaussian = latents_for_view_gaussian.gather(0, shuffle_index)
1160
+
1161
+ vb_size = latents_for_view.size(0)
1162
+
1163
+ latent_model_input = latents_for_view_gaussian
1164
+ latent_model_input = (
1165
+ latent_model_input.repeat_interleave(2, dim=0)
1166
+ if do_classifier_free_guidance
1167
+ else latent_model_input
1168
+ )
1169
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1170
+
1171
+ prompt_embeds_input = torch.cat([prompt_embeds] * vb_size)
1172
+ add_text_embeds_input = torch.cat([add_text_embeds] * vb_size)
1173
+ add_time_ids_input = torch.cat([add_time_ids] * vb_size)
1174
+
1175
+ added_cond_kwargs = {"text_embeds": add_text_embeds_input, "time_ids": add_time_ids_input}
1176
+ noise_pred = self.unet(
1177
+ latent_model_input,
1178
+ t,
1179
+ encoder_hidden_states=prompt_embeds_input,
1180
+ added_cond_kwargs=added_cond_kwargs,
1181
+ return_dict=False,
1182
+ )[0]
1183
+
1184
+ if do_classifier_free_guidance:
1185
+ noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2]
1186
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1187
+
1188
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
1189
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1190
+
1191
+ if shuffle:
1192
+ noise_pred = noise_pred.gather(0, original_index)
1193
+
1194
+ self.scheduler._init_step_index(t)
1195
+ latents_denoised_batch = self.scheduler.step(noise_pred, t, latents_for_view, **extra_step_kwargs, return_dict=False)[0]
1196
+
1197
+ for latents_view_denoised, (h, w) in zip(
1198
+ latents_denoised_batch.chunk(vb_size), batch_view
1199
+ ):
1200
+ value_global[:, :, h::current_scale_num, w::current_scale_num] += latents_view_denoised
1201
+ count_global[:, :, h::current_scale_num, w::current_scale_num] += 1
1202
+
1203
+ value_global = value_global[: ,:, h_pad:, w_pad:]
1204
+
1205
+ if use_multidiffusion:
1206
+ c2 = cosine_factor ** cosine_scale_2
1207
+ value += value_global * c2
1208
+ count += torch.ones_like(value_global) * c2
1209
+ else:
1210
+ value += value_global
1211
+ count += torch.ones_like(value_global)
1212
+
1213
+ latents = torch.where(count > 0, value / count, value)
1214
+
1215
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1216
+ progress_bar.update()
1217
+ if callback is not None and i % callback_steps == 0:
1218
+ step_idx = i // getattr(self.scheduler, "order", 1)
1219
+ callback(step_idx, t, latents)
1220
+
1221
+ latents = (latents - latents.mean()) / latents.std() * anchor_std + anchor_mean
1222
+ if self.lowvram:
1223
+ latents = latents.cpu()
1224
+ torch.cuda.empty_cache()
1225
+ if not output_type == "latent":
1226
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1227
+ if self.lowvram:
1228
+ needs_upcasting = False
1229
+ self.unet.cpu()
1230
+ self.vae.to(device)
1231
+
1232
+ if needs_upcasting:
1233
+ self.upcast_vae()
1234
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1235
+
1236
+ print("### Phase {} Decoding ###".format(current_scale_num))
1237
+ if current_height > 2048 or current_width > 2048:
1238
+ self.enable_vae_tiling()
1239
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1240
+ else:
1241
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1242
+
1243
+ image = self.image_processor.postprocess(image, output_type=output_type)
1244
+ image[0].save(f'{result_path}/AccDiffusion_{current_scale_num}.png')
1245
+ output_images.append(image[0])
1246
+
1247
+ if needs_upcasting:
1248
+ self.vae.to(dtype=torch.float16)
1249
+ else:
1250
+ image = latents
1251
+
1252
+ self.maybe_free_model_hooks()
1253
+
1254
  return output_images
1255
 
1256
 
 
1287
 
1288
  args = parser.parse_args()
1289
 
 
1290
  pipe = AccDiffusionSDXLPipeline.from_pretrained(args.model_ckpt, torch_dtype=torch.float16).to("cuda")
1291
 
1292
  @spaces.GPU(duration=200)
1293
  def infer(prompt, resolution, num_inference_steps, guidance_scale, seed, use_multidiffusion, use_skip_residual, use_dilated_sampling, use_progressive_upscaling, shuffle, use_md_prompt, progress=gr.Progress(track_tqdm=True)):
1294
  set_seed(seed)
 
1318
  cosine_scale_1=args.cosine_scale_1,
1319
  cosine_scale_2=args.cosine_scale_2,
1320
  cosine_scale_3=args.cosine_scale_3,
1321
+ sigma=args.sigma, use_guassian=args.use_guassian,
 
1322
  multi_decoder=args.multi_decoder,
1323
  upscale_mode=args.upscale_mode,
1324
  use_multidiffusion=use_multidiffusion,
 
1327
  use_dilated_sampling=use_dilated_sampling,
1328
  shuffle=shuffle,
1329
  result_path=result_path,
1330
+ debug=args.debug, save_attention_map=args.save_attention_map, use_md_prompt=use_md_prompt, c=args.c
 
1331
  )
1332
  print(images)
1333
 
1334
  return images
1335
+
 
1336
  MAX_SEED = np.iinfo(np.int32).max
1337
 
1338
+ css = """
1339
+ body {
1340
+ background: linear-gradient(135deg, #2c3e50, #4ca1af);
1341
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
1342
+ color: #ffffff;
1343
+ }
1344
+ #col-container {
1345
+ margin: 20px auto;
1346
+ padding: 20px;
1347
+ max-width: 900px;
1348
+ background-color: rgba(0, 0, 0, 0.5);
1349
+ border-radius: 12px;
1350
+ box-shadow: 0 4px 12px rgba(0,0,0,0.5);
1351
+ }
1352
+ h1, h2 {
1353
+ text-align: center;
1354
+ margin-bottom: 10px;
1355
+ }
1356
+ footer {
1357
+ visibility: hidden;
1358
+ }
1359
+ """
1360
+
1361
  with gr.Blocks(css=css) as demo:
1362
  with gr.Column(elem_id="col-container"):
1363
  gr.Markdown("<h1>AccDiffusion: Advanced AI Art Generator</h1>")
 
1372
  with gr.Row():
1373
  resolution = gr.Radio(
1374
  label="Resolution",
1375
+ choices = [
1376
  "1024,1024", "2048,2048", "2048,1024", "1536,3072", "3072,3072", "4096,4096", "4096,2048"
1377
  ],
1378
+ value = "1024,1024",
1379
  interactive=True
1380
  )
1381
  with gr.Column():
 
1394
  output_images = gr.Gallery(label="Output Images", format="png").style(grid=[2], height="auto")
1395
  gr.Markdown("### Example Prompts")
1396
  gr.Examples(
1397
+ examples = [
1398
  ["A surreal landscape with floating islands and vibrant colors."],
1399
  ["Cyberpunk cityscape at night with neon lights and futuristic architecture."],
1400
  ["A majestic dragon soaring over a medieval castle amidst stormy skies."],
 
1402
  ["Abstract geometric patterns in vivid, pulsating colors."],
1403
  ["A mystical forest illuminated by bioluminescent plants under a starry sky."]
1404
  ],
1405
+ inputs = [prompt],
1406
  label="Click an example to populate the prompt box."
1407
  )
1408
  submit_btn.click(
1409
+ fn = infer,
1410
+ inputs = [prompt, resolution, num_inference_steps, guidance_scale, seed,
1411
+ use_multidiffusion, use_skip_residual, use_dilated_sampling, use_progressive_upscaling, shuffle, use_md_prompt],
1412
+ outputs = [output_images],
1413
  show_api=False
1414
  )
1415
  demo.launch(show_api=False, show_error=True)
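
For context on the tiling used by the new get_views helper, here is a standalone sketch of the same sliding-window arithmetic (random jitter omitted); the 2048x2048 example and vae_scale_factor=8 are assumptions for illustration, not code from this commit:

# Simplified re-derivation of get_views' window layout (no random jitter).
def sliding_views(height_px, width_px, window_size=128, stride=64, vae_scale_factor=8):
    h = height_px // vae_scale_factor   # latent-space height
    w = width_px // vae_scale_factor    # latent-space width
    n_h = int((h - window_size) / stride - 1e-6) + 2 if h > window_size else 1
    n_w = int((w - window_size) / stride - 1e-6) + 2 if w > window_size else 1
    views = []
    for i in range(n_h * n_w):
        # Clamp the last window so it ends exactly at the latent border,
        # mirroring the h_end/w_end adjustment in get_views.
        h_start = min((i // n_w) * stride, max(h - window_size, 0))
        w_start = min((i % n_w) * stride, max(w - window_size, 0))
        views.append((h_start, h_start + window_size, w_start, w_start + window_size))
    return views

# A 2048x2048 image maps to a 256x256 latent, i.e. a 3x3 grid of
# overlapping 128x128 windows with stride 64 -> 9 views in total.
print(len(sliding_views(2048, 2048)))  # 9
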