Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -18,9 +18,9 @@ pipe = DiffusionPipeline.from_pretrained(
|
|
18 |
|
19 |
device="cuda" if torch.cuda.is_available() else "cpu"
|
20 |
|
21 |
-
def load_model(custom_model):
|
22 |
# This is where you load your trained weights
|
23 |
-
pipe.load_lora_weights(custom_model, use_auth_token=True)
|
24 |
pipe.to(device)
|
25 |
return "Model loaded!"
|
26 |
|
@@ -88,6 +88,7 @@ with gr.Blocks(css=css) as demo:
|
|
88 |
with gr.Row():
|
89 |
with gr.Column():
|
90 |
custom_model = gr.Textbox(label="Your custom model ID", placeholder="your_username/your_trained_model_name", info="Make sure your model is set to PUBLIC ")
|
|
|
91 |
model_status = gr.Textbox(label="model status", interactive=False)
|
92 |
load_model_btn = gr.Button("Load my model")
|
93 |
|
@@ -125,7 +126,7 @@ with gr.Blocks(css=css) as demo:
|
|
125 |
|
126 |
load_model_btn.click(
|
127 |
fn = load_model,
|
128 |
-
inputs=[custom_model],
|
129 |
outputs = [model_status]
|
130 |
)
|
131 |
submit_btn.click(
|
|
|
18 |
|
19 |
device="cuda" if torch.cuda.is_available() else "cpu"
|
20 |
|
21 |
+
def load_model(custom_model, weight_name):
|
22 |
# This is where you load your trained weights
|
23 |
+
pipe.load_lora_weights(custom_model, weight_name=weight_name, use_auth_token=True)
|
24 |
pipe.to(device)
|
25 |
return "Model loaded!"
|
26 |
|
|
|
88 |
with gr.Row():
|
89 |
with gr.Column():
|
90 |
custom_model = gr.Textbox(label="Your custom model ID", placeholder="your_username/your_trained_model_name", info="Make sure your model is set to PUBLIC ")
|
91 |
+
weight_name = gr.Textbox(label="Safetensors file", value="pytorch_lora_weights.safetensors")
|
92 |
model_status = gr.Textbox(label="model status", interactive=False)
|
93 |
load_model_btn = gr.Button("Load my model")
|
94 |
|
|
|
126 |
|
127 |
load_model_btn.click(
|
128 |
fn = load_model,
|
129 |
+
inputs=[custom_model, weight_name],
|
130 |
outputs = [model_status]
|
131 |
)
|
132 |
submit_btn.click(
|