add compose
- app.py +4 -2
- demo/demos.py +32 -0
- demo/model.py +106 -4
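
In short, this commit wires a new 'Multi-adapters (Depth & Keypose)' tab into the Space: app.py registers the tab, demo/demos.py adds the create_demo_depth_keypose UI builder, and demo/model.py adds Model_all.process_depth_keypose, which turns the two guidance inputs into depth and keypose feature pyramids and fuses them with per-adapter weights before DDIM sampling. A minimal sketch of that fusion step follows; fuse_adapter_features and the toy tensor shapes are illustrative assumptions, only the weighted-sum expression comes from the diff.

# Sketch of the weighted multi-adapter fusion added in process_depth_keypose.
# The helper name and shapes are assumptions; the zip/weighted-sum mirrors the diff.
import torch

def fuse_adapter_features(feats_depth, feats_keypose, w_depth=1.0, w_keypose=1.5):
    # One weighted sum per scale of the two adapters' feature pyramids.
    return [f_d * w_depth + f_k * w_keypose
            for f_d, f_k in zip(feats_depth, feats_keypose)]

# Toy pyramids with the channel counts the adapters are built with.
feats_depth = [torch.randn(1, c, 8, 8) for c in (320, 640, 1280, 1280)]
feats_keypose = [torch.randn(1, c, 8, 8) for c in (320, 640, 1280, 1280)]
fused = fuse_adapter_features(feats_depth, feats_keypose)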
app.py
CHANGED
@@ -8,7 +8,7 @@ os.system('mim install mmcv-full==1.7.0')
 
 from demo.model import Model_all
 import gradio as gr
-from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth
+from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth, create_demo_depth_keypose
 import torch
 import subprocess
 import shlex
@@ -44,7 +44,7 @@ for url in urls_mmpose:
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 model = Model_all(device)
 
-DESCRIPTION = '''# T2I-Adapter (Sketch & Keypose & Segmentation)
+DESCRIPTION = '''# T2I-Adapter (Sketch & Keypose & Segmentation & Depth)
 [Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter)
 
 This gradio demo is for a simple experience of T2I-Adapter:
@@ -74,5 +74,7 @@ with gr.Blocks(css='style.css') as demo:
             create_demo_seg(model.process_seg)
         with gr.TabItem('Depth'):
            create_demo_depth(model.process_depth)
+        with gr.TabItem('Multi-adapters (Depth & Keypose)'):
+            create_demo_depth_keypose(model.process_depth_keypose)
 
 demo.queue().launch(debug=True, server_name='0.0.0.0')
demo/demos.py
CHANGED
@@ -120,6 +120,38 @@ def create_demo_depth(process):
         run_button.click(fn=process, inputs=ips, outputs=[result])
     return demo
 
+def create_demo_depth_keypose(process):
+    with gr.Blocks() as demo:
+        with gr.Row():
+            gr.Markdown('## T2I-Adapter (Depth & Keypose)')
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    input_img_depth = gr.Image(source='upload', type="numpy", label='Depth guidance')
+                    input_img_keypose = gr.Image(source='upload', type="numpy", label='Keypose guidance')
+
+                prompt = gr.Textbox(label="Prompt")
+                neg_prompt = gr.Textbox(label="Negative Prompt",
+                                        value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
+                pos_prompt = gr.Textbox(label="Positive Prompt",
+                                        value='crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
+                with gr.Row():
+                    type_in_depth = gr.inputs.Radio(['Depth', 'Image'], type="value", default='Image', label='You can input an image or a depth map')
+                    type_in_keypose = gr.inputs.Radio(['Keypose', 'Image'], type="value", default='Image', label='You can input an image or a keypose map (mmpose style)')
+                with gr.Row():
+                    w_depth = gr.Slider(label="Depth guidance weight", minimum=0, maximum=2, value=1.0, step=0.1)
+                    w_keypose = gr.Slider(label="Keypose guidance weight", minimum=0, maximum=2, value=1.5, step=0.1)
+                run_button = gr.Button(label="Run")
+                con_strength = gr.Slider(label="Controlling Strength (the guidance strength of the multi-guidance on the result)", minimum=0, maximum=1, value=1, step=0.1)
+                scale = gr.Slider(label="Guidance Scale (classifier-free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
+                fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
+                base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
+            with gr.Column():
+                result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=3, height='auto')
+        ips = [input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth, w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
+        run_button.click(fn=process, inputs=ips, outputs=[result])
+    return demo
+
 def create_demo_draw(process):
     with gr.Blocks() as demo:
         with gr.Row():
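
A reading aid for the hunk above: Gradio passes the `ips` components to `process` positionally, so their order has to match the signature of Model_all.process_depth_keypose added in demo/model.py below. The check below is illustrative only and not part of the commit.

# Illustrative check (not in the commit): ips order vs. the new method's parameters.
import inspect
from demo.model import Model_all

params = list(inspect.signature(Model_all.process_depth_keypose).parameters)[1:]  # drop `self`
assert params == ['input_img_depth', 'input_img_keypose', 'type_in_depth', 'type_in_keypose',
                  'w_depth', 'w_keypose', 'prompt', 'neg_prompt', 'pos_prompt',
                  'fix_sample', 'scale', 'con_strength', 'base_model']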
demo/model.py
CHANGED
@@ -135,8 +135,8 @@ class Model_all:
 
         # sketch part
         self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
-                                    use_conv=False)
-        self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
+                                    use_conv=False)#.to(device)
+        # self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
         self.model_edge = pidinet().to(device)
         self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in torch.load('models/table5_pidinet.pth', map_location=device)['state_dict'].items()})
 
@@ -144,8 +144,8 @@ class Model_all:
         self.model_seger = seger().to(device)
         self.model_seger.eval()
         self.coler = Colorize(n=182)
-        self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)
-        self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
+        self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)#.to(device)
+        # self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
         self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
 
         # depth part
@@ -311,6 +311,108 @@ class Model_all:
 
         return [im_depth, x_samples_ddim]
 
+    @torch.no_grad()
+    def process_depth_keypose(self, input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth, w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
+        if self.current_base != base_model:
+            ckpt = os.path.join("models", base_model)
+            pl_sd = torch.load(ckpt, map_location="cuda")
+            if "state_dict" in pl_sd:
+                sd = pl_sd["state_dict"]
+            else:
+                sd = pl_sd
+            self.base_model.load_state_dict(sd, strict=False)
+            self.current_base = base_model
+            if 'anything' in base_model.lower():
+                self.load_vae()
+
+        if fix_sample == 'True':
+            seed_everything(42)
+        im_depth = cv2.resize(input_img_depth, (512, 512))
+        im_keypose = cv2.resize(input_img_keypose, (512, 512))
+
+        # get depth
+        if type_in_depth == 'Depth':
+            im_depth_out = im_depth.copy()
+            depth = img2tensor(im_depth).unsqueeze(0) / 255.
+        elif type_in_depth == 'Image':
+            im_depth = img2tensor(im_depth).unsqueeze(0) / 127.5 - 1.0
+            depth = self.depth_model(im_depth.to(self.device)).repeat(1, 3, 1, 1)
+            depth -= torch.min(depth)
+            depth /= torch.max(depth)
+            im_depth_out = tensor2img(depth)
+
+        # get keypose
+        if type_in_keypose == 'Keypose':
+            im_keypose_out = im_keypose.copy()
+            pose = img2tensor(im_keypose).unsqueeze(0) / 255.
+        elif type_in_keypose == 'Image':
+            image = im_keypose.copy()
+            im_keypose = img2tensor(im_keypose).unsqueeze(0) / 255.
+            mmdet_results = inference_detector(self.det_model, image)
+            # keep the person class bounding boxes.
+            person_results = process_mmdet_results(mmdet_results, self.det_cat_id)
+
+            # optional
+            return_heatmap = False
+            dataset = self.pose_model.cfg.data['test']['type']
+
+            # e.g. use ('backbone', ) to return backbone feature
+            output_layer_names = None
+            pose_results, _ = inference_top_down_pose_model(
+                self.pose_model,
+                image,
+                person_results,
+                bbox_thr=self.bbox_thr,
+                format='xyxy',
+                dataset=dataset,
+                dataset_info=None,
+                return_heatmap=return_heatmap,
+                outputs=output_layer_names)
+
+            # show the results
+            im_keypose_out = imshow_keypoints(
+                image,
+                pose_results,
+                skeleton=self.skeleton,
+                pose_kpt_color=self.pose_kpt_color,
+                pose_link_color=self.pose_link_color,
+                radius=2,
+                thickness=2)
+
+        # extract condition features
+        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
+        nc = self.base_model.get_learned_conditioning([neg_prompt])
+        features_adapter_depth = self.model_depth(depth.to(self.device))
+        pose = img2tensor(im_keypose_out, bgr2rgb=True, float32=True) / 255.
+        pose = pose.unsqueeze(0)
+        features_adapter_keypose = self.model_pose(pose.to(self.device))
+        features_adapter = [f_d * w_depth + f_k * w_keypose for f_d, f_k in zip(features_adapter_depth, features_adapter_keypose)]
+        shape = [4, 64, 64]
+
+        # sampling
+        con_strength = int((1 - con_strength) * 50)
+        samples_ddim, _ = self.sampler.sample(S=50,
+                                              conditioning=c,
+                                              batch_size=1,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=nc,
+                                              eta=0.0,
+                                              x_T=None,
+                                              features_adapter1=features_adapter,
+                                              mode='sketch',
+                                              con_strength=con_strength)
+
+        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        x_samples_ddim = x_samples_ddim.to('cpu')
+        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
+        x_samples_ddim = 255. * x_samples_ddim
+        x_samples_ddim = x_samples_ddim.astype(np.uint8)
+
+        return [im_depth_out, im_keypose_out, x_samples_ddim]
+
     @torch.no_grad()
     def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
                     con_strength, base_model):
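
For orientation, a hedged sketch of calling the new entry point directly, outside the Gradio UI. The argument order follows the signature in the hunk above; the image paths, prompts and example values are placeholders, not part of the commit, and Model_all still expects the Space's checkpoints under models/.

# Sketch only: exercise Model_all.process_depth_keypose without the UI.
import cv2
import torch
from demo.model import Model_all

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Model_all(device)

# Gradio hands the callback RGB uint8 arrays, so convert after cv2.imread (BGR).
depth_guide = cv2.cvtColor(cv2.imread('examples/room.png'), cv2.COLOR_BGR2RGB)      # hypothetical path
keypose_guide = cv2.cvtColor(cv2.imread('examples/person.png'), cv2.COLOR_BGR2RGB)  # hypothetical path

im_depth_out, im_keypose_out, result = model.process_depth_keypose(
    depth_guide, keypose_guide,
    'Image', 'Image',   # derive the depth / keypose maps from plain images
    1.0, 1.5,           # w_depth, w_keypose
    'a person reading in a sunlit room',       # prompt (placeholder)
    'ugly, deformed, low contrast',            # negative prompt (placeholder)
    'maximum details, intricately detailed',   # positive prompt (placeholder)
    'False',            # fix_sample
    7.5,                # scale
    1.0,                # con_strength
    'sd-v1-4.ckpt')     # base_model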