fffiloni commited on
Commit
c5267aa
·
1 Parent(s): 515dd49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -3
app.py CHANGED
@@ -17,8 +17,15 @@ def load_model(custom_model):
17
  pipe.to("cuda")
18
  return "Model loaded!"
19
 
20
- def infer (prompt):
21
- image = pipe(prompt=prompt, num_inference_steps=50).images[0]
 
 
 
 
 
 
 
22
  return image
23
 
24
  css = """
@@ -37,6 +44,34 @@ with gr.Blocks(css=css) as demo:
37
  load_model_btn = gr.Button("Load my model")
38
 
39
  prompt_in = gr.Textbox(label="Prompt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  submit_btn = gr.Button("Submit")
41
  image_out = gr.Image(label="Image output")
42
 
@@ -47,7 +82,7 @@ with gr.Blocks(css=css) as demo:
47
  )
48
  submit_btn.click(
49
  fn = infer,
50
- inputs = [prompt_in],
51
  outputs = [image_out]
52
  )
53
 
 
17
  pipe.to("cuda")
18
  return "Model loaded!"
19
 
20
def infer(prompt, inf_steps, guidance_scale, seed, lora_weight, progress=gr.Progress(track_tqdm=True)):
    """Run the loaded diffusion pipeline and return the generated image.

    Args:
        prompt: Text prompt to condition the generation.
        inf_steps: Number of denoising steps (passed as num_inference_steps).
        guidance_scale: Classifier-free guidance strength; coerced to float.
        seed: Integer seed used to build a deterministic CUDA generator.
        lora_weight: LoRA scale applied through cross_attention_kwargs.
        progress: Gradio progress tracker mirroring the pipeline's tqdm bar.

    Returns:
        The first PIL image produced by the pipeline.
    """
    # Seed on the same device the pipeline runs on so results are reproducible.
    generator = torch.Generator(device="cuda").manual_seed(seed)
    image = pipe(
        prompt=prompt,
        num_inference_steps=inf_steps,
        guidance_scale=float(guidance_scale),
        generator=generator,
        # BUGFIX: the parameter was spelled `lora_weigth` while the body used
        # `lora_weight`, raising NameError on every call. Gradio supplies
        # inputs positionally, so renaming the parameter is safe for callers.
        cross_attention_kwargs={"scale": float(lora_weight)},
    ).images[0]
    return image
30
 
31
  css = """
 
44
  load_model_btn = gr.Button("Load my model")
45
 
46
# Prompt entry plus the denoising-step count (more steps = slower but
# typically sharper output).
prompt_in = gr.Textbox(label="Prompt")
inf_steps = gr.Slider(
    label="Inference steps",
    minimum=12, maximum=50,
    value=25, step=1,
)
54
# Classifier-free guidance strength.
# BUGFIX: maximum was 0.9 while the default value was 7.5, i.e. the default
# lay outside the slider's range. 7.5 is the usual diffusers default, so the
# range was widened instead of lowering the default.
# NOTE(review): upper bound of 15.0 is a conventional CFG ceiling — confirm
# the intended maximum with the author.
guidance_scale = gr.Slider(
    label="Guidance scale",
    minimum=0.1,
    maximum=15.0,
    step=0.1,
    value=7.5,
)
61
# Seed control so a generation can be reproduced exactly (range 0..500000).
seed = gr.Slider(
    label="Seed",
    minimum=0, maximum=500000,
    value=42, step=1,
)
68
# Scale applied to the LoRA weights at inference time (fed to
# cross_attention_kwargs in the pipeline call).
# BUGFIX: user-facing label typo corrected ("LoRa weigth" -> "LoRA weight").
lora_weight = gr.Slider(
    label="LoRA weight",
    minimum=0.0,
    maximum=10.0,
    step=0.01,
    value=0.9,
)
75
# Trigger button and the image component the result is rendered into.
submit_btn = gr.Button("Submit")
image_out = gr.Image(label="Image output")
77
 
 
82
  )
83
# Wire the submit button: forward every user-tunable setting to infer()
# and display the returned image.
submit_btn.click(
    fn=infer,
    inputs=[prompt_in, inf_steps, guidance_scale, seed, lora_weight],
    outputs=[image_out],
)
88