prithivMLmods committed on
Commit
c314fac
·
verified ·
1 Parent(s): 0c8e12c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -273
app.py CHANGED
@@ -7,422 +7,363 @@ import json
7
  import os
8
  from PIL import Image
9
  from diffusers import FluxKontextPipeline
10
- from diffusers.utils import load_image
11
- from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, list_repo_files
12
  from safetensors.torch import load_file
13
  import requests
14
  import re
15
 
16
- # Load Kontext model
17
  MAX_SEED = np.iinfo(np.int32).max
18
 
19
  pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
20
 
21
- # Load LoRA data
22
- flux_loras_raw = [
23
- {
24
- "image": "https://huggingface.co/prithivMLmods/FLUX.1-Kontext-Cinematic-Relighting/resolve/main/images/1.png",
25
- "title": "Kontext Cinematic Relighting",
26
- "repo": "prithivMLmods/FLUX.1-Kontext-Cinematic-Relighting",
27
- "trigger_word": "Cinematic Relighting, Relight this portrait with warm, cinematic indoor lighting. Add soft amber highlights and gentle shadows to the face mimicking golden-hour light through a cozy room. Maintain natural skin texture and soft facial shadows, while enhancing eye catchlights for a vivid, lifelike look. Adjust white balance to a warmer tone, and slightly boost exposure to soften the darker midtones. Preserve the subject's pose and expression, and enhance the depth with gentle background bokeh and subtle filmic glow.",
28
- "weights": "FLUX.1-Kontext-Cinematic-Relighting.safetensors"
29
- },
30
- ]
31
- print(f"Loaded {len(flux_loras_raw)} LoRAs")
32
- # Global variables for LoRA management
33
- current_lora = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  lora_cache = {}
35
 
36
  def load_lora_weights(repo_id, weights_filename):
37
- """Load LoRA weights from HuggingFace"""
38
  try:
39
- # First try with the specified filename
40
- try:
41
  lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
42
- if repo_id not in lora_cache:
43
- lora_cache[repo_id] = lora_path
44
- return lora_path
45
- except Exception as e:
46
- print(f"Failed to load {weights_filename}, trying to find alternative LoRA files...")
47
-
48
- # If the specified file doesn't exist, try to find any .safetensors file
49
- from huggingface_hub import list_repo_files
50
- try:
51
- files = list_repo_files(repo_id)
52
- safetensors_files = [f for f in files if f.endswith(('.safetensors', '.bin')) and 'lora' in f.lower()]
53
-
54
- if not safetensors_files:
55
- # Try without 'lora' in filename
56
- safetensors_files = [f for f in files if f.endswith('.safetensors')]
57
-
58
- if safetensors_files:
59
- # Try the first available file
60
- for file in safetensors_files:
61
- try:
62
- print(f"Trying alternative file: {file}")
63
- lora_path = hf_hub_download(repo_id=repo_id, filename=file)
64
- if repo_id not in lora_cache:
65
- lora_cache[repo_id] = lora_path
66
- print(f"Successfully loaded alternative LoRA file: {file}")
67
- return lora_path
68
- except:
69
- continue
70
-
71
- print(f"No suitable LoRA files found in {repo_id}")
72
- return None
73
-
74
- except Exception as list_error:
75
- print(f"Error listing files in repo {repo_id}: {list_error}")
76
- return None
77
-
78
  except Exception as e:
79
- print(f"Error loading LoRA from {repo_id}: {e}")
80
  return None
81
 
82
- def update_selection(selected_state: gr.SelectData, flux_loras):
83
- """Update UI when a LoRA is selected"""
84
- if selected_state.index >= len(flux_loras):
85
- return "### No LoRA selected", gr.update(), None
86
 
87
- lora = flux_loras[selected_state.index]
88
- lora_title = lora["title"]
89
- lora_repo = lora["repo"]
90
- trigger_word = lora["trigger_word"]
91
 
92
- # Create a more informative selected text
93
- updated_text = f"### 🎨 Selected Style: {lora_title}"
94
- new_placeholder = f"Describe additional details, e.g., 'wearing a red hat' or 'smiling'"
95
 
96
  return updated_text, gr.update(placeholder=new_placeholder), selected_state.index
97
 
98
- def get_huggingface_lora(link):
99
- """Download LoRA from HuggingFace link"""
100
  split_link = link.split("/")
101
  if len(split_link) == 2:
102
  try:
103
  model_card = ModelCard.load(link)
104
  trigger_word = model_card.data.get("instance_prompt", "")
105
 
106
- # Try to find the correct safetensors file
107
- files = list_repo_files(link)
108
- safetensors_files = [f for f in files if f.endswith('.safetensors')]
109
 
110
- # Prioritize files with 'lora' in the name
111
- lora_files = [f for f in safetensors_files if 'lora' in f.lower()]
112
- if lora_files:
113
- safetensors_file = lora_files[0]
114
- elif safetensors_files:
115
- safetensors_file = safetensors_files[0]
116
- else:
117
- # Try .bin files as fallback
118
- bin_files = [f for f in files if f.endswith('.bin') and 'lora' in f.lower()]
119
- if bin_files:
120
- safetensors_file = bin_files[0]
121
- else:
122
- safetensors_file = "pytorch_lora_weights.safetensors" # Default fallback
123
 
124
- print(f"Found LoRA file: {safetensors_file} in {link}")
125
- return split_link[1], safetensors_file, trigger_word
126
 
 
127
  except Exception as e:
128
- print(f"Error in get_huggingface_lora: {e}")
129
- # Try basic detection
130
- try:
131
- files = list_repo_files(link)
132
- safetensors_file = next((f for f in files if f.endswith('.safetensors')), "pytorch_lora_weights.safetensors")
133
- return split_link[1], safetensors_file, ""
134
- except:
135
- raise Exception(f"Error loading LoRA: {e}")
136
  else:
137
  raise Exception("Invalid HuggingFace repository format")
138
 
139
- def load_custom_lora(link):
140
- """Load custom LoRA from user input"""
141
  if not link:
142
- return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### 🎨 Select an art style from the gallery", None
143
 
144
  try:
145
- repo_name, weights_file, trigger_word = get_huggingface_lora(link)
146
 
147
  card = f'''
148
- <div class="custom_lora_card">
149
- <div style="display: flex; align-items: center; margin-bottom: 12px;">
150
- <span style="font-size: 18px; margin-right: 8px;">✅</span>
151
- <strong style="font-size: 16px;">Custom LoRA Loaded!</strong>
152
- </div>
153
- <div style="background: rgba(255, 255, 255, 0.8); padding: 12px; border-radius: 8px;">
154
- <h4 style="margin: 0 0 8px 0; color: #333;">{repo_name}</h4>
155
- <small style="color: #666;">{"Trigger: <code style='background: #f0f0f0; padding: 2px 6px; border-radius: 4px;'><b>"+trigger_word+"</b></code>" if trigger_word else "No trigger word found"}</small>
156
  </div>
157
  </div>
158
  '''
159
 
160
- custom_lora_data = {
161
  "repo": link,
162
  "weights": weights_file,
163
  "trigger_word": trigger_word
164
  }
165
 
166
- return gr.update(visible=True), card, gr.update(visible=True), custom_lora_data, gr.Gallery(selected_index=None), f"🎨 Custom Style: {repo_name}", None
167
 
168
  except Exception as e:
169
- return gr.update(visible=True), f"Error: {str(e)}", gr.update(visible=False), None, gr.update(), "### 🎨 Select an art style from the gallery", None
170
 
171
- def remove_custom_lora():
172
- """Remove custom LoRA"""
173
  return "", gr.update(visible=False), gr.update(visible=False), None, None
174
 
175
- def classify_gallery(flux_loras):
176
- """Sort gallery by likes"""
177
- try:
178
- sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
179
- gallery_items = []
180
-
181
- for item in sorted_gallery:
182
- if "image" in item and "title" in item:
183
- image_path = item["image"]
184
- title = item["title"]
185
-
186
- # Simply use the path as-is for Gradio to handle
187
- gallery_items.append((image_path, title))
188
- print(f"Added to gallery: {image_path} - {title}")
189
-
190
- print(f"Total gallery items: {len(gallery_items)}")
191
- return gallery_items, sorted_gallery
192
- except Exception as e:
193
- print(f"Error in classify_gallery: {e}")
194
- import traceback
195
- traceback.print_exc()
196
- return [], []
197
 
198
- def infer_with_lora_wrapper(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
199
- """Wrapper function to handle state serialization"""
200
- return infer_with_lora(input_image, prompt, selected_index, custom_lora, seed, randomize_seed, guidance_scale, lora_scale, flux_loras, progress)
201
 
202
  @spaces.GPU
203
- def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
204
- """Generate image with selected LoRA"""
205
- global current_lora, pipe
206
-
207
- # Check if input image is provided
208
- if input_image is None:
209
- gr.Warning("Please upload your portrait photo first! 📸")
210
- return None, seed, gr.update(visible=False)
211
 
212
  if randomize_seed:
213
  seed = random.randint(0, MAX_SEED)
214
 
215
- # Determine which LoRA to use
216
  lora_to_use = None
217
- if custom_lora:
218
- lora_to_use = custom_lora
219
- elif selected_index is not None and flux_loras and selected_index < len(flux_loras):
220
- lora_to_use = flux_loras[selected_index]
221
- # Load LoRA if needed
222
- if lora_to_use and lora_to_use != current_lora:
 
 
223
  try:
224
- # Unload current LoRA
225
- if current_lora:
226
- pipe.unload_lora_weights()
227
- print(f"Unloaded previous LoRA")
228
 
229
- # Load new LoRA
230
- repo_id = lora_to_use.get("repo", "unknown")
231
- weights_file = lora_to_use.get("weights", "pytorch_lora_weights.safetensors")
232
- print(f"Loading LoRA: {repo_id} with weights: {weights_file}")
233
-
234
- lora_path = load_lora_weights(repo_id, weights_file)
235
  if lora_path:
236
  pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
237
  pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
238
- print(f"Successfully loaded: {lora_path} with scale {lora_scale}")
239
- current_lora = lora_to_use
240
- else:
241
- print(f"Failed to load LoRA from {repo_id}")
242
- gr.Warning(f"Failed to load {lora_to_use.get('title', 'style')}. Please try a different art style.")
243
- return None, seed, gr.update(visible=False)
244
 
245
  except Exception as e:
246
- print(f"Error loading LoRA: {e}")
247
- # Continue without LoRA
248
  else:
249
- if lora_to_use:
250
- print(f"Using already loaded LoRA: {lora_to_use.get('repo', 'unknown')}")
251
-
252
- try:
253
- # Convert image to RGB
254
- input_image = input_image.convert("RGB")
255
- except Exception as e:
256
- print(f"Error processing image: {e}")
257
- gr.Warning("Error processing the uploaded image. Please try a different photo. 📸")
258
- return None, seed, gr.update(visible=False)
259
-
260
- # Check if LoRA is selected
261
- if lora_to_use is None:
262
- gr.Warning("Please select an art style from the gallery first! 🎨")
263
- return None, seed, gr.update(visible=False)
264
 
265
- # Add trigger word to prompt
266
- trigger_word = lora_to_use.get("trigger_word", "")
267
-
268
- # Special handling for different art styles
269
- if trigger_word == "ghibli":
270
- prompt = f"Create a Studio Ghibli anime style portrait of the person in the photo, {prompt}. Maintain the facial identity while transforming into whimsical anime art style."
271
- elif trigger_word == "homer":
272
- prompt = f"Paint the person in Winslow Homer's American realist style, {prompt}. Keep facial features while applying watercolor and marine art techniques."
273
- elif trigger_word == "gogh":
274
- prompt = f"Transform the portrait into Van Gogh's post-impressionist style with swirling brushstrokes, {prompt}. Maintain facial identity with expressive colors."
275
- elif trigger_word == "Cezanne":
276
- prompt = f"Render the person in Paul Cézanne's geometric post-impressionist style, {prompt}. Keep facial structure while applying structured brushwork."
277
- elif trigger_word == "Renoir":
278
- prompt = f"Paint the portrait in Pierre-Auguste Renoir's impressionist style with soft light, {prompt}. Maintain identity with luminous skin tones."
279
- elif trigger_word == "claude monet":
280
- prompt = f"Create an impressionist portrait in Claude Monet's style with visible brushstrokes, {prompt}. Keep facial features while using light and color."
281
- elif trigger_word == "fantasy":
282
- prompt = f"Transform into an epic fantasy character portrait, {prompt}. Maintain facial identity while adding magical and fantastical elements."
283
- elif trigger_word == ", How2Draw":
284
  prompt = f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
285
- elif trigger_word == ", video game screenshot in the style of THSMS":
286
- prompt = f"create a video game screenshot in the style of THSMS with the person from the photo, {prompt}. maintain the facial identity of the person and general features"
287
  else:
288
- prompt = f"convert the style of this portrait photo to {trigger_word} while maintaining the identity of the person. {prompt}. Make sure to maintain the person's facial identity and features, while still changing the overall style to {trigger_word}."
289
 
290
  try:
291
  image = pipe(
292
  image=input_image,
293
  prompt=prompt,
294
  guidance_scale=guidance_scale,
 
295
  generator=torch.Generator().manual_seed(seed),
 
 
 
296
  ).images[0]
297
 
298
  return image, seed, gr.update(visible=True)
299
 
300
  except Exception as e:
301
- print(f"Error during inference: {e}")
302
  return None, seed, gr.update(visible=False)
303
 
304
- # CSS styling with beautiful gradient pastel design
305
- css = '''
306
- #gen_btn{height: 100%}
307
- #gen_column{align-self: stretch}
308
- #title{text-align: center}
309
- #title h1{font-size: 3em; display:inline-flex; align-items:center}
310
- #title img{width: 100px; margin-right: 0.5em}
311
- #gallery .grid-wrap{height: 10vh}
312
- #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
313
- .card_internal{display: flex;height: 100px;margin-top: .5em}
314
- .card_internal img{margin-right: 1em}
315
- .styler{--form-gap-width: 0px !important}
316
- #progress{height:30px}
317
- #progress .generating{display:none}
318
- .progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
319
- .progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
320
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
 
322
- # Create Gradio interface
323
- with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
324
- gr_flux_loras = gr.State(value=flux_loras_raw)
325
 
326
  title = gr.HTML(
327
- """<h1>Flux Kontext DLC 🎈</h1>""",
 
328
  )
329
 
330
  selected_state = gr.State(value=None)
331
- custom_loaded_lora = gr.State(value=None)
332
 
333
- with gr.Row(elem_id="main_app"):
334
- with gr.Column(scale=4, elem_id="box_column"):
335
- with gr.Group(elem_id="gallery_box"):
336
- input_image = gr.Image(label="Upload an image for editing", type="pil", height=260)
337
 
338
  gallery = gr.Gallery(
339
- label="Choose the Flux Kontext LoRA",
340
  allow_preview=False,
341
  columns=3,
342
- elem_id="gallery",
343
  show_share_button=False,
344
  height=400
345
  )
346
 
347
- custom_model = gr.Textbox(
348
- label="🔗 Or use a custom LoRA from HuggingFace",
349
- placeholder="e.g., username/lora-name",
350
  visible=True
351
  )
352
- custom_model_card = gr.HTML(visible=False)
353
- custom_model_button = gr.Button("Remove custom LoRA", visible=False)
354
 
355
  with gr.Column(scale=5):
356
  with gr.Row():
357
  prompt = gr.Textbox(
358
- label="Additional Details (optional)",
359
  show_label=False,
360
  lines=1,
361
  max_lines=1,
362
- placeholder="Describe additional details, e.g., 'wearing a red hat' or 'smiling'",
363
- elem_id="prompt"
364
  )
365
- run_button = gr.Button("Edit Image", elem_id="run_button")
366
 
367
- result = gr.Image(label="Your Kontext Edited Image", interactive=False)
368
  reuse_button = gr.Button("Reuse this image", visible=False)
369
 
370
- with gr.Accordion("Advanced Settings", open=False):
371
  lora_scale = gr.Slider(
372
- label="Style Strength",
373
  minimum=0,
374
  maximum=2,
375
  step=0.1,
376
- value=1.0,
377
- info="How strongly to apply the art style (1.0 = balanced)"
378
  )
379
  seed = gr.Slider(
380
- label="Random Seed",
381
  minimum=0,
382
  maximum=MAX_SEED,
383
  step=1,
384
  value=0,
385
- info="Set to 0 for random results"
386
  )
387
- randomize_seed = gr.Checkbox(label="Randomize seed for each generation", value=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  guidance_scale = gr.Slider(
389
- label="Image Guidance",
390
  minimum=1,
391
  maximum=10,
392
  step=0.1,
393
- value=2.5,
394
- info="How closely to follow the input image (lower = more creative)"
395
  )
396
 
397
  prompt_title = gr.Markdown(
398
- value="### Select an art style from the gallery",
399
  visible=True,
400
- elem_id="selected_lora",
401
  )
402
 
403
  # Event handlers
404
- custom_model.input(
405
- fn=load_custom_lora,
406
- inputs=[custom_model],
407
- outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title, selected_state],
408
  )
409
 
410
- custom_model_button.click(
411
- fn=remove_custom_lora,
412
- outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora, selected_state]
413
  )
414
 
415
  gallery.select(
416
- fn=update_selection,
417
- inputs=[gr_flux_loras],
418
  outputs=[prompt_title, prompt, selected_state],
419
  show_progress=False
420
  )
421
 
422
  gr.on(
423
  triggers=[run_button.click, prompt.submit],
424
- fn=infer_with_lora_wrapper,
425
- inputs=[input_image, prompt, selected_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, gr_flux_loras],
426
  outputs=[result, seed, reuse_button]
427
  )
428
 
@@ -431,11 +372,12 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
431
  inputs=[result],
432
  outputs=[input_image]
433
  )
434
-
 
435
  demo.load(
436
- fn=classify_gallery,
437
- inputs=[gr_flux_loras],
438
- outputs=[gallery, gr_flux_loras]
439
  )
440
 
441
  demo.queue(default_concurrency_limit=None)
 
7
  import os
8
  from PIL import Image
9
  from diffusers import FluxKontextPipeline
10
+ from diffusers.utils import load_image, peft_utils
11
+ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
12
  from safetensors.torch import load_file
13
  import requests
14
  import re
15
 
16
+ # Load the base model
17
  MAX_SEED = np.iinfo(np.int32).max
18
 
19
  pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
20
 
21
# Temporary workaround for a diffusers LoRA loading issue: the private helper
# _derive_exclude_modules can place "proj_out" layers on the exclude list,
# which prevents some Kontext LoRAs from attaching — presumably why the patch
# filters them out (TODO: remove once fixed upstream).
try:
    from diffusers.utils.peft_utils import _derive_exclude_modules

    def new_derive_exclude_modules(*args, **kwargs):
        """Wrap _derive_exclude_modules, dropping any 'proj_out' entries."""
        exclude_modules = _derive_exclude_modules(*args, **kwargs)
        if exclude_modules is not None:
            exclude_modules = [n for n in exclude_modules if "proj_out" not in n]
        return exclude_modules

    # Monkey-patch the module-level reference so diffusers picks up the wrapper.
    peft_utils._derive_exclude_modules = new_derive_exclude_modules
except Exception:
    # Narrowed from a bare `except:` (which also swallowed SystemExit /
    # KeyboardInterrupt). The helper is private and may not exist in this
    # diffusers version; in that case fall back to stock behavior.
    pass
32
+
33
# Load adapter configurations from the bundled JSON manifest.
def _normalize_lora_entry(item):
    """Fill in defaults for the optional fields of one raw JSON adapter entry."""
    return {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item.get("trigger_word", ""),
        "trigger_position": item.get("trigger_position", "prepend"),
        "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
    }

with open("lora_configs.json", "r") as config_file:
    lora_configs = [_normalize_lora_entry(item) for item in json.load(config_file)]
print(f"Loaded {len(lora_configs)} LoRAs from JSON")
48
+
49
+ # Global variables for adapter management
50
+ active_lora_adapter = None
51
  lora_cache = {}
52
 
53
def load_lora_weights(repo_id, weights_filename):
    """Download adapter weights from HuggingFace, memoizing the local path.

    Returns the local file path for the repo's weights, or None when the
    download fails.
    """
    try:
        if repo_id in lora_cache:
            # Cache hit: reuse the previously downloaded file.
            return lora_cache[repo_id]
        local_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
        lora_cache[repo_id] = local_path
        return local_path
    except Exception as e:
        print(f"Error loading adapter from {repo_id}: {e}")
        return None
63
 
64
def on_lora_select(selected_state: gr.SelectData, lora_configs):
    """Refresh the caption and prompt placeholder when a gallery adapter is picked.

    Returns (title markdown, prompt-textbox update, selected index or None).
    """
    idx = selected_state.index
    if idx >= len(lora_configs):
        # Out-of-range click: nothing to select.
        return "### No adapter selected", gr.update(), None

    entry = lora_configs[idx]
    repo = entry["repo"]
    caption = f"### Selected: [{repo}](https://huggingface.co/{repo})"
    placeholder = "optional description, e.g. 'a man with glasses and a beard'"
    return caption, gr.update(placeholder=placeholder), idx
76
 
77
def fetch_lora_from_hf(link):
    """Resolve a 'user/repo' HuggingFace link to (repo_name, weights_file, trigger_word).

    Raises Exception for a malformed link or any hub/model-card failure.
    """
    parts = link.split("/")
    if len(parts) != 2:
        raise Exception("Invalid HuggingFace repository format")

    try:
        card = ModelCard.load(link)
        trigger = card.data.get("instance_prompt", "")

        # Prefer the first *.safetensors file whose name mentions "lora".
        listing = HfFileSystem().ls(link, detail=False)
        weights = next(
            (
                path.split("/")[-1]
                for path in listing
                if path.endswith(".safetensors") and "lora" in path.lower()
            ),
            None,
        )
        if not weights:
            # Fall back to the diffusers default filename.
            weights = "pytorch_lora_weights.safetensors"
        return parts[1], weights, trigger
    except Exception as e:
        raise Exception(f"Error loading adapter: {e}")
102
 
103
def load_user_lora(link):
    """Load a user-provided adapter from a 'user/repo' link.

    Returns updates for: (card visibility, card HTML, remove-button visibility,
    adapter state dict, gallery selection, title markdown, selected index).
    """
    if not link:
        # Empty input: hide the custom-adapter UI and reset the selection.
        return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### Click on an adapter in the gallery to select it", None

    try:
        repo_name, weights_file, trigger_word = fetch_lora_from_hf(link)

        trigger_html = (
            "Using: <code><b>" + trigger_word + "</b></code> as trigger word"
            if trigger_word
            else "No trigger word found"
        )
        card = f'''
<div style="border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 10px 0;">
<span><strong>Loaded custom adapter:</strong></span>
<div style="margin-top: 8px;">
<h4>{repo_name}</h4>
<small>{trigger_html}</small>
</div>
</div>
'''

        user_lora_data = {
            "repo": link,
            "weights": weights_file,
            "trigger_word": trigger_word,
        }
        return gr.update(visible=True), card, gr.update(visible=True), user_lora_data, gr.Gallery(selected_index=None), f"Custom: {repo_name}", None
    except Exception as e:
        return gr.update(visible=True), f"Error: {str(e)}", gr.update(visible=False), None, gr.update(), "### Click on an adapter in the gallery to select it", None
131
 
132
def unload_user_lora():
    """Clear the custom-adapter UI.

    Resets the link textbox, hides the card and remove button, and drops both
    the stored adapter dict and the gallery selection state.
    """
    cleared_textbox = ""
    hidden_button = gr.update(visible=False)
    hidden_card = gr.update(visible=False)
    return cleared_textbox, hidden_button, hidden_card, None, None
135
 
136
def sort_lora_gallery(lora_configs):
    """Order adapters by descending like-count.

    Returns (gallery pairs of (image, title), the sorted config list).
    Entries without a "likes" key are treated as having zero likes.
    """
    ranked = sorted(lora_configs, key=lambda cfg: cfg.get("likes", 0), reverse=True)
    pairs = []
    for cfg in ranked:
        pairs.append((cfg["image"], cfg["title"]))
    return pairs, ranked
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
def generate_image_wrapper(input_image, prompt, selected_index, user_lora, seed=42, randomize_seed=False, steps=28, guidance_scale=2.5, lora_scale=1.75, width=960, height=1280, lora_configs=None, progress=gr.Progress(track_tqdm=True)):
    """Thin pass-through so Gradio state values reach generate_image unchanged."""
    return generate_image(
        input_image, prompt, selected_index, user_lora,
        seed, randomize_seed, steps, guidance_scale,
        lora_scale, width, height, lora_configs, progress,
    )
144
 
145
@spaces.GPU
def generate_image(input_image, prompt, selected_index, user_lora, seed=42, randomize_seed=False, steps=28, guidance_scale=2.5, lora_scale=1.0, width=960, height=1280, lora_configs=None, progress=gr.Progress(track_tqdm=True)):
    """Generate an image using the selected adapter.

    Returns (image or None, seed actually used, visibility update for the
    reuse button). A custom user adapter takes precedence over a gallery pick.
    """
    global active_lora_adapter, pipe

    # Guard clauses: without them .convert() / ["trigger_word"] below crash
    # with AttributeError / TypeError instead of warning the user.
    if input_image is None:
        gr.Warning("Please upload an image first!")
        return None, seed, gr.update(visible=False)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Select the adapter to use.
    lora_to_use = None
    if user_lora:
        lora_to_use = user_lora
    elif selected_index is not None and lora_configs and selected_index < len(lora_configs):
        lora_to_use = lora_configs[selected_index]

    if lora_to_use is None:
        gr.Warning("Please select an adapter from the gallery first!")
        return None, seed, gr.update(visible=False)

    # (Re)load the adapter only when the selection changed.
    if lora_to_use != active_lora_adapter:
        try:
            if active_lora_adapter:
                # BUGFIX: this line was garbled ("pipe разгрузить_веса_lora()"),
                # a SyntaxError; restore the diffusers unload call.
                pipe.unload_lora_weights()

            lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
            if lora_path:
                pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
                pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
                print(f"loaded: {lora_path} with scale {lora_scale}")
                active_lora_adapter = lora_to_use
        except Exception as e:
            # Best-effort: generation proceeds without the adapter.
            print(f"Error loading adapter: {e}")
    else:
        print(f"using already loaded adapter: {lora_to_use}")

    input_image = input_image.convert("RGB")

    # Build the final prompt around the adapter's trigger word; two adapters
    # need special phrasing, everything else uses the generic template.
    trigger_word = lora_to_use.get("trigger_word", "")
    if trigger_word == ", How2Draw":
        prompt = f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
    elif trigger_word == "__ ":
        prompt = f" {prompt}. Accurately render the toolimpact logo and any tool impact iconography. The toolimpact logo begins with a two-line-tall drop-cap capital letter T with a dot in the center of its top bar."
    else:
        prompt = f" {prompt}. convert the style of this photo or image to {trigger_word}. Maintain the facial identity of any persons and the general features of the image!"

    try:
        image = pipe(
            image=input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=torch.Generator().manual_seed(seed),
            width=width,
            height=height,
            max_area=width * height
        ).images[0]
        return image, seed, gr.update(visible=True)
    except Exception as e:
        print(f"Error during generation: {e}")
        return None, seed, gr.update(visible=False)
206
 
207
+ # CSS styling
208
+ css = """
209
+ #app_container {
210
+ display: flex;
211
+ gap: 20px;
212
+ }
213
+ #left_panel {
214
+ min-width: 400px;
215
+ }
216
+ #lora_info {
217
+ color: #2563eb;
218
+ font-weight: bold;
219
+ }
220
+ #edit_prompt {
221
+ flex-grow: 1;
222
+ }
223
+ #generate_button {
224
+ background: linear-gradient(45deg, #2563eb, #3b82f6);
225
+ color: white;
226
+ border: none;
227
+ padding: 8px 16px;
228
+ border-radius: 6px;
229
+ font-weight: bold;
230
+ }
231
+ .user_lora_card {
232
+ background: #f8fafc;
233
+ border: 1px solid #e2e8f0;
234
+ border-radius: 8px;
235
+ padding: 12px;
236
+ margin: 8px 0;
237
+ }
238
+ #lora_gallery{
239
+ overflow: scroll !important
240
+ }
241
+ """
242
 
243
+ # Build the Gradio interface
244
+ with gr.Blocks(css=css) as demo:
245
+ gr_lora_configs = gr.State(value=lora_configs)
246
 
247
  title = gr.HTML(
248
+ """<h1>Image Style Transfer using FLUX.1 with Adapters</h1>
249
+ <p>Edit images using custom style adapters. Fast generation with minimal steps.</p>""",
250
  )
251
 
252
  selected_state = gr.State(value=None)
253
+ user_lora = gr.State(value=None)
254
 
255
+ with gr.Row(elem_id="app_container"):
256
+ with gr.Column(scale=4, elem_id="left_panel"):
257
+ with gr.Group(elem_id="lora_selection"):
258
+ input_image = gr.Image(label="Upload a picture", type="pil", height=300)
259
 
260
  gallery = gr.Gallery(
261
+ label="Pick an Adapter",
262
  allow_preview=False,
263
  columns=3,
264
+ elem_id="lora_gallery",
265
  show_share_button=False,
266
  height=400
267
  )
268
 
269
+ user_lora_input = gr.Textbox(
270
+ label="Or enter a custom HuggingFace adapter",
271
+ placeholder="e.g., username/adapter-name",
272
  visible=True
273
  )
274
+ user_lora_card = gr.HTML(visible=False)
275
+ unload_user_lora_button = gr.Button("Remove custom adapter", visible=True)
276
 
277
  with gr.Column(scale=5):
278
  with gr.Row():
279
  prompt = gr.Textbox(
280
+ label="Editing Prompt",
281
  show_label=False,
282
  lines=1,
283
  max_lines=1,
284
+ placeholder="optional description, e.g. 'colorize and stylize, leave all else as is'",
285
+ elem_id="edit_prompt"
286
  )
287
+ run_button = gr.Button("Generate", elem_id="generate_button")
288
 
289
+ result = gr.Image(label="Generated Image", interactive=False)
290
  reuse_button = gr.Button("Reuse this image", visible=False)
291
 
292
+ with gr.Accordion("Advanced Settings", open=True):
293
  lora_scale = gr.Slider(
294
+ label="Adapter Scale",
295
  minimum=0,
296
  maximum=2,
297
  step=0.1,
298
+ value=1.5,
299
+ info="Controls the strength of the adapter effect"
300
  )
301
  seed = gr.Slider(
302
+ label="Seed",
303
  minimum=0,
304
  maximum=MAX_SEED,
305
  step=1,
306
  value=0,
 
307
  )
308
+ steps = gr.Slider(
309
+ label="Steps",
310
+ minimum=1,
311
+ maximum=40,
312
+ value=10,
313
+ step=1
314
+ )
315
+ width = gr.Slider(
316
+ label="Width",
317
+ minimum=128,
318
+ maximum=2560,
319
+ step=1,
320
+ value=960,
321
+ )
322
+ height = gr.Slider(
323
+ label="Height",
324
+ minimum=128,
325
+ maximum=2560,
326
+ step=1,
327
+ value=1280,
328
+ )
329
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
330
  guidance_scale = gr.Slider(
331
+ label="Guidance Scale",
332
  minimum=1,
333
  maximum=10,
334
  step=0.1,
335
+ value=2.8,
 
336
  )
337
 
338
  prompt_title = gr.Markdown(
339
+ value="### Click on an adapter in the gallery to select it",
340
  visible=True,
341
+ elem_id="lora_info",
342
  )
343
 
344
  # Event handlers
345
+ user_lora_input.input(
346
+ fn=load_user_lora,
347
+ inputs=[user_lora_input],
348
+ outputs=[user_lora_card, user_lora_card, unload_user_lora_button, user_lora, gallery, prompt_title, selected_state],
349
  )
350
 
351
+ unload_user_lora_button.click(
352
+ fn=unload_user_lora,
353
+ outputs=[user_lora_input, unload_user_lora_button, user_lora_card, user_lora, selected_state]
354
  )
355
 
356
  gallery.select(
357
+ fn=on_lora_select,
358
+ inputs=[gr_lora_configs],
359
  outputs=[prompt_title, prompt, selected_state],
360
  show_progress=False
361
  )
362
 
363
  gr.on(
364
  triggers=[run_button.click, prompt.submit],
365
+ fn=generate_image_wrapper,
366
+ inputs=[input_image, prompt, selected_state, user_lora, seed, randomize_seed, steps, guidance_scale, lora_scale, width, height, gr_lora_configs],
367
  outputs=[result, seed, reuse_button]
368
  )
369
 
 
372
  inputs=[result],
373
  outputs=[input_image]
374
  )
375
+
376
+ # Initialize the gallery
377
  demo.load(
378
+ fn=sort_lora_gallery,
379
+ inputs=[gr_lora_configs],
380
+ outputs=[gallery, gr_lora_configs]
381
  )
382
 
383
  demo.queue(default_concurrency_limit=None)