fffiloni commited on
Commit
b2ca5bb
·
1 Parent(s): 02cde20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -38
app.py CHANGED
@@ -57,10 +57,14 @@ def load_model(custom_model):
57
 
58
  print(f"Safetensors available: {sfts_available_files}")
59
 
60
- return "Model Ready", gr.update(choices=sfts_available_files, value=sfts_available_files[0], visible=True), gr.update(value=instance_prompt, visible=True)
61
 
62
- def custom_model_changed(custom_model):
63
- return "Model changed, you must reload before re-run"
 
 
 
 
64
 
65
  @spaces.GPU
66
  def infer (custom_model, weight_name, prompt, inf_steps, guidance_scale, seed, lora_weight, progress=gr.Progress(track_tqdm=True)):
@@ -159,41 +163,45 @@ with gr.Blocks(css=css) as demo:
159
  <h2 style="text-align: center;">SD-XL Custom Model Inference</h2>
160
  <p style="text-align: center;">Use this demo to check results from your previously trained LoRa model.</p>
161
  """)
162
- with gr.Row():
163
- with gr.Column():
164
- if not is_shared_ui:
165
- your_username = api.whoami()["name"]
166
- my_models = api.list_models(author=your_username, filter=["diffusers", "stable-diffusion-xl", 'lora'])
167
- model_names = [item.modelId for item in my_models]
168
-
169
- if not is_shared_ui:
170
- custom_model = gr.Dropdown(
171
- label = "Your custom model ID",
172
- choices = model_names,
173
- allow_custom_value = True
174
- #placeholder = "username/model_id"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  )
176
- else:
177
- custom_model = gr.Textbox(
178
- label="Your custom model ID",
179
- placeholder="your_username/your_trained_model_name",
180
- info="Make sure your model is set to PUBLIC"
181
  )
182
-
183
- weight_name = gr.Dropdown(
184
- label="Safetensors file",
185
- #value="pytorch_lora_weights.safetensors",
186
- info="specify which one if model has several .safetensors files",
187
- allow_custom_value=True,
188
- visible = False
189
- )
190
- with gr.Column():
191
- load_model_btn = gr.Button("Load my model")
192
- model_status = gr.Textbox(
193
- label = "model status",
194
- show_label = False
195
- )
196
- trigger_word = gr.Textbox(label="Trigger word", interactive=False, visible=False)
197
 
198
  prompt_in = gr.Textbox(
199
  label="Your Prompt",
@@ -242,14 +250,14 @@ with gr.Blocks(css=css) as demo:
242
 
243
  custom_model.blur(
244
  fn=custom_model_changed,
245
- inputs = [custom_model],
246
  outputs = [model_status],
247
  queue = False
248
  )
249
  load_model_btn.click(
250
  fn = load_model,
251
  inputs=[custom_model],
252
- outputs = [model_status, weight_name, trigger_word],
253
  queue = False
254
  )
255
  submit_btn.click(
 
57
 
58
  print(f"Safetensors available: {sfts_available_files}")
59
 
60
+ return custom_model, "Model Ready", gr.update(choices=sfts_available_files, value=sfts_available_files[0], visible=True), gr.update(value=instance_prompt, visible=True)
61
 
62
+ def custom_model_changed(custom_model, previous_model):
63
+ if custom_model != previous_model:
64
+ status_message = "Model changed, you must reload before re-run"
65
+ else:
66
+ status_message = "Model ready"
67
+ return status_message
68
 
69
  @spaces.GPU
70
  def infer (custom_model, weight_name, prompt, inf_steps, guidance_scale, seed, lora_weight, progress=gr.Progress(track_tqdm=True)):
 
163
  <h2 style="text-align: center;">SD-XL Custom Model Inference</h2>
164
  <p style="text-align: center;">Use this demo to check results from your previously trained LoRa model.</p>
165
  """)
166
+ with gr.Box():
167
+ with gr.Row():
168
+ with gr.Column():
169
+ if not is_shared_ui:
170
+ your_username = api.whoami()["name"]
171
+ my_models = api.list_models(author=your_username, filter=["diffusers", "stable-diffusion-xl", 'lora'])
172
+ model_names = [item.modelId for item in my_models]
173
+
174
+ if not is_shared_ui:
175
+ custom_model = gr.Dropdown(
176
+ label = "Your custom model ID",
177
+ choices = model_names,
178
+ allow_custom_value = True
179
+ #placeholder = "username/model_id"
180
+ )
181
+ else:
182
+ custom_model = gr.Textbox(
183
+ label="Your custom model ID",
184
+ placeholder="your_username/your_trained_model_name",
185
+ info="Make sure your model is set to PUBLIC"
186
+ )
187
+
188
+ weight_name = gr.Dropdown(
189
+ label="Safetensors file",
190
+ #value="pytorch_lora_weights.safetensors",
191
+ info="specify which one if model has several .safetensors files",
192
+ allow_custom_value=True,
193
+ visible = False
194
  )
195
+ with gr.Column():
196
+ load_model_btn = gr.Button("Load my model")
197
+ previous_model = gr.Textbox(
198
+ visible=False
 
199
  )
200
+ model_status = gr.Textbox(
201
+ label = "model status",
202
+ show_label = False
203
+ )
204
+ trigger_word = gr.Textbox(label="Trigger word", interactive=False, visible=False)
 
 
 
 
 
 
 
 
 
 
205
 
206
  prompt_in = gr.Textbox(
207
  label="Your Prompt",
 
250
 
251
  custom_model.blur(
252
  fn=custom_model_changed,
253
+ inputs = [custom_model, previous_model],
254
  outputs = [model_status],
255
  queue = False
256
  )
257
  load_model_btn.click(
258
  fn = load_model,
259
  inputs=[custom_model],
260
+ outputs = [previous_model, model_status, weight_name, trigger_word],
261
  queue = False
262
  )
263
  submit_btn.click(