prithivMLmods committed on
Commit
cb494fe
·
verified ·
1 Parent(s): fe44296

Update app.py

Files changed (1)
app.py +293 -280
app.py CHANGED
@@ -27,52 +27,64 @@ import shutil
  import uuid
  import zipfile

- loras = [
-     # Sample Qwen-compatible LoRAs
      {
-         "image": "https://huggingface.co/prithivMLmods/Qwen-Image-Studio-Realism/resolve/main/images/2.png",
-         "title": "Studio Realism",
-         "repo": "prithivMLmods/Qwen-Image-Studio-Realism",
-         "weights": "qwen-studio-realism.safetensors",
-         "trigger_word": "Studio Realism"
      },
      {
-         "image": "https://huggingface.co/prithivMLmods/Qwen-Image-Sketch-Smudge/resolve/main/images/1.png",
-         "title": "Sketch Smudge",
-         "repo": "prithivMLmods/Qwen-Image-Sketch-Smudge",
-         "weights": "qwen-sketch-smudge.safetensors",
-         "trigger_word": "Sketch Smudge"
      },
      {
-         "image": "https://huggingface.co/prithivMLmods/Qwen-Image-Anime-LoRA/resolve/main/images/1.png",
-         "title": "Qwen Anime",
-         "repo": "prithivMLmods/Qwen-Image-Anime-LoRA",
-         "weights": "qwen-anime.safetensors",
-         "trigger_word": "Qwen Anime"
      },
      {
-         "image": "https://huggingface.co/prithivMLmods/Qwen-Image-Synthetic-Face/resolve/main/images/2.png",
-         "title": "Synthetic Face",
-         "repo": "prithivMLmods/Qwen-Image-Synthetic-Face",
-         "weights": "qwen-synthetic-face.safetensors",
-         "trigger_word": "Synthetic Face"
      },
      {
-         "image": "https://huggingface.co/prithivMLmods/Qwen-Image-Fragmented-Portraiture/resolve/main/images/3.png",
-         "title": "Fragmented Portraiture",
-         "repo": "prithivMLmods/Qwen-Image-Fragmented-Portraiture",
-         "weights": "qwen-fragmented-portraiture.safetensors",
-         "trigger_word": "Fragmented Portraiture"
      },
  ]

- # Initialize the base model
- dtype = torch.bfloat16
- device = "cuda" if torch.cuda.is_available() else "cpu"
- base_model = "Qwen/Qwen-Image"

- # Scheduler configuration from the Qwen-Image-Lightning repository
- scheduler_config = {
      "base_image_seq_len": 256,
      "base_shift": math.log(3),
      "invert_sigmas": False,
@@ -89,18 +101,19 @@ scheduler_config = {
      "use_karras_sigmas": False,
  }

- scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
- pipe = DiffusionPipeline.from_pretrained(
-     base_model, scheduler=scheduler, torch_dtype=dtype
- ).to(device)

- # Lightning LoRA info (no global state)
- LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
- LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V1.0.safetensors"

- MAX_SEED = np.iinfo(np.int32).max

- class calculateDuration:
      def __init__(self, activity_name=""):
          self.activity_name = activity_name

@@ -111,272 +124,269 @@ class calculateDuration:
      def __exit__(self, exc_type, exc_value, traceback):
          self.end_time = time.time()
          self.elapsed_time = self.end_time - self.start_time
-         if self.activity_name:
-             print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
-         else:
-             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
-
- def get_image_size(aspect_ratio):
-     """Converts aspect ratio string to width, height tuple."""
-     if aspect_ratio == "1:1":
-         return 1024, 1024
-     elif aspect_ratio == "16:9":
-         return 1152, 640
-     elif aspect_ratio == "9:16":
-         return 640, 1152
-     elif aspect_ratio == "4:3":
-         return 1024, 768
-     elif aspect_ratio == "3:4":
-         return 768, 1024
-     elif aspect_ratio == "3:2":
-         return 1024, 688
-     elif aspect_ratio == "2:3":
-         return 688, 1024
-     else:
-         return 1024, 1024

- def update_selection(evt: gr.SelectData, aspect_ratio):
-     selected_lora = loras[evt.index]
-     new_placeholder = f"Type a prompt for {selected_lora['title']}"
-     lora_repo = selected_lora["repo"]
-     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"

-     # Update aspect ratio if specified in LoRA config
-     if "aspect" in selected_lora:
-         if selected_lora["aspect"] == "portrait":
-             aspect_ratio = "9:16"
-         elif selected_lora["aspect"] == "landscape":
-             aspect_ratio = "16:9"
          else:
-             aspect_ratio = "1:1"

      return (
          gr.update(placeholder=new_placeholder),
-         updated_text,
-         evt.index,
-         aspect_ratio,
      )

- def handle_speed_mode(speed_mode):
-     """Update UI based on speed/quality toggle."""
-     if speed_mode == "Speed (8 steps)":
          return gr.update(value="Speed mode selected - 8 steps with Lightning LoRA"), 8, 1.0
      else:
          return gr.update(value="Quality mode selected - 45 steps for best quality"), 45, 3.5

  @spaces.GPU(duration=70)
- def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
-     pipe.to("cuda")
-     generator = torch.Generator(device="cuda").manual_seed(seed)

-     with calculateDuration("Generating image"):
-         # Generate image
-         image = pipe(
-             prompt=prompt_mash,
              negative_prompt=negative_prompt,
              num_inference_steps=steps,
-             true_cfg_scale=cfg_scale,  # Use true_cfg_scale for Qwen-Image
              width=width,
              height=height,
              generator=generator,
          ).images[0]

-     return image

  @spaces.GPU(duration=70)
- def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, progress=gr.Progress(track_tqdm=True)):
-     if selected_index is None:
-         raise gr.Error("You must select a LoRA before proceeding.")

-     selected_lora = loras[selected_index]
-     lora_path = selected_lora["repo"]
-     trigger_word = selected_lora["trigger_word"]

-     # Prepare prompt with trigger word
-     if trigger_word:
-         if "trigger_position" in selected_lora:
-             if selected_lora["trigger_position"] == "prepend":
-                 prompt_mash = f"{trigger_word} {prompt}"
-             else:
-                 prompt_mash = f"{prompt} {trigger_word}"
          else:
-             prompt_mash = f"{trigger_word} {prompt}"
      else:
-         prompt_mash = prompt
-
-     # Always unload any existing LoRAs first to avoid conflicts
-     with calculateDuration("Unloading existing LoRAs"):
-         pipe.unload_lora_weights()
-
-     # Load LoRAs based on speed mode
-     if speed_mode == "Speed (8 steps)":
-         with calculateDuration("Loading Lightning LoRA and style LoRA"):
-             # Load Lightning LoRA first
-             pipe.load_lora_weights(
-                 LIGHTNING_LORA_REPO,
-                 weight_name=LIGHTNING_LORA_WEIGHT,
                  adapter_name="lightning"
              )

-             # Load the selected style LoRA
-             weight_name = selected_lora.get("weights", None)
-             pipe.load_lora_weights(
-                 lora_path,
-                 weight_name=weight_name,
                  low_cpu_mem_usage=True,
                  adapter_name="style"
              )

-             # Set both adapters active with their weights
-             pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
-     else:
-         # Quality mode - only load the style LoRA
-         with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
-             weight_name = selected_lora.get("weights", None)
-             pipe.load_lora_weights(
-                 lora_path,
-                 weight_name=weight_name,
                  low_cpu_mem_usage=True
              )

-     # Set random seed for reproducibility
-     with calculateDuration("Randomizing seed"):
-         if randomize_seed:
-             seed = random.randint(0, MAX_SEED)

-     # Get image dimensions from aspect ratio
-     width, height = get_image_size(aspect_ratio)

-     # Generate the image
-     final_image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale)

-     return final_image, seed

- def get_huggingface_safetensors(link):
-     split_link = link.split("/")
      if len(split_link) != 2:
-         raise Exception("Invalid Hugging Face repository link format.")

-     print(f"Repository attempted: {split_link}")

-     # Load model card
-     model_card = ModelCard.load(link)
      base_model = model_card.data.get("base_model")
-     print(f"Base model: {base_model}")

-     # Validate model type (for Qwen-Image)
      acceptable_models = {"Qwen/Qwen-Image"}
-
      models_to_check = base_model if isinstance(base_model, list) else [base_model]

      if not any(model in acceptable_models for model in models_to_check):
-         raise Exception("Not a Qwen-Image LoRA!")

-     # Extract image and trigger word
-     image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
-     trigger_word = model_card.data.get("instance_prompt", "")
-     image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None

-     # Initialize Hugging Face file system
      fs = HfFileSystem()
      try:
-         list_of_files = fs.ls(link, detail=False)
-
-         # Find safetensors file
-         safetensors_name = None
-         for file in list_of_files:
-             filename = file.split("/")[-1]
              if filename.endswith(".safetensors"):
-                 safetensors_name = filename
                  break
-
-         if not safetensors_name:
-             raise Exception("No valid *.safetensors file found in the repository.")
-
      except Exception as e:
          print(e)
-         raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")

-     return split_link[1], link, safetensors_name, trigger_word, image_url

- def check_custom_model(link):
-     print(f"Checking a custom model on: {link}")

-     if link.endswith('.safetensors'):
-         if 'huggingface.co' in link:
-             parts = link.split('/')
              try:
-                 hf_index = parts.index('huggingface.co')
-                 username = parts[hf_index + 1]
-                 repo_name = parts[hf_index + 2]
-                 repo = f"{username}/{repo_name}"
-
-                 safetensors_name = parts[-1]
-
-                 try:
-                     model_card = ModelCard.load(repo)
-                     trigger_word = model_card.data.get("instance_prompt", "")
-                     image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
-                     image_url = f"https://huggingface.co/{repo}/resolve/main/{image_path}" if image_path else None
-                 except:
-                     trigger_word = ""
-                     image_url = None
-
-                 return repo_name, repo, safetensors_name, trigger_word, image_url
-             except:
-                 raise Exception("Invalid safetensors URL format")

-     if link.startswith("https://"):
-         if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
-             link_split = link.split("huggingface.co/")
-             return get_huggingface_safetensors(link_split[1])
-         else:
-             return get_huggingface_safetensors(link)

- def add_custom_lora(custom_lora):
-     global loras
-     if custom_lora:
          try:
-             title, repo, path, trigger_word, image = check_custom_model(custom_lora)
-             print(f"Loaded custom LoRA: {repo}")
-             card = f'''
              <div class="custom_lora_card">
-             <span>Loaded custom LoRA:</span>
              <div class="card_internal">
-             <img src="{image}" />
              <div>
-             <h3>{title}</h3>
-             <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
              </div>
              </div>
              </div>
              '''
-             existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
              if existing_item_index is None:
-                 new_item = {
-                     "image": image,
-                     "title": title,
-                     "repo": repo,
-                     "weights": path,
-                     "trigger_word": trigger_word
                  }
-                 print(new_item)
-                 loras.append(new_item)
-                 existing_item_index = len(loras) - 1  # Get the actual index after adding

-             return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
          except Exception as e:
-             gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-Qwen-Image LoRA, this was the issue: {e}")
-             return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-Qwen-Image LoRA"), gr.update(visible=True), gr.update(), "", None, ""
-     else:
-         return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

- def remove_custom_lora():
      return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

- run_lora.zerogpu = True

- css = '''
  #gen_btn{height: 100%}
  #gen_column{align-self: stretch}
  #title{text-align: center}
@@ -385,112 +395,115 @@ css = '''
  #gallery .grid-wrap{height: 10vh}
  #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
  .card_internal{display: flex;height: 100px;margin-top: .5em}
- .card_internal img{margin-right: 1em}
  .styler{--form-gap-width: 0px !important}
  #speed_status{padding: .5em; border-radius: 5px; margin: 1em 0}
  '''

- with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120)) as app:
-     title = gr.HTML("""<h1>Qwen Image LoRA DLC ❤️‍🔥</h1>""", elem_id="title")
-     selected_index = gr.State(None)

      with gr.Row():
          with gr.Column(scale=3):
-             prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
          with gr.Column(scale=1, elem_id="gen_column"):
-             generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")

      with gr.Row():
          with gr.Column():
-             selected_info = gr.Markdown("")
-             gallery = gr.Gallery(
-                 [(item["image"], item["title"]) for item in loras],
-                 label="LoRA Gallery",
                  allow_preview=False,
                  columns=3,
                  elem_id="gallery",
                  show_share_button=False
              )
              with gr.Group():
-                 custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="username/qwen-image-custom-lora")
-                 gr.Markdown("[Check Qwen-Image LoRAs](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
-                 custom_lora_info = gr.HTML(visible=False)
-                 custom_lora_button = gr.Button("Remove custom LoRA", visible=False)

          with gr.Column():
-             result = gr.Image(label="Generated Image")

              with gr.Row():
-                 aspect_ratio = gr.Dropdown(
                      label="Aspect Ratio",
                      choices=["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3"],
                      value="1:1"
-                 )
              with gr.Row():
-                 speed_mode = gr.Dropdown(
                      label="Generation Mode",
                      choices=["Speed (8 steps)", "Quality (45 steps)"],
-                     value="Quality (48 steps)",
                  )

-             speed_status = gr.Markdown("Quality mode active", elem_id="speed_status")

      with gr.Row():
          with gr.Accordion("Advanced Settings", open=False):
              with gr.Column():
                  with gr.Row():
-                     cfg_scale = gr.Slider(
-                         label="Guidance Scale (True CFG)",
                          minimum=1.0,
                          maximum=5.0,
                          step=0.1,
                          value=3.5,
-                         info="Lower for speed mode, higher for quality"
                      )
-                     steps = gr.Slider(
-                         label="Steps",
                          minimum=4,
                          maximum=50,
                          step=1,
                          value=45,
-                         info="Automatically set by speed mode"
                      )

                  with gr.Row():
-                     randomize_seed = gr.Checkbox(True, label="Randomize seed")
-                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
-                     lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2, step=0.01, value=1.0)
-
-     # Event handlers
-     gallery.select(
-         update_selection,
-         inputs=[aspect_ratio],
-         outputs=[prompt, selected_info, selected_index, aspect_ratio]
      )

-     speed_mode.change(
-         handle_speed_mode,
-         inputs=[speed_mode],
-         outputs=[speed_status, steps, cfg_scale]
      )

-     custom_lora.input(
-         add_custom_lora,
-         inputs=[custom_lora],
-         outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
      )

-     custom_lora_button.click(
-         remove_custom_lora,
-         outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
      )

      gr.on(
-         triggers=[generate_button.click, prompt.submit],
-         fn=run_lora,
-         inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode],
-         outputs=[result, seed]
      )

- app.queue()
- app.launch(share=False, ssr_mode=False, show_error=True)

@@ -27,52 +27,64 @@ import shutil
  import uuid
  import zipfile
+ from urllib.parse import urlparse  # used by parse_custom_model_source below

+ # META: CUDA_CHECK / GPU_INFO
+ print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
+ print("torch.__version__ =", torch.__version__)
+ print("torch.version.cuda =", torch.version.cuda)
+ print("cuda available:", torch.cuda.is_available())
+ print("cuda device count:", torch.cuda.device_count())
+ if torch.cuda.is_available():
+     print("current device:", torch.cuda.current_device())
+     print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
+
+ print("Using device:", "cuda" if torch.cuda.is_available() else "cpu")
+
+ # List of predefined style models (formerly LoRAs)
+ style_definitions = [
      {
+         "thumbnail_url": "https://huggingface.co/prithivMLmods/Qwen-Image-Studio-Realism/resolve/main/images/2.png",
+         "style_name": "Studio Realism",
+         "repo_id": "prithivMLmods/Qwen-Image-Studio-Realism",
+         "weight_file": "qwen-studio-realism.safetensors",
+         "activation_phrase": "Studio Realism"
      },
      {
+         "thumbnail_url": "https://huggingface.co/prithivMLmods/Qwen-Image-Sketch-Smudge/resolve/main/images/1.png",
+         "style_name": "Sketch Smudge",
+         "repo_id": "prithivMLmods/Qwen-Image-Sketch-Smudge",
+         "weight_file": "qwen-sketch-smudge.safetensors",
+         "activation_phrase": "Sketch Smudge"
      },
      {
+         "thumbnail_url": "https://huggingface.co/prithivMLmods/Qwen-Image-Anime-LoRA/resolve/main/images/1.png",
+         "style_name": "Qwen Anime",
+         "repo_id": "prithivMLmods/Qwen-Image-Anime-LoRA",
+         "weight_file": "qwen-anime.safetensors",
+         "activation_phrase": "Qwen Anime"
      },
      {
+         "thumbnail_url": "https://huggingface.co/prithivMLmods/Qwen-Image-Synthetic-Face/resolve/main/images/2.png",
+         "style_name": "Synthetic Face",
+         "repo_id": "prithivMLmods/Qwen-Image-Synthetic-Face",
+         "weight_file": "qwen-synthetic-face.safetensors",
+         "activation_phrase": "Synthetic Face"
      },
      {
+         "thumbnail_url": "https://huggingface.co/prithivMLmods/Qwen-Image-Fragmented-Portraiture/resolve/main/images/3.png",
+         "style_name": "Fragmented Portraiture",
+         "repo_id": "prithivMLmods/Qwen-Image-Fragmented-Portraiture",
+         "weight_file": "qwen-fragmented-portraiture.safetensors",
+         "activation_phrase": "Fragmented Portraiture"
      },
  ]

+ # --- Model Initialization ---
+ model_precision = torch.bfloat16
+ processing_device = "cuda" if torch.cuda.is_available() else "cpu"
+ foundation_model_id = "Qwen/Qwen-Image"

+ # Sampler configuration from the Qwen-Image-Lightning repository
+ sampler_settings = {
      "base_image_seq_len": 256,
      "base_shift": math.log(3),
      "invert_sigmas": False,

@@ -89,18 +101,19 @@ scheduler_config = {
      "use_karras_sigmas": False,
  }

+ sampler = FlowMatchEulerDiscreteScheduler.from_config(sampler_settings)
+ diffusion_pipeline = DiffusionPipeline.from_pretrained(
+     foundation_model_id, scheduler=sampler, torch_dtype=model_precision
+ ).to(processing_device)

+ # Information for the fast generation LoRA
+ FAST_GENERATION_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
+ FAST_GENERATION_LORA_WEIGHTS = "Qwen-Image-Lightning-8steps-V1.0.safetensors"

+ MAX_SEED_VALUE = np.iinfo(np.int32).max

+ class ExecutionTimer:
+     """A context manager to time a block of code."""
      def __init__(self, activity_name=""):
          self.activity_name = activity_name

@@ -111,272 +124,269 @@ class calculateDuration:
      def __exit__(self, exc_type, exc_value, traceback):
          self.end_time = time.time()
          self.elapsed_time = self.end_time - self.start_time
+         activity_log = f" for {self.activity_name}" if self.activity_name else ""
+         print(f"Elapsed time{activity_log}: {self.elapsed_time:.6f} seconds")

+ def get_dimensions_from_ratio(aspect_ratio_str):
+     """Converts an aspect ratio string to a (width, height) tuple."""
+     ratios = {
+         "1:1": (1024, 1024),
+         "16:9": (1152, 640),
+         "9:16": (640, 1152),
+         "4:3": (1024, 768),
+         "3:4": (768, 1024),
+         "3:2": (1024, 688),
+         "2:3": (688, 1024),
+     }
+     return ratios.get(aspect_ratio_str, (1024, 1024))
+
+ def on_style_select(event_data: gr.SelectData, current_aspect_ratio):
+     """Handles the user selecting a style from the gallery."""
+     selected_style = style_definitions[event_data.index]
+     new_placeholder = f"Type a prompt for {selected_style['style_name']}"
+     repo_id = selected_style["repo_id"]
+     updated_info_text = f"### Selected: [{repo_id}](https://huggingface.co/{repo_id}) ✨"

+     # Update aspect ratio if specified in the style's configuration
+     if "aspect" in selected_style:
+         if selected_style["aspect"] == "portrait":
+             current_aspect_ratio = "9:16"
+         elif selected_style["aspect"] == "landscape":
+             current_aspect_ratio = "16:9"
          else:
+             current_aspect_ratio = "1:1"

      return (
          gr.update(placeholder=new_placeholder),
+         updated_info_text,
+         event_data.index,
+         current_aspect_ratio,
      )

+ def on_mode_change(generation_mode):
+     """Updates UI elements based on the selected generation mode (Speed/Quality)."""
+     if generation_mode == "Speed (8 steps)":
          return gr.update(value="Speed mode selected - 8 steps with Lightning LoRA"), 8, 1.0
      else:
          return gr.update(value="Quality mode selected - 45 steps for best quality"), 45, 3.5

  @spaces.GPU(duration=70)
+ def execute_image_generation(full_prompt, steps, seed_val, cfg, width, height, negative_prompt=""):
+     """Generates an image using the diffusion pipeline."""
+     diffusion_pipeline.to("cuda")
+     generator = torch.Generator(device="cuda").manual_seed(seed_val)

+     with ExecutionTimer("Image Generation"):
+         generated_image = diffusion_pipeline(
+             prompt=full_prompt,
              negative_prompt=negative_prompt,
              num_inference_steps=steps,
+             true_cfg_scale=cfg,
              width=width,
              height=height,
              generator=generator,
          ).images[0]

+     return generated_image

  @spaces.GPU(duration=70)
+ def handle_generate_request(prompt_text, cfg, steps, style_idx, use_random_seed, seed_val, aspect_ratio_str, style_scale, generation_mode, progress=gr.Progress(track_tqdm=True)):
+     """Main function to handle a user's image generation request."""
+     if style_idx is None:
+         raise gr.Error("You must select a style before generating an image.")

+     selected_style = style_definitions[style_idx]
+     style_repo_path = selected_style["repo_id"]
+     activation_phrase = selected_style["activation_phrase"]

+     # Combine the user prompt with the style's activation phrase
+     if activation_phrase:
+         position = selected_style.get("trigger_position", "prepend")
+         if position == "prepend":
+             full_prompt = f"{activation_phrase} {prompt_text}"
          else:
+             full_prompt = f"{prompt_text} {activation_phrase}"
      else:
+         full_prompt = prompt_text
+
+     # Always unload existing adapters to start fresh
+     with ExecutionTimer("Unloading existing adapters"):
+         diffusion_pipeline.unload_lora_weights()
+
+     # Load adapters based on the selected generation mode
+     if generation_mode == "Speed (8 steps)":
+         with ExecutionTimer("Loading Lightning and Style adapters"):
+             # Load the fast generation adapter first
+             diffusion_pipeline.load_lora_weights(
+                 FAST_GENERATION_LORA_REPO,
+                 weight_name=FAST_GENERATION_LORA_WEIGHTS,
                  adapter_name="lightning"
              )

+             # Load the selected style adapter
+             weight_file = selected_style.get("weight_file", None)
+             diffusion_pipeline.load_lora_weights(
+                 style_repo_path,
+                 weight_name=weight_file,
                  low_cpu_mem_usage=True,
                  adapter_name="style"
              )

+             # Set both adapters active with their respective weights
+             diffusion_pipeline.set_adapters(["lightning", "style"], adapter_weights=[1.0, style_scale])
+     else:  # Quality mode
+         with ExecutionTimer(f"Loading adapter weights for {selected_style['style_name']}"):
+             weight_file = selected_style.get("weight_file", None)
+             diffusion_pipeline.load_lora_weights(
+                 style_repo_path,
+                 weight_name=weight_file,
                  low_cpu_mem_usage=True
              )

+     # Set the seed for reproducibility
+     with ExecutionTimer("Setting seed"):
+         if use_random_seed:
+             seed_val = random.randint(0, MAX_SEED_VALUE)

+     # Get image dimensions
+     width, height = get_dimensions_from_ratio(aspect_ratio_str)

+     # Generate the final image
+     final_image = execute_image_generation(full_prompt, steps, seed_val, cfg, width, height)

+     return final_image, seed_val

+ def fetch_hf_safetensors_details(repo_link):
+     """Fetches details of a LoRA from a Hugging Face repository."""
+     split_link = repo_link.split("/")
      if len(split_link) != 2:
+         raise ValueError("Invalid Hugging Face repository link format.")

+     print(f"Attempting to load repository: {repo_link}")

+     model_card = ModelCard.load(repo_link)
      base_model = model_card.data.get("base_model")
+     print(f"Base model identified: {base_model}")

+     # Validate that the LoRA is compatible with Qwen-Image
      acceptable_models = {"Qwen/Qwen-Image"}
      models_to_check = base_model if isinstance(base_model, list) else [base_model]

      if not any(model in acceptable_models for model in models_to_check):
+         raise TypeError("The provided model is not a Qwen-Image compatible LoRA.")

+     # Extract metadata from the model card
+     image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url")
+     activation_phrase = model_card.data.get("instance_prompt", "")
+     image_url = f"https://huggingface.co/{repo_link}/resolve/main/{image_path}" if image_path else None

+     # Find the .safetensors file in the repository
      fs = HfFileSystem()
      try:
+         repo_files = fs.ls(repo_link, detail=False)
+         safetensors_filename = None
+         for file_path in repo_files:
+             filename = file_path.split("/")[-1]
              if filename.endswith(".safetensors"):
+                 safetensors_filename = filename
                  break
+         if not safetensors_filename:
+             raise FileNotFoundError("No .safetensors file was found in the repository.")
      except Exception as e:
          print(e)
+         raise IOError("Could not access the Hugging Face repository or find a valid .safetensors file.")

+     return split_link[1], repo_link, safetensors_filename, activation_phrase, image_url

+ def parse_custom_model_source(source_text):
+     """Parses a user-provided link to a custom LoRA."""
+     print(f"Parsing custom model source: {source_text}")

+     if source_text.endswith('.safetensors') and 'huggingface.co' in source_text:
+         parts = source_text.split('/')
+         try:
+             hf_index = parts.index('huggingface.co')
+             username = parts[hf_index + 1]
+             repo_name = parts[hf_index + 2]
+             repo_id = f"{username}/{repo_name}"
+             safetensors_filename = parts[-1]
+
              try:
+                 model_card = ModelCard.load(repo_id)
+                 activation_phrase = model_card.data.get("instance_prompt", "")
+                 image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url")
+                 image_url = f"https://huggingface.co/{repo_id}/resolve/main/{image_path}" if image_path else None
+             except Exception:
+                 activation_phrase = ""
+                 image_url = None
+
+             return repo_name, repo_id, safetensors_filename, activation_phrase, image_url
+         except ValueError:
+             raise ValueError("Invalid .safetensors URL format.")

+     if source_text.startswith("https://"):
+         parsed_url = urlparse(source_text)
+         if "huggingface.co" in parsed_url.netloc:
+             repo_link = parsed_url.path.strip("/")
+             return fetch_hf_safetensors_details(repo_link)
+
+     # Assume it's a direct repo path like "username/repo-name"
+     return fetch_hf_safetensors_details(source_text)

+
+ def add_custom_style_model(custom_model_path):
+     """Adds a custom LoRA provided by the user to the session."""
+     global style_definitions
+     if custom_model_path:
          try:
+             style_name, repo_id, weight_file, activation_phrase, thumbnail_url = parse_custom_model_source(custom_model_path)
+             print(f"Successfully loaded custom style: {repo_id}")
+
+             card_html = f'''
              <div class="custom_lora_card">
+             <span>Loaded custom style:</span>
              <div class="card_internal">
+             <img src="{thumbnail_url}" alt="{style_name}" />
              <div>
+             <h3>{style_name}</h3>
+             <small>{"Activation phrase: <code><b>"+activation_phrase+"</b></code>" if activation_phrase else "No activation phrase found. If required, include it in your prompt."}<br></small>
              </div>
              </div>
              </div>
              '''
+
+             # Check if this style already exists
+             existing_item_index = next((index for (index, item) in enumerate(style_definitions) if item['repo_id'] == repo_id), None)
+
              if existing_item_index is None:
+                 new_style_item = {
+                     "thumbnail_url": thumbnail_url,
+                     "style_name": style_name,
+                     "repo_id": repo_id,
+                     "weight_file": weight_file,
+                     "activation_phrase": activation_phrase
                  }
+                 style_definitions.append(new_style_item)
+                 existing_item_index = len(style_definitions) - 1
+
+             return gr.update(visible=True, value=card_html), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {weight_file}", existing_item_index, activation_phrase

          except Exception as e:
+             gr.Warning(f"Failed to load custom style. Error: {e}")
+             error_message = "Invalid input. Could not load the specified style. Please check the link or repository path."
+             return gr.update(visible=True, value=error_message), gr.update(visible=True), gr.update(), "", None, ""
+
+     # If input is empty, hide the custom section
+     return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

+ def remove_custom_style_model():
+     """Resets the UI when a custom LoRA is removed."""
      return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""


+ # --- Gradio UI Definition ---
+
+ app_css = '''
  #gen_btn{height: 100%}
  #gen_column{align-self: stretch}
  #title{text-align: center}

@@ -385,112 +395,115 @@ css = '''
  #gallery .grid-wrap{height: 10vh}
  #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
  .card_internal{display: flex;height: 100px;margin-top: .5em}
+ .card_internal img{margin-right: 1em; object-fit: cover;}
  .styler{--form-gap-width: 0px !important}
  #speed_status{padding: .5em; border-radius: 5px; margin: 1em 0}
+ .custom_lora_card{padding: 1em; border: 1px solid var(--border-color-primary); border-radius: var(--radius-lg)}
  '''

+ with gr.Blocks(theme="bethecloud/storj_theme", css=app_css, delete_cache=(120, 120)) as web_interface:
+     main_title = gr.HTML("""<h1>Qwen Image Style Showcase ❤️‍🔥</h1>""", elem_id="title")
+     selected_style_index = gr.State(None)

      with gr.Row():
          with gr.Column(scale=3):
+             prompt_textbox = gr.Textbox(label="Prompt", lines=1, placeholder="Select a style to begin...")
          with gr.Column(scale=1, elem_id="gen_column"):
+             generate_btn = gr.Button("Generate", variant="primary", elem_id="gen_btn")

      with gr.Row():
          with gr.Column():
+             selected_style_info = gr.Markdown("")
+             style_gallery = gr.Gallery(
+                 [(item["thumbnail_url"], item["style_name"]) for item in style_definitions],
+                 label="Style Gallery",
                  allow_preview=False,
                  columns=3,
                  elem_id="gallery",
                  show_share_button=False
              )
              with gr.Group():
+                 custom_style_textbox = gr.Textbox(label="Load Custom Style", info="Enter a Hugging Face repository path (e.g., username/repo-name)", placeholder="username/qwen-image-custom-style")
+                 gr.Markdown("[Find More Qwen-Image Styles Here](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
+                 custom_style_info_html = gr.HTML(visible=False)
+                 remove_custom_style_btn = gr.Button("Remove Custom Style", visible=False)

          with gr.Column():
+             output_image_display = gr.Image(label="Generated Image")

              with gr.Row():
+                 aspect_ratio_dropdown = gr.Dropdown(
                      label="Aspect Ratio",
                      choices=["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3"],
                      value="1:1"
+                 )
              with gr.Row():
+                 generation_mode_dropdown = gr.Dropdown(
                      label="Generation Mode",
                      choices=["Speed (8 steps)", "Quality (45 steps)"],
+                     value="Quality (45 steps)",
                  )

+             generation_mode_status_display = gr.Markdown("Quality mode active", elem_id="speed_status")

      with gr.Row():
          with gr.Accordion("Advanced Settings", open=False):
              with gr.Column():
                  with gr.Row():
+                     cfg_scale_slider = gr.Slider(
+                         label="Guidance Scale (CFG)",
                          minimum=1.0,
                          maximum=5.0,
                          step=0.1,
                          value=3.5,
+                         info="Adjusts how strictly the model follows the prompt. Lower for speed, higher for quality."
                      )
+                     steps_slider = gr.Slider(
+                         label="Inference Steps",
                          minimum=4,
                          maximum=50,
                          step=1,
                          value=45,
+                         info="Number of steps for the generation process. Automatically set by Generation Mode."
                      )

                  with gr.Row():
+                     randomize_seed_checkbox = gr.Checkbox(True, label="Use Random Seed")
+                     seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED_VALUE, step=1, value=0, randomize=True)
+                     style_scale_slider = gr.Slider(label="Style Strength", minimum=0, maximum=2, step=0.01, value=1.0)
+
+     # --- Event Handlers ---
+     style_gallery.select(
+         on_style_select,
+         inputs=[aspect_ratio_dropdown],
+         outputs=[prompt_textbox, selected_style_info, selected_style_index, aspect_ratio_dropdown]
      )

+     generation_mode_dropdown.change(
+         on_mode_change,
+         inputs=[generation_mode_dropdown],
+         outputs=[generation_mode_status_display, steps_slider, cfg_scale_slider]
      )

+     custom_style_textbox.submit(
+         add_custom_style_model,
+         inputs=[custom_style_textbox],
+         outputs=[custom_style_info_html, remove_custom_style_btn, style_gallery, selected_style_info, selected_style_index, prompt_textbox]
      )

+     remove_custom_style_btn.click(
+         remove_custom_style_model,
+         outputs=[custom_style_info_html, remove_custom_style_btn, style_gallery, selected_style_info, selected_style_index, custom_style_textbox]
      )

+     # Combined trigger for generation
+     generate_triggers = [generate_btn.click, prompt_textbox.submit]
      gr.on(
+         triggers=generate_triggers,
+         fn=handle_generate_request,
+         inputs=[prompt_textbox, cfg_scale_slider, steps_slider, selected_style_index, randomize_seed_checkbox, seed_slider, aspect_ratio_dropdown, style_scale_slider, generation_mode_dropdown],
+         outputs=[output_image_display, seed_slider]
      )

+ web_interface.queue()
+ web_interface.launch(share=False, ssr_mode=False, show_error=True)
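
For reference, the Speed-mode path introduced in this commit can be exercised outside Gradio. The sketch below assumes the same base model, scheduler settings, and adapter repositories that app.py uses; the abbreviated sampler_settings dict, the choice of the Studio Realism style LoRA, the prompt, and the 0.9 style weight are illustrative stand-ins, not values taken from the commit.

# Minimal sketch of Speed mode (8-step Lightning LoRA composed with a style
# LoRA), run outside Gradio. Assumes a CUDA device and the repositories listed
# above; sampler_settings is abbreviated to the keys visible in this diff.
import math

import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler

sampler_settings = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "use_karras_sigmas": False,
    # ...remaining keys as configured in app.py...
}

sampler = FlowMatchEulerDiscreteScheduler.from_config(sampler_settings)
pipe = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", scheduler=sampler, torch_dtype=torch.bfloat16
).to("cuda")

# Compose the two adapters the way handle_generate_request() does:
# Lightning stays at full strength, the style adapter is scaled by the user.
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors",
    adapter_name="lightning",
)
pipe.load_lora_weights(
    "prithivMLmods/Qwen-Image-Studio-Realism",
    weight_name="qwen-studio-realism.safetensors",
    adapter_name="style",
)
pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, 0.9])

# The activation phrase is prepended to the prompt, as in the app's default path.
image = pipe(
    prompt="Studio Realism portrait of a violinist in warm key light",
    num_inference_steps=8,  # Speed-mode step count
    true_cfg_scale=1.0,     # Speed-mode guidance value
    width=1024,
    height=1024,            # 1:1 preset from get_dimensions_from_ratio()
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
image.save("qwen_speed_mode.png")

Quality mode differs only in skipping the Lightning adapter and running 45 steps at true_cfg_scale=3.5.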