prithivMLmods committed on
Commit 90a0e37 · verified · 1 Parent(s): 151bae0

Update app.py

Files changed (1)
  1. app.py +283 -284
app.py CHANGED
@@ -37,54 +37,54 @@ if torch.cuda.is_available():
      print("current device:", torch.cuda.current_device())
      print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

- print("Using device:", processing_device)

- # List of predefined style models (formerly LoRAs)
- style_definitions = [
      {
-         "thumbnail_url": "https://huggingface.co/prithivMLmods/Qwen-Image-Studio-Realism/resolve/main/images/2.png",
-         "style_name": "Studio Realism",
-         "repo_id": "prithivMLmods/Qwen-Image-Studio-Realism",
-         "weight_file": "qwen-studio-realism.safetensors",
-         "activation_phrase": "Studio Realism"
      },
      {
-         "thumbnail_url": "https://huggingface.co/prithivMLmods/Qwen-Image-Sketch-Smudge/resolve/main/images/1.png",
-         "style_name": "Sketch Smudge",
-         "repo_id": "prithivMLmods/Qwen-Image-Sketch-Smudge",
-         "weight_file": "qwen-sketch-smudge.safetensors",
-         "activation_phrase": "Sketch Smudge"
      },
      {
-         "thumbnail_url": "https://huggingface.co/prithivMLmods/Qwen-Image-Anime-LoRA/resolve/main/images/1.png",
-         "style_name": "Qwen Anime",
-         "repo_id": "prithivMLmods/Qwen-Image-Anime-LoRA",
-         "weight_file": "qwen-anime.safetensors",
-         "activation_phrase": "Qwen Anime"
      },
      {
-         "thumbnail_url": "https://huggingface.co/prithivMLmods/Qwen-Image-Synthetic-Face/resolve/main/images/2.png",
-         "style_name": "Synthetic Face",
-         "repo_id": "prithivMLmods/Qwen-Image-Synthetic-Face",
-         "weight_file": "qwen-synthetic-face.safetensors",
-         "activation_phrase": "Synthetic Face"
      },
      {
-         "thumbnail_url": "https://huggingface.co/prithivMLmods/Qwen-Image-Fragmented-Portraiture/resolve/main/images/3.png",
-         "style_name": "Fragmented Portraiture",
-         "repo_id": "prithivMLmods/Qwen-Image-Fragmented-Portraiture",
-         "weight_file": "qwen-fragmented-portraiture.safetensors",
-         "activation_phrase": "Fragmented Portraiture"
      },
  ]

- # --- Model Initialization ---
- model_precision = torch.bfloat16
- processing_device = "cuda" if torch.cuda.is_available() else "cpu"
- foundation_model_id = "Qwen/Qwen-Image"

- # Sampler configuration from the Qwen-Image-Lightning repository
- sampler_settings = {
      "base_image_seq_len": 256,
      "base_shift": math.log(3),
      "invert_sigmas": False,
@@ -101,21 +101,20 @@ sampler_settings = {
      "use_karras_sigmas": False,
  }

- sampler = FlowMatchEulerDiscreteScheduler.from_config(sampler_settings)
- diffusion_pipeline = DiffusionPipeline.from_pretrained(
-     foundation_model_id, scheduler=sampler, torch_dtype=model_precision
- ).to(processing_device)

- # Information for the fast generation LoRA
- FAST_GENERATION_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
- FAST_GENERATION_LORA_WEIGHTS = "Qwen-Image-Lightning-8steps-V1.0.safetensors"

- MAX_SEED_VALUE = np.iinfo(np.int32).max

- class ExecutionTimer:
-     """A context manager to time a block of code."""
-     def __init__(self, activity_name=""):
-         self.activity_name = activity_name

      def __enter__(self):
          self.start_time = time.time()
@@ -124,269 +123,272 @@ class ExecutionTimer:
      def __exit__(self, exc_type, exc_value, traceback):
          self.end_time = time.time()
          self.elapsed_time = self.end_time - self.start_time
-         activity_log = f" for {self.activity_name}" if self.activity_name else ""
-         print(f"Elapsed time{activity_log}: {self.elapsed_time:.6f} seconds")
-
- def get_dimensions_from_ratio(aspect_ratio_str):
-     """Converts an aspect ratio string to a (width, height) tuple."""
-     ratios = {
-         "1:1": (1024, 1024),
-         "16:9": (1152, 640),
-         "9:16": (640, 1152),
-         "4:3": (1024, 768),
-         "3:4": (768, 1024),
-         "3:2": (1024, 688),
-         "2:3": (688, 1024),
-     }
-     return ratios.get(aspect_ratio_str, (1024, 1024))
-
- def on_style_select(event_data: gr.SelectData, current_aspect_ratio):
-     """Handles the user selecting a style from the gallery."""
-     selected_style = style_definitions[event_data.index]
-     new_placeholder = f"Type a prompt for {selected_style['style_name']}"
-     repo_id = selected_style["repo_id"]
-     updated_info_text = f"### Selected: [{repo_id}](https://huggingface.co/{repo_id}) ✨"

-     # Update aspect ratio if specified in the style's configuration
-     if "aspect" in selected_style:
-         if selected_style["aspect"] == "portrait":
-             current_aspect_ratio = "9:16"
-         elif selected_style["aspect"] == "landscape":
-             current_aspect_ratio = "16:9"
          else:
-             current_aspect_ratio = "1:1"

      return (
          gr.update(placeholder=new_placeholder),
-         updated_info_text,
-         event_data.index,
-         current_aspect_ratio,
      )

- def on_mode_change(generation_mode):
-     """Updates UI elements based on the selected generation mode (Speed/Quality)."""
-     if generation_mode == "Speed (8 steps)":
          return gr.update(value="Speed mode selected - 8 steps with Lightning LoRA"), 8, 1.0
      else:
          return gr.update(value="Quality mode selected - 45 steps for best quality"), 45, 3.5

  @spaces.GPU(duration=70)
- def execute_image_generation(full_prompt, steps, seed_val, cfg, width, height, negative_prompt=""):
-     """Generates an image using the diffusion pipeline."""
-     diffusion_pipeline.to("cuda")
-     generator = torch.Generator(device="cuda").manual_seed(seed_val)

-     with ExecutionTimer("Image Generation"):
-         generated_image = diffusion_pipeline(
-             prompt=full_prompt,
              negative_prompt=negative_prompt,
              num_inference_steps=steps,
-             true_cfg_scale=cfg,
              width=width,
              height=height,
              generator=generator,
          ).images[0]

-     return generated_image

  @spaces.GPU(duration=70)
- def handle_generate_request(prompt_text, cfg, steps, style_idx, use_random_seed, seed_val, aspect_ratio_str, style_scale, generation_mode, progress=gr.Progress(track_tqdm=True)):
-     """Main function to handle a user's image generation request."""
-     if style_idx is None:
-         raise gr.Error("You must select a style before generating an image.")

-     selected_style = style_definitions[style_idx]
-     style_repo_path = selected_style["repo_id"]
-     activation_phrase = selected_style["activation_phrase"]

-     # Combine the user prompt with the style's activation phrase
-     if activation_phrase:
-         position = selected_style.get("trigger_position", "prepend")
-         if position == "prepend":
-             full_prompt = f"{activation_phrase} {prompt_text}"
          else:
-             full_prompt = f"{prompt_text} {activation_phrase}"
      else:
-         full_prompt = prompt_text
-
-     # Always unload existing adapters to start fresh
-     with ExecutionTimer("Unloading existing adapters"):
-         diffusion_pipeline.unload_lora_weights()
-
-     # Load adapters based on the selected generation mode
-     if generation_mode == "Speed (8 steps)":
-         with ExecutionTimer("Loading Lightning and Style adapters"):
-             # Load the fast generation adapter first
-             diffusion_pipeline.load_lora_weights(
-                 FAST_GENERATION_LORA_REPO,
-                 weight_name=FAST_GENERATION_LORA_WEIGHTS,
                  adapter_name="lightning"
              )

-             # Load the selected style adapter
-             weight_file = selected_style.get("weight_file", None)
-             diffusion_pipeline.load_lora_weights(
-                 style_repo_path,
-                 weight_name=weight_file,
                  low_cpu_mem_usage=True,
                  adapter_name="style"
              )

-             # Set both adapters active with their respective weights
-             diffusion_pipeline.set_adapters(["lightning", "style"], adapter_weights=[1.0, style_scale])
-     else: # Quality mode
-         with ExecutionTimer(f"Loading adapter weights for {selected_style['style_name']}"):
-             weight_file = selected_style.get("weight_file", None)
-             diffusion_pipeline.load_lora_weights(
-                 style_repo_path,
-                 weight_name=weight_file,
                  low_cpu_mem_usage=True
              )

-     # Set the seed for reproducibility
-     with ExecutionTimer("Setting seed"):
-         if use_random_seed:
-             seed_val = random.randint(0, MAX_SEED_VALUE)

-     # Get image dimensions
-     width, height = get_dimensions_from_ratio(aspect_ratio_str)

-     # Generate the final image
-     final_image = execute_image_generation(full_prompt, steps, seed_val, cfg, width, height)

-     return final_image, seed_val

- def fetch_hf_safetensors_details(repo_link):
-     """Fetches details of a LoRA from a Hugging Face repository."""
-     split_link = repo_link.split("/")
      if len(split_link) != 2:
-         raise ValueError("Invalid Hugging Face repository link format.")

-     print(f"Attempting to load repository: {repo_link}")

-     model_card = ModelCard.load(repo_link)
      base_model = model_card.data.get("base_model")
-     print(f"Base model identified: {base_model}")

-     # Validate that the LoRA is compatible with Qwen-Image
      acceptable_models = {"Qwen/Qwen-Image"}
      models_to_check = base_model if isinstance(base_model, list) else [base_model]

      if not any(model in acceptable_models for model in models_to_check):
-         raise TypeError("The provided model is not a Qwen-Image compatible LoRA.")

-     # Extract metadata from the model card
-     image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url")
-     activation_phrase = model_card.data.get("instance_prompt", "")
-     image_url = f"https://huggingface.co/{repo_link}/resolve/main/{image_path}" if image_path else None

-     # Find the .safetensors file in the repository
      fs = HfFileSystem()
      try:
-         repo_files = fs.ls(repo_link, detail=False)
-         safetensors_filename = None
-         for file_path in repo_files:
-             filename = file_path.split("/")[-1]
              if filename.endswith(".safetensors"):
-                 safetensors_filename = filename
                  break
-         if not safetensors_filename:
-             raise FileNotFoundError("No .safetensors file was found in the repository.")
      except Exception as e:
          print(e)
-         raise IOError("Could not access the Hugging Face repository or find a valid .safetensors file.")

-     return split_link[1], repo_link, safetensors_filename, activation_phrase, image_url

- def parse_custom_model_source(source_text):
-     """Parses a user-provided link to a custom LoRA."""
-     print(f"Parsing custom model source: {source_text}")

-     if source_text.endswith('.safetensors') and 'huggingface.co' in source_text:
-         parts = source_text.split('/')
-         try:
-             hf_index = parts.index('huggingface.co')
-             username = parts[hf_index + 1]
-             repo_name = parts[hf_index + 2]
-             repo_id = f"{username}/{repo_name}"
-             safetensors_filename = parts[-1]
-
              try:
-                 model_card = ModelCard.load(repo_id)
-                 activation_phrase = model_card.data.get("instance_prompt", "")
-                 image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url")
-                 image_url = f"https://huggingface.co/{repo_id}/resolve/main/{image_path}" if image_path else None
-             except Exception:
-                 activation_phrase = ""
-                 image_url = None
-
-             return repo_name, repo_id, safetensors_filename, activation_phrase, image_url
-         except ValueError:
-             raise ValueError("Invalid .safetensors URL format.")
-
-     if source_text.startswith("https://"):
-         parsed_url = urlparse(source_text)
-         if "huggingface.co" in parsed_url.netloc:
-             repo_link = parsed_url.path.strip("/")
-             return fetch_hf_safetensors_details(repo_link)

-     # Assume it's a direct repo path like "username/repo-name"
-     return fetch_hf_safetensors_details(source_text)
-

- def add_custom_style_model(custom_model_path):
-     """Adds a custom LoRA provided by the user to the session."""
-     global style_definitions
-     if custom_model_path:
          try:
-             style_name, repo_id, weight_file, activation_phrase, thumbnail_url = parse_custom_model_source(custom_model_path)
-             print(f"Successfully loaded custom style: {repo_id}")
-
-             card_html = f'''
          <div class="custom_lora_card">
-             <span>Loaded custom style:</span>
              <div class="card_internal">
-                 <img src="{thumbnail_url}" alt="{style_name}" />
                  <div>
-                     <h3>{style_name}</h3>
-                     <small>{"Activation phrase: <code><b>"+activation_phrase+"</b></code>" if activation_phrase else "No activation phrase found. If required, include it in your prompt."}<br></small>
                  </div>
              </div>
          </div>
          '''
-
-             # Check if this style already exists
-             existing_item_index = next((index for (index, item) in enumerate(style_definitions) if item['repo_id'] == repo_id), None)
-
              if existing_item_index is None:
-                 new_style_item = {
-                     "thumbnail_url": thumbnail_url,
-                     "style_name": style_name,
-                     "repo_id": repo_id,
-                     "weight_file": weight_file,
-                     "activation_phrase": activation_phrase
                  }
-                 style_definitions.append(new_style_item)
-                 existing_item_index = len(style_definitions) - 1
-
-             return gr.update(visible=True, value=card_html), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {weight_file}", existing_item_index, activation_phrase

          except Exception as e:
-             gr.Warning(f"Failed to load custom style. Error: {e}")
-             error_message = f"Invalid input. Could not load the specified style. Please check the link or repository path."
-             return gr.update(visible=True, value=error_message), gr.update(visible=True), gr.update(), "", None, ""
-
-     # If input is empty, hide the custom section
-     return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

- def remove_custom_style_model():
-     """Resets the UI when a custom LoRA is removed."""
      return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

- # --- Gradio UI Definition ---
-
- app_css = '''
  #gen_btn{height: 100%}
  #gen_column{align-self: stretch}
  #title{text-align: center}
@@ -395,115 +397,112 @@ app_css = '''
  #gallery .grid-wrap{height: 10vh}
  #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
  .card_internal{display: flex;height: 100px;margin-top: .5em}
- .card_internal img{margin-right: 1em; object-fit: cover;}
  .styler{--form-gap-width: 0px !important}
  #speed_status{padding: .5em; border-radius: 5px; margin: 1em 0}
- .custom_lora_card{padding: 1em; border: 1px solid var(--border-color-primary); border-radius: var(--radius-lg)}
  '''

- with gr.Blocks(theme="bethecloud/storj_theme", css=app_css, delete_cache=(120, 120)) as web_interface:
-     main_title = gr.HTML("""<h1>Qwen Image LoRA DLC❤️‍🔥</h1>""", elem_id="title")
-     selected_style_index = gr.State(None)

      with gr.Row():
          with gr.Column(scale=3):
-             prompt_textbox = gr.Textbox(label="Prompt", lines=1, placeholder="Select a style to begin...")
          with gr.Column(scale=1, elem_id="gen_column"):
-             generate_btn = gr.Button("Generate", variant="primary", elem_id="gen_btn")

      with gr.Row():
          with gr.Column():
-             selected_style_info = gr.Markdown("")
-             style_gallery = gr.Gallery(
-                 [(item["thumbnail_url"], item["style_name"]) for item in style_definitions],
-                 label="Style Gallery",
                  allow_preview=False,
                  columns=3,
                  elem_id="gallery",
                  show_share_button=False
              )
              with gr.Group():
-                 custom_style_textbox = gr.Textbox(label="Load Custom Style", info="Enter a Hugging Face repository path (e.g., username/repo-name)", placeholder="username/qwen-image-custom-style")
-                 gr.Markdown("[Find More Qwen-Image Styles Here](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
-                 custom_style_info_html = gr.HTML(visible=False)
-                 remove_custom_style_btn = gr.Button("Remove Custom Style", visible=False)

          with gr.Column():
-             output_image_display = gr.Image(label="Generated Image")

              with gr.Row():
-                 aspect_ratio_dropdown = gr.Dropdown(
                      label="Aspect Ratio",
                      choices=["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3"],
                      value="1:1"
-                 )
              with gr.Row():
-                 generation_mode_dropdown = gr.Dropdown(
                      label="Generation Mode",
                      choices=["Speed (8 steps)", "Quality (45 steps)"],
                      value="Quality (45 steps)",
                  )

-     generation_mode_status_display = gr.Markdown("Quality mode active", elem_id="speed_status")

      with gr.Row():
          with gr.Accordion("Advanced Settings", open=False):
              with gr.Column():
                  with gr.Row():
-                     cfg_scale_slider = gr.Slider(
-                         label="Guidance Scale (CFG)",
                          minimum=1.0,
                          maximum=5.0,
                          step=0.1,
                          value=3.5,
-                         info="Adjusts how strictly the model follows the prompt. Lower for speed, higher for quality."
                      )
-                     steps_slider = gr.Slider(
-                         label="Inference Steps",
                          minimum=4,
                          maximum=50,
                          step=1,
                          value=45,
-                         info="Number of steps for the generation process. Automatically set by Generation Mode."
                      )

                  with gr.Row():
-                     randomize_seed_checkbox = gr.Checkbox(True, label="Use Random Seed")
-                     seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED_VALUE, step=1, value=0, randomize=True)
-                     style_scale_slider = gr.Slider(label="Style Strength", minimum=0, maximum=2, step=0.01, value=1.0)
-
-     # --- Event Handlers ---
-     style_gallery.select(
-         on_style_select,
-         inputs=[aspect_ratio_dropdown],
-         outputs=[prompt_textbox, selected_style_info, selected_style_index, aspect_ratio_dropdown]
      )

-     generation_mode_dropdown.change(
-         on_mode_change,
-         inputs=[generation_mode_dropdown],
-         outputs=[generation_mode_status_display, steps_slider, cfg_scale_slider]
      )

-     custom_style_textbox.submit(
-         add_custom_style_model,
-         inputs=[custom_style_textbox],
-         outputs=[custom_style_info_html, remove_custom_style_btn, style_gallery, selected_style_info, selected_style_index, prompt_textbox]
      )

-     remove_custom_style_btn.click(
-         remove_custom_style_model,
-         outputs=[custom_style_info_html, remove_custom_style_btn, style_gallery, selected_style_info, selected_style_index, custom_style_textbox]
      )

-     # Combined trigger for generation
-     generate_triggers = [generate_btn.click, prompt_textbox.submit]
      gr.on(
-         triggers=generate_triggers,
-         fn=handle_generate_request,
-         inputs=[prompt_textbox, cfg_scale_slider, steps_slider, selected_style_index, randomize_seed_checkbox, seed_slider, aspect_ratio_dropdown, style_scale_slider, generation_mode_dropdown],
-         outputs=[output_image_display, seed_slider]
      )

- web_interface.queue()
- web_interface.launch(share=False, ssr_mode=False, show_error=True)
 
      print("current device:", torch.cuda.current_device())
      print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))

+ print("Using device:", device)

+ loras = [
+     # Sample Qwen-compatible LoRAs
      {
+         "image": "https://huggingface.co/prithivMLmods/Qwen-Image-Studio-Realism/resolve/main/images/2.png",
+         "title": "Studio Realism",
+         "repo": "prithivMLmods/Qwen-Image-Studio-Realism",
+         "weights": "qwen-studio-realism.safetensors",
+         "trigger_word": "Studio Realism"
      },
      {
+         "image": "https://huggingface.co/prithivMLmods/Qwen-Image-Sketch-Smudge/resolve/main/images/1.png",
+         "title": "Sketch Smudge",
+         "repo": "prithivMLmods/Qwen-Image-Sketch-Smudge",
+         "weights": "qwen-sketch-smudge.safetensors",
+         "trigger_word": "Sketch Smudge"
      },
      {
+         "image": "https://huggingface.co/prithivMLmods/Qwen-Image-Anime-LoRA/resolve/main/images/1.png",
+         "title": "Qwen Anime",
+         "repo": "prithivMLmods/Qwen-Image-Anime-LoRA",
+         "weights": "qwen-anime.safetensors",
+         "trigger_word": "Qwen Anime"
      },
      {
+         "image": "https://huggingface.co/prithivMLmods/Qwen-Image-Synthetic-Face/resolve/main/images/2.png",
+         "title": "Synthetic Face",
+         "repo": "prithivMLmods/Qwen-Image-Synthetic-Face",
+         "weights": "qwen-synthetic-face.safetensors",
+         "trigger_word": "Synthetic Face"
      },
      {
+         "image": "https://huggingface.co/prithivMLmods/Qwen-Image-Fragmented-Portraiture/resolve/main/images/3.png",
+         "title": "Fragmented Portraiture",
+         "repo": "prithivMLmods/Qwen-Image-Fragmented-Portraiture",
+         "weights": "qwen-fragmented-portraiture.safetensors",
+         "trigger_word": "Fragmented Portraiture"
      },
  ]

+ # Initialize the base model
+ dtype = torch.bfloat16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ base_model = "Qwen/Qwen-Image"

+ # Scheduler configuration from the Qwen-Image-Lightning repository
+ scheduler_config = {
      "base_image_seq_len": 256,
      "base_shift": math.log(3),
      "invert_sigmas": False,

      "use_karras_sigmas": False,
  }

+ scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
+ pipe = DiffusionPipeline.from_pretrained(
+     base_model, scheduler=scheduler, torch_dtype=dtype
+ ).to(device)

+ # Lightning LoRA info (no global state)
+ LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
+ LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V1.0.safetensors"
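+ # Speed mode stacks this 8-step Lightning LoRA with the selected style LoRA (see adapter_weights=[1.0, lora_scale] below)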

+ MAX_SEED = np.iinfo(np.int32).max

+ class Timer:
+     def __init__(self, task_name=""):
+         self.task_name = task_name

      def __enter__(self):
          self.start_time = time.time()

      def __exit__(self, exc_type, exc_value, traceback):
          self.end_time = time.time()
          self.elapsed_time = self.end_time - self.start_time
+         if self.task_name:
+             print(f"Elapsed time for {self.task_name}: {self.elapsed_time:.6f} seconds")
+         else:
+             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+
+ def compute_image_dimensions(aspect_ratio):
+     """Converts aspect ratio string to width, height tuple."""
+     if aspect_ratio == "1:1":
+         return 1024, 1024
+     elif aspect_ratio == "16:9":
+         return 1152, 640
+     elif aspect_ratio == "9:16":
+         return 640, 1152
+     elif aspect_ratio == "4:3":
+         return 1024, 768
+     elif aspect_ratio == "3:4":
+         return 768, 1024
+     elif aspect_ratio == "3:2":
+         return 1024, 688
+     elif aspect_ratio == "2:3":
+         return 688, 1024
+     else:
+         return 1024, 1024
+
+ def handle_lora_selection(evt: gr.SelectData, aspect_ratio):
+     selected_lora = loras[evt.index]
+     new_placeholder = f"Type a prompt for {selected_lora['title']}"
+     lora_repo = selected_lora["repo"]
+     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"

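+     # Note: "aspect" is an optional per-LoRA key ("portrait" or "landscape"); none of the built-in entries set it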
+     # Update aspect ratio if specified in LoRA config
+     if "aspect" in selected_lora:
+         if selected_lora["aspect"] == "portrait":
+             aspect_ratio = "9:16"
+         elif selected_lora["aspect"] == "landscape":
+             aspect_ratio = "16:9"
          else:
+             aspect_ratio = "1:1"

      return (
          gr.update(placeholder=new_placeholder),
+         updated_text,
+         evt.index,
+         aspect_ratio,
      )

+ def adjust_generation_mode(speed_mode):
+     """Update UI based on speed/quality toggle."""
+     if speed_mode == "Speed (8 steps)":
          return gr.update(value="Speed mode selected - 8 steps with Lightning LoRA"), 8, 1.0
      else:
          return gr.update(value="Quality mode selected - 45 steps for best quality"), 45, 3.5

  @spaces.GPU(duration=70)
+ def create_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
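+     # Under ZeroGPU, CUDA is attached only inside @spaces.GPU-decorated calls, hence the explicit .to("cuda") here
+     # (lora_scale is accepted for signature compatibility; adapter weights were already set by the caller)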
+     pipe.to("cuda")
+     generator = torch.Generator(device="cuda").manual_seed(seed)

+     with Timer("Generating image"):
+         # Generate image
+         image = pipe(
+             prompt=prompt_mash,
              negative_prompt=negative_prompt,
              num_inference_steps=steps,
+             true_cfg_scale=cfg_scale, # Use true_cfg_scale for Qwen-Image
              width=width,
              height=height,
              generator=generator,
          ).images[0]

+     return image

  @spaces.GPU(duration=70)
+ def process_adapter_generation(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, progress=gr.Progress(track_tqdm=True)):
+     if selected_index is None:
+         raise gr.Error("You must select a LoRA before proceeding.")

+     selected_lora = loras[selected_index]
+     lora_path = selected_lora["repo"]
+     trigger_word = selected_lora["trigger_word"]

+     # Prepare prompt with trigger word
+     if trigger_word:
+         if "trigger_position" in selected_lora:
+             if selected_lora["trigger_position"] == "prepend":
+                 prompt_mash = f"{trigger_word} {prompt}"
+             else:
+                 prompt_mash = f"{prompt} {trigger_word}"
          else:
+             prompt_mash = f"{trigger_word} {prompt}"
      else:
+         prompt_mash = prompt
+
+     # Always unload any existing LoRAs first to avoid conflicts
+     with Timer("Unloading existing LoRAs"):
+         pipe.unload_lora_weights()
+
+     # Load LoRAs based on speed mode
+     if speed_mode == "Speed (8 steps)":
+         with Timer("Loading Lightning LoRA and style LoRA"):
+             # Load Lightning LoRA first
+             pipe.load_lora_weights(
+                 LIGHTNING_LORA_REPO,
+                 weight_name=LIGHTNING_LORA_WEIGHT,
                  adapter_name="lightning"
              )

+             # Load the selected style LoRA
+             weight_name = selected_lora.get("weights", None)
+             pipe.load_lora_weights(
+                 lora_path,
+                 weight_name=weight_name,
                  low_cpu_mem_usage=True,
                  adapter_name="style"
              )

+             # Set both adapters active with their weights
+             pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
+     else:
+         # Quality mode - only load the style LoRA
+         with Timer(f"Loading LoRA weights for {selected_lora['title']}"):
+             weight_name = selected_lora.get("weights", None)
+             pipe.load_lora_weights(
+                 lora_path,
+                 weight_name=weight_name,
                  low_cpu_mem_usage=True
              )

+     # Set random seed for reproducibility
+     with Timer("Randomizing seed"):
+         if randomize_seed:
+             seed = random.randint(0, MAX_SEED)

+     # Get image dimensions from aspect ratio
+     width, height = compute_image_dimensions(aspect_ratio)

+     # Generate the image
+     final_image = create_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale)

+     return final_image, seed

+ def fetch_hf_adapter_files(link):
+     split_link = link.split("/")
      if len(split_link) != 2:
+         raise Exception("Invalid Hugging Face repository link format.")

+     print(f"Repository attempted: {split_link}")

+     # Load model card
+     model_card = ModelCard.load(link)
      base_model = model_card.data.get("base_model")
+     print(f"Base model: {base_model}")

+     # Validate model type (for Qwen-Image)
      acceptable_models = {"Qwen/Qwen-Image"}
+
      models_to_check = base_model if isinstance(base_model, list) else [base_model]

      if not any(model in acceptable_models for model in models_to_check):
+         raise Exception("Not a Qwen-Image LoRA!")

+     # Extract image and trigger word
+     image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+     trigger_word = model_card.data.get("instance_prompt", "")
+     image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None

+     # Initialize Hugging Face file system
      fs = HfFileSystem()
      try:
+         list_of_files = fs.ls(link, detail=False)
+
+         # Find safetensors file
+         safetensors_name = None
+         for file in list_of_files:
+             filename = file.split("/")[-1]
              if filename.endswith(".safetensors"):
+                 safetensors_name = filename
                  break
+
+         if not safetensors_name:
+             raise Exception("No valid *.safetensors file found in the repository.")
+
      except Exception as e:
          print(e)
+         raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")

+     return split_link[1], link, safetensors_name, trigger_word, image_url

+ def validate_custom_adapter(link):
+     print(f"Checking a custom model on: {link}")

+     if link.endswith('.safetensors'):
+         if 'huggingface.co' in link:
+             parts = link.split('/')
              try:
+                 hf_index = parts.index('huggingface.co')
+                 username = parts[hf_index + 1]
+                 repo_name = parts[hf_index + 2]
+                 repo = f"{username}/{repo_name}"
+
+                 safetensors_name = parts[-1]
+
+                 try:
+                     model_card = ModelCard.load(repo)
+                     trigger_word = model_card.data.get("instance_prompt", "")
+                     image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+                     image_url = f"https://huggingface.co/{repo}/resolve/main/{image_path}" if image_path else None
+                 except:
+                     trigger_word = ""
+                     image_url = None
+
+                 return repo_name, repo, safetensors_name, trigger_word, image_url
+             except:
+                 raise Exception("Invalid safetensors URL format")

+     if link.startswith("https://"):
+         if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
+             link_split = link.split("huggingface.co/")
+             return fetch_hf_adapter_files(link_split[1])
+     else:
+         return fetch_hf_adapter_files(link)

+ def incorporate_custom_adapter(custom_lora):
+     global loras
+     if custom_lora:
          try:
+             title, repo, path, trigger_word, image = validate_custom_adapter(custom_lora)
+             print(f"Loaded custom LoRA: {repo}")
+             card = f'''
          <div class="custom_lora_card">
+             <span>Loaded custom LoRA:</span>
              <div class="card_internal">
+                 <img src="{image}" />
                  <div>
+                     <h3>{title}</h3>
+                     <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
                  </div>
              </div>
          </div>
          '''
+             existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
              if existing_item_index is None:
+                 new_item = {
+                     "image": image,
+                     "title": title,
+                     "repo": repo,
+                     "weights": path,
+                     "trigger_word": trigger_word
                  }
+                 print(new_item)
+                 loras.append(new_item)
+                 existing_item_index = len(loras) - 1 # Get the actual index after adding

+             return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
          except Exception as e:
+             gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-Qwen-Image LoRA, this was the issue: {e}")
+             return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-Qwen-Image LoRA"), gr.update(visible=True), gr.update(), "", None, ""
+     else:
+         return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

+ def discard_custom_adapter():
      return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

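+ # Extra ZeroGPU hint on the entry point; likely redundant with the @spaces.GPU decorator above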
+ process_adapter_generation.zerogpu = True

+ css = '''
  #gen_btn{height: 100%}
  #gen_column{align-self: stretch}
  #title{text-align: center}

  #gallery .grid-wrap{height: 10vh}
  #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
  .card_internal{display: flex;height: 100px;margin-top: .5em}
+ .card_internal img{margin-right: 1em}
  .styler{--form-gap-width: 0px !important}
  #speed_status{padding: .5em; border-radius: 5px; margin: 1em 0}
  '''

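+ # delete_cache=(frequency, age): every 120 s, delete cached files older than 120 s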
+ with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120)) as app:
+     title = gr.HTML("""<h1>Qwen Image LoRA DLC⛵</h1>""", elem_id="title")
+     selected_index = gr.State(None)

      with gr.Row():
          with gr.Column(scale=3):
+             prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
          with gr.Column(scale=1, elem_id="gen_column"):
+             generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")

      with gr.Row():
          with gr.Column():
+             selected_info = gr.Markdown("")
+             gallery = gr.Gallery(
+                 [(item["image"], item["title"]) for item in loras],
+                 label="LoRA Gallery",
                  allow_preview=False,
                  columns=3,
                  elem_id="gallery",
                  show_share_button=False
              )
              with gr.Group():
+                 custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="username/qwen-image-custom-lora")
+                 gr.Markdown("[Check Qwen-Image LoRAs](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
+                 custom_lora_info = gr.HTML(visible=False)
+                 custom_lora_button = gr.Button("Remove custom LoRA", visible=False)

          with gr.Column():
+             result = gr.Image(label="Generated Image")

              with gr.Row():
+                 aspect_ratio = gr.Dropdown(
                      label="Aspect Ratio",
                      choices=["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3"],
                      value="1:1"
+                 )
              with gr.Row():
+                 speed_mode = gr.Dropdown(
                      label="Generation Mode",
                      choices=["Speed (8 steps)", "Quality (45 steps)"],
                      value="Quality (45 steps)",
                  )

+     speed_status = gr.Markdown("Quality mode active", elem_id="speed_status")

      with gr.Row():
          with gr.Accordion("Advanced Settings", open=False):
              with gr.Column():
                  with gr.Row():
+                     cfg_scale = gr.Slider(
+                         label="Guidance Scale (True CFG)",
                          minimum=1.0,
                          maximum=5.0,
                          step=0.1,
                          value=3.5,
+                         info="Lower for speed mode, higher for quality"
                      )
+                     steps = gr.Slider(
+                         label="Steps",
                          minimum=4,
                          maximum=50,
                          step=1,
                          value=45,
+                         info="Automatically set by speed mode"
                      )

                  with gr.Row():
+                     randomize_seed = gr.Checkbox(True, label="Randomize seed")
+                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
+                     lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2, step=0.01, value=1.0)
+
+     # Event handlers
+     gallery.select(
+         handle_lora_selection,
+         inputs=[aspect_ratio],
+         outputs=[prompt, selected_info, selected_index, aspect_ratio]
      )

+     speed_mode.change(
+         adjust_generation_mode,
+         inputs=[speed_mode],
+         outputs=[speed_status, steps, cfg_scale]
      )

+     custom_lora.input(
+         incorporate_custom_adapter,
+         inputs=[custom_lora],
+         outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
      )

+     custom_lora_button.click(
+         discard_custom_adapter,
+         outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
      )

      gr.on(
+         triggers=[generate_button.click, prompt.submit],
+         fn=process_adapter_generation,
+         inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode],
+         outputs=[result, seed]
      )

+ app.queue()
+ app.launch(share=False, ssr_mode=False, show_error=True)