fffiloni commited on
Commit
4c26217
·
1 Parent(s): 3276af4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -18,9 +18,9 @@ pipe = DiffusionPipeline.from_pretrained(
18
 
19
  device="cuda" if torch.cuda.is_available() else "cpu"
20
 
21
- def load_model(custom_model):
22
  # This is where you load your trained weights
23
- pipe.load_lora_weights(custom_model, use_auth_token=True)
24
  pipe.to(device)
25
  return "Model loaded!"
26
 
@@ -88,6 +88,7 @@ with gr.Blocks(css=css) as demo:
88
  with gr.Row():
89
  with gr.Column():
90
  custom_model = gr.Textbox(label="Your custom model ID", placeholder="your_username/your_trained_model_name", info="Make sure your model is set to PUBLIC ")
 
91
  model_status = gr.Textbox(label="model status", interactive=False)
92
  load_model_btn = gr.Button("Load my model")
93
 
@@ -125,7 +126,7 @@ with gr.Blocks(css=css) as demo:
125
 
126
  load_model_btn.click(
127
  fn = load_model,
128
- inputs=[custom_model],
129
  outputs = [model_status]
130
  )
131
  submit_btn.click(
 
18
 
19
  device="cuda" if torch.cuda.is_available() else "cpu"
20
 
21
+ def load_model(custom_model, weight_name):
22
  # This is where you load your trained weights
23
+ pipe.load_lora_weights(custom_model, weight_name=weight_name, use_auth_token=True)
24
  pipe.to(device)
25
  return "Model loaded!"
26
 
 
88
  with gr.Row():
89
  with gr.Column():
90
  custom_model = gr.Textbox(label="Your custom model ID", placeholder="your_username/your_trained_model_name", info="Make sure your model is set to PUBLIC ")
91
+ weight_name = gr.Textbox(label="Safetensors file", value="pytorch_lora_weights.safetensors")
92
  model_status = gr.Textbox(label="model status", interactive=False)
93
  load_model_btn = gr.Button("Load my model")
94
 
 
126
 
127
  load_model_btn.click(
128
  fn = load_model,
129
+ inputs=[custom_model, weight_name],
130
  outputs = [model_status]
131
  )
132
  submit_btn.click(