RamAnanth1 committed
Commit 4be7ef1 · Parent(s): 1e6d524

First attempt at porting to diffusers

Files changed (1): app.py (+112 -201)
app.py CHANGED
@@ -4,182 +4,103 @@ import gradio as gr
  import numpy as np
  import torch
 
- from pytorch_lightning import seed_everything
- from util import resize_image, HWC3, apply_canny
- from ldm.models.diffusion.ddim import DDIMSampler
-
- from annotator.openpose import apply_openpose
-
- from cldm.model import create_model, load_state_dict
-
- from huggingface_hub import hf_hub_url, cached_download
-
- REPO_ID = "lllyasviel/ControlNet"
- canny_checkpoint = "models/control_sd15_canny.pth"
- scribble_checkpoint = "models/control_sd15_scribble.pth"
- pose_checkpoint = "models/control_sd15_openpose.pth"
-
- # REPO_ID = "webui/ControlNet-modules-safetensors"
- # canny_checkpoint = "control_canny-fp16.safetensors"
- # scribble_checkpoint = "control_scribble-fp16.safetensors"
- # pose_checkpoint = "control_openpose-fp16.safetensors"
-
- canny_model = create_model('./models/cldm_v15.yaml').cpu()
- canny_model.load_state_dict(load_state_dict(cached_download(
-     hf_hub_url(REPO_ID, canny_checkpoint)
- ), location='cpu'))
- canny_model = canny_model.cuda()
- ddim_sampler = DDIMSampler(canny_model)
-
- pose_model = create_model('./models/cldm_v15.yaml').cpu()
- pose_model.load_state_dict(load_state_dict(cached_download(
-     hf_hub_url(REPO_ID, pose_checkpoint)
- ), location='cpu'))
- pose_model = pose_model.cuda()
- ddim_sampler_pose = DDIMSampler(pose_model)
-
- scribble_model = create_model('./models/cldm_v15.yaml').cpu()
- scribble_model.load_state_dict(load_state_dict(cached_download(
-     hf_hub_url(REPO_ID, scribble_checkpoint)
- ), location='cpu'))
- scribble_model = scribble_model.cuda()
- ddim_sampler_scribble = DDIMSampler(scribble_model)
 
- save_memory = False
 
- def process(input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
-     # TODO: Add other control tasks
-     if input_control == "Scribble":
-         return process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta)
-     elif input_control == "Pose":
-         return process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, image_resolution, ddim_steps, scale, seed, eta)
-
-     return process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold)
-
- def process_canny(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold):
-     with torch.no_grad():
-         img = resize_image(HWC3(input_image), image_resolution)
-         H, W, C = img.shape
-
-         detected_map = apply_canny(img, low_threshold, high_threshold)
-         detected_map = HWC3(detected_map)
-
-         control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
-         control = torch.stack([control for _ in range(num_samples)], dim=0)
-         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
-
-         seed_everything(seed)
-
-         if save_memory:
-             canny_model.low_vram_shift(is_diffusing=False)
-
-         cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
-         un_cond = {"c_concat": [control], "c_crossattn": [canny_model.get_learned_conditioning([n_prompt] * num_samples)]}
-         shape = (4, H // 8, W // 8)
-
-         if save_memory:
-             canny_model.low_vram_shift(is_diffusing=False)
-
-         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
-                                                      shape, cond, verbose=False, eta=eta,
-                                                      unconditional_guidance_scale=scale,
-                                                      unconditional_conditioning=un_cond)
-
-         if save_memory:
-             canny_model.low_vram_shift(is_diffusing=False)
-
-         x_samples = canny_model.decode_first_stage(samples)
-         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
-
-         results = [x_samples[i] for i in range(num_samples)]
-     return [255 - detected_map] + results
-
- def process_scribble(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta):
-     with torch.no_grad():
-         img = resize_image(HWC3(input_image), image_resolution)
-         H, W, C = img.shape
-
-         detected_map = np.zeros_like(img, dtype=np.uint8)
-         detected_map[np.min(img, axis=2) < 127] = 255
-
-         control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
-         control = torch.stack([control for _ in range(num_samples)], dim=0)
-         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
-
-         seed_everything(seed)
-
-         if save_memory:
-             scribble_model.low_vram_shift(is_diffusing=False)
-
-         cond = {"c_concat": [control], "c_crossattn": [scribble_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
-         un_cond = {"c_concat": [control], "c_crossattn": [scribble_model.get_learned_conditioning([n_prompt] * num_samples)]}
-         shape = (4, H // 8, W // 8)
-
-         if save_memory:
-             scribble_model.low_vram_shift(is_diffusing=False)
-
-         samples, intermediates = ddim_sampler_scribble.sample(ddim_steps, num_samples,
-                                                               shape, cond, verbose=False, eta=eta,
-                                                               unconditional_guidance_scale=scale,
-                                                               unconditional_conditioning=un_cond)
-
-         if save_memory:
-             scribble_model.low_vram_shift(is_diffusing=False)
-
-         x_samples = scribble_model.decode_first_stage(samples)
-         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
-
-         results = [x_samples[i] for i in range(num_samples)]
-     return [255 - detected_map] + results
-
- def process_pose(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, detect_resolution, ddim_steps, scale, seed, eta):
-     with torch.no_grad():
-         input_image = HWC3(input_image)
-         detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution))
-         detected_map = HWC3(detected_map)
-         img = resize_image(input_image, image_resolution)
-         H, W, C = img.shape
-
-         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
-
-         control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
-         control = torch.stack([control for _ in range(num_samples)], dim=0)
-         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
-
-         if seed == -1:
-             seed = random.randint(0, 65535)
-         seed_everything(seed)
-
-         if save_memory:
-             pose_model.low_vram_shift(is_diffusing=False)
-
-         cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
-         un_cond = {"c_concat": [control], "c_crossattn": [pose_model.get_learned_conditioning([n_prompt] * num_samples)]}
-         shape = (4, H // 8, W // 8)
-
-         if save_memory:
-             pose_model.low_vram_shift(is_diffusing=False)
-
-         samples, intermediates = ddim_sampler_pose.sample(ddim_steps, num_samples,
-                                                           shape, cond, verbose=False, eta=eta,
-                                                           unconditional_guidance_scale=scale,
-                                                           unconditional_conditioning=un_cond)
-
-         if save_memory:
-             pose_model.low_vram_shift(is_diffusing=False)
-
-         x_samples = pose_model.decode_first_stage(samples)
-         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
-
-         results = [x_samples[i] for i in range(num_samples)]
-     return [detected_map] + results
-
- def create_canvas(w, h):
-     new_control_options = ["Interactive Scribble"]
-     return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
-
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+ from diffusers import UniPCMultistepScheduler
+ from PIL import Image
+ from controlnet_aux import OpenposeDetector
+
+ # Constants
+ low_threshold = 100
+ high_threshold = 200
+
+ # Models
+ controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+ pipe_canny = StableDiffusionControlNetPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", controlnet=controlnet_canny, safety_checker=None, torch_dtype=torch.float16
+ )
+ pipe_canny.scheduler = UniPCMultistepScheduler.from_config(pipe_canny.scheduler.config)
+
+ # This command loads the individual model components on GPU on-demand. So, we don't
+ # need to explicitly call pipe.to("cuda").
+ pipe_canny.enable_model_cpu_offload()
+
+ pipe_canny.enable_xformers_memory_efficient_attention()
+
+ # Generator seed
+ generator = torch.manual_seed(0)
+
+ pose_model = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
+ controlnet_pose = ControlNetModel.from_pretrained(
+     "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
+ )
+ pipe_pose = StableDiffusionControlNetPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5", controlnet=controlnet_pose, safety_checker=None, torch_dtype=torch.float16
+ )
+ pipe_pose.scheduler = UniPCMultistepScheduler.from_config(pipe_pose.scheduler.config)
+
+ # This command loads the individual model components on GPU on-demand. So, we don't
+ # need to explicitly call pipe.to("cuda").
+ pipe_pose.enable_model_cpu_offload()
+
+ # xformers
+ pipe_pose.enable_xformers_memory_efficient_attention()
+
+ from pytorch_lightning import seed_everything
+ from util import resize_image, HWC3, apply_canny
+ from ldm.models.diffusion.ddim import DDIMSampler
+
+ from annotator.openpose import apply_openpose
+
+ from cldm.model import create_model, load_state_dict
+
+ def get_canny_filter(image):
+     if not isinstance(image, np.ndarray):
+         image = np.array(image)
+
+     image = cv2.Canny(image, low_threshold, high_threshold)
+     image = image[:, :, None]
+     image = np.concatenate([image, image, image], axis=2)
+     canny_image = Image.fromarray(image)
+     return canny_image
+
+ def get_pose(image):
+     return pose_model(image)
+
+ def process(input_image, prompt, input_control):
+     # TODO: Add other control tasks
+     if input_control == "Scribble":
+         return process_canny(input_image, prompt)
+     elif input_control == "Pose":
+         return process_pose(input_image, prompt)
+
+     return process_canny(input_image, prompt)
+
+ def process_canny(input_image, prompt):
+     canny_image = get_canny_filter(input_image)
+     output = pipe_canny(
+         prompt,
+         canny_image,
+         generator=generator,
+         num_images_per_prompt=1,
+         num_inference_steps=20,
+     )
+     return [canny_image, output.images[0]]
+
+ def process_pose(input_image, prompt):
+     pose_image = get_pose(input_image)
+     output = pipe_pose(
+         prompt,
+         pose_image,
+         generator=generator,
+         num_images_per_prompt=1,
+         num_inference_steps=20,
+     )
+     return [pose_image, output.images[0]]
 
  block = gr.Blocks().queue()
  control_task_list = [
@@ -222,52 +143,42 @@ with block:
          [
              "bird.png",
              "bird",
-             "Canny Edge Map",
-             "best quality, extremely detailed",
-             'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
-             1,
-             512,
-             20,
-             9.0,
-             123490213,
-             0.0,
-             100,
-             200
+             "Canny Edge Map"
          ],
-         [
-             "turtle.png",
-             "turtle",
-             "Scribble",
-             "best quality, extremely detailed",
-             'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
-             1,
-             512,
-             20,
-             9.0,
-             123490213,
-             0.0,
-             100,
-             200
-         ],
-         [
-             "pose1.png",
-             "Chef in the Kitchen",
-             "Pose",
-             "best quality, extremely detailed",
-             'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
-             1,
-             512,
-             20,
-             9.0,
-             123490213,
-             0.0,
-             100,
-             200
-         ]
+         # [
+         #     "turtle.png",
+         #     "turtle",
+         #     "Scribble",
+         #     "best quality, extremely detailed",
+         #     'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
+         #     1,
+         #     512,
+         #     20,
+         #     9.0,
+         #     123490213,
+         #     0.0,
+         #     100,
+         #     200
+         # ],
+         # [
+         #     "pose1.png",
+         #     "Chef in the Kitchen",
+         #     "Pose",
+         #     "best quality, extremely detailed",
+         #     'longbody, lowres, bad anatomy, bad hands, missing fingers, pubic hair,extra digit, fewer digits, cropped, worst quality, low quality',
+         #     1,
+         #     512,
+         #     20,
+         #     9.0,
+         #     123490213,
+         #     0.0,
+         #     100,
+         #     200
+         # ]
      ]
      examples = gr.Examples(examples=examples_list, inputs=[input_image, prompt, input_control, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, scale, seed, eta, low_threshold, high_threshold], outputs=[result_gallery], cache_examples=True, fn=process)
      gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=RamAnanth1.ControlNet)")
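
The ported canny path can also be exercised outside Gradio. Below is a minimal sketch, assuming the same model IDs and diffusers/controlnet_aux versions as in the commit; the input and output file names are illustrative, not part of the change:

import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler

# Build the canny-conditioned pipeline the same way app.py does.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()  # streams components to the GPU on demand

# Same edge-map preparation as get_canny_filter(): single-channel Canny
# edges replicated into a 3-channel conditioning image.
image = np.array(Image.open("bird.png").convert("RGB"))  # illustrative input
edges = cv2.Canny(image, 100, 200)[:, :, None]
canny_image = Image.fromarray(np.concatenate([edges, edges, edges], axis=2))

# Fixed seed, mirroring torch.manual_seed(0) in the commit.
generator = torch.manual_seed(0)
result = pipe("bird", canny_image, generator=generator, num_inference_steps=20).images[0]
result.save("bird_canny_out.png")  # illustrative output path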
 
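One reproducibility caveat in the new code: torch.manual_seed(0) returns the default global generator, whose state advances on every pipeline call, so two identical requests will not produce identical images. A fresh per-call generator restores determinism; a sketch under the same assumptions (the helper name run_canny is illustrative):

import torch

def run_canny(prompt, canny_image, seed=0):
    # A fresh generator per call makes identical inputs yield identical outputs.
    generator = torch.Generator().manual_seed(seed)
    return pipe_canny(
        prompt,
        canny_image,
        generator=generator,
        num_images_per_prompt=1,
        num_inference_steps=20,
    ).images[0]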