prithivMLmods committed
Commit 02e4f58 · verified · 1 Parent(s): 104aded

Update app.py

Files changed (1):
  1. app.py +131 -39
app.py CHANGED
@@ -13,7 +13,9 @@ import gradio as gr
 import spaces
 from diffusers import (
     DiffusionPipeline,
-    FlowMatchEulerDiscreteScheduler)
+    FlowMatchEulerDiscreteScheduler,
+    AutoencoderKL,
+    AutoPipelineForImage2Image)
 from huggingface_hub import (
     hf_hub_download,
     HfFileSystem,
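Note: the revised module below instantiates AutoencoderTiny for the preview decoder, but this import hunk only adds AutoencoderKL and AutoPipelineForImage2Image. A minimal sketch of the import the new code appears to assume (an inference, not part of the recorded diff):

    # Presumed import for the taef1 preview decoder used in the next hunk;
    # AutoencoderTiny is the diffusers class behind madebyollin/taef1.
    from diffusers import AutoencoderTiny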
@@ -142,15 +144,30 @@ pipe = DiffusionPipeline.from_pretrained(
     base_model, scheduler=scheduler, torch_dtype=dtype
 ).to(device)
 
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
+pipe.vae = taef1
+pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
+    base_model,
+    vae=good_vae,
+    transformer=pipe.transformer,
+    text_encoder=pipe.text_encoder,
+    tokenizer=pipe.tokenizer,
+    text_encoder_2=pipe.text_encoder_2,
+    tokenizer_2=pipe.tokenizer_2,
+    scheduler=scheduler,
+    torch_dtype=dtype
+).to(device)
+
 # Lightning LoRA info (no global state)
 LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
 LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V1.0.safetensors"
 
-MAX_SEED = np.iinfo(np.int32).max
+MAX_SEED = 2**32 - 1
 
-class Timer:
-    def __init__(self, task_name=""):
-        self.task_name = task_name
+class calculateDuration:
+    def __init__(self, activity_name=""):
+        self.activity_name = activity_name
 
     def __enter__(self):
         self.start_time = time.time()
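The block above swaps the pipeline's VAE for the tiny TAEF1 autoencoder so latents can be decoded cheaply on every step, keeping the original AutoencoderKL (good_vae) aside for the final full-quality decode; pipe_i2i reuses pipe's transformer, text encoders, and tokenizers, so the second pipeline adds no duplicate weights. A hedged sketch of the two decode paths, mirroring the calls the diff itself makes:

    import torch

    @torch.no_grad()
    def decode_preview(latents):
        # Fast, lower-fidelity decode via the tiny VAE (pipe.vae is taef1 here).
        img = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
        return pipe.image_processor.pt_to_pil(img)[0]

    @torch.no_grad()
    def decode_final(latents):
        # Full-quality decode with the original VAE, run once at the end.
        img = good_vae.decode(latents / good_vae.config.scaling_factor, return_dict=False)[0]
        return pipe.image_processor.pt_to_pil(img)[0]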
@@ -159,8 +176,8 @@ class Timer:
     def __exit__(self, exc_type, exc_value, traceback):
         self.end_time = time.time()
         self.elapsed_time = self.end_time - self.start_time
-        if self.task_name:
-            print(f"Elapsed time for {self.task_name}: {self.elapsed_time:.6f} seconds")
+        if self.activity_name:
+            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
         else:
             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
 
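For reference, the renamed timer is used as a context manager; a minimal usage sketch (the activity name is illustrative):

    import time

    # Prints roughly "Elapsed time for loading LoRA: 1.230000 seconds" on exit.
    with calculateDuration("loading LoRA"):
        time.sleep(1.23)  # stand-in for the timed work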
@@ -213,26 +230,85 @@ def adjust_generation_mode(speed_mode):
     return gr.update(value="Base mode selected - 48 steps for best quality"), 48, 4.0
 
 @spaces.GPU(duration=100)
-def create_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
+def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
+    generator = torch.Generator(device="cuda").manual_seed(seed)
     pipe.to("cuda")
+
+    batch_size = 1
+    prompt = prompt_mash
+    do_classifier_free_guidance = cfg_scale > 1.0
+    prompt_embeds, pooled_prompt_embeds = pipe.encode_prompt(
+        prompt,
+        num_images_per_prompt=1,
+        do_classifier_free_guidance=do_classifier_free_guidance,
+        prompt_2=None,
+        max_sequence_length=256,
+    )
+    height, width = height - height % 16, width - width % 16
+    latents = pipe.prepare_latents(
+        batch_size,
+        pipe.transformer.config.in_channels,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    )
+    pipe.scheduler.set_timesteps(steps)
+    timesteps = pipe.scheduler.timesteps
+    joint_attention_kwargs = {"scale": lora_scale}
+    for i in range(steps):
+        t = pipe.scheduler.sigmas[i]
+        latent_model_input = latents
+        with torch.no_grad():
+            noise_pred = pipe.transformer(
+                hidden_states=latent_model_input,
+                timestep=t,
+                guidance=cfg_scale,
+                pooled_projections=pooled_prompt_embeds,
+                encoder_hidden_states=prompt_embeds,
+                joint_attention_kwargs=joint_attention_kwargs,
+                return_dict=False,
+            )[0]
+        latents = pipe.scheduler.step(
+            model_output=noise_pred,
+            timestep=t,
+            sample=latent_model_input,
+            return_dict=False,
+        )[0]
+        # preview
+        with torch.no_grad():
+            decoded = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
+        image = pipe.image_processor.pt_to_pil(decoded)[0]
+        yield image
+    # final
+    with torch.no_grad():
+        decoded = good_vae.decode(latents / good_vae.config.scaling_factor, return_dict=False)[0]
+    image = pipe.image_processor.pt_to_pil(decoded)[0]
+    yield image
+
+@spaces.GPU(duration=100)
+def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
     generator = torch.Generator(device="cuda").manual_seed(seed)
-
-    with Timer("Generating image"):
-        # Generate image
-        image = pipe(
-            prompt=prompt_mash,
-            negative_prompt=negative_prompt,
-            num_inference_steps=steps,
-            true_cfg_scale=cfg_scale,  # Use true_cfg_scale for Qwen-Image
-            width=width,
-            height=height,
-            generator=generator,
-        ).images[0]
-
-    return image
+    pipe_i2i.to("cuda")
+    image_input = load_image(image_input_path)
+    final_image = pipe_i2i(
+        prompt=prompt_mash,
+        image=image_input,
+        strength=image_strength,
+        num_inference_steps=steps,
+        guidance_scale=cfg_scale,
+        width=width,
+        height=height,
+        generator=generator,
+        joint_attention_kwargs={"scale": lora_scale},
+        output_type="pil",
+    ).images[0]
+    return final_image
 
 @spaces.GPU(duration=100)
-def process_adapter_generation(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, progress=gr.Progress(track_tqdm=True)):
+def process_adapter_generation(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, image_input, image_strength, negative_prompt="", progress=gr.Progress(track_tqdm=True)):
     if selected_index is None:
         raise gr.Error("You must select a LoRA before proceeding.")
 
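Since generate_image now yields a preview image after every scheduler step and the full-VAE decode last, callers iterate it instead of calling it once. A hedged standalone sketch (prompt, seed, and sizes are placeholders):

    # Each yield is a PIL image; all but the last come from the tiny preview VAE.
    for frame in generate_image("a watercolor fox", steps=8, seed=42,
                                cfg_scale=1.0, width=1024, height=1024,
                                lora_scale=1.0):
        frame.save("latest_preview.png")  # final overwrite is the good_vae decode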
@@ -253,14 +329,16 @@ def process_adapter_generation(prompt, cfg_scale, steps, selected_index, randomi
     prompt_mash = prompt
 
     # Always unload any existing LoRAs first to avoid conflicts
-    with Timer("Unloading existing LoRAs"):
+    with calculateDuration("Unloading existing LoRAs"):
         pipe.unload_lora_weights()
+        pipe_i2i.unload_lora_weights()
+
+    pipe_to_use = pipe_i2i if image_input is not None else pipe
 
-    # Load LoRAs based on speed mode
     if speed_mode == "Fast (8 steps)":
-        with Timer("Loading Lightning LoRA and style LoRA"):
+        with calculateDuration("Loading Lightning LoRA and style LoRA"):
             # Load Lightning LoRA first
-            pipe.load_lora_weights(
+            pipe_to_use.load_lora_weights(
                 LIGHTNING_LORA_REPO,
                 weight_name=LIGHTNING_LORA_WEIGHT,
                 adapter_name="lightning"
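pipe_to_use routes all LoRA loading onto whichever pipeline will actually run, so image-to-image requests get the same adapters as text-to-image. A consolidated sketch of the fast-mode setup across this and the following hunks (the adapter_name="style" on the second load is an assumption; that context line is not shown in the diff, though it is implied by set_adapters):

    pipe_to_use = pipe_i2i if image_input is not None else pipe
    pipe_to_use.load_lora_weights(
        LIGHTNING_LORA_REPO, weight_name=LIGHTNING_LORA_WEIGHT, adapter_name="lightning"
    )
    pipe_to_use.load_lora_weights(
        lora_path, weight_name=weight_name, low_cpu_mem_usage=True,
        adapter_name="style"  # assumed; implied by set_adapters below
    )
    # Lightning at full strength for the 8-step schedule; style scaled by the slider.
    pipe_to_use.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])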
@@ -268,7 +346,7 @@ def process_adapter_generation(prompt, cfg_scale, steps, selected_index, randomi
 
             # Load the selected style LoRA
             weight_name = selected_lora.get("weights", None)
-            pipe.load_lora_weights(
+            pipe_to_use.load_lora_weights(
                 lora_path,
                 weight_name=weight_name,
                 low_cpu_mem_usage=True,
@@ -276,29 +354,36 @@ def process_adapter_generation(prompt, cfg_scale, steps, selected_index, randomi
             )
 
         # Set both adapters active with their weights
-        pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
+        pipe_to_use.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
     else:
         # Quality mode - only load the style LoRA
-        with Timer(f"Loading LoRA weights for {selected_lora['title']}"):
+        with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
             weight_name = selected_lora.get("weights", None)
-            pipe.load_lora_weights(
+            pipe_to_use.load_lora_weights(
                 lora_path,
                 weight_name=weight_name,
                 low_cpu_mem_usage=True
             )
 
     # Set random seed for reproducibility
-    with Timer("Randomizing seed"):
+    with calculateDuration("Randomizing seed"):
         if randomize_seed:
             seed = random.randint(0, MAX_SEED)
 
     # Get image dimensions from aspect ratio
     width, height = compute_image_dimensions(aspect_ratio)
 
-    # Generate the image
-    final_image = create_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale)
+    if image_input is not None:
+        final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
+        yield final_image, seed, gr.update(visible=False)
+    else:
+        image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt)
 
-    return final_image, seed
+        step_counter = 0
+        for image in image_generator:
+            step_counter += 1
+            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
+            yield image, seed, gr.update(value=progress_bar, visible=True)
 
 def fetch_hf_adapter_files(link):
     split_link = link.split("/")
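In the streaming branch, each preview yielded by generate_image is re-yielded to Gradio together with the seed and a progress-bar update, with step_counter feeding the CSS variables that size the bar. A sketch of what one streamed update carries at step 3 of 8 (values illustrative; preview_image and seed come from the loop):

    progress_bar = ('<div class="progress-container">'
                    '<div class="progress-bar" style="--current: 3; --total: 8;"></div></div>')
    # Positional mapping to the gr.on outputs: (result, seed, progress_html).
    update = (preview_image, seed, gr.update(value=progress_bar, visible=True))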
@@ -422,8 +507,6 @@ def incorporate_custom_adapter(custom_lora):
 def discard_custom_adapter():
     return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
 
-process_adapter_generation.zerogpu = True
-
 css = '''
 #gen_btn{height: 100%}
 #gen_column{align-self: stretch}
@@ -436,6 +519,10 @@ css = '''
 .card_internal img{margin-right: 1em}
 .styler{--form-gap-width: 0px !important}
 #speed_status{padding: .5em; border-radius: 5px; margin: 1em 0}
+#progress{height:30px}
+#progress .generating{display:none}
+.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
+.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
 '''
 
 with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120)) as app:
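The bar width is driven by calc(var(--current) / var(--total) * 100%), so the inline --current/--total values set on each yield translate directly into a percentage: at step 3 of 8 the bar renders at 37.5% width, and the 0.5s transition smooths each jump.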
@@ -467,6 +554,7 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120))
 
         with gr.Column():
             result = gr.Image(label="Generated Image")
+            progress_html = gr.HTML(visible=False, elem_id="progress")
 
             with gr.Row():
                 aspect_ratio = gr.Dropdown(
@@ -508,6 +596,10 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120))
                 randomize_seed = gr.Checkbox(True, label="Randomize seed")
                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                 lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2, step=0.01, value=1.0)
+
+            with gr.Row():
+                image_input = gr.Image(label="Input Image for Image2Image", type="filepath")
+                image_strength = gr.Slider(label="Image Strength", minimum=0, maximum=1, step=0.01, value=0.35)
 
     # Event handlers
     gallery.select(
@@ -536,8 +628,8 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120))
     gr.on(
         triggers=[generate_button.click, prompt.submit],
         fn=process_adapter_generation,
-        inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode],
-        outputs=[result, seed]
+        inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, image_input, image_strength],
+        outputs=[result, seed, progress_html]
     )
 
     app.queue()