Adapter committed on
Commit 6d93939 · 1 Parent(s): 2254a67

add compose

Files changed (3)
  1. app.py +4 -2
  2. demo/demos.py +32 -0
  3. demo/model.py +106 -4
app.py CHANGED
@@ -8,7 +8,7 @@ os.system('mim install mmcv-full==1.7.0')
 
 from demo.model import Model_all
 import gradio as gr
-from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth
+from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth, create_demo_depth_keypose
 import torch
 import subprocess
 import shlex
@@ -44,7 +44,7 @@ for url in urls_mmpose:
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 model = Model_all(device)
 
-DESCRIPTION = '''# T2I-Adapter (Sketch & Keypose & Segmentation)
+DESCRIPTION = '''# T2I-Adapter (Sketch & Keypose & Segmentation & Depth)
 [Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter)
 
 This gradio demo is for a simple experience of T2I-Adapter:
@@ -74,5 +74,7 @@ with gr.Blocks(css='style.css') as demo:
             create_demo_seg(model.process_seg)
         with gr.TabItem('Depth'):
             create_demo_depth(model.process_depth)
+        with gr.TabItem('Multi-adapters (Depth & Keypose)'):
+            create_demo_depth_keypose(model.process_depth_keypose)
 
 demo.queue().launch(debug=True, server_name='0.0.0.0')
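
For orientation, the resulting tab wiring in app.py after this commit should look roughly like the sketch below. This is a hedged reconstruction, not the full file: the gr.Tabs() wrapper, the earlier tabs, and the gr.Markdown(DESCRIPTION) call are assumed from the unchanged parts of app.py and do not appear in the hunks above.

# Hypothetical reconstruction of the tab wiring after this commit; only the
# Depth tab and the new multi-adapter tab are shown, the rest is assumed.
import torch
import gradio as gr
from demo.model import Model_all
from demo.demos import create_demo_depth, create_demo_depth_keypose

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Model_all(device)

with gr.Blocks(css='style.css') as demo:
    with gr.Tabs():  # assumed wrapper, not visible in the diff
        with gr.TabItem('Depth'):
            create_demo_depth(model.process_depth)
        # new tab added by this commit
        with gr.TabItem('Multi-adapters (Depth & Keypose)'):
            create_demo_depth_keypose(model.process_depth_keypose)

demo.queue().launch(debug=True, server_name='0.0.0.0')
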
demo/demos.py CHANGED
@@ -120,6 +120,38 @@ def create_demo_depth(process):
         run_button.click(fn=process, inputs=ips, outputs=[result])
     return demo
 
+def create_demo_depth_keypose(process):
+    with gr.Blocks() as demo:
+        with gr.Row():
+            gr.Markdown('## T2I-Adapter (Depth & Keypose)')
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    input_img_depth = gr.Image(source='upload', type="numpy", label='Depth guidance')
+                    input_img_keypose = gr.Image(source='upload', type="numpy", label='Keypose guidance')
+
+                prompt = gr.Textbox(label="Prompt")
+                neg_prompt = gr.Textbox(label="Negative Prompt",
+                                        value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
+                pos_prompt = gr.Textbox(label="Positive Prompt",
+                                        value='crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
+                with gr.Row():
+                    type_in_depth = gr.inputs.Radio(['Depth', 'Image'], type="value", default='Image', label='You can input an image or a depth map')
+                    type_in_keypose = gr.inputs.Radio(['Keypose', 'Image'], type="value", default='Image', label='You can input an image or a keypose map (mmpose style)')
+                with gr.Row():
+                    w_depth = gr.Slider(label="Depth guidance weight", minimum=0, maximum=2, value=1.0, step=0.1)
+                    w_keypose = gr.Slider(label="Keypose guidance weight", minimum=0, maximum=2, value=1.5, step=0.1)
+                run_button = gr.Button(label="Run")
+                con_strength = gr.Slider(label="Controlling Strength (The guidance strength of the multi-guidance to the result)", minimum=0, maximum=1, value=1, step=0.1)
+                scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
+                fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
+                base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
+            with gr.Column():
+                result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=3, height='auto')
+        ips = [input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth, w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
+        run_button.click(fn=process, inputs=ips, outputs=[result])
+    return demo
+
 def create_demo_draw(process):
     with gr.Blocks() as demo:
         with gr.Row():
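
One thing worth checking in this change: run_button.click passes the ips component values positionally, so the list in create_demo_depth_keypose must line up exactly with the parameters of the process_depth_keypose method added in demo/model.py below. A small hedged sanity check along those lines (not part of the commit; it assumes it is run inside this Space's environment so that demo.model imports cleanly):

# Hypothetical check that the ips ordering matches the model method signature.
import inspect
from demo.model import Model_all

expected = ['input_img_depth', 'input_img_keypose', 'type_in_depth', 'type_in_keypose',
            'w_depth', 'w_keypose', 'prompt', 'neg_prompt', 'pos_prompt',
            'fix_sample', 'scale', 'con_strength', 'base_model']
params = list(inspect.signature(Model_all.process_depth_keypose).parameters)[1:]  # drop self
assert params == expected, params
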
demo/model.py CHANGED
@@ -135,8 +135,8 @@ class Model_all:
 
         # sketch part
         self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
-                                    use_conv=False).to(device)
-        self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
+                                    use_conv=False)#.to(device)
+        # self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
         self.model_edge = pidinet().to(device)
         self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in torch.load('models/table5_pidinet.pth', map_location=device)['state_dict'].items()})
 
@@ -144,8 +144,8 @@ class Model_all:
         self.model_seger = seger().to(device)
         self.model_seger.eval()
         self.coler = Colorize(n=182)
-        self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
-        self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
+        self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)#.to(device)
+        # self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
         self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
 
         # depth part
@@ -311,6 +311,108 @@ class Model_all:
 
         return [im_depth, x_samples_ddim]
 
+    @torch.no_grad()
+    def process_depth_keypose(self, input_img_depth, input_img_keypose, type_in_depth, type_in_keypose, w_depth, w_keypose, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
+        if self.current_base != base_model:
+            ckpt = os.path.join("models", base_model)
+            pl_sd = torch.load(ckpt, map_location="cuda")
+            if "state_dict" in pl_sd:
+                sd = pl_sd["state_dict"]
+            else:
+                sd = pl_sd
+            self.base_model.load_state_dict(sd, strict=False)
+            self.current_base = base_model
+            if 'anything' in base_model.lower():
+                self.load_vae()
+
+        if fix_sample == 'True':
+            seed_everything(42)
+        im_depth = cv2.resize(input_img_depth, (512, 512))
+        im_keypose = cv2.resize(input_img_keypose, (512, 512))
+
+        # get depth
+        if type_in_depth == 'Depth':
+            im_depth_out = im_depth.copy()
+            depth = img2tensor(im_depth).unsqueeze(0) / 255.
+        elif type_in_depth == 'Image':
+            im_depth = img2tensor(im_depth).unsqueeze(0) / 127.5 - 1.0
+            depth = self.depth_model(im_depth.to(self.device)).repeat(1, 3, 1, 1)
+            depth -= torch.min(depth)
+            depth /= torch.max(depth)
+            im_depth_out = tensor2img(depth)
+
+        # get keypose
+        if type_in_keypose == 'Keypose':
+            im_keypose_out = im_keypose.copy()
+            pose = img2tensor(im_keypose).unsqueeze(0) / 255.
+        elif type_in_keypose == 'Image':
+            image = im_keypose.copy()
+            im_keypose = img2tensor(im_keypose).unsqueeze(0) / 255.
+            mmdet_results = inference_detector(self.det_model, image)
+            # keep the person class bounding boxes.
+            person_results = process_mmdet_results(mmdet_results, self.det_cat_id)
+
+            # optional
+            return_heatmap = False
+            dataset = self.pose_model.cfg.data['test']['type']
+
+            # e.g. use ('backbone', ) to return backbone feature
+            output_layer_names = None
+            pose_results, _ = inference_top_down_pose_model(
+                self.pose_model,
+                image,
+                person_results,
+                bbox_thr=self.bbox_thr,
+                format='xyxy',
+                dataset=dataset,
+                dataset_info=None,
+                return_heatmap=return_heatmap,
+                outputs=output_layer_names)
+
+            # show the results
+            im_keypose_out = imshow_keypoints(
+                image,
+                pose_results,
+                skeleton=self.skeleton,
+                pose_kpt_color=self.pose_kpt_color,
+                pose_link_color=self.pose_link_color,
+                radius=2,
+                thickness=2)
+
+        # extract condition features
+        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
+        nc = self.base_model.get_learned_conditioning([neg_prompt])
+        features_adapter_depth = self.model_depth(depth.to(self.device))
+        pose = img2tensor(im_keypose_out, bgr2rgb=True, float32=True) / 255.
+        pose = pose.unsqueeze(0)
+        features_adapter_keypose = self.model_pose(pose.to(self.device))
+        features_adapter = [f_d*w_depth + f_k*w_keypose for f_d, f_k in zip(features_adapter_depth, features_adapter_keypose)]
+        shape = [4, 64, 64]
+
+        # sampling
+        con_strength = int((1 - con_strength) * 50)
+        samples_ddim, _ = self.sampler.sample(S=50,
+                                              conditioning=c,
+                                              batch_size=1,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=nc,
+                                              eta=0.0,
+                                              x_T=None,
+                                              features_adapter1=features_adapter,
+                                              mode='sketch',
+                                              con_strength=con_strength)
+
+        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        x_samples_ddim = x_samples_ddim.to('cpu')
+        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
+        x_samples_ddim = 255. * x_samples_ddim
+        x_samples_ddim = x_samples_ddim.astype(np.uint8)
+
+        return [im_depth_out, im_keypose_out, x_samples_ddim]
+
     @torch.no_grad()
     def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
                     con_strength, base_model):
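
The core of the "add compose" change is the feature fusion just before the sampler call: each adapter returns a list of multi-scale feature maps, and the two lists are merged with per-adapter scalar weights before being passed as features_adapter1. A minimal, self-contained sketch of that composition (compose_adapter_features is a hypothetical helper, and the tensor shapes are illustrative placeholders; only the channel widths [320, 640, 1280, 1280] and the default weights w_depth=1.0, w_keypose=1.5 come from the diff above):

# Sketch of the multi-adapter composition, not code from the repo.
import torch

def compose_adapter_features(feats_a, feats_b, w_a, w_b):
    # element-wise weighted sum, one tensor per UNet scale
    return [f_a * w_a + f_b * w_b for f_a, f_b in zip(feats_a, feats_b)]

channels = [320, 640, 1280, 1280]
strides = [1, 2, 4, 8]  # placeholder spatial downsampling per scale
features_depth = [torch.randn(1, c, 64 // s, 64 // s) for c, s in zip(channels, strides)]
features_keypose = [torch.randn(1, c, 64 // s, 64 // s) for c, s in zip(channels, strides)]

features_adapter = compose_adapter_features(features_depth, features_keypose, w_a=1.0, w_b=1.5)
print([tuple(f.shape) for f in features_adapter])
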