prithivMLmods commited on
Commit
94d22ce
·
verified ·
1 Parent(s): bca1c98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +268 -289
app.py CHANGED
@@ -24,80 +24,30 @@ import shutil
24
  import uuid
25
  import zipfile
26
 
27
- def calculate_shift(
28
- image_seq_len,
29
- base_seq_len: int = 256,
30
- max_seq_len: int = 4096,
31
- base_shift: float = 0.5,
32
- max_shift: float = 1.16,
33
- ):
34
- m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
35
- b = base_shift - m * base_seq_len
36
- mu = image_seq_len * m + b
37
- return mu
38
-
39
  def save_image(img):
40
  unique_name = str(uuid.uuid4()) + ".png"
41
  img.save(unique_name)
42
  return unique_name
43
 
44
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
45
- if randomize_seed:
46
- seed = random.randint(0, MAX_SEED)
47
- return seed
48
 
49
- # Qwen Image pipeline with live preview capability
50
- @torch.inference_mode()
51
- def qwen_pipe_call_that_returns_an_iterable_of_images(
52
- self,
53
- prompt: Union[str, List[str]] = None,
54
- negative_prompt: Optional[Union[str, List[str]]] = None,
55
- height: Optional[int] = None,
56
- width: Optional[int] = None,
57
- num_inference_steps: int = 50,
58
- guidance_scale: float = 4.0,
59
- num_images_per_prompt: Optional[int] = 1,
60
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
61
- output_type: Optional[str] = "pil",
62
- ):
63
- height = height or 1024
64
- width = width or 1024
65
-
66
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
67
- device = self._execution_device
68
-
69
- # Generate intermediate images during the process
70
- for i in range(num_inference_steps):
71
- if i % 5 == 0: # Show progress every 5 steps
72
- # Generate partial result
73
- temp_result = self(
74
- prompt=prompt,
75
- negative_prompt=negative_prompt,
76
- height=height,
77
- width=width,
78
- guidance_scale=guidance_scale,
79
- num_inference_steps=max(1, i + 1),
80
- num_images_per_prompt=num_images_per_prompt,
81
- generator=generator,
82
- output_type=output_type,
83
- ).images[0]
84
- yield temp_result
85
- torch.cuda.empty_cache()
86
-
87
- # Final high-quality result
88
- final_result = self(
89
- prompt=prompt,
90
- negative_prompt=negative_prompt,
91
- height=height,
92
- width=width,
93
- guidance_scale=guidance_scale,
94
- num_inference_steps=num_inference_steps,
95
- num_images_per_prompt=num_images_per_prompt,
96
- generator=generator,
97
- output_type=output_type,
98
- ).images[0]
99
-
100
- yield final_result
101
 
102
  loras = [
103
  # Sample Qwen-compatible LoRAs
@@ -138,56 +88,16 @@ loras = [
138
  },
139
  ]
140
 
141
- #--------------------------------------------------Model Initialization-----------------------------------------------------------------------------------------#
142
- dtype = torch.bfloat16
143
- device = "cuda" if torch.cuda.is_available() else "cpu"
144
- base_model = "Qwen/Qwen-Image"
145
-
146
- # Load Qwen Image pipeline
147
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)
148
-
149
- # Add aspect ratios for Qwen
150
- aspect_ratios = {
151
- "1:1": (1024, 1024),
152
- "16:9": (1344, 768),
153
- "9:16": (768, 1344),
154
- "4:3": (1152, 896),
155
- "3:4": (896, 1152),
156
- "3:2": (1216, 832),
157
- "2:3": (832, 1216)
158
- }
159
-
160
- MAX_SEED = 2**32-1
161
-
162
- # Add the custom method to the pipeline
163
- pipe.qwen_pipe_call_that_returns_an_iterable_of_images = qwen_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
164
-
165
- class calculateDuration:
166
- def __init__(self, activity_name=""):
167
- self.activity_name = activity_name
168
-
169
- def __enter__(self):
170
- self.start_time = time.time()
171
- return self
172
-
173
- def __exit__(self, exc_type, exc_value, traceback):
174
- self.end_time = time.time()
175
- self.elapsed_time = self.end_time - self.start_time
176
- if self.activity_name:
177
- print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
178
- else:
179
- print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
180
-
181
  def load_lora_opt(pipe, lora_input):
182
  lora_input = lora_input.strip()
183
  if not lora_input:
184
  return
185
-
186
  # If it's just an ID like "author/model"
187
  if "/" in lora_input and not lora_input.startswith("http"):
188
  pipe.load_lora_weights(lora_input, adapter_name="default")
189
  return
190
-
191
  if lora_input.startswith("http"):
192
  url = lora_input
193
  # Repo page (no blob/resolve)
@@ -195,11 +105,11 @@ def load_lora_opt(pipe, lora_input):
195
  repo_id = urlparse(url).path.strip("/")
196
  pipe.load_lora_weights(repo_id, adapter_name="default")
197
  return
198
-
199
  # Blob link → convert to resolve link
200
  if "/blob/" in url:
201
  url = url.replace("/blob/", "/resolve/")
202
-
203
  # Download direct file
204
  tmp_dir = tempfile.mkdtemp()
205
  local_path = os.path.join(tmp_dir, os.path.basename(urlparse(url).path))
@@ -215,148 +125,47 @@ def load_lora_opt(pipe, lora_input):
215
  finally:
216
  shutil.rmtree(tmp_dir, ignore_errors=True)
217
 
218
- def update_selection(evt: gr.SelectData, width, height):
219
- selected_lora = loras[evt.index]
220
- new_placeholder = f"Type a prompt for {selected_lora['title']}"
221
- lora_repo = selected_lora["repo"]
222
- updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✅"
223
-
224
- if "aspect" in selected_lora:
225
- if selected_lora["aspect"] == "portrait":
226
- width = 768
227
- height = 1024
228
- elif selected_lora["aspect"] == "landscape":
229
- width = 1024
230
- height = 768
231
- else:
232
- width = 1024
233
- height = 1024
234
-
235
- return (
236
- gr.update(placeholder=new_placeholder),
237
- updated_text,
238
- evt.index,
239
- width,
240
- height,
241
- )
242
-
243
- @spaces.GPU(duration=120)
244
- def generate_image(prompt_mash, negative_prompt, steps, seed, cfg_scale, width, height, lora_scale, progress):
245
- pipe.to("cuda")
246
- generator = torch.Generator(device="cuda").manual_seed(seed)
247
-
248
- with calculateDuration("Generating image"):
249
- # Generate image with live preview
250
- for img in pipe.qwen_pipe_call_that_returns_an_iterable_of_images(
251
- prompt=prompt_mash,
252
- negative_prompt=negative_prompt,
253
- num_inference_steps=steps,
254
- guidance_scale=cfg_scale,
255
- width=width,
256
- height=height,
257
- generator=generator,
258
- ):
259
- yield img
260
-
261
- def set_dimensions(ar):
262
- w, h = aspect_ratios[ar]
263
- return gr.update(value=w), gr.update(value=h)
264
-
265
- @spaces.GPU(duration=120)
266
- def run_lora(prompt, negative_prompt, use_negative_prompt, aspect_ratio, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
267
- if selected_index is None:
268
- raise gr.Error("You must select a LoRA before proceeding.🧨")
269
-
270
- selected_lora = loras[selected_index]
271
- lora_path = selected_lora["repo"]
272
- trigger_word = selected_lora["trigger_word"]
273
-
274
- # Set dimensions based on aspect ratio
275
- width, height = aspect_ratios[aspect_ratio]
276
-
277
- if trigger_word:
278
- if "trigger_position" in selected_lora:
279
- if selected_lora["trigger_position"] == "prepend":
280
- prompt_mash = f"{trigger_word} {prompt}"
281
- else:
282
- prompt_mash = f"{prompt} {trigger_word}"
283
- else:
284
- prompt_mash = f"{trigger_word} {prompt}"
285
- else:
286
- prompt_mash = prompt
287
-
288
- # Handle negative prompt
289
- final_negative_prompt = negative_prompt if use_negative_prompt else ""
290
-
291
- with calculateDuration("Unloading LoRA"):
292
- # Clear existing adapters
293
- current_adapters = pipe.get_list_adapters() if hasattr(pipe, 'get_list_adapters') else []
294
- for adapter in current_adapters:
295
- if hasattr(pipe, 'delete_adapters'):
296
- pipe.delete_adapters(adapter)
297
- if hasattr(pipe, 'disable_lora'):
298
- pipe.disable_lora()
299
-
300
- # Load new LoRA weights
301
- with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
302
- weight_name = selected_lora.get("weights", None)
303
- load_lora_opt(pipe, lora_path)
304
- if hasattr(pipe, 'set_adapters'):
305
- pipe.set_adapters(["default"], adapter_weights=[lora_scale])
306
-
307
- with calculateDuration("Randomizing seed"):
308
- if randomize_seed:
309
- seed = random.randint(0, MAX_SEED)
310
-
311
- image_generator = generate_image(prompt_mash, final_negative_prompt, steps, seed, cfg_scale, width, height, lora_scale, progress)
312
-
313
- final_image = None
314
- step_counter = 0
315
- for image in image_generator:
316
- step_counter += 1
317
- final_image = image
318
- progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
319
- yield image, seed, gr.update(value=progress_bar, visible=True)
320
-
321
- yield final_image, seed, gr.update(value=progress_bar, visible=False)
322
-
323
  def get_huggingface_safetensors(link):
324
  split_link = link.split("/")
325
  if len(split_link) == 2:
326
- model_card = ModelCard.load(link)
327
- base_model = model_card.data.get("base_model")
328
- print(base_model)
329
-
330
- # Allow Qwen models
331
- if base_model and "qwen" not in base_model.lower():
332
- raise Exception("Qwen-compatible LoRA Not Found!")
333
-
334
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
335
- trigger_word = model_card.data.get("instance_prompt", "")
336
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
337
-
338
- fs = HfFileSystem()
339
  try:
340
- list_of_files = fs.ls(link, detail=False)
341
- for file in list_of_files:
342
- if file.endswith(".safetensors"):
343
- safetensors_name = file.split("/")[-1]
344
- if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
345
- image_elements = file.split("/")
346
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  except Exception as e:
348
- print(e)
349
- gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
350
- raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
351
-
352
- return split_link[1], link, safetensors_name, trigger_word, image_url
353
 
354
  def check_custom_model(link):
355
  if link.startswith("https://"):
356
  if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
357
  link_split = link.split("huggingface.co/")
358
  return get_huggingface_safetensors(link_split[1])
359
- else:
360
  return get_huggingface_safetensors(link)
361
 
362
  def add_custom_lora(custom_lora):
@@ -364,6 +173,9 @@ def add_custom_lora(custom_lora):
364
  if custom_lora:
365
  try:
366
  title, repo, path, trigger_word, image = check_custom_model(custom_lora)
 
 
 
367
  print(f"Loaded custom LoRA: {repo}")
368
  card = f'''
369
  <div class="custom_lora_card">
@@ -377,6 +189,7 @@ def add_custom_lora(custom_lora):
377
  </div>
378
  </div>
379
  '''
 
380
  existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
381
  if not existing_item_index:
382
  new_item = {
@@ -386,21 +199,161 @@ def add_custom_lora(custom_lora):
386
  "weights": path,
387
  "trigger_word": trigger_word
388
  }
389
- print(new_item)
390
  existing_item_index = len(loras)
391
  loras.append(new_item)
392
-
393
  return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
394
  except Exception as e:
395
- gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-Qwen compatible LoRA")
396
- return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-Qwen compatible LoRA"), gr.update(visible=False), gr.update(), "", None, ""
397
  else:
398
  return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
399
 
400
  def remove_custom_lora():
401
  return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
402
 
403
- run_lora.zerogpu = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
 
405
  css = '''
406
  #gen_btn{height: 100%}
@@ -420,7 +373,7 @@ css = '''
420
  '''
421
 
422
  with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120)) as app:
423
- title = gr.HTML("""<h1>Qwen Image LoRA DLC🥳</h1>""", elem_id="title",)
424
  selected_index = gr.State(None)
425
 
426
  with gr.Row():
@@ -428,101 +381,127 @@ with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120))
428
  prompt = gr.Textbox(label="Prompt", lines=1, placeholder="✦︎ Choose the LoRA and type the prompt")
429
  with gr.Column(scale=1, elem_id="gen_column"):
430
  generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
431
-
432
  with gr.Row():
433
  with gr.Column():
434
  selected_info = gr.Markdown("")
435
  gallery = gr.Gallery(
436
  [(item["image"], item["title"]) for item in loras],
437
- label="Qwen LoRA Collection",
438
  allow_preview=False,
439
  columns=3,
440
  elem_id="gallery",
441
  show_share_button=False
442
  )
443
-
444
  with gr.Group():
445
- custom_lora = gr.Textbox(label="Enter Custom Qwen LoRA", placeholder="prithivMLmods/Qwen-Image-Sketch-Smudge")
446
- gr.Markdown("[Check the list of Qwen-compatible LoRAs](https://huggingface.co/models?search=qwen+lora)", elem_id="lora_list")
447
-
448
  custom_lora_info = gr.HTML(visible=False)
449
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
450
-
451
  with gr.Column():
452
- progress_bar = gr.Markdown(elem_id="progress", visible=False)
453
- result = gr.Image(label="Generated Image", format="png")
454
-
455
  with gr.Row():
456
  aspect_ratio = gr.Dropdown(
457
  label="Aspect Ratio",
458
  choices=list(aspect_ratios.keys()),
459
  value="1:1",
460
  )
 
 
461
 
462
  with gr.Row():
463
  with gr.Accordion("Advanced Settings", open=False):
464
-
465
  with gr.Row():
466
  use_negative_prompt = gr.Checkbox(
467
- label="Use negative prompt", value=True, visible=True
 
468
  )
469
  negative_prompt = gr.Text(
470
  label="Negative prompt",
471
  max_lines=1,
472
  placeholder="Enter a negative prompt",
473
  value="text, watermark, copyright, blurry, low resolution",
474
- visible=True,
475
  )
476
 
477
- with gr.Column():
478
- with gr.Row():
479
- cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=4.0)
480
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=50)
481
-
482
- with gr.Row():
483
- width = gr.Slider(label="Width", minimum=256, maximum=2048, step=64, value=1024)
484
- height = gr.Slider(label="Height", minimum=256, maximum=2048, step=64, value=1024)
485
-
486
- with gr.Row():
487
- randomize_seed = gr.Checkbox(True, label="Randomize seed")
488
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
489
- lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2, step=0.01, value=1.0)
490
-
491
- # Event handlers
492
- gallery.select(
493
- update_selection,
494
- inputs=[width, height],
495
- outputs=[prompt, selected_info, selected_index, width, height]
496
- )
497
-
 
 
 
 
 
 
 
498
  aspect_ratio.change(
499
  fn=set_dimensions,
500
  inputs=aspect_ratio,
501
  outputs=[width, height]
502
  )
503
-
 
504
  use_negative_prompt.change(
505
  fn=lambda x: gr.update(visible=x),
506
  inputs=use_negative_prompt,
507
  outputs=negative_prompt
508
  )
509
-
 
 
 
 
 
 
510
  custom_lora.input(
511
  add_custom_lora,
512
  inputs=[custom_lora],
513
  outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
514
  )
515
-
516
  custom_lora_button.click(
517
  remove_custom_lora,
518
  outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
519
  )
520
-
521
  gr.on(
522
  triggers=[generate_button.click, prompt.submit],
523
  fn=run_lora,
524
- inputs=[prompt, negative_prompt, use_negative_prompt, aspect_ratio, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
525
- outputs=[result, seed, progress_bar]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  )
527
 
528
  app.queue()
 
24
  import uuid
25
  import zipfile
26
 
27
+ # Helper functions
 
 
 
 
 
 
 
 
 
 
 
28
  def save_image(img):
29
  unique_name = str(uuid.uuid4()) + ".png"
30
  img.save(unique_name)
31
  return unique_name
32
 
33
+ MAX_SEED = np.iinfo(np.int32).max
34
+ MAX_IMAGE_SIZE = 2048
 
 
35
 
36
+ # Load Qwen/Qwen-Image pipeline
37
+ dtype = torch.bfloat16
38
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+
40
+ # Load Qwen model
41
+ pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype).to(device)
42
+
43
+ # Aspect ratios
44
+ aspect_ratios = {
45
+ "1:1": (1328, 1328),
46
+ "16:9": (1664, 928),
47
+ "9:16": (928, 1664),
48
+ "4:3": (1472, 1140),
49
+ "3:4": (1140, 1472)
50
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  loras = [
53
  # Sample Qwen-compatible LoRAs
 
88
  },
89
  ]
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def load_lora_opt(pipe, lora_input):
92
  lora_input = lora_input.strip()
93
  if not lora_input:
94
  return
95
+
96
  # If it's just an ID like "author/model"
97
  if "/" in lora_input and not lora_input.startswith("http"):
98
  pipe.load_lora_weights(lora_input, adapter_name="default")
99
  return
100
+
101
  if lora_input.startswith("http"):
102
  url = lora_input
103
  # Repo page (no blob/resolve)
 
105
  repo_id = urlparse(url).path.strip("/")
106
  pipe.load_lora_weights(repo_id, adapter_name="default")
107
  return
108
+
109
  # Blob link → convert to resolve link
110
  if "/blob/" in url:
111
  url = url.replace("/blob/", "/resolve/")
112
+
113
  # Download direct file
114
  tmp_dir = tempfile.mkdtemp()
115
  local_path = os.path.join(tmp_dir, os.path.basename(urlparse(url).path))
 
125
  finally:
126
  shutil.rmtree(tmp_dir, ignore_errors=True)
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  def get_huggingface_safetensors(link):
129
  split_link = link.split("/")
130
  if len(split_link) == 2:
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  try:
132
+ response = requests.get(f"https://huggingface.co/api/models/{link}")
133
+ response.raise_for_status()
134
+ model_info = response.json()
135
+
136
+ # Check if it's a Qwen model
137
+ if "qwen" not in model_info.get("tags", []):
138
+ raise Exception("Not a Qwen LoRA model!")
139
+
140
+ # Get image if available
141
+ image_url = None
142
+ if "cardData" in model_info and "widget" in model_info["cardData"]:
143
+ if len(model_info["cardData"]["widget"]) > 0:
144
+ image_url = model_info["cardData"]["widget"][0].get("output", {}).get("url", None)
145
+
146
+ # Try to find safetensors file
147
+ safetensors_name = None
148
+ try:
149
+ model_files = requests.get(f"https://huggingface.co/api/models/{link}/tree/main").json()
150
+ for file in model_files:
151
+ if file.get("path", "").endswith(".safetensors"):
152
+ safetensors_name = file["path"]
153
+ break
154
+ except:
155
+ pass
156
+
157
+ return split_link[1], link, safetensors_name, "trigger_word", image_url
158
  except Exception as e:
159
+ print(f"Error getting model info: {e}")
160
+ raise Exception(f"Failed to get model info: {e}")
161
+ return None, None, None, None, None
 
 
162
 
163
  def check_custom_model(link):
164
  if link.startswith("https://"):
165
  if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
166
  link_split = link.split("huggingface.co/")
167
  return get_huggingface_safetensors(link_split[1])
168
+ else:
169
  return get_huggingface_safetensors(link)
170
 
171
  def add_custom_lora(custom_lora):
 
173
  if custom_lora:
174
  try:
175
  title, repo, path, trigger_word, image = check_custom_model(custom_lora)
176
+ if not title:
177
+ raise Exception("Invalid LoRA model")
178
+
179
  print(f"Loaded custom LoRA: {repo}")
180
  card = f'''
181
  <div class="custom_lora_card">
 
189
  </div>
190
  </div>
191
  '''
192
+
193
  existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
194
  if not existing_item_index:
195
  new_item = {
 
199
  "weights": path,
200
  "trigger_word": trigger_word
201
  }
 
202
  existing_item_index = len(loras)
203
  loras.append(new_item)
204
+
205
  return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
206
  except Exception as e:
207
+ gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-Qwen LoRA")
208
+ return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-Qwen LoRA"), gr.update(visible=False), gr.update(), "", None, ""
209
  else:
210
  return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
211
 
212
  def remove_custom_lora():
213
  return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
214
 
215
+ def update_selection(evt: gr.SelectData, width, height):
216
+ selected_lora = loras[evt.index]
217
+ new_placeholder = f"Type a prompt for {selected_lora['title']}"
218
+ lora_repo = selected_lora["repo"]
219
+ updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✅"
220
+
221
+ # Update aspect ratio based on LoRA if it has aspect info
222
+ if "aspect" in selected_lora:
223
+ if selected_lora["aspect"] == "portrait":
224
+ width = 928
225
+ height = 1664
226
+ elif selected_lora["aspect"] == "landscape":
227
+ width = 1664
228
+ height = 928
229
+ else:
230
+ width = 1328
231
+ height = 1328
232
+
233
+ return (
234
+ gr.update(placeholder=new_placeholder),
235
+ updated_text,
236
+ evt.index,
237
+ width,
238
+ height,
239
+ )
240
+
241
+ @spaces.GPU(duration=120)
242
+ def generate_qwen(
243
+ prompt: str,
244
+ negative_prompt: str = "",
245
+ seed: int = 0,
246
+ width: int = 1024,
247
+ height: int = 1024,
248
+ guidance_scale: float = 4.0,
249
+ randomize_seed: bool = False,
250
+ num_inference_steps: int = 50,
251
+ num_images: int = 1,
252
+ zip_images: bool = False,
253
+ lora_input: str = "",
254
+ lora_scale: float = 1.0,
255
+ progress=gr.Progress(track_tqdm=True),
256
+ ):
257
+ if randomize_seed:
258
+ seed = random.randint(0, MAX_SEED)
259
+
260
+ generator = torch.Generator(device).manual_seed(seed)
261
+
262
+ start_time = time.time()
263
+
264
+ # Clear any existing LoRA adapters
265
+ current_adapters = pipe.get_list_adapters()
266
+ for adapter in current_adapters:
267
+ pipe.delete_adapters(adapter)
268
+ pipe.disable_lora()
269
+
270
+ use_lora = False
271
+ if lora_input and lora_input.strip() != "":
272
+ load_lora_opt(pipe, lora_input)
273
+ pipe.set_adapters(["default"], adapter_weights=[lora_scale])
274
+ use_lora = True
275
+
276
+ images = pipe(
277
+ prompt=prompt,
278
+ negative_prompt=negative_prompt if negative_prompt else "",
279
+ height=height,
280
+ width=width,
281
+ guidance_scale=guidance_scale,
282
+ num_inference_steps=num_inference_steps,
283
+ num_images_per_prompt=num_images,
284
+ generator=generator,
285
+ output_type="pil",
286
+ ).images
287
+
288
+ end_time = time.time()
289
+ duration = end_time - start_time
290
+
291
+ image_paths = [save_image(img) for img in images]
292
+ zip_path = None
293
+ if zip_images:
294
+ zip_name = str(uuid.uuid4()) + ".zip"
295
+ with zipfile.ZipFile(zip_name, 'w') as zipf:
296
+ for i, img_path in enumerate(image_paths):
297
+ zipf.write(img_path, arcname=f"Img_{i}.png")
298
+ zip_path = zip_name
299
+
300
+ # Clean up adapters
301
+ current_adapters = pipe.get_list_adapters()
302
+ for adapter in current_adapters:
303
+ pipe.delete_adapters(adapter)
304
+ pipe.disable_lora()
305
+
306
+ return image_paths, seed, f"{duration:.2f}", zip_path
307
+
308
+ @spaces.GPU(duration=120)
309
+ def run_lora(
310
+ prompt: str,
311
+ negative_prompt: str,
312
+ use_negative_prompt: bool,
313
+ seed: int,
314
+ width: int,
315
+ height: int,
316
+ guidance_scale: float,
317
+ randomize_seed: bool,
318
+ num_inference_steps: int,
319
+ num_images: int,
320
+ zip_images: bool,
321
+ selected_index: int,
322
+ lora_scale: float,
323
+ progress=gr.Progress(track_tqdm=True),
324
+ ):
325
+ if selected_index is None:
326
+ raise gr.Error("You must select a LoRA before proceeding.🧨")
327
+
328
+ selected_lora = loras[selected_index]
329
+ lora_repo = selected_lora["repo"]
330
+ trigger_word = selected_lora["trigger_word"]
331
+
332
+ if trigger_word:
333
+ prompt_mash = f"{trigger_word} {prompt}"
334
+ else:
335
+ prompt_mash = prompt
336
+
337
+ final_negative_prompt = negative_prompt if use_negative_prompt else ""
338
+
339
+ if randomize_seed:
340
+ seed = random.randint(0, MAX_SEED)
341
+
342
+ return generate_qwen(
343
+ prompt=prompt_mash,
344
+ negative_prompt=final_negative_prompt,
345
+ seed=seed,
346
+ width=width,
347
+ height=height,
348
+ guidance_scale=guidance_scale,
349
+ randomize_seed=False, # Already handled
350
+ num_inference_steps=num_inference_steps,
351
+ num_images=num_images,
352
+ zip_images=zip_images,
353
+ lora_input=lora_repo,
354
+ lora_scale=lora_scale,
355
+ progress=progress,
356
+ )
357
 
358
  css = '''
359
  #gen_btn{height: 100%}
 
373
  '''
374
 
375
  with gr.Blocks(theme="bethecloud/storj_theme", css=css, delete_cache=(120, 120)) as app:
376
+ title = gr.HTML("""<h1>Qwen Image LoRA DLC ❤️‍🔥</h1>""", elem_id="title")
377
  selected_index = gr.State(None)
378
 
379
  with gr.Row():
 
381
  prompt = gr.Textbox(label="Prompt", lines=1, placeholder="✦︎ Choose the LoRA and type the prompt")
382
  with gr.Column(scale=1, elem_id="gen_column"):
383
  generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
384
+
385
  with gr.Row():
386
  with gr.Column():
387
  selected_info = gr.Markdown("")
388
  gallery = gr.Gallery(
389
  [(item["image"], item["title"]) for item in loras],
390
+ label="Qwen LoRA DLC's",
391
  allow_preview=False,
392
  columns=3,
393
  elem_id="gallery",
394
  show_share_button=False
395
  )
 
396
  with gr.Group():
397
+ custom_lora = gr.Textbox(label="Enter Custom LoRA", placeholder="prithivMLmods/Qwen-Image-Sketch-Smudge")
398
+ gr.Markdown("[Check the list of Qwen LoRA's](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
 
399
  custom_lora_info = gr.HTML(visible=False)
400
  custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
401
+
402
  with gr.Column():
403
+ result = gr.Gallery(label="Generated Images", columns=1, show_label=False, preview=True)
 
 
404
  with gr.Row():
405
  aspect_ratio = gr.Dropdown(
406
  label="Aspect Ratio",
407
  choices=list(aspect_ratios.keys()),
408
  value="1:1",
409
  )
410
+ with gr.Row():
411
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=48)
412
 
413
  with gr.Row():
414
  with gr.Accordion("Advanced Settings", open=False):
415
+
416
  with gr.Row():
417
  use_negative_prompt = gr.Checkbox(
418
+ label="Use negative prompt",
419
+ value=True,
420
  )
421
  negative_prompt = gr.Text(
422
  label="Negative prompt",
423
  max_lines=1,
424
  placeholder="Enter a negative prompt",
425
  value="text, watermark, copyright, blurry, low resolution",
 
426
  )
427
 
428
+ with gr.Row():
429
+ cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=4.0)
430
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=50)
431
+
432
+ with gr.Row():
433
+ width = gr.Slider(label="Width", minimum=512, maximum=2048, step=64, value=1328)
434
+ height = gr.Slider(label="Height", minimum=512, maximum=2048, step=64, value=1328)
435
+
436
+ with gr.Row():
437
+ num_images = gr.Slider(label="Number of Images", minimum=1, maximum=5, step=1, value=1)
438
+ zip_images = gr.Checkbox(label="Zip generated images", value=False)
439
+
440
+ with gr.Row():
441
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
442
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
443
+ lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=2, step=0.01, value=1.0)
444
+
445
+ # Output information
446
+ with gr.Row():
447
+ seed_display = gr.Textbox(label="Seed used", interactive=False)
448
+ generation_time = gr.Textbox(label="Generation time (seconds)", interactive=False)
449
+ zip_file = gr.File(label="Download ZIP")
450
+
451
+ # Update aspect ratio
452
+ def set_dimensions(ar):
453
+ w, h = aspect_ratios[ar]
454
+ return gr.update(value=w), gr.update(value=h)
455
+
456
  aspect_ratio.change(
457
  fn=set_dimensions,
458
  inputs=aspect_ratio,
459
  outputs=[width, height]
460
  )
461
+
462
+ # Negative prompt visibility
463
  use_negative_prompt.change(
464
  fn=lambda x: gr.update(visible=x),
465
  inputs=use_negative_prompt,
466
  outputs=negative_prompt
467
  )
468
+
469
+ gallery.select(
470
+ update_selection,
471
+ inputs=[width, height],
472
+ outputs=[prompt, selected_info, selected_index, width, height]
473
+ )
474
+
475
  custom_lora.input(
476
  add_custom_lora,
477
  inputs=[custom_lora],
478
  outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
479
  )
480
+
481
  custom_lora_button.click(
482
  remove_custom_lora,
483
  outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
484
  )
485
+
486
  gr.on(
487
  triggers=[generate_button.click, prompt.submit],
488
  fn=run_lora,
489
+ inputs=[
490
+ prompt,
491
+ negative_prompt,
492
+ use_negative_prompt,
493
+ seed,
494
+ width,
495
+ height,
496
+ #guidance_scale,
497
+ randomize_seed,
498
+ steps,
499
+ num_images,
500
+ zip_images,
501
+ selected_index,
502
+ lora_scale,
503
+ ],
504
+ outputs=[result, seed_display, generation_time, zip_file]
505
  )
506
 
507
  app.queue()