Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -57,10 +57,14 @@ def load_model(custom_model):
|
|
57 |
|
58 |
print(f"Safetensors available: {sfts_available_files}")
|
59 |
|
60 |
-
return "Model Ready", gr.update(choices=sfts_available_files, value=sfts_available_files[0], visible=True), gr.update(value=instance_prompt, visible=True)
|
61 |
|
62 |
-
def custom_model_changed(custom_model):
|
63 |
-
|
|
|
|
|
|
|
|
|
64 |
|
65 |
@spaces.GPU
|
66 |
def infer (custom_model, weight_name, prompt, inf_steps, guidance_scale, seed, lora_weight, progress=gr.Progress(track_tqdm=True)):
|
@@ -159,41 +163,45 @@ with gr.Blocks(css=css) as demo:
|
|
159 |
<h2 style="text-align: center;">SD-XL Custom Model Inference</h2>
|
160 |
<p style="text-align: center;">Use this demo to check results from your previously trained LoRa model.</p>
|
161 |
""")
|
162 |
-
with gr.
|
163 |
-
with gr.
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
)
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
info="Make sure your model is set to PUBLIC"
|
181 |
)
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
allow_custom_value=True,
|
188 |
-
visible = False
|
189 |
-
)
|
190 |
-
with gr.Column():
|
191 |
-
load_model_btn = gr.Button("Load my model")
|
192 |
-
model_status = gr.Textbox(
|
193 |
-
label = "model status",
|
194 |
-
show_label = False
|
195 |
-
)
|
196 |
-
trigger_word = gr.Textbox(label="Trigger word", interactive=False, visible=False)
|
197 |
|
198 |
prompt_in = gr.Textbox(
|
199 |
label="Your Prompt",
|
@@ -242,14 +250,14 @@ with gr.Blocks(css=css) as demo:
|
|
242 |
|
243 |
custom_model.blur(
|
244 |
fn=custom_model_changed,
|
245 |
-
inputs = [custom_model],
|
246 |
outputs = [model_status],
|
247 |
queue = False
|
248 |
)
|
249 |
load_model_btn.click(
|
250 |
fn = load_model,
|
251 |
inputs=[custom_model],
|
252 |
-
outputs = [model_status, weight_name, trigger_word],
|
253 |
queue = False
|
254 |
)
|
255 |
submit_btn.click(
|
|
|
57 |
|
58 |
print(f"Safetensors available: {sfts_available_files}")
|
59 |
|
60 |
+
return custom_model, "Model Ready", gr.update(choices=sfts_available_files, value=sfts_available_files[0], visible=True), gr.update(value=instance_prompt, visible=True)
|
61 |
|
62 |
+
def custom_model_changed(custom_model, previous_model):
|
63 |
+
if custom_model != previous_model:
|
64 |
+
status_message = "Model changed, you must reload before re-run"
|
65 |
+
else:
|
66 |
+
status_message = "Model ready"
|
67 |
+
return status_message
|
68 |
|
69 |
@spaces.GPU
|
70 |
def infer (custom_model, weight_name, prompt, inf_steps, guidance_scale, seed, lora_weight, progress=gr.Progress(track_tqdm=True)):
|
|
|
163 |
<h2 style="text-align: center;">SD-XL Custom Model Inference</h2>
|
164 |
<p style="text-align: center;">Use this demo to check results from your previously trained LoRa model.</p>
|
165 |
""")
|
166 |
+
with gr.Box():
|
167 |
+
with gr.Row():
|
168 |
+
with gr.Column():
|
169 |
+
if not is_shared_ui:
|
170 |
+
your_username = api.whoami()["name"]
|
171 |
+
my_models = api.list_models(author=your_username, filter=["diffusers", "stable-diffusion-xl", 'lora'])
|
172 |
+
model_names = [item.modelId for item in my_models]
|
173 |
+
|
174 |
+
if not is_shared_ui:
|
175 |
+
custom_model = gr.Dropdown(
|
176 |
+
label = "Your custom model ID",
|
177 |
+
choices = model_names,
|
178 |
+
allow_custom_value = True
|
179 |
+
#placeholder = "username/model_id"
|
180 |
+
)
|
181 |
+
else:
|
182 |
+
custom_model = gr.Textbox(
|
183 |
+
label="Your custom model ID",
|
184 |
+
placeholder="your_username/your_trained_model_name",
|
185 |
+
info="Make sure your model is set to PUBLIC"
|
186 |
+
)
|
187 |
+
|
188 |
+
weight_name = gr.Dropdown(
|
189 |
+
label="Safetensors file",
|
190 |
+
#value="pytorch_lora_weights.safetensors",
|
191 |
+
info="specify which one if model has several .safetensors files",
|
192 |
+
allow_custom_value=True,
|
193 |
+
visible = False
|
194 |
)
|
195 |
+
with gr.Column():
|
196 |
+
load_model_btn = gr.Button("Load my model")
|
197 |
+
previous_model = gr.Textbox(
|
198 |
+
visible=False
|
|
|
199 |
)
|
200 |
+
model_status = gr.Textbox(
|
201 |
+
label = "model status",
|
202 |
+
show_label = False
|
203 |
+
)
|
204 |
+
trigger_word = gr.Textbox(label="Trigger word", interactive=False, visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
|
206 |
prompt_in = gr.Textbox(
|
207 |
label="Your Prompt",
|
|
|
250 |
|
251 |
custom_model.blur(
|
252 |
fn=custom_model_changed,
|
253 |
+
inputs = [custom_model, previous_model],
|
254 |
outputs = [model_status],
|
255 |
queue = False
|
256 |
)
|
257 |
load_model_btn.click(
|
258 |
fn = load_model,
|
259 |
inputs=[custom_model],
|
260 |
+
outputs = [previous_model, model_status, weight_name, trigger_word],
|
261 |
queue = False
|
262 |
)
|
263 |
submit_btn.click(
|