svjack committed
Commit 69414d8 · verified · 1 Parent(s): 5999dd6

Create pose_app.py

Files changed (1)
  1. pose_app.py +219 -0
pose_app.py ADDED
@@ -0,0 +1,219 @@
+ import random
+ import torch
+ import cv2
+ import gradio as gr
+ import numpy as np
+ from huggingface_hub import snapshot_download
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+ from diffusers.utils import load_image
+ from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline
+ from kolors.models.modeling_chatglm import ChatGLMModel
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
+ from kolors.models.controlnet import ControlNetModel
+ from diffusers import AutoencoderKL
+ from kolors.models.unet_2d_condition import UNet2DConditionModel
+ from diffusers import EulerDiscreteScheduler
+ from PIL import Image
+ from annotator.dwpose import DWposeDetector
+ from annotator.util import resize_image, HWC3
+
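+ # Download the Kolors base model, the IP-Adapter-Plus weights and the ControlNet-Pose weights from the Hugging Face Hub.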
+ device = "cuda"
+ ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
+ ckpt_dir_ipa = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus")
+ ckpt_dir_pose = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Pose")
+
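+ # Load the pipeline components (ChatGLM text encoder, tokenizer, VAE, scheduler, UNet, pose ControlNet) in fp16 on the GPU.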
+ text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
+ tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
+ vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
+ unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
+
+ controlnet_pose = ControlNetModel.from_pretrained(f"{ckpt_dir_pose}", revision=None).half().to(device)
+
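+ # CLIP image encoder and processor used as the IP-Adapter's image feature extractor (336px inputs for IP-Adapter-Plus).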
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_dir_ipa}/image_encoder', ignore_mismatched_sizes=True).to(dtype=torch.float16, device=device)
+ ip_img_size = 336
+ clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
+
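+ # Assemble the Kolors ControlNet img2img pipeline and attach the IP-Adapter weights.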
+ pipe_pose = StableDiffusionXLControlNetImg2ImgPipeline(
+     vae=vae,
+     controlnet=controlnet_pose,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     unet=unet,
+     scheduler=scheduler,
+     image_encoder=image_encoder,
+     feature_extractor=clip_image_processor,
+     force_zeros_for_empty_prompt=False
+ )
+
+ pipe_pose.load_ip_adapter(f'{ckpt_dir_ipa}', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
+
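+ # DWpose detector used to extract the pose skeleton from the input image.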
+ model_dwpose = DWposeDetector()
+
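+ # Run DWpose on the input image and return the pose map resized back to the original resolution.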
+ def process_dwpose_condition(image, res=1024):
+     h, w, _ = image.shape
+     img = resize_image(HWC3(image), res)
+     out_res, out_img = model_dwpose(img)  # run the detector on the resized copy so `img` is actually used
+     result = HWC3(out_img)
+     result = cv2.resize(result, (w, h))
+     return Image.fromarray(result)
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
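+ # One generation pass: build the DWpose condition image, then run the ControlNet img2img pipeline with the IP-Adapter reference image.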
+ def infer_pose(prompt,
+                image=None,
+                ipa_img=None,
+                negative_prompt="nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯",  # nsfw, facial shadow, low resolution, jpeg artifacts, blurry, bad, black face, neon lights
+                seed=66,
+                randomize_seed=False,
+                guidance_scale=5.0,
+                num_inference_steps=50,
+                controlnet_conditioning_scale=0.5,
+                control_guidance_end=0.9,
+                strength=1.0,
+                ip_scale=0.5,
+                ):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed)
+     init_image = resize_image(image, MAX_IMAGE_SIZE)
+     pipe = pipe_pose.to("cuda")
+     pipe.set_ip_adapter_scale([ip_scale])
+     condi_img = process_dwpose_condition(np.array(init_image), MAX_IMAGE_SIZE)
+     image = pipe(
+         prompt=prompt,
+         image=init_image,
+         controlnet_conditioning_scale=controlnet_conditioning_scale,
+         control_guidance_end=control_guidance_end,
+         ip_adapter_image=[ipa_img],
+         strength=strength,
+         control_image=condi_img,
+         negative_prompt=negative_prompt,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         num_images_per_prompt=1,
+         generator=generator,
+     ).images[0]
+     return [condi_img, image], seed
+
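+ # Example prompt (Chinese): "A girl in a purple puff-sleeve dress, wearing a crown and white lace gloves, ultra-high resolution, best quality, 8k quality."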
+ pose_examples = [
+     ["一位穿着紫色泡泡袖连衣裙、戴着皇冠和白色蕾丝手套的女孩,超高分辨率,最佳品质,8k画质",
+      "image/woman_3.png", "image/woman_4.png"],
+ ]
+
+ css = """
+ #col-left {
+     margin: 0 auto;
+     max-width: 600px;
+ }
+ #col-right {
+     margin: 0 auto;
+     max-width: 750px;
+ }
+ #button {
+     color: blue;
+ }
+ """
+
+ def load_description(fp):
+     with open(fp, 'r', encoding='utf-8') as f:
+         content = f.read()
+     return content
+
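+ # Gradio UI: prompt, input image and IP-Adapter reference image on the left, result gallery and seed on the right.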
+ with gr.Blocks(css=css) as PoseApp:
+     gr.HTML(load_description("assets/title.md"))
+     with gr.Row():
+         with gr.Column(elem_id="col-left"):
+             with gr.Row():
+                 prompt = gr.Textbox(
+                     label="Prompt",
+                     placeholder="Enter your prompt",
+                     lines=2
+                 )
+             with gr.Row():
+                 image = gr.Image(label="Image", type="pil")
+                 ipa_image = gr.Image(label="IP-Adapter-Image", type="pil")
+             with gr.Accordion("Advanced Settings", open=False):
+                 negative_prompt = gr.Textbox(
+                     label="Negative prompt",
+                     placeholder="Enter a negative prompt",
+                     visible=True,
+                     # Translation: nsfw, facial shadow, low resolution, bad anatomy, bad hands, missing fingers, worst quality, low quality, jpeg artifacts, blurry, bad, black face, neon lights
+                     value="nsfw,脸部阴影,低分辨率,糟糕的解剖结构、糟糕的手,缺失手指、质量最差、低质量、jpeg伪影、模糊、糟糕,黑脸,霓虹灯"
+                 )
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                 )
+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                 with gr.Row():
+                     guidance_scale = gr.Slider(
+                         label="Guidance scale",
+                         minimum=0.0,
+                         maximum=10.0,
+                         step=0.1,
+                         value=5.0,
+                     )
+                     num_inference_steps = gr.Slider(
+                         label="Number of inference steps",
+                         minimum=10,
+                         maximum=50,
+                         step=1,
+                         value=30,
+                     )
+                 with gr.Row():
+                     controlnet_conditioning_scale = gr.Slider(
+                         label="Controlnet Conditioning Scale",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.5,
+                     )
+                     control_guidance_end = gr.Slider(
+                         label="Control Guidance End",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.9,
+                     )
+                 with gr.Row():
+                     strength = gr.Slider(
+                         label="Strength",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=1.0,
+                     )
+                     ip_scale = gr.Slider(
+                         label="IP_Scale",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.5,
+                     )
+             with gr.Row():
+                 pose_button = gr.Button("Pose", elem_id="button")
+
+         with gr.Column(elem_id="col-right"):
+             result = gr.Gallery(label="Result", show_label=False, columns=2)
+             seed_used = gr.Number(label="Seed Used")
+
+     with gr.Row():
+         gr.Examples(
+             fn=infer_pose,
+             examples=pose_examples,
+             inputs=[prompt, image, ipa_image],
+             outputs=[result, seed_used],
+             label="Pose"
+         )
+
+     pose_button.click(
+         fn=infer_pose,
+         inputs=[prompt, image, ipa_image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength, ip_scale],
+         outputs=[result, seed_used]
+     )
+
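+ # Enable request queuing and launch the app with a public share link.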
+ PoseApp.queue().launch(debug=True, share=True)