prithivMLmods committed
Commit 2a31e71 · verified · 1 Parent(s): 94d22ce

Update app.py

Files changed (1):
  1. app.py +310 -325
app.py CHANGED
@@ -24,31 +24,6 @@ import shutil
  import uuid
  import zipfile
  
- # Helper functions
- def save_image(img):
-     unique_name = str(uuid.uuid4()) + ".png"
-     img.save(unique_name)
-     return unique_name
- 
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 2048
- 
- # Load Qwen/Qwen-Image pipeline
- dtype = torch.bfloat16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- 
- # Load Qwen model
- pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype).to(device)
- 
- # Aspect ratios
- aspect_ratios = {
-     "1:1": (1328, 1328),
-     "16:9": (1664, 928),
-     "9:16": (928, 1664),
-     "4:3": (1472, 1140),
-     "3:4": (1140, 1472)
- }
- 
  loras = [
      # Sample Qwen-compatible LoRAs
      {
@@ -88,84 +63,271 @@ loras = [
      },
  ]
  
- def load_lora_opt(pipe, lora_input):
-     lora_input = lora_input.strip()
-     if not lora_input:
-         return
- 
-     # If it's just an ID like "author/model"
-     if "/" in lora_input and not lora_input.startswith("http"):
-         pipe.load_lora_weights(lora_input, adapter_name="default")
-         return
- 
-     if lora_input.startswith("http"):
-         url = lora_input
-         # Repo page (no blob/resolve)
-         if "huggingface.co" in url and "/blob/" not in url and "/resolve/" not in url:
-             repo_id = urlparse(url).path.strip("/")
-             pipe.load_lora_weights(repo_id, adapter_name="default")
-             return
- 
-         # Blob link → convert to resolve link
-         if "/blob/" in url:
-             url = url.replace("/blob/", "/resolve/")
- 
-         # Download direct file
-         tmp_dir = tempfile.mkdtemp()
-         local_path = os.path.join(tmp_dir, os.path.basename(urlparse(url).path))
-         try:
-             print(f"Downloading LoRA from {url}...")
-             resp = requests.get(url, stream=True)
-             resp.raise_for_status()
-             with open(local_path, "wb") as f:
-                 for chunk in resp.iter_content(chunk_size=8192):
-                     f.write(chunk)
-             print(f"Saved LoRA to {local_path}")
-             pipe.load_lora_weights(local_path, adapter_name="default")
-         finally:
-             shutil.rmtree(tmp_dir, ignore_errors=True)
- 
- def get_huggingface_safetensors(link):
-     split_link = link.split("/")
-     if len(split_link) == 2:
-         try:
-             response = requests.get(f"https://huggingface.co/api/models/{link}")
-             response.raise_for_status()
-             model_info = response.json()
- 
-             # Check if it's a Qwen model
-             if "qwen" not in model_info.get("tags", []):
-                 raise Exception("Not a Qwen LoRA model!")
- 
-             # Get image if available
-             image_url = None
-             if "cardData" in model_info and "widget" in model_info["cardData"]:
-                 if len(model_info["cardData"]["widget"]) > 0:
-                     image_url = model_info["cardData"]["widget"][0].get("output", {}).get("url", None)
- 
-             # Try to find safetensors file
-             safetensors_name = None
-             try:
-                 model_files = requests.get(f"https://huggingface.co/api/models/{link}/tree/main").json()
-                 for file in model_files:
-                     if file.get("path", "").endswith(".safetensors"):
-                         safetensors_name = file["path"]
-                         break
-             except:
-                 pass
- 
-             return split_link[1], link, safetensors_name, "trigger_word", image_url
-         except Exception as e:
-             print(f"Error getting model info: {e}")
-             raise Exception(f"Failed to get model info: {e}")
-     return None, None, None, None, None
+ # Initialize the base model
+ dtype = torch.bfloat16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ base_model = "Qwen/Qwen-Image"
+ 
+ # Scheduler configuration from the Qwen-Image-Lightning repository
+ scheduler_config = {
+     "base_image_seq_len": 256,
+     "base_shift": math.log(3),
+     "invert_sigmas": False,
+     "max_image_seq_len": 8192,
+     "max_shift": math.log(3),
+     "num_train_timesteps": 1000,
+     "shift": 1.0,
+     "shift_terminal": None,
+     "stochastic_sampling": False,
+     "time_shift_type": "exponential",
+     "use_beta_sigmas": False,
+     "use_dynamic_shifting": True,
+     "use_exponential_sigmas": False,
+     "use_karras_sigmas": False,
+ }
+ 
+ scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
+ pipe = DiffusionPipeline.from_pretrained(
+     base_model, scheduler=scheduler, torch_dtype=dtype
+ ).to(device)
+ 
+ # Lightning LoRA info (no global state)
+ LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
+ LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V1.0.safetensors"
+ 
+ MAX_SEED = np.iinfo(np.int32).max
+ 
+ class calculateDuration:
+     def __init__(self, activity_name=""):
+         self.activity_name = activity_name
+ 
+     def __enter__(self):
+         self.start_time = time.time()
+         return self
+ 
+     def __exit__(self, exc_type, exc_value, traceback):
+         self.end_time = time.time()
+         self.elapsed_time = self.end_time - self.start_time
+         if self.activity_name:
+             print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
+         else:
+             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
+ 
+ def get_image_size(aspect_ratio):
+     """Converts aspect ratio string to width, height tuple."""
+     if aspect_ratio == "1:1":
+         return 1024, 1024
+     elif aspect_ratio == "16:9":
+         return 1152, 640
+     elif aspect_ratio == "9:16":
+         return 640, 1152
+     elif aspect_ratio == "4:3":
+         return 1024, 768
+     elif aspect_ratio == "3:4":
+         return 768, 1024
+     elif aspect_ratio == "3:2":
+         return 1024, 688
+     elif aspect_ratio == "2:3":
+         return 688, 1024
+     else:
+         return 1024, 1024
+ 
+ def update_selection(evt: gr.SelectData, aspect_ratio):
+     selected_lora = loras[evt.index]
+     new_placeholder = f"Type a prompt for {selected_lora['title']}"
+     lora_repo = selected_lora["repo"]
+     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
+ 
+     # Update aspect ratio if specified in LoRA config
+     if "aspect" in selected_lora:
+         if selected_lora["aspect"] == "portrait":
+             aspect_ratio = "9:16"
+         elif selected_lora["aspect"] == "landscape":
+             aspect_ratio = "16:9"
+         else:
+             aspect_ratio = "1:1"
+ 
+     return (
+         gr.update(placeholder=new_placeholder),
+         updated_text,
+         evt.index,
+         aspect_ratio,
+     )
+ 
+ def handle_speed_mode(speed_mode):
+     """Update UI based on speed/quality toggle."""
+     if speed_mode == "Speed (8 steps)":
+         return gr.update(value="Speed mode selected - 8 steps with Lightning LoRA"), 8, 1.0
+     else:
+         return gr.update(value="Quality mode selected - 45 steps for best quality"), 45, 3.5
+ 
+ @spaces.GPU(duration=70)
+ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
+     pipe.to("cuda")
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+ 
+     with calculateDuration("Generating image"):
+         # Generate image
+         image = pipe(
+             prompt=prompt_mash,
+             negative_prompt=negative_prompt,
+             num_inference_steps=steps,
+             true_cfg_scale=cfg_scale,  # Use true_cfg_scale for Qwen-Image
+             width=width,
+             height=height,
+             generator=generator,
+         ).images[0]
+ 
+     return image
+ 
+ @spaces.GPU(duration=70)
+ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, progress=gr.Progress(track_tqdm=True)):
+     if selected_index is None:
+         raise gr.Error("You must select a LoRA before proceeding.")
+ 
+     selected_lora = loras[selected_index]
+     lora_path = selected_lora["repo"]
+     trigger_word = selected_lora["trigger_word"]
+ 
+     # Prepare prompt with trigger word
+     if trigger_word:
+         if "trigger_position" in selected_lora:
+             if selected_lora["trigger_position"] == "prepend":
+                 prompt_mash = f"{trigger_word} {prompt}"
+             else:
+                 prompt_mash = f"{prompt} {trigger_word}"
+         else:
+             prompt_mash = f"{trigger_word} {prompt}"
+     else:
+         prompt_mash = prompt
+ 
+     # Always unload any existing LoRAs first to avoid conflicts
+     with calculateDuration("Unloading existing LoRAs"):
+         pipe.unload_lora_weights()
+ 
+     # Load LoRAs based on speed mode
+     if speed_mode == "Speed (8 steps)":
+         with calculateDuration("Loading Lightning LoRA and style LoRA"):
+             # Load Lightning LoRA first
+             pipe.load_lora_weights(
+                 LIGHTNING_LORA_REPO,
+                 weight_name=LIGHTNING_LORA_WEIGHT,
+                 adapter_name="lightning"
+             )
+ 
+             # Load the selected style LoRA
+             weight_name = selected_lora.get("weights", None)
+             pipe.load_lora_weights(
+                 lora_path,
+                 weight_name=weight_name,
+                 low_cpu_mem_usage=True,
+                 adapter_name="style"
+             )
+ 
+             # Set both adapters active with their weights
+             pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
+     else:
+         # Quality mode - only load the style LoRA
+         with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
+             weight_name = selected_lora.get("weights", None)
+             pipe.load_lora_weights(
+                 lora_path,
+                 weight_name=weight_name,
+                 low_cpu_mem_usage=True
+             )
+ 
+     # Set random seed for reproducibility
+     with calculateDuration("Randomizing seed"):
+         if randomize_seed:
+             seed = random.randint(0, MAX_SEED)
+ 
+     # Get image dimensions from aspect ratio
+     width, height = get_image_size(aspect_ratio)
+ 
+     # Generate the image
+     final_image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale)
+ 
+     return final_image, seed
+ 
+ def get_huggingface_safetensors(link):
+     split_link = link.split("/")
+     if len(split_link) != 2:
+         raise Exception("Invalid Hugging Face repository link format.")
+ 
+     print(f"Repository attempted: {split_link}")
+ 
+     # Load model card
+     model_card = ModelCard.load(link)
+     base_model = model_card.data.get("base_model")
+     print(f"Base model: {base_model}")
+ 
+     # Validate model type (for Qwen-Image)
+     acceptable_models = {"Qwen/Qwen-Image"}
+ 
+     models_to_check = base_model if isinstance(base_model, list) else [base_model]
+ 
+     if not any(model in acceptable_models for model in models_to_check):
+         raise Exception("Not a Qwen-Image LoRA!")
+ 
+     # Extract image and trigger word
+     image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+     trigger_word = model_card.data.get("instance_prompt", "")
+     image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
+ 
+     # Initialize Hugging Face file system
+     fs = HfFileSystem()
+     try:
+         list_of_files = fs.ls(link, detail=False)
+ 
+         # Find safetensors file
+         safetensors_name = None
+         for file in list_of_files:
+             filename = file.split("/")[-1]
+             if filename.endswith(".safetensors"):
+                 safetensors_name = filename
+                 break
+ 
+         if not safetensors_name:
+             raise Exception("No valid *.safetensors file found in the repository.")
+ 
+     except Exception as e:
+         print(e)
+         raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")
+ 
+     return split_link[1], link, safetensors_name, trigger_word, image_url
  
  def check_custom_model(link):
+     print(f"Checking a custom model on: {link}")
+ 
+     if link.endswith('.safetensors'):
+         if 'huggingface.co' in link:
+             parts = link.split('/')
+             try:
+                 hf_index = parts.index('huggingface.co')
+                 username = parts[hf_index + 1]
+                 repo_name = parts[hf_index + 2]
+                 repo = f"{username}/{repo_name}"
+ 
+                 safetensors_name = parts[-1]
+ 
+                 try:
+                     model_card = ModelCard.load(repo)
+                     trigger_word = model_card.data.get("instance_prompt", "")
+                     image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
+                     image_url = f"https://huggingface.co/{repo}/resolve/main/{image_path}" if image_path else None
+                 except:
+                     trigger_word = ""
+                     image_url = None
+ 
+                 return repo_name, repo, safetensors_name, trigger_word, image_url
+             except:
+                 raise Exception("Invalid safetensors URL format")
+ 
      if link.startswith("https://"):
          if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
              link_split = link.split("huggingface.co/")
              return get_huggingface_safetensors(link_split[1])
-     else:
-         return get_huggingface_safetensors(link)
+         else:
+             return get_huggingface_safetensors(link)
@@ -173,9 +335,6 @@ def add_custom_lora(custom_lora):
      if custom_lora:
          try:
              title, repo, path, trigger_word, image = check_custom_model(custom_lora)
-             if not title:
-                 raise Exception("Invalid LoRA model")
- 
              print(f"Loaded custom LoRA: {repo}")
              card = f'''
              <div class="custom_lora_card">
@@ -189,9 +348,8 @@ def add_custom_lora(custom_lora):
              </div>
          </div>
          '''
- 
          existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
-         if not existing_item_index:
+         if existing_item_index is None:
              new_item = {
                  "image": image,
                  "title": title,
@@ -199,161 +357,21 @@ def add_custom_lora(custom_lora):
                  "weights": path,
                  "trigger_word": trigger_word
              }
-             existing_item_index = len(loras)
+             print(new_item)
              loras.append(new_item)
+             existing_item_index = len(loras) - 1  # Get the actual index after adding
  
          return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
      except Exception as e:
-         gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-Qwen LoRA")
-         return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-Qwen LoRA"), gr.update(visible=False), gr.update(), "", None, ""
+         gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-Qwen-Image LoRA, this was the issue: {e}")
+         return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-Qwen-Image LoRA"), gr.update(visible=True), gr.update(), "", None, ""
      else:
          return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
  
  def remove_custom_lora():
      return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
  
- def update_selection(evt: gr.SelectData, width, height):
-     selected_lora = loras[evt.index]
-     new_placeholder = f"Type a prompt for {selected_lora['title']}"
-     lora_repo = selected_lora["repo"]
-     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✅"
- 
-     # Update aspect ratio based on LoRA if it has aspect info
-     if "aspect" in selected_lora:
-         if selected_lora["aspect"] == "portrait":
-             width = 928
-             height = 1664
-         elif selected_lora["aspect"] == "landscape":
-             width = 1664
-             height = 928
-         else:
-             width = 1328
-             height = 1328
- 
-     return (
-         gr.update(placeholder=new_placeholder),
-         updated_text,
-         evt.index,
-         width,
-         height,
-     )
- 
- @spaces.GPU(duration=120)
- def generate_qwen(
-     prompt: str,
-     negative_prompt: str = "",
-     seed: int = 0,
-     width: int = 1024,
-     height: int = 1024,
-     guidance_scale: float = 4.0,
-     randomize_seed: bool = False,
-     num_inference_steps: int = 50,
-     num_images: int = 1,
-     zip_images: bool = False,
-     lora_input: str = "",
-     lora_scale: float = 1.0,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
- 
-     generator = torch.Generator(device).manual_seed(seed)
- 
-     start_time = time.time()
- 
-     # Clear any existing LoRA adapters
-     current_adapters = pipe.get_list_adapters()
-     for adapter in current_adapters:
-         pipe.delete_adapters(adapter)
-     pipe.disable_lora()
- 
-     use_lora = False
-     if lora_input and lora_input.strip() != "":
-         load_lora_opt(pipe, lora_input)
-         pipe.set_adapters(["default"], adapter_weights=[lora_scale])
-         use_lora = True
- 
-     images = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt if negative_prompt else "",
-         height=height,
-         width=width,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         num_images_per_prompt=num_images,
-         generator=generator,
-         output_type="pil",
-     ).images
- 
-     end_time = time.time()
-     duration = end_time - start_time
- 
-     image_paths = [save_image(img) for img in images]
-     zip_path = None
-     if zip_images:
-         zip_name = str(uuid.uuid4()) + ".zip"
-         with zipfile.ZipFile(zip_name, 'w') as zipf:
-             for i, img_path in enumerate(image_paths):
-                 zipf.write(img_path, arcname=f"Img_{i}.png")
-         zip_path = zip_name
- 
-     # Clean up adapters
-     current_adapters = pipe.get_list_adapters()
-     for adapter in current_adapters:
-         pipe.delete_adapters(adapter)
-     pipe.disable_lora()
- 
-     return image_paths, seed, f"{duration:.2f}", zip_path
- 
- @spaces.GPU(duration=120)
- def run_lora(
-     prompt: str,
-     negative_prompt: str,
-     use_negative_prompt: bool,
-     seed: int,
-     width: int,
-     height: int,
-     guidance_scale: float,
-     randomize_seed: bool,
-     num_inference_steps: int,
-     num_images: int,
-     zip_images: bool,
-     selected_index: int,
-     lora_scale: float,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if selected_index is None:
-         raise gr.Error("You must select a LoRA before proceeding.🧨")
- 
-     selected_lora = loras[selected_index]
-     lora_repo = selected_lora["repo"]
-     trigger_word = selected_lora["trigger_word"]
- 
-     if trigger_word:
-         prompt_mash = f"{trigger_word} {prompt}"
-     else:
-         prompt_mash = prompt
- 
-     final_negative_prompt = negative_prompt if use_negative_prompt else ""
- 
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
- 
-     return generate_qwen(
-         prompt=prompt_mash,
-         negative_prompt=final_negative_prompt,
-         seed=seed,
-         width=width,
-         height=height,
-         guidance_scale=guidance_scale,
-         randomize_seed=False,  # Already handled
-         num_inference_steps=num_inference_steps,
-         num_images=num_images,
-         zip_images=zip_images,
-         lora_input=lora_repo,
-         lora_scale=lora_scale,
-         progress=progress,
-     )
+ run_lora.zerogpu = True
  
  css = '''
  #gen_btn{height: 100%}
@@ -366,10 +384,7 @@
  .card_internal{display: flex;height: 100px;margin-top: .5em}
  .card_internal img{margin-right: 1em}
  .styler{--form-gap-width: 0px !important}
- #progress{height:30px}
- #progress .generating{display:none}
- .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
- .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
+ #speed_status{padding: .5em; border-radius: 5px; margin: 1em 0}
  '''
  
  with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120)) as app:
@@ -378,7 +393,7 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120))
  
      with gr.Row():
          with gr.Column(scale=3):
-             prompt = gr.Textbox(label="Prompt", lines=1, placeholder="✦︎ Choose the LoRA and type the prompt")
+             prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
          with gr.Column(scale=1, elem_id="gen_column"):
              generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
  
@@ -387,89 +402,73 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120))
              selected_info = gr.Markdown("")
              gallery = gr.Gallery(
                  [(item["image"], item["title"]) for item in loras],
-                 label="Qwen LoRA DLC's",
+                 label="LoRA Gallery",
                  allow_preview=False,
                  columns=3,
                  elem_id="gallery",
                  show_share_button=False
              )
              with gr.Group():
-                 custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="prithivMLmods/Qwen-Image-Sketch-Smudge")
-                 gr.Markdown("[Check the list of Qwen LoRA's](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
+                 custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="username/qwen-image-custom-lora")
+                 gr.Markdown("[Check Qwen-Image LoRAs](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
                  custom_lora_info = gr.HTML(visible=False)
                  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
  
          with gr.Column():
-             result = gr.Gallery(label="Generated Images", columns=1, show_label=False, preview=True)
+             result = gr.Image(label="Generated Image")
+ 
              with gr.Row():
                  aspect_ratio = gr.Dropdown(
                      label="Aspect Ratio",
-                     choices=list(aspect_ratios.keys()),
-                     value="1:1",
-                 )
+                     choices=["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3"],
+                     value="1:1"
+                 )
              with gr.Row():
-                 steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=48)
- 
-             with gr.Row():
-                 with gr.Accordion("Advanced Settings", open=False):
- 
-                     with gr.Row():
-                         use_negative_prompt = gr.Checkbox(
-                             label="Use negative prompt",
-                             value=True,
-                         )
-                         negative_prompt = gr.Text(
-                             label="Negative prompt",
-                             max_lines=1,
-                             placeholder="Enter a negative prompt",
-                             value="text, watermark, copyright, blurry, low resolution",
-                         )
- 
-                     with gr.Row():
-                         cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=4.0)
-                         steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=50)
- 
-                     with gr.Row():
-                         width = gr.Slider(label="Width", minimum=512, maximum=2048, step=64, value=1328)
-                         height = gr.Slider(label="Height", minimum=512, maximum=2048, step=64, value=1328)
- 
-                     with gr.Row():
-                         num_images = gr.Slider(label="Number of Images", minimum=1, maximum=5, step=1, value=1)
-                         zip_images = gr.Checkbox(label="Zip generated images", value=False)
- 
-                     with gr.Row():
-                         randomize_seed = gr.Checkbox(True, label="Randomize seed")
-                         seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
-                         lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2, step=0.01, value=1.0)
- 
-             # Output information
-             with gr.Row():
-                 seed_display = gr.Textbox(label="Seed used", interactive=False)
-                 generation_time = gr.Textbox(label="Generation time (seconds)", interactive=False)
-                 zip_file = gr.File(label="Download ZIP")
- 
-     # Update aspect ratio
-     def set_dimensions(ar):
-         w, h = aspect_ratios[ar]
-         return gr.update(value=w), gr.update(value=h)
- 
-     aspect_ratio.change(
-         fn=set_dimensions,
-         inputs=aspect_ratio,
-         outputs=[width, height]
-     )
- 
-     # Negative prompt visibility
-     use_negative_prompt.change(
-         fn=lambda x: gr.update(visible=x),
-         inputs=use_negative_prompt,
-         outputs=negative_prompt
-     )
- 
+                 speed_mode = gr.Dropdown(
+                     label="Generation Mode",
+                     choices=["Speed (8 steps)", "Quality (45 steps)"],
+                     value="Quality (45 steps)",
+                 )
+ 
+             speed_status = gr.Markdown("Quality mode active", elem_id="speed_status")
+ 
+             with gr.Row():
+                 with gr.Accordion("Advanced Settings", open=False):
+                     with gr.Column():
+                         with gr.Row():
+                             cfg_scale = gr.Slider(
+                                 label="Guidance Scale (True CFG)",
+                                 minimum=1.0,
+                                 maximum=5.0,
+                                 step=0.1,
+                                 value=3.5,
+                                 info="Lower for speed mode, higher for quality"
+                             )
+                             steps = gr.Slider(
+                                 label="Steps",
+                                 minimum=4,
+                                 maximum=50,
+                                 step=1,
+                                 value=45,
+                                 info="Automatically set by speed mode"
+                             )
+ 
+                         with gr.Row():
+                             randomize_seed = gr.Checkbox(True, label="Randomize seed")
+                             seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
+                             lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2, step=0.01, value=1.0)
+ 
+     # Event handlers
      gallery.select(
          update_selection,
-         inputs=[width, height],
-         outputs=[prompt, selected_info, selected_index, width, height]
+         inputs=[aspect_ratio],
+         outputs=[prompt, selected_info, selected_index, aspect_ratio]
+     )
+ 
+     speed_mode.change(
+         handle_speed_mode,
+         inputs=[speed_mode],
+         outputs=[speed_status, steps, cfg_scale]
      )
  
      custom_lora.input(
@@ -486,22 +485,8 @@
      gr.on(
          triggers=[generate_button.click, prompt.submit],
          fn=run_lora,
-         inputs=[
-             prompt,
-             negative_prompt,
-             use_negative_prompt,
-             seed,
-             width,
-             height,
-             #guidance_scale,
-             randomize_seed,
-             steps,
-             num_images,
-             zip_images,
-             selected_index,
-             lora_scale,
-         ],
-         outputs=[result, seed_display, generation_time, zip_file]
+         inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode],
+         outputs=[result, seed]
      )
  
  app.queue()
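
Review note: the core change in this commit is the two-adapter fast path — the Lightning LoRA for 8-step sampling stacked on a style LoRA via `set_adapters`. A minimal standalone sketch of that path follows, using the same repos, scheduler config, and diffusers calls as the diff; the style LoRA id, prompt, seed, and adapter weights are illustrative assumptions, not part of the commit.

```python
# Minimal sketch of the speed path (assumes a CUDA machine with a diffusers
# version that ships the Qwen-Image pipeline; values below are illustrative).
import math
import torch
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler

# Same shift/exponential settings the diff copies from Qwen-Image-Lightning;
# from_config fills in the remaining defaults.
scheduler = FlowMatchEulerDiscreteScheduler.from_config({
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "use_dynamic_shifting": True,
    "base_shift": math.log(3),
    "max_shift": math.log(3),
    "base_image_seq_len": 256,
    "max_image_seq_len": 8192,
    "time_shift_type": "exponential",
})
pipe = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", scheduler=scheduler, torch_dtype=torch.bfloat16
).to("cuda")

# Stack the Lightning (speed) adapter with a style adapter, as run_lora does.
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning",
    weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors",
    adapter_name="lightning",
)
pipe.load_lora_weights(
    "prithivMLmods/Qwen-Image-Sketch-Smudge",  # example style LoRA (assumption)
    adapter_name="style",
)
pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, 1.0])

image = pipe(
    prompt="Sketch Smudge, a lighthouse at dusk",  # illustrative prompt
    num_inference_steps=8,   # speed mode
    true_cfg_scale=1.0,      # handle_speed_mode pairs 8 steps with CFG 1.0
    width=1024,
    height=1024,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("preview.png")
```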
 
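Review note: custom-LoRA validation now goes through `huggingface_hub` helpers instead of raw REST calls. A sketch of that flow in isolation — the same `ModelCard.load` and `HfFileSystem.ls` calls as the new `get_huggingface_safetensors`; the repo id in the usage comment is a hypothetical example.

```python
from huggingface_hub import HfFileSystem, ModelCard

def find_qwen_lora(repo_id: str) -> str:
    """Check that repo_id is a Qwen-Image LoRA and return its .safetensors filename."""
    card = ModelCard.load(repo_id)
    base = card.data.get("base_model")
    bases = base if isinstance(base, list) else [base]
    if "Qwen/Qwen-Image" not in bases:
        raise ValueError(f"{repo_id} is not a Qwen-Image LoRA")
    # List repo files and pick the first safetensors checkpoint, as the diff does.
    for path in HfFileSystem().ls(repo_id, detail=False):
        name = path.split("/")[-1]
        if name.endswith(".safetensors"):
            return name
    raise ValueError(f"No *.safetensors file found in {repo_id}")

# Example (hypothetical repo id):
# print(find_qwen_lora("prithivMLmods/Qwen-Image-Sketch-Smudge"))
```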