ktrndy committed (verified)
Commit 5280b61 · Parent(s): c606e76

Update app.py

Files changed (1):
  1. app.py +106 −36
app.py CHANGED
@@ -3,9 +3,10 @@ import numpy as np
 import random
 import os
 import torch
-from diffusers import StableDiffusionPipeline
+from diffusers import StableDiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel
 from peft import PeftModel, LoraConfig
 
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
 
@@ -29,7 +30,14 @@ def infer(
     guidance_scale=7.0,
     lora_scale=1.0,
     num_inference_steps=20,
-    progress=gr.Progress(track_tqdm=True),
+    controlnet_checkbox=False,
+    controlnet_strength=0.0,
+    controlnet_mode="edge_detection",
+    controlnet_image=None,
+    ip_adapter_checkbox=False,
+    ip_adapter_scale=0.0,
+    ip_adapter_image=None,
+    progress=gr.Progress(track_tqdm=True),
 ):
     generator = torch.Generator(device).manual_seed(seed)
 
@@ -40,9 +48,52 @@ def infer(
     if model_id is None:
         raise ValueError("Please specify the base model name or path")
 
-    pipe = StableDiffusionPipeline.from_pretrained(model_id,
-                                                   torch_dtype=torch_dtype,
-                                                   safety_checker=None).to(device)
+    if controlnet_checkbox:
+        if controlnet_mode == "depth_map":
+            controlnet = ControlNetModel.from_pretrained(
+                "lllyasviel/sd-controlnet-depth",
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            )
+        elif controlnet_mode == "pose_estimation":
+            controlnet = ControlNetModel.from_pretrained(
+                "lllyasviel/sd-controlnet-openpose",
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            )
+        elif controlnet_mode == "normal_map":
+            controlnet = ControlNetModel.from_pretrained(
+                "lllyasviel/sd-controlnet-normal",
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            )
+        elif controlnet_mode == "scribbles":
+            controlnet = ControlNetModel.from_pretrained(
+                "lllyasviel/sd-controlnet-scribble",
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            )
+        else:
+            # default: controlnet_mode == "edge_detection"
+            controlnet = ControlNetModel.from_pretrained(
+                "lllyasviel/sd-controlnet-canny",
+                cache_dir="./models_cache",
+                torch_dtype=torch_dtype
+            )
+        pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id,
+                                                                 controlnet=controlnet,
+                                                                 torch_dtype=torch_dtype,
+                                                                 safety_checker=None).to(device)
+    else:
+        pipe = StableDiffusionPipeline.from_pretrained(model_id,
+                                                       torch_dtype=torch_dtype,
+                                                       safety_checker=None).to(device)
+
+
+    if ip_adapter_checkbox:
+        pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
+        pipe.set_ip_adapter_scale(ip_adapter_scale)
+
     pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
     pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)
 
@@ -54,16 +105,31 @@ def infer(
         pipe.text_encoder.half()
 
     pipe.to(device)
-
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
+
+    if controlnet_checkbox:
+        image = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            width=width,
+            height=height,
+            generator=generator,
+            image=controlnet_image,
+            controlnet_conditioning_scale=controlnet_strength,
+            ip_adapter_image=ip_adapter_image if ip_adapter_checkbox else None
+        ).images[0]
+    else:
+        image = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            width=width,
+            height=height,
+            generator=generator,
+            ip_adapter_image=ip_adapter_image if ip_adapter_checkbox else None
+        ).images[0]
 
     return image
 
@@ -138,25 +204,24 @@ with gr.Blocks(css=css, fill_height=True) as demo:
                 label="ControlNet",
             )
             with gr.Column(visible=False) as controlnet_params:
-                control_strength = gr.Slider(
+                controlnet_strength = gr.Slider(
                     label="ControlNet conditioning scale",
                     minimum=0.0,
                     maximum=1.0,
                     step=0.01,
                     value=1.0,
                 )
-                control_mode = gr.Dropdown(
+                controlnet_mode = gr.Dropdown(
                     label="ControlNet mode",
                     choices=["edge_detection",
+                             "depth_map",
                              "pose_estimation",
-                             "straight_line_detection",
-                             "hed_boundary",
-                             "scribbles",
-                             "human pose"],
+                             "normal_map",
+                             "scribbles"],
                     value="edge_detection",
                     max_choices=1
                 )
-                condition_image = gr.Image(
+                controlnet_image = gr.Image(
                     label="ControlNet condition image",
                     type="pil",
                     format="png"
@@ -168,27 +233,26 @@ with gr.Blocks(css=css, fill_height=True) as demo:
                 )
 
             with gr.Row():
-                controlnet_checkbox = gr.Checkbox(
+                ip_adapter_checkbox = gr.Checkbox(
                     label="IPAdapter",
                 )
-                with gr.Column(visible=False) as controlnet_params:
-                    control_strength = gr.Slider(
-                        label="ControlNet conditioning scale",
+                with gr.Column(visible=False) as ip_adapter_params:
+                    ip_adapter_scale = gr.Slider(
+                        label="IPAdapter scale",
                         minimum=0.0,
                         maximum=1.0,
                         step=0.01,
                         value=1.0,
                     )
-                    control_mode = gr.Dropdown(
-                        label="ControlNet mode",
-                        choices=["edge_detection", "other"],
-                        value="edge_detection",
-                        max_choices=1
+                    ip_adapter_image = gr.Image(
+                        label="IPAdapter condition image",
+                        type="pil",
+                        format="png"
                     )
-                controlnet_checkbox.change(
+                ip_adapter_checkbox.change(
                     fn=lambda x: gr.Row.update(visible=x),
-                    inputs=controlnet_checkbox,
-                    outputs=controlnet_params
+                    inputs=ip_adapter_checkbox,
+                    outputs=ip_adapter_params
                 )
 
         with gr.Accordion("Optional Settings", open=False):
@@ -225,8 +289,14 @@ with gr.Blocks(css=css, fill_height=True) as demo:
             seed,
             guidance_scale,
             lora_scale,
-            num_inference_steps
-
+            num_inference_steps,
+            controlnet_checkbox,
+            controlnet_strength,
+            controlnet_mode,
+            controlnet_image,
+            ip_adapter_checkbox,
+            ip_adapter_scale,
+            ip_adapter_image
         ],
         outputs=[result],
     )
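
Note: as committed, the ControlNet branch built the pipeline with StableDiffusionPipeline, which accepts neither a controlnet argument at load time nor the image=/controlnet_conditioning_scale= call arguments; the hunk above substitutes diffusers' StableDiffusionControlNetPipeline. A minimal standalone sketch of that path (model IDs taken from the commit, cache_dir omitted; the prompt and the all-black condition image are placeholders):

    import numpy as np
    import torch
    from PIL import Image
    from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    # "edge_detection" branch: the Canny-conditioned ControlNet
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny",
        torch_dtype=torch_dtype,
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=torch_dtype,
        safety_checker=None,
    ).to(device)

    # Placeholder condition image; in the app this is the user-supplied edge map.
    condition_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))

    image = pipe(
        prompt="a photo of a cat",  # placeholder prompt
        image=condition_image,
        controlnet_conditioning_scale=1.0,
        num_inference_steps=20,
    ).images[0]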
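
The IP-Adapter path relies on diffusers' built-in loader, exactly as in the added lines. Continuing the sketch above (the reference image and scale value are placeholders; a scale of 0.0 ignores the reference, values near 1.0 follow it closely):

    # Load the IP-Adapter weights named in the commit and set their influence.
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models",
                         weight_name="ip-adapter-plus_sd15.bin")
    pipe.set_ip_adapter_scale(0.7)

    reference_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))  # placeholder
    image = pipe(
        prompt="a photo of a cat",        # placeholder prompt
        ip_adapter_image=reference_image,
        image=condition_image,
        num_inference_steps=20,
    ).images[0]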
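
One caveat left untouched above: both .change handlers return gr.Row.update(visible=x) although the component being toggled is a gr.Column. The version-agnostic spelling is the generic gr.update. A minimal sketch of the show/hide wiring under that assumption:

    import gradio as gr

    with gr.Blocks() as demo:
        ip_adapter_checkbox = gr.Checkbox(label="IPAdapter")
        with gr.Column(visible=False) as ip_adapter_params:
            ip_adapter_scale = gr.Slider(0.0, 1.0, value=1.0, step=0.01,
                                         label="IPAdapter scale")
        # gr.update works for any component type, so the same lambda can
        # serve both the ControlNet and the IPAdapter parameter groups.
        ip_adapter_checkbox.change(
            fn=lambda x: gr.update(visible=x),
            inputs=ip_adapter_checkbox,
            outputs=ip_adapter_params,
        )

    demo.launch()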