Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,9 +18,9 @@ pipe = DiffusionPipeline.from_pretrained(
|
|
| 18 |
|
| 19 |
device="cuda" if torch.cuda.is_available() else "cpu"
|
| 20 |
|
| 21 |
-
def load_model(custom_model):
|
| 22 |
# This is where you load your trained weights
|
| 23 |
-
pipe.load_lora_weights(custom_model, use_auth_token=True)
|
| 24 |
pipe.to(device)
|
| 25 |
return "Model loaded!"
|
| 26 |
|
|
@@ -88,6 +88,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 88 |
with gr.Row():
|
| 89 |
with gr.Column():
|
| 90 |
custom_model = gr.Textbox(label="Your custom model ID", placeholder="your_username/your_trained_model_name", info="Make sure your model is set to PUBLIC ")
|
|
|
|
| 91 |
model_status = gr.Textbox(label="model status", interactive=False)
|
| 92 |
load_model_btn = gr.Button("Load my model")
|
| 93 |
|
|
@@ -125,7 +126,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 125 |
|
| 126 |
load_model_btn.click(
|
| 127 |
fn = load_model,
|
| 128 |
-
inputs=[custom_model],
|
| 129 |
outputs = [model_status]
|
| 130 |
)
|
| 131 |
submit_btn.click(
|
|
|
|
| 18 |
|
| 19 |
device="cuda" if torch.cuda.is_available() else "cpu"
|
| 20 |
|
| 21 |
+
def load_model(custom_model, weight_name):
|
| 22 |
# This is where you load your trained weights
|
| 23 |
+
pipe.load_lora_weights(custom_model, weight_name=weight_name, use_auth_token=True)
|
| 24 |
pipe.to(device)
|
| 25 |
return "Model loaded!"
|
| 26 |
|
|
|
|
| 88 |
with gr.Row():
|
| 89 |
with gr.Column():
|
| 90 |
custom_model = gr.Textbox(label="Your custom model ID", placeholder="your_username/your_trained_model_name", info="Make sure your model is set to PUBLIC ")
|
| 91 |
+
weight_name = gr.Textbox(label="Safetensors file", value="pytorch_lora_weights.safetensors")
|
| 92 |
model_status = gr.Textbox(label="model status", interactive=False)
|
| 93 |
load_model_btn = gr.Button("Load my model")
|
| 94 |
|
|
|
|
| 126 |
|
| 127 |
load_model_btn.click(
|
| 128 |
fn = load_model,
|
| 129 |
+
inputs=[custom_model, weight_name],
|
| 130 |
outputs = [model_status]
|
| 131 |
)
|
| 132 |
submit_btn.click(
|