prithivMLmods committed (verified)
Commit 2f83709 · 1 Parent(s): a84e1b8

Update app.py

Files changed (1):
  1. app.py +162 -301
app.py CHANGED: old version (removed lines marked "-")

@@ -1,6 +1,6 @@
 import gradio as gr
 import numpy as np
-import spaces
 import torch
 import random
 import json
@@ -8,415 +8,276 @@ import os
 from PIL import Image
 from diffusers import FluxKontextPipeline
 from diffusers.utils import load_image
-from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, list_repo_files
 from safetensors.torch import load_file
 import requests
 import re

-# Load Kontext model
 MAX_SEED = np.iinfo(np.int32).max

-pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")

-# Load LoRA data
-flux_loras_raw = [
-    {
-        "image": "https://huggingface.co/fal/Realism-Detailer-Kontext-Dev-LoRA/resolve/main/outputs/1.png",
-        "title": "Realism Detailer Kontext",
-        "repo": "fal/Realism-Detailer-Kontext-Dev-LoRA",
-        "trigger_word": "Add details to this face, improve skin details",
-        "weights": "high_detail.safetensors"
-    },
-]
-print(f"Loaded {len(flux_loras_raw)} LoRAs")
-# Global variables for LoRA management
-current_lora = None
-lora_cache = {}
-
-def load_lora_weights(repo_id, weights_filename):
-    """Load LoRA weights from HuggingFace"""
-    try:
-        # First try with the specified filename
-        try:
-            lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
-            if repo_id not in lora_cache:
-                lora_cache[repo_id] = lora_path
-            return lora_path
-        except Exception as e:
-            print(f"Failed to load {weights_filename}, trying to find alternative LoRA files...")
-
-            # If the specified file doesn't exist, try to find any .safetensors file
-            from huggingface_hub import list_repo_files
-            try:
-                files = list_repo_files(repo_id)
-                safetensors_files = [f for f in files if f.endswith(('.safetensors', '.bin')) and 'lora' in f.lower()]
-
-                if not safetensors_files:
-                    # Try without 'lora' in filename
-                    safetensors_files = [f for f in files if f.endswith('.safetensors')]
-
-                if safetensors_files:
-                    # Try the first available file
-                    for file in safetensors_files:
-                        try:
-                            print(f"Trying alternative file: {file}")
-                            lora_path = hf_hub_download(repo_id=repo_id, filename=file)
-                            if repo_id not in lora_cache:
-                                lora_cache[repo_id] = lora_path
-                            print(f"Successfully loaded alternative LoRA file: {file}")
-                            return lora_path
-                        except:
-                            continue
-
-                print(f"No suitable LoRA files found in {repo_id}")
-                return None
-
-            except Exception as list_error:
-                print(f"Error listing files in repo {repo_id}: {list_error}")
-                return None
-
-    except Exception as e:
-        print(f"Error loading LoRA from {repo_id}: {e}")
-        return None

 def update_selection(selected_state: gr.SelectData, flux_loras):
     """Update UI when a LoRA is selected"""
     if selected_state.index >= len(flux_loras):
-        return "### No LoRA selected", gr.update(), None
-
-    lora = flux_loras[selected_state.index]
-    lora_title = lora["title"]
-    lora_repo = lora["repo"]
-    trigger_word = lora["trigger_word"]
 
-    # Create a more informative selected text
-    updated_text = f"### 🎨 Selected Style: {lora_title}"
-    new_placeholder = f"Describe additional details, e.g., 'wearing a red hat' or 'smiling'"
 
-    return updated_text, gr.update(placeholder=new_placeholder), selected_state.index
 
-def get_huggingface_lora(link):
-    """Download LoRA from HuggingFace link"""
-    split_link = link.split("/")
-    if len(split_link) == 2:
-        try:
-            model_card = ModelCard.load(link)
-            trigger_word = model_card.data.get("instance_prompt", "")
-
-            # Try to find the correct safetensors file
-            files = list_repo_files(link)
-            safetensors_files = [f for f in files if f.endswith('.safetensors')]
-
-            # Prioritize files with 'lora' in the name
-            lora_files = [f for f in safetensors_files if 'lora' in f.lower()]
-            if lora_files:
-                safetensors_file = lora_files[0]
-            elif safetensors_files:
-                safetensors_file = safetensors_files[0]
-            else:
-                # Try .bin files as fallback
-                bin_files = [f for f in files if f.endswith('.bin') and 'lora' in f.lower()]
-                if bin_files:
-                    safetensors_file = bin_files[0]
-                else:
-                    safetensors_file = "pytorch_lora_weights.safetensors"  # Default fallback
-
-            print(f"Found LoRA file: {safetensors_file} in {link}")
-            return split_link[1], safetensors_file, trigger_word
-
-        except Exception as e:
-            print(f"Error in get_huggingface_lora: {e}")
-            # Try basic detection
-            try:
-                files = list_repo_files(link)
-                safetensors_file = next((f for f in files if f.endswith('.safetensors')), "pytorch_lora_weights.safetensors")
-                return split_link[1], safetensors_file, ""
-            except:
-                raise Exception(f"Error loading LoRA: {e}")
-    else:
-        raise Exception("Invalid HuggingFace repository format")
-
-def load_custom_lora(link):
-    """Load custom LoRA from user input"""
-    if not link:
-        return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### 🎨 Select an art style from the gallery", None
-
-    try:
-        repo_name, weights_file, trigger_word = get_huggingface_lora(link)
-
-        card = f'''
-        <div class="custom_lora_card">
-            <div style="display: flex; align-items: center; margin-bottom: 12px;">
-                <span style="font-size: 18px; margin-right: 8px;">✅</span>
-                <strong style="font-size: 16px;">Custom LoRA Loaded!</strong>
-            </div>
-            <div style="background: rgba(255, 255, 255, 0.8); padding: 12px; border-radius: 8px;">
-                <h4 style="margin: 0 0 8px 0; color: #333;">{repo_name}</h4>
-                <small style="color: #666;">{"Trigger: <code style='background: #f0f0f0; padding: 2px 6px; border-radius: 4px;'><b>"+trigger_word+"</b></code>" if trigger_word else "No trigger word found"}</small>
-            </div>
-        </div>
-        '''
-
-        custom_lora_data = {
-            "repo": link,
-            "weights": weights_file,
-            "trigger_word": trigger_word
-        }
-
-        return gr.update(visible=True), card, gr.update(visible=True), custom_lora_data, gr.Gallery(selected_index=None), f"🎨 Custom Style: {repo_name}", None
-
-    except Exception as e:
-        return gr.update(visible=True), f"Error: {str(e)}", gr.update(visible=False), None, gr.update(), "### 🎨 Select an art style from the gallery", None
-
-def remove_custom_lora():
-    """Remove custom LoRA"""
-    return "", gr.update(visible=False), gr.update(visible=False), None, None
 
-def classify_gallery(flux_loras):
-    """Sort gallery by likes"""
-    try:
-        sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
-        gallery_items = []
-
-        for item in sorted_gallery:
-            if "image" in item and "title" in item:
-                image_path = item["image"]
-                title = item["title"]
-
-                # Simply use the path as-is for Gradio to handle
-                gallery_items.append((image_path, title))
-                print(f"Added to gallery: {image_path} - {title}")
-
-        print(f"Total gallery items: {len(gallery_items)}")
-        return gallery_items, sorted_gallery
-    except Exception as e:
-        print(f"Error in classify_gallery: {e}")
-        import traceback
-        traceback.print_exc()
-        return [], []
-
-def infer_with_lora_wrapper(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
     """Wrapper function to handle state serialization"""
-    return infer_with_lora(input_image, prompt, selected_index, custom_lora, seed, randomize_seed, guidance_scale, lora_scale, flux_loras, progress)
 
-@spaces.GPU
-def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
     """Generate image with selected LoRA"""
-    global current_lora, pipe
 
-    # Check if input image is provided
-    if input_image is None:
-        gr.Warning("Please upload your portrait photo first! 📸")
-        return None, seed, gr.update(visible=False)
 
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
 
-    # Determine which LoRA to use
     lora_to_use = None
-    if custom_lora:
-        lora_to_use = custom_lora
-    elif selected_index is not None and flux_loras and selected_index < len(flux_loras):
         lora_to_use = flux_loras[selected_index]
-    # Load LoRA if needed
-    if lora_to_use and lora_to_use != current_lora:
         try:
-            # Unload current LoRA
-            if current_lora:
-                pipe.unload_lora_weights()
-                print(f"Unloaded previous LoRA")
-
-            # Load new LoRA
-            repo_id = lora_to_use.get("repo", "unknown")
-            weights_file = lora_to_use.get("weights", "pytorch_lora_weights.safetensors")
-            print(f"Loading LoRA: {repo_id} with weights: {weights_file}")
-
-            lora_path = load_lora_weights(repo_id, weights_file)
-            if lora_path:
-                pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
-                pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
-                print(f"Successfully loaded: {lora_path} with scale {lora_scale}")
-                current_lora = lora_to_use
-            else:
-                print(f"Failed to load LoRA from {repo_id}")
-                gr.Warning(f"Failed to load {lora_to_use.get('title', 'style')}. Please try a different art style.")
-                return None, seed, gr.update(visible=False)
 
         except Exception as e:
             print(f"Error loading LoRA: {e}")
-            # Continue without LoRA
-    else:
-        if lora_to_use:
-            print(f"Using already loaded LoRA: {lora_to_use.get('repo', 'unknown')}")
-
-    try:
-        # Convert image to RGB
-        input_image = input_image.convert("RGB")
-    except Exception as e:
-        print(f"Error processing image: {e}")
-        gr.Warning("Error processing the uploaded image. Please try a different photo. 📸")
-        return None, seed, gr.update(visible=False)
-
-    # Check if LoRA is selected
-    if lora_to_use is None:
-        gr.Warning("Please select an art style from the gallery first! 🎨")
-        return None, seed, gr.update(visible=False)
-
-    # Add trigger word to prompt
-    trigger_word = lora_to_use.get("trigger_word", "")
-
-    # Special handling for different art styles
-    if trigger_word == "ghibli":
-        prompt = f"Create a Studio Ghibli anime style portrait of the person in the photo, {prompt}. Maintain the facial identity while transforming into whimsical anime art style."
-    elif trigger_word == "homer":
-        prompt = f"Paint the person in Winslow Homer's American realist style, {prompt}. Keep facial features while applying watercolor and marine art techniques."
-    elif trigger_word == "gogh":
-        prompt = f"Transform the portrait into Van Gogh's post-impressionist style with swirling brushstrokes, {prompt}. Maintain facial identity with expressive colors."
-    elif trigger_word == "Cezanne":
-        prompt = f"Render the person in Paul Cézanne's geometric post-impressionist style, {prompt}. Keep facial structure while applying structured brushwork."
-    elif trigger_word == "Renoir":
-        prompt = f"Paint the portrait in Pierre-Auguste Renoir's impressionist style with soft light, {prompt}. Maintain identity with luminous skin tones."
-    elif trigger_word == "claude monet":
-        prompt = f"Create an impressionist portrait in Claude Monet's style with visible brushstrokes, {prompt}. Keep facial features while using light and color."
-    elif trigger_word == "fantasy":
-        prompt = f"Transform into an epic fantasy character portrait, {prompt}. Maintain facial identity while adding magical and fantastical elements."
-    elif trigger_word == ", How2Draw":
-        prompt = f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
-    elif trigger_word == ", video game screenshot in the style of THSMS":
-        prompt = f"create a video game screenshot in the style of THSMS with the person from the photo, {prompt}. maintain the facial identity of the person and general features"
-    else:
-        prompt = f"convert the style of this portrait photo to {trigger_word} while maintaining the identity of the person. {prompt}. Make sure to maintain the person's facial identity and features, while still changing the overall style to {trigger_word}."
 
     try:
         image = pipe(
-            image=input_image,
-            prompt=prompt,
             guidance_scale=guidance_scale,
-            generator=torch.Generator().manual_seed(seed),
         ).images[0]
 
-        return image, seed, gr.update(visible=True)
 
     except Exception as e:
         print(f"Error during inference: {e}")
-        return None, seed, gr.update(visible=False)
 
 # Create Gradio interface
-with gr.Blocks() as demo:
     gr_flux_loras = gr.State(value=flux_loras_raw)
 
     title = gr.HTML(
-        """<h1>FLUX Kontex Super LoRAs🖖</h1>""",
     )
 
     selected_state = gr.State(value=None)
     custom_loaded_lora = gr.State(value=None)
 
     with gr.Row(elem_id="main_app"):
         with gr.Column(scale=4, elem_id="box_column"):
             with gr.Group(elem_id="gallery_box"):
-                input_image = gr.Image(label="Upload your portrait photo 📸", type="pil", height=300)
-
                 gallery = gr.Gallery(
-                    label="Choose Your Art Style",
                     allow_preview=False,
-                    columns=3,
                     elem_id="gallery",
                     show_share_button=False,
-                    height=400
                 )
 
                 custom_model = gr.Textbox(
-                    label="🔗 Or use a custom LoRA from HuggingFace",
                     placeholder="e.g., username/lora-name",
-                    visible=True
                 )
                 custom_model_card = gr.HTML(visible=False)
-                custom_model_button = gr.Button("Remove custom LoRA", visible=False)
 
         with gr.Column(scale=5):
             with gr.Row():
                 prompt = gr.Textbox(
-                    label="Additional Details (optional)",
                     show_label=False,
                     lines=1,
                     max_lines=1,
-                    placeholder="Describe additional details, e.g., 'wearing a red hat' or 'smiling'",
                     elem_id="prompt"
                 )
-                run_button = gr.Button("Generate", elem_id="run_button")
 
-            result = gr.Image(label="Your Artistic Portrait", interactive=False)
-            reuse_button = gr.Button("🔄 Reuse this image", visible=False)
 
-            with gr.Accordion("⚙️ Advanced Settings", open=False):
                 lora_scale = gr.Slider(
-                    label="Style Strength",
                     minimum=0,
                     maximum=2,
                     step=0.1,
                     value=1.0,
-                    info="How strongly to apply the art style (1.0 = balanced)"
                 )
                 seed = gr.Slider(
-                    label="Random Seed",
                     minimum=0,
                     maximum=MAX_SEED,
                     step=1,
                     value=0,
-                    info="Set to 0 for random results"
                 )
-                randomize_seed = gr.Checkbox(label="🎲 Randomize seed for each generation", value=True)
                 guidance_scale = gr.Slider(
-                    label="Image Guidance",
                     minimum=1,
                     maximum=10,
                     step=0.1,
                     value=2.5,
-                    info="How closely to follow the input image (lower = more creative)"
                 )
 
             prompt_title = gr.Markdown(
-                value="### 🎨 Select an art style from the gallery",
                 visible=True,
                 elem_id="selected_lora",
             )
 
     # Event handlers
-    custom_model.input(
-        fn=load_custom_lora,
-        inputs=[custom_model],
-        outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title, selected_state],
-    )
-
-    custom_model_button.click(
-        fn=remove_custom_lora,
-        outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora, selected_state]
-    )
 
     gallery.select(
         fn=update_selection,
         inputs=[gr_flux_loras],
-        outputs=[prompt_title, prompt, selected_state],
         show_progress=False
     )
 
     gr.on(
         triggers=[run_button.click, prompt.submit],
         fn=infer_with_lora_wrapper,
-        inputs=[input_image, prompt, selected_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, gr_flux_loras],
-        outputs=[result, seed, reuse_button]
-    )
-
-    reuse_button.click(
-        fn=lambda image: image,
-        inputs=[result],
-        outputs=[input_image]
     )
 
     # Initialize gallery
     demo.load(
-        fn=classify_gallery,
        inputs=[gr_flux_loras],
        outputs=[gallery, gr_flux_loras]
    )
 
app.py CHANGED: new version (added lines marked "+")

 import gradio as gr
 import numpy as np
+import spaces  # Special module for Hugging Face Spaces; not needed for local execution
 import torch
 import random
 import json
 from PIL import Image
 from diffusers import FluxKontextPipeline
 from diffusers.utils import load_image
+from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
 from safetensors.torch import load_file
 import requests
 import re

+# Load the FLUX.1-Kontext base model from the Hugging Face Hub
 MAX_SEED = np.iinfo(np.int32).max

+pipe = FluxKontextPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-Kontext-dev",
+    torch_dtype=torch.bfloat16
+).to("cuda")

+# Load LoRA data from our custom JSON file
+with open("kontext_loras.json", "r") as file:
+    data = json.load(file)
+# Add default values for keys that might be missing, to prevent errors
+flux_loras_raw = [
+    {
+        "image": item["image"],
+        "title": item["title"],
+        "repo": item["repo"],
+        "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
+        "prompt": item.get("prompt", f"Turn this image into {item['title']} style."),
+        # The following keys are kept for compatibility with the original demo structure,
+        # but the simplified logic doesn't heavily rely on them.
+        "lora_type": item.get("lora_type", "flux"),
+        "lora_scale_config": item.get("lora_scale", 1.0),  # default scale of 1.0
+        "prompt_placeholder": item.get("prompt_placeholder", "You can edit the prompt here..."),
+    }
+    for item in data
+]
+print(f"Loaded {len(flux_loras_raw)} LoRAs from kontext_loras.json")
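Note on the loader above: only `image`, `title`, and `repo` are strictly required per entry; `weights`, `prompt`, and the compatibility keys all fall back to defaults. The `kontext_loras.json` file itself is not part of this commit, but a minimal version could be generated with a sketch like this (the sample entry reuses the Realism Detailer LoRA that was hard-coded in the previous version):

```python
# Sketch: generate a minimal kontext_loras.json matching what app.py expects.
# Any list of entries with at least "image", "title", and "repo" will load.
import json

example_loras = [
    {
        "image": "https://huggingface.co/fal/Realism-Detailer-Kontext-Dev-LoRA/resolve/main/outputs/1.png",
        "title": "Realism Detailer Kontext",
        "repo": "fal/Realism-Detailer-Kontext-Dev-LoRA",
        "weights": "high_detail.safetensors",  # optional; defaults to pytorch_lora_weights.safetensors
        "prompt": "Add details to this face, improve skin details",  # optional
        "lora_scale": 1.0,  # optional; surfaced as lora_scale_config
    }
]

with open("kontext_loras.json", "w") as f:
    json.dump(example_loras, f, indent=2)
```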
 
 def update_selection(selected_state: gr.SelectData, flux_loras):
     """Update UI when a LoRA is selected"""
     if selected_state.index >= len(flux_loras):
+        return "### No LoRA selected", gr.update(), None, gr.update()
+
+    selected_lora = flux_loras[selected_state.index]
+    lora_repo = selected_lora["repo"]
+    default_prompt = selected_lora.get("prompt")
+
+    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
+
+    optimal_scale = selected_lora.get("lora_scale_config", 1.0)
+    print("Selected Style: ", selected_lora['title'])
+    print("Optimal Scale: ", optimal_scale)
+    return updated_text, gr.update(value=default_prompt), selected_state.index, optimal_scale
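Aside on the signature: Gradio injects the click event into `update_selection` because its first parameter is annotated as `gr.SelectData`; `selected_state.index` is then the position of the clicked gallery item. A minimal sketch of the same pattern, with hypothetical component values:

```python
import gradio as gr

def on_select(evt: gr.SelectData):
    # Gradio injects the event payload because of the type annotation;
    # evt.index is the position of the clicked item.
    return f"You picked item {evt.index}"

with gr.Blocks() as sketch:
    gallery = gr.Gallery(value=["cat.png", "dog.png"])  # hypothetical images
    picked = gr.Markdown()
    gallery.select(fn=on_select, outputs=[picked])
```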
 
+# This wrapper is kept for compatibility with the Gradio event triggers
+def infer_with_lora_wrapper(input_image, prompt, selected_index, lora_state, custom_lora, seed=0, guidance_scale=2.5, num_inference_steps=28, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
     """Wrapper function to handle state serialization"""
+    # The 'custom_lora' and 'lora_state' arguments are no longer used but kept in the signature
+    return infer_with_lora(input_image, prompt, selected_index, seed, guidance_scale, num_inference_steps, lora_scale, flux_loras, progress)
 
+@spaces.GPU  # This decorator is only for Hugging Face Spaces hardware, not needed for local execution
+def infer_with_lora(input_image, prompt, selected_index, seed=0, guidance_scale=2.5, num_inference_steps=28, lora_scale=1.0, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
     """Generate image with selected LoRA"""
+    global pipe
 
+    # The seed is now always taken directly from the input; randomization has been removed.
 
+    # Unload any previous LoRA to ensure a clean state
+    if "selected_lora" in pipe.get_active_adapters():
+        pipe.unload_lora_weights()
 
+    # Determine which LoRA to use from our gallery
     lora_to_use = None
+    if selected_index is not None and flux_loras and selected_index < len(flux_loras):
         lora_to_use = flux_loras[selected_index]
+
+    if lora_to_use:
+        print(f"Applying LoRA: {lora_to_use['title']}")
         try:
+            # Load LoRA directly from the Hugging Face Hub
+            pipe.load_lora_weights(
+                lora_to_use["repo"],
+                weight_name=lora_to_use["weights"],
+                adapter_name="selected_lora"
+            )
+            pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
+            print(f"Loaded {lora_to_use['repo']} with scale {lora_scale}")
         except Exception as e:
             print(f"Error loading LoRA: {e}")
+
+    # Use the prompt from the textbox directly.
+    final_prompt = prompt
+    print(f"Using prompt: {final_prompt}")
+
+    input_image = input_image.convert("RGB")
 
     try:
         image = pipe(
+            image=input_image,
+            width=input_image.size[0],
+            height=input_image.size[1],
+            prompt=final_prompt,
             guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            generator=torch.Generator().manual_seed(seed)
         ).images[0]
 
+        # The seed value is no longer returned, as it is not being changed.
+        return image, lora_scale
 
     except Exception as e:
         print(f"Error during inference: {e}")
+        # Return an error state for all outputs
+        return None, lora_scale
+
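For readers who want the same inference path without the Gradio layer, the calls above compose into a short standalone script. A sketch, assuming a CUDA GPU and reusing the Realism Detailer LoRA from the old hard-coded list (the input and output file names are hypothetical):

```python
# Standalone sketch of the inference path used by infer_with_lora above.
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Load and activate one LoRA from the Hub, as infer_with_lora does.
pipe.load_lora_weights(
    "fal/Realism-Detailer-Kontext-Dev-LoRA",
    weight_name="high_detail.safetensors",
    adapter_name="selected_lora",
)
pipe.set_adapters(["selected_lora"], adapter_weights=[1.0])

image = load_image("portrait.jpg").convert("RGB")  # hypothetical input file
result = pipe(
    image=image,
    width=image.size[0],
    height=image.size[1],
    prompt="Add details to this face, improve skin details",
    guidance_scale=2.5,
    num_inference_steps=28,
    generator=torch.Generator().manual_seed(0),
).images[0]
result.save("styled_portrait.png")
```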
+# CSS styling
+css = """
+#main_app {
+    display: flex;
+    gap: 20px;
+}
+#box_column {
+    min-width: 400px;
+}
+#title{text-align: center}
+#title h1{font-size: 3em; display:inline-flex; align-items:center}
+#title img{width: 100px; margin-right: 0.5em}
+#selected_lora {
+    color: #2563eb;
+    font-weight: bold;
+}
+#prompt {
+    flex-grow: 1;
+}
+#run_button {
+    background: linear-gradient(45deg, #2563eb, #3b82f6);
+    color: white;
+    border: none;
+    padding: 8px 16px;
+    border-radius: 6px;
+    font-weight: bold;
+}
+.custom_lora_card {
+    background: #f8fafc;
+    border: 1px solid #e2e8f0;
+    border-radius: 8px;
+    padding: 12px;
+    margin: 8px 0;
+}
+#gallery{
+    overflow: scroll !important
+}
+/* Custom CSS to ensure the input image is fully visible */
+#input_image_display div[data-testid="image"] img {
+    object-fit: contain !important;
+}
+"""
 
 # Create Gradio interface
+with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend Deca"), "sans-serif"])) as demo:
     gr_flux_loras = gr.State(value=flux_loras_raw)
 
     title = gr.HTML(
+        """<h1><img src="https://huggingface.co/spaces/kontext-community/FLUX.1-Kontext-portrait/resolve/main/dora_kontext.png" alt="LoRA"> Kontext-Style LoRA Explorer</h1>""",
+        elem_id="title",
     )
+    gr.Markdown("A demo for the style LoRAs from the [Kontext-Style](https://huggingface.co/Kontext-Style) organization 🤗")
 
     selected_state = gr.State(value=None)
+    # The following states are no longer used by the simplified logic but kept for component structure
     custom_loaded_lora = gr.State(value=None)
+    lora_state = gr.State(value=1.0)
 
     with gr.Row(elem_id="main_app"):
         with gr.Column(scale=4, elem_id="box_column"):
             with gr.Group(elem_id="gallery_box"):
+                input_image = gr.Image(
+                    label="Upload a picture of yourself",
+                    type="pil",
+                    height=300,
+                    elem_id="input_image_display"
+                )
                 gallery = gr.Gallery(
+                    label="Pick a LoRA",
                     allow_preview=False,
+                    columns=4,
                     elem_id="gallery",
                     show_share_button=False,
+                    height=300,
+                    object_fit="contain"
                 )
 
                 custom_model = gr.Textbox(
+                    label="Or enter a custom HuggingFace FLUX LoRA",
                     placeholder="e.g., username/lora-name",
+                    visible=False
                 )
                 custom_model_card = gr.HTML(visible=False)
+                custom_model_button = gr.Button("Remove custom LoRA", visible=False)
 
         with gr.Column(scale=5):
             with gr.Row():
                 prompt = gr.Textbox(
+                    label="Editing Prompt",
                     show_label=False,
                     lines=1,
                     max_lines=1,
+                    placeholder="opt - describe the person/subject, e.g. 'a man with glasses and a beard'",
                     elem_id="prompt"
                 )
+                run_button = gr.Button("Generate", elem_id="run_button")
 
+            result = gr.Image(label="Generated Image", interactive=False, height=512)
 
+            with gr.Accordion("Advanced Settings", open=False):
                 lora_scale = gr.Slider(
+                    label="LoRA Scale",
                     minimum=0,
                     maximum=2,
                     step=0.1,
                     value=1.0,
+                    info="Controls the strength of the LoRA effect"
                 )
                 seed = gr.Slider(
+                    label="Seed",
                     minimum=0,
                     maximum=MAX_SEED,
                     step=1,
                     value=0,
                 )
                 guidance_scale = gr.Slider(
+                    label="Guidance Scale",
                     minimum=1,
                     maximum=10,
                     step=0.1,
                     value=2.5,
+                )
+                num_inference_steps = gr.Slider(
+                    label="Timesteps",
+                    minimum=1,
+                    maximum=100,
+                    step=1,
+                    value=28,
+                    info="Number of inference steps"
                 )
 
             prompt_title = gr.Markdown(
+                value="### Click on a LoRA in the gallery to select it",
                 visible=True,
                 elem_id="selected_lora",
             )
 
     # Event handlers
+    # The custom model inputs are no longer wired up, as they have been hidden.
 
     gallery.select(
         fn=update_selection,
         inputs=[gr_flux_loras],
+        outputs=[prompt_title, prompt, selected_state, lora_scale],
         show_progress=False
     )
 
     gr.on(
         triggers=[run_button.click, prompt.submit],
         fn=infer_with_lora_wrapper,
+        inputs=[input_image, prompt, selected_state, lora_state, custom_loaded_lora, seed, guidance_scale, num_inference_steps, lora_scale, gr_flux_loras],
+        outputs=[result, lora_state]
     )
 
     # Initialize gallery
     demo.load(
+        fn=lambda loras: ([(item["image"], item["title"]) for item in loras], loras),
         inputs=[gr_flux_loras],
         outputs=[gallery, gr_flux_loras]
     )
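The diff ends at the `demo.load` wiring, so the launch call is not shown; a Gradio app like this is typically served with the standard entry point:

```python
# Standard Gradio entry point (not part of the diff shown above).
if __name__ == "__main__":
    demo.launch()
```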