Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -57,10 +57,14 @@ def load_model(custom_model):
|
|
| 57 |
|
| 58 |
print(f"Safetensors available: {sfts_available_files}")
|
| 59 |
|
| 60 |
-
return "Model Ready", gr.update(choices=sfts_available_files, value=sfts_available_files[0], visible=True), gr.update(value=instance_prompt, visible=True)
|
| 61 |
|
| 62 |
-
def custom_model_changed(custom_model):
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
@spaces.GPU
|
| 66 |
def infer (custom_model, weight_name, prompt, inf_steps, guidance_scale, seed, lora_weight, progress=gr.Progress(track_tqdm=True)):
|
|
@@ -159,41 +163,45 @@ with gr.Blocks(css=css) as demo:
|
|
| 159 |
<h2 style="text-align: center;">SD-XL Custom Model Inference</h2>
|
| 160 |
<p style="text-align: center;">Use this demo to check results from your previously trained LoRa model.</p>
|
| 161 |
""")
|
| 162 |
-
with gr.
|
| 163 |
-
with gr.
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
)
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
info="Make sure your model is set to PUBLIC"
|
| 181 |
)
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
allow_custom_value=True,
|
| 188 |
-
visible = False
|
| 189 |
-
)
|
| 190 |
-
with gr.Column():
|
| 191 |
-
load_model_btn = gr.Button("Load my model")
|
| 192 |
-
model_status = gr.Textbox(
|
| 193 |
-
label = "model status",
|
| 194 |
-
show_label = False
|
| 195 |
-
)
|
| 196 |
-
trigger_word = gr.Textbox(label="Trigger word", interactive=False, visible=False)
|
| 197 |
|
| 198 |
prompt_in = gr.Textbox(
|
| 199 |
label="Your Prompt",
|
|
@@ -242,14 +250,14 @@ with gr.Blocks(css=css) as demo:
|
|
| 242 |
|
| 243 |
custom_model.blur(
|
| 244 |
fn=custom_model_changed,
|
| 245 |
-
inputs = [custom_model],
|
| 246 |
outputs = [model_status],
|
| 247 |
queue = False
|
| 248 |
)
|
| 249 |
load_model_btn.click(
|
| 250 |
fn = load_model,
|
| 251 |
inputs=[custom_model],
|
| 252 |
-
outputs = [model_status, weight_name, trigger_word],
|
| 253 |
queue = False
|
| 254 |
)
|
| 255 |
submit_btn.click(
|
|
|
|
| 57 |
|
| 58 |
print(f"Safetensors available: {sfts_available_files}")
|
| 59 |
|
| 60 |
+
return custom_model, "Model Ready", gr.update(choices=sfts_available_files, value=sfts_available_files[0], visible=True), gr.update(value=instance_prompt, visible=True)
|
| 61 |
|
| 62 |
+
def custom_model_changed(custom_model, previous_model):
|
| 63 |
+
if custom_model != previous_model:
|
| 64 |
+
status_message = "Model changed, you must reload before re-run"
|
| 65 |
+
else:
|
| 66 |
+
status_message = "Model ready"
|
| 67 |
+
return status_message
|
| 68 |
|
| 69 |
@spaces.GPU
|
| 70 |
def infer (custom_model, weight_name, prompt, inf_steps, guidance_scale, seed, lora_weight, progress=gr.Progress(track_tqdm=True)):
|
|
|
|
| 163 |
<h2 style="text-align: center;">SD-XL Custom Model Inference</h2>
|
| 164 |
<p style="text-align: center;">Use this demo to check results from your previously trained LoRa model.</p>
|
| 165 |
""")
|
| 166 |
+
with gr.Box():
|
| 167 |
+
with gr.Row():
|
| 168 |
+
with gr.Column():
|
| 169 |
+
if not is_shared_ui:
|
| 170 |
+
your_username = api.whoami()["name"]
|
| 171 |
+
my_models = api.list_models(author=your_username, filter=["diffusers", "stable-diffusion-xl", 'lora'])
|
| 172 |
+
model_names = [item.modelId for item in my_models]
|
| 173 |
+
|
| 174 |
+
if not is_shared_ui:
|
| 175 |
+
custom_model = gr.Dropdown(
|
| 176 |
+
label = "Your custom model ID",
|
| 177 |
+
choices = model_names,
|
| 178 |
+
allow_custom_value = True
|
| 179 |
+
#placeholder = "username/model_id"
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
custom_model = gr.Textbox(
|
| 183 |
+
label="Your custom model ID",
|
| 184 |
+
placeholder="your_username/your_trained_model_name",
|
| 185 |
+
info="Make sure your model is set to PUBLIC"
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
weight_name = gr.Dropdown(
|
| 189 |
+
label="Safetensors file",
|
| 190 |
+
#value="pytorch_lora_weights.safetensors",
|
| 191 |
+
info="specify which one if model has several .safetensors files",
|
| 192 |
+
allow_custom_value=True,
|
| 193 |
+
visible = False
|
| 194 |
)
|
| 195 |
+
with gr.Column():
|
| 196 |
+
load_model_btn = gr.Button("Load my model")
|
| 197 |
+
previous_model = gr.Textbox(
|
| 198 |
+
visible=False
|
|
|
|
| 199 |
)
|
| 200 |
+
model_status = gr.Textbox(
|
| 201 |
+
label = "model status",
|
| 202 |
+
show_label = False
|
| 203 |
+
)
|
| 204 |
+
trigger_word = gr.Textbox(label="Trigger word", interactive=False, visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
prompt_in = gr.Textbox(
|
| 207 |
label="Your Prompt",
|
|
|
|
| 250 |
|
| 251 |
custom_model.blur(
|
| 252 |
fn=custom_model_changed,
|
| 253 |
+
inputs = [custom_model, previous_model],
|
| 254 |
outputs = [model_status],
|
| 255 |
queue = False
|
| 256 |
)
|
| 257 |
load_model_btn.click(
|
| 258 |
fn = load_model,
|
| 259 |
inputs=[custom_model],
|
| 260 |
+
outputs = [previous_model, model_status, weight_name, trigger_word],
|
| 261 |
queue = False
|
| 262 |
)
|
| 263 |
submit_btn.click(
|