multimodalart HF Staff commited on
Commit
a556947
·
verified ·
1 Parent(s): 51f469a

Fix user scale override

Browse files
Files changed (1) hide show
  1. app.py +16 -16
app.py CHANGED
@@ -55,7 +55,7 @@ def load_lora_weights(repo_id, weights_filename):
55
  def update_selection(selected_state: gr.SelectData, flux_loras):
56
  """Update UI when a LoRA is selected"""
57
  if selected_state.index >= len(flux_loras):
58
- return "### No LoRA selected", gr.update(), None
59
 
60
  lora_repo = flux_loras[selected_state.index]["repo"]
61
  trigger_word = flux_loras[selected_state.index]["trigger_word"]
@@ -67,7 +67,10 @@ def update_selection(selected_state: gr.SelectData, flux_loras):
67
  else:
68
  new_placeholder = f"opt - describe the person/subject, e.g. 'a man with glasses and a beard'"
69
 
70
- return updated_text, gr.update(placeholder=new_placeholder), selected_state.index
 
 
 
71
 
72
  def get_huggingface_lora(link):
73
  """Download LoRA from HuggingFace link"""
@@ -133,12 +136,12 @@ def classify_gallery(flux_loras):
133
  sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
134
  return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
135
 
136
- def infer_with_lora_wrapper(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75,portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
137
  """Wrapper function to handle state serialization"""
138
- return infer_with_lora(input_image, prompt, selected_index, custom_lora, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, flux_loras, progress)
139
 
140
  @spaces.GPU
141
- def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
142
  """Generate image with selected LoRA"""
143
  global current_lora, pipe
144
 
@@ -155,14 +158,9 @@ def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, r
155
  # Load LoRA if needed
156
  if lora_to_use and lora_to_use != current_lora:
157
  try:
158
- # Unload current LoRA
159
  if current_lora:
160
  pipe.unload_lora_weights()
161
 
162
- # Load new LoRA
163
- if lora_to_use["lora_scale_config"]:
164
- lora_scale = lora_to_use["lora_scale_config"]
165
- print("lora scale loaded from config", lora_scale)
166
  lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
167
  if lora_path:
168
  pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
@@ -173,8 +171,9 @@ def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, r
173
  except Exception as e:
174
  print(f"Error loading LoRA: {e}")
175
  # Continue without LoRA
176
- else:
177
- print(f"using already loaded lora: {lora_to_use}")
 
178
 
179
  input_image = input_image.convert("RGB")
180
  # Add trigger word to prompt
@@ -204,7 +203,7 @@ def infer_with_lora(input_image, prompt, selected_index, custom_lora, seed=42, r
204
  height=input_image.size[1],
205
  prompt=prompt,
206
  guidance_scale=guidance_scale,
207
- generator=torch.Generator().manual_seed(seed),
208
  ).images[0]
209
 
210
  return image, seed, gr.update(visible=True), lora_scale
@@ -264,6 +263,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
264
 
265
  selected_state = gr.State(value=None)
266
  custom_loaded_lora = gr.State(value=None)
 
267
 
268
  with gr.Row(elem_id="main_app"):
269
  with gr.Column(scale=4, elem_id="box_column"):
@@ -348,15 +348,15 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
348
  gallery.select(
349
  fn=update_selection,
350
  inputs=[gr_flux_loras],
351
- outputs=[prompt_title, prompt, selected_state],
352
  show_progress=False
353
  )
354
 
355
  gr.on(
356
  triggers=[run_button.click, prompt.submit],
357
  fn=infer_with_lora_wrapper,
358
- inputs=[input_image, prompt, selected_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, gr_flux_loras],
359
- outputs=[result, seed, reuse_button, lora_scale]
360
  )
361
 
362
  reuse_button.click(
 
55
  def update_selection(selected_state: gr.SelectData, flux_loras):
56
  """Update UI when a LoRA is selected"""
57
  if selected_state.index >= len(flux_loras):
58
+ return "### No LoRA selected", gr.update(), None, gr.update()
59
 
60
  lora_repo = flux_loras[selected_state.index]["repo"]
61
  trigger_word = flux_loras[selected_state.index]["trigger_word"]
 
67
  else:
68
  new_placeholder = f"opt - describe the person/subject, e.g. 'a man with glasses and a beard'"
69
 
70
+ optimal_scale = flux_loras[selected_state.index].get("lora_scale_config", 1.0)
71
+
72
+ return updated_text, gr.update(placeholder=new_placeholder), selected_state.index, gr.update(value=optimal_scale)
73
+
74
 
75
  def get_huggingface_lora(link):
76
  """Download LoRA from HuggingFace link"""
 
136
  sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
137
  return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
138
 
139
+ def infer_with_lora_wrapper(input_image, prompt, selected_index, lora_state, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75,portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
140
  """Wrapper function to handle state serialization"""
141
+ return infer_with_lora(input_image, prompt, selected_index, lora_state, custom_lora, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, flux_loras, progress)
142
 
143
  @spaces.GPU
144
+ def infer_with_lora(input_image, prompt, selected_index, lora_state, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
145
  """Generate image with selected LoRA"""
146
  global current_lora, pipe
147
 
 
158
  # Load LoRA if needed
159
  if lora_to_use and lora_to_use != current_lora:
160
  try:
 
161
  if current_lora:
162
  pipe.unload_lora_weights()
163
 
 
 
 
 
164
  lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
165
  if lora_path:
166
  pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
 
171
  except Exception as e:
172
  print(f"Error loading LoRA: {e}")
173
  # Continue without LoRA
174
+ elif lora_scale != lora_state:
175
+ pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
176
+ print(f"using already loaded lora: {lora_to_use}, udpated {lora_scale} based on user preference")
177
 
178
  input_image = input_image.convert("RGB")
179
  # Add trigger word to prompt
 
203
  height=input_image.size[1],
204
  prompt=prompt,
205
  guidance_scale=guidance_scale,
206
+ generator=torch.Generator().manual_seed(seed)
207
  ).images[0]
208
 
209
  return image, seed, gr.update(visible=True), lora_scale
 
263
 
264
  selected_state = gr.State(value=None)
265
  custom_loaded_lora = gr.State(value=None)
266
+ lora_state = gr.State(value=1.0)
267
 
268
  with gr.Row(elem_id="main_app"):
269
  with gr.Column(scale=4, elem_id="box_column"):
 
348
  gallery.select(
349
  fn=update_selection,
350
  inputs=[gr_flux_loras],
351
+ outputs=[prompt_title, prompt, selected_state, lora_scale],
352
  show_progress=False
353
  )
354
 
355
  gr.on(
356
  triggers=[run_button.click, prompt.submit],
357
  fn=infer_with_lora_wrapper,
358
+ inputs=[input_image, prompt, selected_state, lora_state, custom_loaded_lora, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, gr_flux_loras],
359
+ outputs=[result, seed, reuse_button, lora_state]
360
  )
361
 
362
  reuse_button.click(