fffiloni commited on
Commit
6be2e5b
·
1 Parent(s): 5494b47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -6
app.py CHANGED
@@ -10,24 +10,41 @@ pipe = DiffusionPipeline.from_pretrained(
10
  use_safetensors=True
11
  )
12
 
13
- # This is where you load your trained weights
14
- pipe.load_lora_weights("victor/outicon")
15
 
16
- pipe.to("cuda")
 
 
 
 
17
 
18
  def infer (prompt):
19
  image = pipe(prompt=prompt, num_inference_steps=50).images[0]
20
  return image
21
 
22
  css = """
23
- #col-container {max-width: 780px; margin-left: auto; margin-right: auto;}
24
  """
25
- with gr.Blocks() as demo:
 
26
  with gr.Column(elem_id="col-container"):
27
- prompt_in = gr.Textbox(label="Prompt")
 
 
 
 
 
 
 
 
 
28
  submit_btn = gr.Button("Submit")
29
  image_out = gr.Image(label="Image output")
30
 
 
 
 
 
 
31
  submit_btn.click(
32
  fn = infer,
33
  inputs = [prompt_in],
 
10
  use_safetensors=True
11
  )
12
 
 
 
13
 
14
+ def load_model(custom_model):
15
+ # This is where you load your trained weights
16
+ pipe.load_lora_weights(custom_model)
17
+ pipe.to("cuda")
18
+ return "Model loaded!"
19
 
20
  def infer (prompt):
21
  image = pipe(prompt=prompt, num_inference_steps=50).images[0]
22
  return image
23
 
24
  css = """
25
+ #col-container {max-width: 580px; margin-left: auto; margin-right: auto;}
26
  """
27
+
28
+ with gr.Blocks(css=css) as demo:
29
  with gr.Column(elem_id="col-container"):
30
+ gr.Markdown("""
31
+ # SD-XL Custom Model Inference
32
+ """)
33
+ with gr.Row():
34
+ with gr.Column():
35
+ custom_model = gr.Textbox(label="Your custom model ID", placeholder="your_username/your_trained_model_name")
36
+ model_status = gr.Textbox(label="model status", interactive=False)
37
+ load_model_btn = gr.Button("Load my model")
38
+
39
+ prompt_in = gr.Textbox(label="Prompt")
40
  submit_btn = gr.Button("Submit")
41
  image_out = gr.Image(label="Image output")
42
 
43
+ load_model_btn.click(
44
+ fn = load_model,
45
+ inputs=[custom_model],
46
+ outputs = [model_status]
47
+ )
48
  submit_btn.click(
49
  fn = infer,
50
  inputs = [prompt_in],