prithivMLmods committed
Commit 35e8372 · verified · 1 Parent(s): db9fb52

Update app.py

Files changed (1): app.py (+127, -210)
app.py CHANGED
@@ -13,15 +13,17 @@ import gradio as gr
 import spaces
 from diffusers import (
     DiffusionPipeline,
-    FlowMatchEulerDiscreteScheduler,
-    AutoencoderTiny,
     AutoencoderKL,
-    AutoPipelineForImage2Image,)
+    AutoencoderTiny,
+    AutoPipelineForImage2Image,
+    FlowMatchEulerDiscreteScheduler
+)
 from huggingface_hub import (
     hf_hub_download,
     HfFileSystem,
     ModelCard,
-    snapshot_download)
+    snapshot_download
+)
 from diffusers.utils import load_image
 import requests
 from urllib.parse import urlparse
@@ -118,10 +120,14 @@ loras = [
     },
 ]
 
-# Initialize the base model
+# Initialize the base model and autoencoders
 dtype = torch.bfloat16
 base_model = "Qwen/Qwen-Image"
 
+# Initialize TAEF1 for fast previews and the standard VAE for high-quality final images
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
+
 # Scheduler configuration from the Qwen-Image-Lightning repository
 scheduler_config = {
     "base_image_seq_len": 256,
@@ -141,34 +147,30 @@ scheduler_config = {
 }
 
 scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
+
+# Main pipeline for text-to-image, using taef1 for fast decoding during generation
 pipe = DiffusionPipeline.from_pretrained(
-    base_model, scheduler=scheduler, torch_dtype=dtype
+    base_model, scheduler=scheduler, torch_dtype=dtype, vae=taef1
 ).to(device)
 
-taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
-good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
-pipe.vae = taef1
+# Image-to-image pipeline, using the high-quality VAE
 pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
     base_model,
     vae=good_vae,
-    transformer=pipe.transformer,
-    text_encoder=pipe.text_encoder,
-    tokenizer=pipe.tokenizer,
-    text_encoder_2=pipe.text_encoder_2,
-    tokenizer_2=pipe.tokenizer_2,
     scheduler=scheduler,
     torch_dtype=dtype
 ).to(device)
 
+
 # Lightning LoRA info (no global state)
 LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
 LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V1.0.safetensors"
 
-MAX_SEED = 2**32 - 1
+MAX_SEED = np.iinfo(np.int32).max
 
-class calculateDuration:
-    def __init__(self, activity_name=""):
-        self.activity_name = activity_name
+class Timer:
+    def __init__(self, task_name=""):
+        self.task_name = task_name
 
     def __enter__(self):
         self.start_time = time.time()
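Note on the hunk above: the commit now builds both autoencoders up front, wires the tiny TAEF1 decoder into the text-to-image pipeline (vae=taef1), and reserves the full AutoencoderKL for the image-to-image pipeline and the final decode. A minimal sketch of that preview-vs-final decode pattern, assuming a latents tensor produced by the pipeline; the helper name decode_with is illustrative and not part of app.py, and the decode call simply mirrors the one in the removed generate_image() below:

import torch
from diffusers import AutoencoderKL, AutoencoderTiny

dtype = torch.bfloat16
# Tiny VAE: fast, lower-fidelity decodes for intermediate previews
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to("cuda")
# Full VAE: slower, higher-quality decode for the final image
good_vae = AutoencoderKL.from_pretrained("Qwen/Qwen-Image", subfolder="vae", torch_dtype=dtype).to("cuda")

def decode_with(vae, latents):
    # Un-scale the latents, then decode to pixel space (same call shape as the removed code)
    with torch.no_grad():
        return vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]

# preview = decode_with(taef1, latents)     # cheap decode while steps are still running
# final   = decode_with(good_vae, latents)  # one expensive decode at the end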
@@ -177,8 +179,8 @@ class calculateDuration:
     def __exit__(self, exc_type, exc_value, traceback):
         self.end_time = time.time()
         self.elapsed_time = self.end_time - self.start_time
-        if self.activity_name:
-            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+        if self.task_name:
+            print(f"Elapsed time for {self.task_name}: {self.elapsed_time:.6f} seconds")
         else:
             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
 
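For reference, the renamed Timer context manager reads the same at call sites as the old calculateDuration; a usage sketch (the wrapped call and its repo id are illustrative, and the printed timing is an example):

with Timer("Loading LoRA weights"):
    pipe.load_lora_weights("some-user/some-qwen-image-lora")  # hypothetical adapter repo
# prints: Elapsed time for Loading LoRA weights: 3.141593 seconds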
@@ -230,88 +232,32 @@ def adjust_generation_mode(speed_mode):
     else:
         return gr.update(value="Base mode selected - 48 steps for best quality"), 48, 4.0
 
-@spaces.GPU(duration=100)
-def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
-    generator = torch.Generator(device="cuda").manual_seed(seed)
-    pipe.to("cuda")
-
-    batch_size = 1
-    prompt = prompt_mash
-    do_classifier_free_guidance = cfg_scale > 1.0
-    prompt_embeds, pooled_prompt_embeds = pipe.encode_prompt(
-        prompt,
-        num_images_per_prompt=1,
-        do_classifier_free_guidance=do_classifier_free_guidance,
-        prompt_2=None,
-        max_sequence_length=256,
-    )
-    height, width = height - height % 16, width - width % 16
-    latents = pipe.prepare_latents(
-        batch_size,
-        pipe.transformer.config.in_channels,
-        height,
-        width,
-        dtype,
-        device,
-        generator,
-        latents=None,
-    )
-    pipe.scheduler.set_timesteps(steps)
-    timesteps = pipe.scheduler.timesteps
-    joint_attention_kwargs = {"scale": lora_scale}
-    for i in range(steps):
-        t = pipe.scheduler.sigmas[i]
-        latent_model_input = latents
-        with torch.no_grad():
-            noise_pred = pipe.transformer(
-                hidden_states=latent_model_input,
-                timestep=t,
-                guidance=cfg_scale,
-                pooled_projections=pooled_prompt_embeds,
-                encoder_hidden_states=prompt_embeds,
-                joint_attention_kwargs=joint_attention_kwargs,
-                return_dict=False,
-            )[0]
-        latents = pipe.scheduler.step(
-            model_output=noise_pred,
-            timestep=t,
-            sample=latent_model_input,
-            return_dict=False,
-        )[0]
-        # preview
-        with torch.no_grad():
-            decoded = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
-            image = pipe.image_processor.pt_to_pil(decoded)[0]
-        yield image
-    # final
-    with torch.no_grad():
-        decoded = good_vae.decode(latents / good_vae.config.scaling_factor, return_dict=False)[0]
-        image = pipe.image_processor.pt_to_pil(decoded)[0]
-    yield image
-
-@spaces.GPU(duration=100)
-def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed):
+def image_to_image_generation(prompt_mash, image_input, strength, steps, cfg_scale, width, height, lora_scale, seed):
+    """Handles the image-to-image generation process."""
     generator = torch.Generator(device="cuda").manual_seed(seed)
     pipe_i2i.to("cuda")
-    image_input = load_image(image_input_path)
+
+    # Resize and convert input image
+    image_input_pil = load_image(image_input).resize((width, height), Image.Resampling.LANCZOS)
+
     final_image = pipe_i2i(
         prompt=prompt_mash,
-        image=image_input,
-        strength=image_strength,
+        image=image_input_pil,
+        strength=strength,
         num_inference_steps=steps,
         guidance_scale=cfg_scale,
-        width=width,
-        height=height,
         generator=generator,
-        joint_attention_kwargs={"scale": lora_scale},
-        output_type="pil",
+        # Note: image-to-image with Qwen doesn't use `true_cfg_scale`
     ).images[0]
     return final_image
 
 @spaces.GPU(duration=100)
-def process_adapter_generation(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, image_input, image_strength, negative_prompt="", progress=gr.Progress(track_tqdm=True)):
+def process_generation_request(
+    prompt, image_input, image_strength, cfg_scale, steps, selected_index,
+    randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, progress=gr.Progress(track_tqdm=True)
+):
     if selected_index is None:
-        raise gr.Error("You must select a LoRA before proceeding.")
+        raise gr.Error("You must select a LoRA before proceeding.🧨")
 
     selected_lora = loras[selected_index]
     lora_path = selected_lora["repo"]
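Note on the simplified image-to-image call above: in standard diffusers image-to-image pipelines, strength controls how much noise is added to the input image and, together with num_inference_steps, how many denoising steps actually run. The sketch below mirrors the usual diffusers get_timesteps logic; treating it as exact for the Qwen pipeline is an assumption, and the numbers are illustrative:

num_inference_steps, strength = 8, 0.6
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 4
t_start = max(num_inference_steps - init_timestep, 0)                          # skip the first 4 timesteps
steps_actually_run = num_inference_steps - t_start
print(steps_actually_run)  # 4 -> higher strength means more denoising and less of the input preserved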
@@ -319,72 +265,85 @@ def process_adapter_generation(prompt, cfg_scale, steps, selected_index, randomi
 
     # Prepare prompt with trigger word
     if trigger_word:
-        if "trigger_position" in selected_lora:
-            if selected_lora["trigger_position"] == "prepend":
-                prompt_mash = f"{trigger_word} {prompt}"
-            else:
-                prompt_mash = f"{prompt} {trigger_word}"
-        else:
-            prompt_mash = f"{trigger_word} {prompt}"
+        prompt_mash = f"{trigger_word}, {prompt}" if prompt else trigger_word
     else:
         prompt_mash = prompt
 
-    # Always unload any existing LoRAs first to avoid conflicts
-    with calculateDuration("Unloading existing LoRAs"):
-        pipe.unload_lora_weights()
-        pipe_i2i.unload_lora_weights()
-
+    # Set random seed if requested
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+
+    # Determine which pipeline to use
     pipe_to_use = pipe_i2i if image_input is not None else pipe
+
+    # Always unload any existing LoRAs first to avoid conflicts
+    with Timer("Unloading existing LoRAs"):
+        pipe_to_use.unload_lora_weights()
 
+    # Load LoRAs based on speed mode
     if speed_mode == "Fast (8 steps)":
-        with calculateDuration("Loading Lightning LoRA and style LoRA"):
-            # Load Lightning LoRA first
+        with Timer("Loading Lightning LoRA and style LoRA"):
             pipe_to_use.load_lora_weights(
                 LIGHTNING_LORA_REPO,
                 weight_name=LIGHTNING_LORA_WEIGHT,
                 adapter_name="lightning"
             )
-
-            # Load the selected style LoRA
-            weight_name = selected_lora.get("weights", None)
+            weight_name = selected_lora.get("weights")
             pipe_to_use.load_lora_weights(
                 lora_path,
                 weight_name=weight_name,
-                low_cpu_mem_usage=True,
                 adapter_name="style"
            )
-
-            # Set both adapters active with their weights
             pipe_to_use.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
-    else:
-        # Quality mode - only load the style LoRA
-        with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
-            weight_name = selected_lora.get("weights", None)
-            pipe_to_use.load_lora_weights(
-                lora_path,
-                weight_name=weight_name,
-                low_cpu_mem_usage=True
-            )
-
-    # Set random seed for reproducibility
-    with calculateDuration("Randomizing seed"):
-        if randomize_seed:
-            seed = random.randint(0, MAX_SEED)
+    else:  # Quality mode
+        with Timer(f"Loading LoRA weights for {selected_lora['title']}"):
+            weight_name = selected_lora.get("weights")
+            pipe_to_use.load_lora_weights(lora_path, weight_name=weight_name)
 
-    # Get image dimensions from aspect ratio
     width, height = compute_image_dimensions(aspect_ratio)
-
+
+    # --- Generation ---
     if image_input is not None:
-        final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
+        # Image-to-Image Generation
+        final_image = image_to_image_generation(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed)
         yield final_image, seed, gr.update(visible=False)
     else:
-        image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt)
-
-        step_counter = 0
-        for image in image_generator:
-            step_counter += 1
-            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
-            yield image, seed, gr.update(value=progress_bar, visible=True)
+        # Text-to-Image Generation with Previews
+        pipe.to("cuda")
+        generator = torch.Generator(device="cuda").manual_seed(seed)
+
+        # Callback for generating previews
+        def callback_on_step_end(pipe, step_index, timestep, callback_kwargs):
+            latents = callback_kwargs["latents"]
+            # Use the fast taef1 decoder for previews
+            with torch.no_grad():
+                image = pipe.decode_latents(latents.to(dtype))[0]
+            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_index + 1}; --total: {steps};"></div></div>'
+            yield {"image": image, "seed": seed, "progress": gr.update(value=progress_bar, visible=True)}
+            return callback_kwargs
+
+        # Generate image with step-by-step previews
+        with Timer("Generating image with previews"):
+            generation_output = pipe(
+                prompt=prompt_mash,
+                num_inference_steps=steps,
+                true_cfg_scale=cfg_scale,
+                width=width,
+                height=height,
+                generator=generator,
+                output_type="latent",  # Get latents to decode with the good VAE later
+                callback_on_step_end=callback_on_step_end
+            )
+
+        # Decode the final image with the high-quality VAE
+        with Timer("Final decoding with good VAE"):
+            final_latents = generation_output.images
+            pipe.vae = good_vae  # Temporarily swap to the good VAE
+            final_image = pipe.decode_latents(final_latents.to(dtype))[0]
+            pipe.vae = taef1  # Swap back to taef1 for the next run
+
+        yield final_image, seed, gr.update(visible=False)
+
 
 def fetch_hf_adapter_files(link):
     split_link = link.split("/")
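The preview path added above hooks diffusers' callback_on_step_end. A standalone sketch of that callback contract, assuming a pipeline that supports the hook; the print body and the commented call (prompt included) are illustrative, and tensors other than latents must be requested via callback_on_step_end_tensor_inputs:

def on_step_end(pipeline, step_index, timestep, callback_kwargs):
    # callback_kwargs holds the tensors listed in callback_on_step_end_tensor_inputs
    latents = callback_kwargs["latents"]
    print(f"step {step_index + 1} done, latents {tuple(latents.shape)}")
    return callback_kwargs  # returned tensors are fed back into the pipeline

# image = pipe(
#     prompt="a watercolor fox",                      # illustrative prompt
#     num_inference_steps=8,
#     callback_on_step_end=on_step_end,
#     callback_on_step_end_tensor_inputs=["latents"],
# ).images[0]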
@@ -393,79 +352,37 @@ def fetch_hf_adapter_files(link):
 
     print(f"Repository attempted: {split_link}")
 
-    # Load model card
     model_card = ModelCard.load(link)
     base_model = model_card.data.get("base_model")
     print(f"Base model: {base_model}")
 
-    # Validate model type (for Qwen-Image)
     acceptable_models = {"Qwen/Qwen-Image"}
-
     models_to_check = base_model if isinstance(base_model, list) else [base_model]
 
     if not any(model in acceptable_models for model in models_to_check):
         raise Exception("Not a Qwen-Image LoRA!")
 
-    # Extract image and trigger word
-    image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+    image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url")
     trigger_word = model_card.data.get("instance_prompt", "")
     image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
 
-    # Initialize Hugging Face file system
     fs = HfFileSystem()
     try:
         list_of_files = fs.ls(link, detail=False)
-
-        # Find safetensors file
-        safetensors_name = None
-        for file in list_of_files:
-            filename = file.split("/")[-1]
-            if filename.endswith(".safetensors"):
-                safetensors_name = filename
-                break
-
+        safetensors_name = next((f.split('/')[-1] for f in list_of_files if f.endswith(".safetensors")), None)
         if not safetensors_name:
             raise Exception("No valid *.safetensors file found in the repository.")
-
     except Exception as e:
         print(e)
-        raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")
+        raise Exception("Could not find a valid *.safetensors file in the Hugging Face repository.")
 
     return split_link[1], link, safetensors_name, trigger_word, image_url
 
 def validate_custom_adapter(link):
     print(f"Checking a custom model on: {link}")
-
-    if link.endswith('.safetensors'):
-        if 'huggingface.co' in link:
-            parts = link.split('/')
-            try:
-                hf_index = parts.index('huggingface.co')
-                username = parts[hf_index + 1]
-                repo_name = parts[hf_index + 2]
-                repo = f"{username}/{repo_name}"
-
-                safetensors_name = parts[-1]
-
-                try:
-                    model_card = ModelCard.load(repo)
-                    trigger_word = model_card.data.get("instance_prompt", "")
-                    image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
-                    image_url = f"https://huggingface.co/{repo}/resolve/main/{image_path}" if image_path else None
-                except:
-                    trigger_word = ""
-                    image_url = None
-
-                return repo_name, repo, safetensors_name, trigger_word, image_url
-            except:
-                raise Exception("Invalid safetensors URL format")
-
-    if link.startswith("https://"):
-        if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
-            link_split = link.split("huggingface.co/")
-            return fetch_hf_adapter_files(link_split[1])
-    else:
-        return fetch_hf_adapter_files(link)
+    if link.startswith("https://huggingface.co"):
+        link = urlparse(link).path.strip("/")
+    return fetch_hf_adapter_files(link)
 
 def incorporate_custom_adapter(custom_lora):
     global loras
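The rewritten validate_custom_adapter above normalizes a full Hugging Face URL down to a user/repo path before delegating to fetch_hf_adapter_files; the same idiom in isolation (the example URL is hypothetical):

from urllib.parse import urlparse

link = "https://huggingface.co/some-user/some-qwen-image-lora"  # hypothetical input
if link.startswith("https://huggingface.co"):
    link = urlparse(link).path.strip("/")  # -> "some-user/some-qwen-image-lora"
print(link)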
@@ -485,29 +402,22 @@ def incorporate_custom_adapter(custom_lora):
             </div>
             </div>
             '''
-            existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
+            existing_item_index = next((i for i, item in enumerate(loras) if item['repo'] == repo), None)
             if existing_item_index is None:
-                new_item = {
-                    "image": image,
-                    "title": title,
-                    "repo": repo,
-                    "weights": path,
-                    "trigger_word": trigger_word
-                }
-                print(new_item)
+                new_item = {"image": image, "title": title, "repo": repo, "weights": path, "trigger_word": trigger_word}
                 loras.append(new_item)
-                existing_item_index = len(loras) - 1 # Get the actual index after adding
+                existing_item_index = len(loras) - 1
 
             return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
         except Exception as e:
-            gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-Qwen-Image LoRA, this was the issue: {e}")
-            return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-Qwen-Image LoRA"), gr.update(visible=True), gr.update(), "", None, ""
-    else:
-        return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
+            gr.Warning(f"Invalid LoRA: {e}")
+            return gr.update(visible=True, value=f"Invalid LoRA: {e}"), gr.update(visible=True), gr.update(), "", None, ""
+    return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
 
 def discard_custom_adapter():
     return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
 
+
 css = '''
 #gen_btn{height: 100%}
 #gen_column{align-self: stretch}
@@ -523,7 +433,7 @@ css = '''
 #progress{height:30px}
 #progress .generating{display:none}
 .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
-.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
+.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.1s ease-in-out}
 '''
 
 with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120)) as app:
@@ -547,6 +457,10 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120))
                 elem_id="gallery",
                 show_share_button=False
             )
+            with gr.Accordion("Image-to-Image (Optional)", open=False):
+                image_input = gr.Image(type="filepath", label="Input Image")
+                image_strength = gr.Slider(label="Image Strength", minimum=0.1, maximum=1.0, step=0.05, value=0.6)
+
             with gr.Group():
                 custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="username/lora-model-name")
                 gr.Markdown("[Check Qwen-Image LoRAs](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
@@ -555,14 +469,14 @@
 
         with gr.Column():
             result = gr.Image(label="Generated Image")
-            progress_html = gr.HTML(visible=False, elem_id="progress")
+            progress_bar = gr.HTML(visible=False, elem_id="progress")
 
             with gr.Row():
                 aspect_ratio = gr.Dropdown(
                     label="Aspect Ratio",
                     choices=["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3"],
                     value="1:1"
-                    )
+                )
             with gr.Row():
                 speed_mode = gr.Dropdown(
                     label="Output Mode",
@@ -577,12 +491,12 @@
             with gr.Column():
                 with gr.Row():
                     cfg_scale = gr.Slider(
-                        label="Guidance Scale (True CFG)",
+                        label="Guidance Scale",
                         minimum=1.0,
                         maximum=5.0,
                         step=0.1,
                         value=4.0,
-                        info="Lower for speed mode, higher for quality"
+                        info="Lower for speed mode, higher for quality. Also called 'True CFG'."
                     )
                     steps = gr.Slider(
                         label="Steps",
@@ -597,10 +511,6 @@
                     randomize_seed = gr.Checkbox(True, label="Randomize seed")
                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                     lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2, step=0.01, value=1.0)
-
-                with gr.Row():
-                    image_input = gr.Image(label="Input Image for Image2Image", type="filepath")
-                    image_strength = gr.Slider(label="Image Strength", minimum=0, maximum=1, step=0.01, value=0.35)
 
     # Event handlers
     gallery.select(
606
  gallery.select(
@@ -626,11 +536,18 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120))
626
  outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
627
  )
628
 
629
- gr.on(
630
- triggers=[generate_button.click, prompt.submit],
631
- fn=process_adapter_generation,
632
- inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, image_input, image_strength],
633
- outputs=[result, seed, progress_html]
 
 
 
 
 
 
 
634
  )
635
 
636
  app.queue()
 