Upload 114 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- EdgeCape/VERSION +1 -0
- EdgeCape/__init__.py +3 -0
- EdgeCape/__pycache__/__init__.cpython-39.pyc +0 -0
- EdgeCape/apis/__init__.py +5 -0
- EdgeCape/apis/__pycache__/__init__.cpython-39.pyc +0 -0
- EdgeCape/apis/__pycache__/test.cpython-39.pyc +0 -0
- EdgeCape/apis/__pycache__/train.cpython-39.pyc +0 -0
- EdgeCape/apis/test.py +198 -0
- EdgeCape/apis/train.py +124 -0
- EdgeCape/core/__init__.py +1 -0
- EdgeCape/core/__pycache__/__init__.cpython-39.pyc +0 -0
- EdgeCape/core/custom_hooks/__pycache__/shuffle_hooks.cpython-39.pyc +0 -0
- EdgeCape/core/custom_hooks/shuffle_hooks.py +28 -0
- EdgeCape/datasets/__init__.py +3 -0
- EdgeCape/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- EdgeCape/datasets/__pycache__/builder.cpython-39.pyc +0 -0
- EdgeCape/datasets/builder.py +55 -0
- EdgeCape/datasets/datasets/__init__.py +6 -0
- EdgeCape/datasets/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- EdgeCape/datasets/datasets/mp100/__init__.py +13 -0
- EdgeCape/datasets/datasets/mp100/__pycache__/__init__.cpython-39.pyc +0 -0
- EdgeCape/datasets/datasets/mp100/__pycache__/custom_test_dataset.cpython-39.pyc +0 -0
- EdgeCape/datasets/datasets/mp100/__pycache__/fewshot_base_dataset.cpython-39.pyc +0 -0
- EdgeCape/datasets/datasets/mp100/__pycache__/fewshot_dataset.cpython-39.pyc +0 -0
- EdgeCape/datasets/datasets/mp100/__pycache__/test_base_dataset.cpython-39.pyc +0 -0
- EdgeCape/datasets/datasets/mp100/__pycache__/test_dataset.cpython-39.pyc +0 -0
- EdgeCape/datasets/datasets/mp100/__pycache__/transformer_base_dataset.cpython-39.pyc +0 -0
- EdgeCape/datasets/datasets/mp100/__pycache__/transformer_dataset.cpython-39.pyc +0 -0
- EdgeCape/datasets/datasets/mp100/custom_test_dataset.py +355 -0
- EdgeCape/datasets/datasets/mp100/fewshot_base_dataset.py +223 -0
- EdgeCape/datasets/datasets/mp100/fewshot_dataset.py +312 -0
- EdgeCape/datasets/datasets/mp100/test_base_dataset.py +226 -0
- EdgeCape/datasets/datasets/mp100/test_dataset.py +319 -0
- EdgeCape/datasets/datasets/mp100/transformer_base_dataset.py +209 -0
- EdgeCape/datasets/datasets/mp100/transformer_dataset.py +319 -0
- EdgeCape/datasets/pipelines/__init__.py +8 -0
- EdgeCape/datasets/pipelines/__pycache__/__init__.cpython-39.pyc +0 -0
- EdgeCape/datasets/pipelines/__pycache__/post_transforms.cpython-39.pyc +0 -0
- EdgeCape/datasets/pipelines/__pycache__/top_down_transform.cpython-39.pyc +0 -0
- EdgeCape/datasets/pipelines/post_transforms.py +121 -0
- EdgeCape/datasets/pipelines/top_down_transform.py +716 -0
- EdgeCape/models/__init__.py +3 -0
- EdgeCape/models/__pycache__/__init__.cpython-39.pyc +0 -0
- EdgeCape/models/backbones/__pycache__/adapter.cpython-39.pyc +0 -0
- EdgeCape/models/backbones/__pycache__/dino.cpython-39.pyc +0 -0
- EdgeCape/models/backbones/adapter.py +935 -0
- EdgeCape/models/backbones/dino.py +206 -0
- EdgeCape/models/detectors/EdgeCape.py +392 -0
- EdgeCape/models/detectors/__init__.py +3 -0
- EdgeCape/models/detectors/__pycache__/EdgeCape.cpython-39.pyc +0 -0
EdgeCape/VERSION
ADDED
@@ -0,0 +1 @@
+0.2.0
EdgeCape/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .core import *  # noqa
+from .datasets import *  # noqa
+from .models import *  # noqa
EdgeCape/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (206 Bytes)
EdgeCape/apis/__init__.py
ADDED
@@ -0,0 +1,5 @@
+from .train import train_model
+
+__all__ = [
+    'train_model'
+]
EdgeCape/apis/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (217 Bytes)
EdgeCape/apis/__pycache__/test.cpython-39.pyc
ADDED
Binary file (5.14 kB)
EdgeCape/apis/__pycache__/train.cpython-39.pyc
ADDED
Binary file (3.19 kB)
EdgeCape/apis/test.py
ADDED
@@ -0,0 +1,198 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+
+import mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+
+
+def single_gpu_test(model, data_loader):
+    """Test model with a single gpu.
+
+    This method tests model with a single gpu and displays test progress bar.
+
+    Args:
+        model (nn.Module): Model to be tested.
+        data_loader (nn.Dataloader): Pytorch data loader.
+
+
+    Returns:
+        list: The prediction results.
+    """
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    prog_bar = mmcv.ProgressBar(len(dataset))
+    for data in data_loader:
+        with torch.no_grad():
+            result = model(return_loss=False, **data)
+        batch_size = len(next(iter(data.values()))[0])
+        # results.append(result)
+        if 'preds' in result:
+            for i in range(batch_size):
+                results.append({
+                    'preds': result['preds'][i][None],
+                    'boxes': result['boxes'][i][None],
+                    'bbox_ids': [result['bbox_ids'][i]],
+                    'image_paths': [result['image_paths'][i]],
+                })
+        # use the first key as main key to calculate the batch size
+        # for _ in range(batch_size):
+        prog_bar.update(batch_size)
+    return results
+
+
+def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+    """Test model with multiple gpus.
+
+    This method tests model with multiple gpus and collects the results
+    under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
+    it encodes results to gpu tensors and use gpu communication for results
+    collection. On cpu mode it saves the results on different gpus to 'tmpdir'
+    and collects them by the rank 0 worker.
+
+    Args:
+        model (nn.Module): Model to be tested.
+        data_loader (nn.Dataloader): Pytorch data loader.
+        tmpdir (str): Path of directory to save the temporary results from
+            different gpus under cpu mode.
+        gpu_collect (bool): Option to use either gpu or cpu to collect results.
+
+    Returns:
+        list: The prediction results.
+    """
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    rank, world_size = get_dist_info()
+    if rank == 0:
+        prog_bar = mmcv.ProgressBar(len(dataset))
+    for data in data_loader:
+        with torch.no_grad():
+            result = model(return_loss=False, **data)
+        results.append(result)
+
+        if rank == 0:
+            # use the first key as main key to calculate the batch size
+            batch_size = len(next(iter(data.values())))
+            for _ in range(batch_size * world_size):
+                prog_bar.update()
+
+    # collect results from all ranks
+    if gpu_collect:
+        results = collect_results_gpu(results, len(dataset))
+    else:
+        results = collect_results_cpu(results, len(dataset), tmpdir)
+    return results
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):
+    """Collect results in cpu mode.
+
+    It saves the results on different gpus to 'tmpdir' and collects
+    them by the rank 0 worker.
+
+    Args:
+        result_part (list): Results to be collected
+        size (int): Result size.
+        tmpdir (str): Path of directory to save the temporary results from
+            different gpus under cpu mode. Default: None
+
+    Returns:
+        list: Ordered results.
+    """
+    rank, world_size = get_dist_info()
+    # create a tmp dir if it is not specified
+    if tmpdir is None:
+        MAX_LEN = 512
+        # 32 is whitespace
+        dir_tensor = torch.full((MAX_LEN, ),
+                                32,
+                                dtype=torch.uint8,
+                                device='cuda')
+        if rank == 0:
+            mmcv.mkdir_or_exist('.dist_test')
+            tmpdir = tempfile.mkdtemp(dir='.dist_test')
+            tmpdir = torch.tensor(
+                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+            dir_tensor[:len(tmpdir)] = tmpdir
+        dist.broadcast(dir_tensor, 0)
+        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    else:
+        mmcv.mkdir_or_exist(tmpdir)
+    # synchronizes all processes to make sure tmpdir exist
+    dist.barrier()
+    # dump the part result to the dir
+    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
+    # synchronizes all processes for loading pickle file
+    dist.barrier()
+    # collect all parts
+    if rank != 0:
+        return None
+
+    # load results of all parts from tmp dir
+    part_list = []
+    for i in range(world_size):
+        part_file = osp.join(tmpdir, f'part_{i}.pkl')
+        part_list.append(mmcv.load(part_file))
+    # sort the results
+    ordered_results = []
+    for res in zip(*part_list):
+        ordered_results.extend(list(res))
+    # the dataloader may pad some samples
+    ordered_results = ordered_results[:size]
+    # remove tmp dir
+    shutil.rmtree(tmpdir)
+    return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+    """Collect results in gpu mode.
+
+    It encodes results to gpu tensors and use gpu communication for results
+    collection.
+
+    Args:
+        result_part (list): Results to be collected
+        size (int): Result size.
+
+    Returns:
+        list: Ordered results.
+    """
+
+    rank, world_size = get_dist_info()
+    # dump result part to tensor with pickle
+    part_tensor = torch.tensor(
+        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+    # gather all result part tensor shape
+    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+    shape_list = [shape_tensor.clone() for _ in range(world_size)]
+    dist.all_gather(shape_list, shape_tensor)
+    # padding result part tensor to max length
+    shape_max = torch.tensor(shape_list).max()
+    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+    part_send[:shape_tensor[0]] = part_tensor
+    part_recv_list = [
+        part_tensor.new_zeros(shape_max) for _ in range(world_size)
+    ]
+    # gather all result part
+    dist.all_gather(part_recv_list, part_send)
+
+    if rank == 0:
+        part_list = []
+        for recv, shape in zip(part_recv_list, shape_list):
+            part_list.append(
+                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+        # sort the results
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        return ordered_results
+    return None
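Note: a minimal sketch of how single_gpu_test above might be driven from an evaluation script. Only single_gpu_test, the build_dataset helper added in EdgeCape/datasets/builder.py below, and the datasets' evaluate() method come from this commit; the config object `cfg`, the already-built `model`, and the mmcv/mmpose helpers are assumptions mirroring the imports used in EdgeCape/apis/train.py.

from mmcv.parallel import MMDataParallel
from mmpose.datasets import build_dataloader

from EdgeCape.apis.test import single_gpu_test
from EdgeCape.datasets.builder import build_dataset

# `cfg` is a loaded mmcv Config and `model` a built/loaded pose model (assumed, not shown).
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
    dataset, samples_per_gpu=1, workers_per_gpu=1, dist=False, shuffle=False)
model = MMDataParallel(model.cuda(), device_ids=[0])
results = single_gpu_test(model, data_loader)  # one dict per query sample
print(dataset.evaluate(results, res_folder=cfg.work_dir, metric='PCK'))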
EdgeCape/apis/train.py
ADDED
@@ -0,0 +1,124 @@
+import os
+import torch
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
+                         build_optimizer)
+
+from mmpose.core import DistEvalHook, EvalHook, Fp16OptimizerHook
+from mmpose.datasets import build_dataloader
+from mmpose.utils import get_root_logger
+from EdgeCape.core.custom_hooks.shuffle_hooks import ShufflePairedSamplesHook
+
+def train_model(model,
+                dataset,
+                val_dataset,
+                cfg,
+                distributed=False,
+                validate=False,
+                timestamp=None,
+                meta=None):
+    """Train model entry function.
+
+    Args:
+        model (nn.Module): The model to be trained.
+        dataset (Dataset): Train dataset.
+        cfg (dict): The config dict for training.
+        distributed (bool): Whether to use distributed training.
+            Default: False.
+        validate (bool): Whether to do evaluation. Default: False.
+        timestamp (str | None): Local time for runner. Default: None.
+        meta (dict | None): Meta dict to record some important information.
+            Default: None
+    """
+    logger = get_root_logger(cfg.log_level)
+
+    # prepare data loaders
+    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+    dataloader_setting = dict(
+        samples_per_gpu=cfg.data.get('samples_per_gpu', {}),
+        workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
+        # cfg.gpus will be ignored if distributed
+        num_gpus=len(cfg.gpu_ids),
+        dist=distributed,
+        seed=cfg.seed,
+        pin_memory=False,
+    )
+    dataloader_setting = dict(dataloader_setting,
+                              **cfg.data.get('train_dataloader', {}))
+
+    data_loaders = [
+        build_dataloader(ds, **dataloader_setting) for ds in dataset
+    ]
+
+    # put model on gpus
+    if distributed:
+        find_unused_parameters = cfg.get('find_unused_parameters', True)  # NOTE: True has been modified to False for faster training.
+        # Sets the `find_unused_parameters` parameter in
+        # torch.nn.parallel.DistributedDataParallel
+        model = MMDistributedDataParallel(
+            model.cuda(),
+            device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False,
+            find_unused_parameters=find_unused_parameters)
+    else:
+        model = MMDataParallel(
+            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+    # build runner
+    optimizer = build_optimizer(model, cfg.optimizer)
+    runner = EpochBasedRunner(
+        model,
+        optimizer=optimizer,
+        work_dir=cfg.work_dir,
+        logger=logger,
+        meta=meta)
+    # an ugly workaround to make .log and .log.json filenames the same
+    runner.timestamp = timestamp
+
+    # fp16 setting
+    fp16_cfg = cfg.get('fp16', None)
+    if fp16_cfg is not None:
+        optimizer_config = Fp16OptimizerHook(
+            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
+    elif distributed and 'type' not in cfg.optimizer_config:
+        optimizer_config = OptimizerHook(**cfg.optimizer_config)
+    else:
+        optimizer_config = cfg.optimizer_config
+
+    # register hooks
+    runner.register_training_hooks(cfg.lr_config, optimizer_config,
+                                   cfg.checkpoint_config, cfg.log_config,
+                                   cfg.get('momentum_config', None))
+    if distributed:
+        runner.register_hook(DistSamplerSeedHook())
+
+    shuffle_cfg = cfg.get('shuffle_cfg', None)
+    if shuffle_cfg is not None:
+        for data_loader in data_loaders:
+            runner.register_hook(ShufflePairedSamplesHook(data_loader, **shuffle_cfg))
+
+    # register eval hooks
+    if validate:
+        eval_cfg = cfg.get('evaluation', {})
+        eval_cfg['res_folder'] = os.path.join(cfg.work_dir, eval_cfg['res_folder'])
+        dataloader_setting = dict(
+            # samples_per_gpu=cfg.data.get('samples_per_gpu', {}),
+            samples_per_gpu=1,
+            workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
+            # cfg.gpus will be ignored if distributed
+            num_gpus=len(cfg.gpu_ids),
+            dist=distributed,
+            shuffle=False,
+            pin_memory=False,
+        )
+        dataloader_setting = dict(dataloader_setting,
+                                  **cfg.data.get('val_dataloader', {}))
+        val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
+        eval_hook = DistEvalHook if distributed else EvalHook
+        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
+
+    if cfg.resume_from:
+        runner.resume(cfg.resume_from)
+    elif cfg.load_from:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
EdgeCape/core/__init__.py
ADDED
@@ -0,0 +1 @@
+
EdgeCape/core/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (152 Bytes)
EdgeCape/core/custom_hooks/__pycache__/shuffle_hooks.cpython-39.pyc
ADDED
Binary file (1.27 kB)
EdgeCape/core/custom_hooks/shuffle_hooks.py
ADDED
@@ -0,0 +1,28 @@
+from mmcv.runner import Hook
+from torch.utils.data import DataLoader
+from mmpose.utils import get_root_logger
+
+class ShufflePairedSamplesHook(Hook):
+    """Non-Distributed ShufflePairedSamples.
+    After each training epoch, run FewShotKeypointDataset.random_paired_samples()
+    """
+
+    def __init__(self,
+                 dataloader,
+                 interval=1):
+        if not isinstance(dataloader, DataLoader):
+            raise TypeError(f'dataloader must be a pytorch DataLoader, '
+                            f'but got {type(dataloader)}')
+
+        self.dataloader = dataloader
+        self.interval = interval
+        self.logger = get_root_logger()
+
+    def after_train_epoch(self, runner):
+        """Called after every training epoch to evaluate the results."""
+        if not self.every_n_epochs(runner, self.interval):
+            return
+        # self.logger.info("Run random_paired_samples()")
+        # self.logger.info(f"Before: {self.dataloader.dataset.paired_samples[0]}")
+        self.dataloader.dataset.random_paired_samples()
+        # self.logger.info(f"After: {self.dataloader.dataset.paired_samples[0]}")
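Note: a minimal sketch of how this hook gets enabled. train_model in EdgeCape/apis/train.py above registers ShufflePairedSamplesHook for every train dataloader when the config defines shuffle_cfg; besides the dataloader, interval is the only keyword the hook accepts. The snippet below is illustrative, not part of the commit.

# In a training config (illustrative):
shuffle_cfg = dict(interval=1)  # re-sample the support/query pairs after every epoch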
EdgeCape/datasets/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .builder import *  # noqa
+from .datasets import *  # noqa
+from .pipelines import *  # noqa
EdgeCape/datasets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (221 Bytes)
EdgeCape/datasets/__pycache__/builder.cpython-39.pyc
ADDED
Binary file (1.9 kB)
EdgeCape/datasets/builder.py
ADDED
@@ -0,0 +1,55 @@
+from mmcv.utils import build_from_cfg
+from torch.utils.data.dataset import ConcatDataset
+
+from mmpose.datasets.dataset_wrappers import RepeatDataset
+from mmpose.datasets.builder import DATASETS
+
+
+def _concat_cfg(cfg):
+    replace = ['ann_file', 'img_prefix']
+    channels = ['num_joints', 'dataset_channel']
+    concat_cfg = []
+    for i in range(len(cfg['type'])):
+        cfg_tmp = cfg.deepcopy()
+        cfg_tmp['type'] = cfg['type'][i]
+        for item in replace:
+            assert item in cfg_tmp
+            assert len(cfg['type']) == len(cfg[item]), (cfg[item])
+            cfg_tmp[item] = cfg[item][i]
+        for item in channels:
+            assert item in cfg_tmp['data_cfg']
+            assert len(cfg['type']) == len(cfg['data_cfg'][item])
+            cfg_tmp['data_cfg'][item] = cfg['data_cfg'][item][i]
+        concat_cfg.append(cfg_tmp)
+    return concat_cfg
+
+
+def _check_vaild(cfg):
+    replace = ['num_joints', 'dataset_channel']
+    if isinstance(cfg['data_cfg'][replace[0]], (list, tuple)):
+        for item in replace:
+            cfg['data_cfg'][item] = cfg['data_cfg'][item][0]
+    return cfg
+
+
+def build_dataset(cfg, default_args=None):
+    """Build a dataset from config dict.
+
+    Args:
+        cfg (dict): Config dict. It should at least contain the key "type".
+        default_args (dict, optional): Default initialization arguments.
+            Default: None.
+
+    Returns:
+        Dataset: The constructed dataset.
+    """
+    if isinstance(cfg['type'], (list, tuple)):  # In training, type=TransformerPoseDataset
+        dataset = ConcatDataset(
+            [build_dataset(c, default_args) for c in _concat_cfg(cfg)])
+    elif cfg['type'] == 'RepeatDataset':
+        dataset = RepeatDataset(
+            build_dataset(cfg['dataset'], default_args), cfg['times'])
+    else:
+        cfg = _check_vaild(cfg)
+        dataset = build_from_cfg(cfg, DATASETS, default_args)
+    return dataset
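Note: a minimal sketch of the list-style dataset config that build_dataset above expands into a ConcatDataset via _concat_cfg. The paths and numbers are placeholders; ann_file, img_prefix, data_cfg.num_joints and data_cfg.dataset_channel must each be lists of the same length as type, and train_pipeline is assumed to be defined elsewhere in the config.

# Illustrative config fragment (placeholder values):
data = dict(
    train=dict(
        type=['TransformerPoseDataset', 'TransformerPoseDataset'],
        ann_file=['path/to/split_a_train.json', 'path/to/split_b_train.json'],
        img_prefix=['path/to/images/', 'path/to/images/'],
        data_cfg=dict(num_joints=[17, 9], dataset_channel=[[0], [0]]),
        pipeline=train_pipeline))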
EdgeCape/datasets/datasets/__init__.py
ADDED
@@ -0,0 +1,6 @@
+from .mp100 import (FewShotKeypointDataset, FewShotBaseDataset,
+                    TransformerBaseDataset, TransformerPoseDataset,)
+
+__all__ = ['FewShotBaseDataset', 'FewShotKeypointDataset',
+           'TransformerBaseDataset', 'TransformerPoseDataset',
+           ]
EdgeCape/datasets/datasets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (353 Bytes)
EdgeCape/datasets/datasets/mp100/__init__.py
ADDED
@@ -0,0 +1,13 @@
+from .fewshot_dataset import FewShotKeypointDataset
+from .fewshot_base_dataset import FewShotBaseDataset
+from .transformer_dataset import TransformerPoseDataset
+from .transformer_base_dataset import TransformerBaseDataset
+from .test_base_dataset import TestBaseDataset
+from .test_dataset import TestPoseDataset
+from .custom_test_dataset import CustomTestPoseDataset
+
+__all__ = [
+    'FewShotKeypointDataset', 'FewShotBaseDataset',
+    'TransformerPoseDataset', 'TransformerBaseDataset',
+    'TestBaseDataset', 'TestPoseDataset', 'CustomTestPoseDataset'
+]
EdgeCape/datasets/datasets/mp100/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (663 Bytes)
EdgeCape/datasets/datasets/mp100/__pycache__/custom_test_dataset.cpython-39.pyc
ADDED
Binary file (10.2 kB)
EdgeCape/datasets/datasets/mp100/__pycache__/fewshot_base_dataset.cpython-39.pyc
ADDED
Binary file (7.22 kB)
EdgeCape/datasets/datasets/mp100/__pycache__/fewshot_dataset.cpython-39.pyc
ADDED
Binary file (8.95 kB)
EdgeCape/datasets/datasets/mp100/__pycache__/test_base_dataset.cpython-39.pyc
ADDED
Binary file (7.4 kB)
EdgeCape/datasets/datasets/mp100/__pycache__/test_dataset.cpython-39.pyc
ADDED
Binary file (9.02 kB)
EdgeCape/datasets/datasets/mp100/__pycache__/transformer_base_dataset.cpython-39.pyc
ADDED
Binary file (7.23 kB)
EdgeCape/datasets/datasets/mp100/__pycache__/transformer_dataset.cpython-39.pyc
ADDED
Binary file (9.06 kB)
EdgeCape/datasets/datasets/mp100/custom_test_dataset.py
ADDED
@@ -0,0 +1,355 @@
+from mmpose.datasets import DATASETS
+import random
+import numpy as np
+import os
+from collections import OrderedDict
+from xtcocotools.coco import COCO
+from .test_base_dataset import TestBaseDataset
+
+@DATASETS.register_module()
+class CustomTestPoseDataset(TestBaseDataset):
+
+    def __init__(self,
+                 ann_file,
+                 img_prefix,
+                 data_cfg,
+                 pipeline,
+                 valid_class_ids,
+                 max_kpt_num=None,
+                 num_shots=1,
+                 num_queries=100,
+                 num_episodes=1,
+                 pck_threshold_list=[0.05, 0.1, 0.15, 0.20, 0.25],
+                 test_mode=True):
+        super().__init__(
+            ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode, PCK_threshold_list=pck_threshold_list)
+
+        self.ann_info['flip_pairs'] = []
+
+        self.ann_info['upper_body_ids'] = []
+        self.ann_info['lower_body_ids'] = []
+
+        self.ann_info['use_different_joint_weights'] = False
+        self.ann_info['joint_weights'] = np.array([1.,],
+                                                  dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
+
+        self.coco = COCO(ann_file)
+
+        self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
+        self.img_ids = self.coco.getImgIds()
+
+        cat = None
+        relevant_names = [
+            '000000052046',
+            '000000052152'
+
+            # '000000027059',
+            # '000000030361'
+            # '000000027936',
+            # 'Pileated_Woodpecker_0004_180307', 'American_Three_Toed_Woodpecker_0019_179870'
+            # '000000016379', '000000008869'
+            # 'commonwarthog_115',
+            # 'commonwarthog_78'
+            # '000000027059', '000000030361', '000000027936'
+            # 'klipspringer_66', '000000008333', '000000026814', '000000047543', '000000052080', 'Common_Tern_0050_148928'
+        ]
+        if len(relevant_names) > 0:
+            if cat is not None:
+                relevant_names = [os.path.join(cat, name) for name in relevant_names]
+                self.img_ids = [img_id for img_id in self.img_ids if self.id2name[img_id] in relevant_names]
+            else:
+                new_ids = []
+                for relevant_name in relevant_names:
+                    new_ids += [img_id for img_id in self.img_ids if relevant_name in self.id2name[img_id]]
+                self.img_ids = new_ids
+        else:
+            self.img_ids = [img_id for img_id in self.img_ids if cat == self.id2name[img_id].split('/')[0]]
+
+        self.classes = [
+            cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
+        ]
+
+        self.num_classes = len(self.classes)
+        self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))
+        self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))
+
+        if valid_class_ids is not None:  # None by default
+            self.valid_class_ids = valid_class_ids
+        else:
+            self.valid_class_ids = self.coco.getCatIds()
+
+        self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]
+        self.cats = self.coco.cats
+        self.max_kpt_num = max_kpt_num
+
+        # Also update self.cat2obj
+        self.db = self._get_db()
+
+        self.num_shots = num_shots
+
+        if not test_mode:
+            # Update every training epoch
+            self.random_paired_samples()
+        else:
+            self.num_queries = num_queries
+            self.num_episodes = num_episodes
+            self.make_paired_samples()
+
+
+    def random_paired_samples(self):
+        num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]
+
+        # balance the dataset
+        max_num_data = max(num_datas)
+
+        all_samples = []
+        for cls in self.valid_class_ids:
+            for i in range(max_num_data):
+                shot = random.sample(self.cat2obj[cls], self.num_shots + 1)
+                all_samples.append(shot)
+
+        self.paired_samples = np.array(all_samples)
+        np.random.shuffle(self.paired_samples)
+
+    def make_paired_samples(self):
+        random.seed(1)
+        np.random.seed(0)
+        all_samples = []
+        self.num_episodes = 1000
+        for cls in self.valid_class_ids:
+            for _ in range(self.num_episodes):
+                if self.cat2obj[cls] == []:
+                    continue
+                self.num_queries = 1
+                self.num_shots = 1
+                if len(self.cat2obj[cls]) < self.num_shots + self.num_queries:
+                    shots = random.choices(self.cat2obj[cls], k=self.num_shots + self.num_queries)
+                else:
+                    shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)
+                sample_ids = shots[:self.num_shots]
+                query_ids = shots[self.num_shots:]
+                for query_id in query_ids:
+                    all_samples.append(sample_ids + [query_id])
+                    all_samples.append([query_id] + [query_id])
+
+        self.paired_samples = np.array(list(set(tuple(x) for x in all_samples)))
+
+    def _select_kpt(self, obj, kpt_id):
+        obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id+1]
+        obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id+1]
+        obj['kpt_id'] = kpt_id
+
+        return obj
+
+    @staticmethod
+    def _get_mapping_id_name(imgs):
+        """
+        Args:
+            imgs (dict): dict of image info.
+
+        Returns:
+            tuple: Image name & id mapping dicts.
+
+            - id2name (dict): Mapping image id to name.
+            - name2id (dict): Mapping image name to id.
+        """
+        id2name = {}
+        name2id = {}
+        for image_id, image in imgs.items():
+            file_name = image['file_name']
+            id2name[image_id] = file_name
+            name2id[file_name] = image_id
+
+        return id2name, name2id
+
+    def _get_db(self):
+        """Ground truth bbox and keypoints."""
+        self.obj_id = 0
+
+        self.cat2obj = {}
+        for i in self.coco.getCatIds():
+            self.cat2obj.update({i: []})
+
+        gt_db = []
+        for img_id in self.img_ids:
+            gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
+        return gt_db
+
+    def _load_coco_keypoint_annotation_kernel(self, img_id):
+        """load annotation from COCOAPI.
+
+        Note:
+            bbox:[x1, y1, w, h]
+        Args:
+            img_id: coco image id
+        Returns:
+            dict: db entry
+        """
+        img_ann = self.coco.loadImgs(img_id)[0]
+        width = img_ann['width']
+        height = img_ann['height']
+
+        ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
+        objs = self.coco.loadAnns(ann_ids)
+
+        # sanitize bboxes
+        valid_objs = []
+        for obj in objs:
+            if 'bbox' not in obj:
+                continue
+            x, y, w, h = obj['bbox']
+            x1 = max(0, x)
+            y1 = max(0, y)
+            x2 = min(width - 1, x1 + max(0, w - 1))
+            y2 = min(height - 1, y1 + max(0, h - 1))
+            if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
+                obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
+                valid_objs.append(obj)
+        objs = valid_objs
+
+        bbox_id = 0
+        rec = []
+        for obj in objs:
+            if 'keypoints' not in obj:
+                continue
+            if max(obj['keypoints']) == 0:
+                continue
+            if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
+                continue
+
+            category_id = obj['category_id']
+            # the number of keypoint for this specific category
+            cat_kpt_num = int(len(obj['keypoints']) / 3)
+            if self.max_kpt_num is None:
+                kpt_num = cat_kpt_num
+            else:
+                kpt_num = self.max_kpt_num
+
+            joints_3d = np.zeros((kpt_num, 3), dtype=np.float32)
+            joints_3d_visible = np.zeros((kpt_num, 3), dtype=np.float32)
+
+            keypoints = np.array(obj['keypoints']).reshape(-1, 3)
+            joints_3d[:cat_kpt_num, :2] = keypoints[:, :2]
+            joints_3d_visible[:cat_kpt_num, :2] = np.minimum(1, keypoints[:, 2:3])
+
+            center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
+
+            image_file = os.path.join(self.img_prefix, self.id2name[img_id])
+
+            self.cat2obj[category_id].append(self.obj_id)
+
+            rec.append({
+                'image_file': image_file,
+                'center': center,
+                'scale': scale,
+                'rotation': 0,
+                'bbox': obj['clean_bbox'][:4],
+                'bbox_score': 1,
+                'joints_3d': joints_3d,
+                'joints_3d_visible': joints_3d_visible,
+                'category_id': category_id,
+                'cat_kpt_num': cat_kpt_num,
+                'bbox_id': self.obj_id,
+                'skeleton': self.coco.cats[obj['category_id']]['skeleton'],
+            })
+            bbox_id = bbox_id + 1
+            self.obj_id += 1
+
+        return rec
+
+    def _xywh2cs(self, x, y, w, h):
+        """This encodes bbox(x,y,w,w) into (center, scale)
+
+        Args:
+            x, y, w, h
+
+        Returns:
+            tuple: A tuple containing center and scale.
+
+            - center (np.ndarray[float32](2,)): center of the bbox (x, y).
+            - scale (np.ndarray[float32](2,)): scale of the bbox w & h.
+        """
+        aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]
+        center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
+        #
+        # if (not self.test_mode) and np.random.rand() < 0.3:
+        #     center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
+
+        if w > aspect_ratio * h:
+            h = w * 1.0 / aspect_ratio
+        elif w < aspect_ratio * h:
+            w = h * aspect_ratio
+
+        # pixel std is 200.0
+        scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
+        # padding to include proper amount of context
+        scale = scale * 1.25
+
+        return center, scale
+
+    def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):
+        """Evaluate interhand2d keypoint results. The pose prediction results
+        will be saved in `${res_folder}/result_keypoints.json`.
+
+        Note:
+            batch_size: N
+            num_keypoints: K
+            heatmap height: H
+            heatmap width: W
+
+        Args:
+            outputs (list(preds, boxes, image_path, output_heatmap))
+                :preds (np.ndarray[N,K,3]): The first two dimensions are
+                    coordinates, score is the third dimension of the array.
+                :boxes (np.ndarray[N,6]): [center[0], center[1], scale[0]
+                    , scale[1],area, score]
+                :image_paths (list[str]): For example, ['C', 'a', 'p', 't',
+                    'u', 'r', 'e', '1', '2', '/', '0', '3', '9', '0', '_',
+                    'd', 'h', '_', 't', 'o', 'u', 'c', 'h', 'R', 'O', 'M',
+                    '/', 'c', 'a', 'm', '4', '1', '0', '2', '0', '9', '/',
+                    'i', 'm', 'a', 'g', 'e', '6', '2', '4', '3', '4', '.',
+                    'j', 'p', 'g']
+                :output_heatmap (np.ndarray[N, K, H, W]): model outpus.
+
+            res_folder (str): Path of directory to save the results.
+            metric (str | list[str]): Metric to be performed.
+                Options: 'PCK', 'AUC', 'EPE'.
+
+        Returns:
+            dict: Evaluation results for evaluation metric.
+        """
+        metrics = metric if isinstance(metric, list) else [metric]
+        allowed_metrics = ['PCK', 'AUC', 'EPE', 'NME']
+        for metric in metrics:
+            if metric not in allowed_metrics:
+                raise KeyError(f'metric {metric} is not supported')
+
+        res_file = os.path.join(res_folder, 'result_keypoints.json')
+
+        kpts = []
+        for output in outputs:
+            preds = output['preds']
+            boxes = output['boxes']
+            image_paths = output['image_paths']
+            bbox_ids = output['bbox_ids']
+
+            batch_size = len(image_paths)
+            for i in range(batch_size):
+                image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
+
+                kpts.append({
+                    'keypoints': preds[i].tolist(),
+                    'center': boxes[i][0:2].tolist(),
+                    'scale': boxes[i][2:4].tolist(),
+                    'area': float(boxes[i][4]),
+                    'score': float(boxes[i][5]),
+                    'image_id': image_id,
+                    'bbox_id': bbox_ids[i]
+                })
+        kpts = self._sort_and_unique_bboxes(kpts)
+
+        self._write_keypoint_results(kpts, res_file)
+        info_str = self._report_metric(res_file, metrics)
+        name_value = OrderedDict(info_str)
+
+        return name_value
EdgeCape/datasets/datasets/mp100/fewshot_base_dataset.py
ADDED
@@ -0,0 +1,223 @@
+import copy
+from abc import ABCMeta, abstractmethod
+import json_tricks as json
+import numpy as np
+
+from mmcv.parallel import DataContainer as DC
+from mmpose.core.evaluation.top_down_eval import (keypoint_auc, keypoint_epe,
+                                                  keypoint_pck_accuracy)
+from torch.utils.data import Dataset
+from mmpose.datasets import DATASETS
+from mmpose.datasets.pipelines import Compose
+
+@DATASETS.register_module()
+class FewShotBaseDataset(Dataset, metaclass=ABCMeta):
+
+    def __init__(self,
+                 ann_file,
+                 img_prefix,
+                 data_cfg,
+                 pipeline,
+                 test_mode=False):
+        self.image_info = {}
+        self.ann_info = {}
+
+        self.annotations_path = ann_file
+        if not img_prefix.endswith('/'):
+            img_prefix = img_prefix + '/'
+        self.img_prefix = img_prefix
+        self.pipeline = pipeline
+        self.test_mode = test_mode
+
+        self.ann_info['image_size'] = np.array(data_cfg['image_size'])
+        self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
+        self.ann_info['num_joints'] = data_cfg['num_joints']
+
+        self.ann_info['flip_pairs'] = None
+
+        self.ann_info['inference_channel'] = data_cfg['inference_channel']
+        self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
+        self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
+
+        self.db = []
+        self.num_shots = 1
+        self.paired_samples = []
+        self.pipeline = Compose(self.pipeline)
+
+    @abstractmethod
+    def _get_db(self):
+        """Load dataset."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def _select_kpt(self, obj, kpt_id):
+        """Select kpt."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
+        """Evaluate keypoint results."""
+        raise NotImplementedError
+
+    @staticmethod
+    def _write_keypoint_results(keypoints, res_file):
+        """Write results into a json file."""
+
+        with open(res_file, 'w') as f:
+            json.dump(keypoints, f, sort_keys=True, indent=4)
+
+    def _report_metric(self,
+                       res_file,
+                       metrics,
+                       pck_thr=0.2,
+                       pckh_thr=0.7,
+                       auc_nor=30):
+        """Keypoint evaluation.
+
+        Args:
+            res_file (str): Json file stored prediction results.
+            metrics (str | list[str]): Metric to be performed.
+                Options: 'PCK', 'PCKh', 'AUC', 'EPE'.
+            pck_thr (float): PCK threshold, default as 0.2.
+            pckh_thr (float): PCKh threshold, default as 0.7.
+            auc_nor (float): AUC normalization factor, default as 30 pixel.
+
+        Returns:
+            List: Evaluation results for evaluation metric.
+        """
+        info_str = []
+
+        with open(res_file, 'r') as fin:
+            preds = json.load(fin)
+        assert len(preds) == len(self.paired_samples)
+
+        outputs = []
+        gts = []
+        masks = []
+        threshold_bbox = []
+        threshold_head_box = []
+
+        for pred, pair in zip(preds, self.paired_samples):
+            item = self.db[pair[-1]]
+            outputs.append(np.array(pred['keypoints'])[:, :-1])
+            gts.append(np.array(item['joints_3d'])[:, :-1])
+
+            mask_query = ((np.array(item['joints_3d_visible'])[:, 0]) > 0)
+            mask_sample = ((np.array(self.db[pair[0]]['joints_3d_visible'])[:, 0]) > 0)
+            for id_s in pair[:-1]:
+                mask_sample = np.bitwise_and(mask_sample, ((np.array(self.db[id_s]['joints_3d_visible'])[:, 0]) > 0))
+            masks.append(np.bitwise_and(mask_query, mask_sample))
+
+            if 'PCK' in metrics:
+                bbox = np.array(item['bbox'])
+                bbox_thr = np.max(bbox[2:])
+                threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
+            if 'PCKh' in metrics:
+                head_box_thr = item['head_size']
+                threshold_head_box.append(
+                    np.array([head_box_thr, head_box_thr]))
+
+        if 'PCK' in metrics:
+            pck_avg = []
+            for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
+                _, pck, _ = keypoint_pck_accuracy(np.expand_dims(output, 0), np.expand_dims(gt,0), np.expand_dims(mask,0), pck_thr, np.expand_dims(thr_bbox,0))
+                pck_avg.append(pck)
+            info_str.append(('PCK', np.mean(pck_avg)))
+
+        return info_str
+
+    def _merge_obj(self, Xs_list, Xq, idx):
+        """ merge Xs_list and Xq.
+
+        :param Xs_list: N-shot samples X
+        :param Xq: query X
+        :param idx: id of paired_samples
+        :return: Xall
+        """
+        Xall = dict()
+        Xall['img_s'] = [Xs['img'] for Xs in Xs_list]
+        Xall['target_s'] = [Xs['target'] for Xs in Xs_list]
+        Xall['target_weight_s'] = [Xs['target_weight'] for Xs in Xs_list]
+        xs_img_metas = [Xs['img_metas'].data for Xs in Xs_list]
+
+        Xall['img_q'] = Xq['img']
+        Xall['target_q'] = Xq['target']
+        Xall['target_weight_q'] = Xq['target_weight']
+        xq_img_metas = Xq['img_metas'].data
+
+        img_metas = dict()
+        for key in xq_img_metas.keys():
+            img_metas['sample_' + key] = [xs_img_meta[key] for xs_img_meta in xs_img_metas]
+            img_metas['query_' + key] = xq_img_metas[key]
+        img_metas['bbox_id'] = idx
+
+        Xall['img_metas'] = DC(img_metas, cpu_only=True)
+
+        return Xall
+
+    def __len__(self):
+        """Get the size of the dataset."""
+        return len(self.paired_samples)
+
+    def __getitem__(self, idx):
+        """Get the sample given index."""
+
+        pair_ids = self.paired_samples[idx]
+        assert len(pair_ids) == self.num_shots + 1
+        sample_id_list = pair_ids[:self.num_shots]
+        query_id = pair_ids[-1]
+
+        sample_obj_list = []
+        for sample_id in sample_id_list:
+            sample_obj = copy.deepcopy(self.db[sample_id])
+            sample_obj['ann_info'] = copy.deepcopy(self.ann_info)
+            sample_obj_list.append(sample_obj)
+
+        query_obj = copy.deepcopy(self.db[query_id])
+        query_obj['ann_info'] = copy.deepcopy(self.ann_info)
+
+        if not self.test_mode:
+            # randomly select "one" keypoint
+            sample_valid = (sample_obj_list[0]['joints_3d_visible'][:, 0] > 0)
+            for sample_obj in sample_obj_list:
+                sample_valid = sample_valid & (sample_obj['joints_3d_visible'][:, 0] > 0)
+            query_valid = (query_obj['joints_3d_visible'][:, 0] > 0)
+
+            valid_s = np.where(sample_valid)[0]
+            valid_q = np.where(query_valid)[0]
+            valid_sq = np.where(sample_valid & query_valid)[0]
+            if len(valid_sq) > 0:
+                kpt_id = np.random.choice(valid_sq)
+            elif len(valid_s) > 0:
+                kpt_id = np.random.choice(valid_s)
+            elif len(valid_q) > 0:
+                kpt_id = np.random.choice(valid_q)
+            else:
+                kpt_id = np.random.choice(np.array(range(len(query_valid))))
+
+            for i in range(self.num_shots):
+                sample_obj_list[i] = self._select_kpt(sample_obj_list[i], kpt_id)
+            query_obj = self._select_kpt(query_obj, kpt_id)
+
+        # when test, all keypoints will be preserved.
+
+        Xs_list = []
+        for sample_obj in sample_obj_list:
+            Xs = self.pipeline(sample_obj)
+            Xs_list.append(Xs)
+        Xq = self.pipeline(query_obj)
+
+        Xall = self._merge_obj(Xs_list, Xq, idx)
+        Xall['skeleton'] = self.db[query_id]['skeleton']
+
+        return Xall
+
+    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
+        """sort kpts and remove the repeated ones."""
+        kpts = sorted(kpts, key=lambda x: x[key])
+        num = len(kpts)
+        for i in range(num - 1, 0, -1):
+            if kpts[i][key] == kpts[i - 1][key]:
+                del kpts[i]
+
+        return kpts
EdgeCape/datasets/datasets/mp100/fewshot_dataset.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mmpose.datasets import DATASETS
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
from collections import OrderedDict
|
6 |
+
from xtcocotools.coco import COCO
|
7 |
+
from .fewshot_base_dataset import FewShotBaseDataset
|
8 |
+
|
9 |
+
@DATASETS.register_module()
|
10 |
+
class FewShotKeypointDataset(FewShotBaseDataset):
|
11 |
+
|
12 |
+
def __init__(self,
|
13 |
+
ann_file,
|
14 |
+
img_prefix,
|
15 |
+
data_cfg,
|
16 |
+
pipeline,
|
17 |
+
valid_class_ids,
|
18 |
+
num_shots = 1,
|
19 |
+
num_queries = 100,
|
20 |
+
num_episodes = 1,
|
21 |
+
test_mode=False):
|
22 |
+
super().__init__(
|
23 |
+
ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode)
|
24 |
+
|
25 |
+
self.ann_info['flip_pairs'] = []
|
26 |
+
|
27 |
+
self.ann_info['upper_body_ids'] = []
|
28 |
+
self.ann_info['lower_body_ids'] = []
|
29 |
+
|
30 |
+
self.ann_info['use_different_joint_weights'] = False
|
31 |
+
self.ann_info['joint_weights'] = np.array([1.,],
|
32 |
+
dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
|
33 |
+
|
34 |
+
self.coco = COCO(ann_file)
|
35 |
+
|
36 |
+
self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
|
37 |
+
self.img_ids = self.coco.getImgIds()
|
38 |
+
self.classes = [
|
39 |
+
cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
|
40 |
+
]
|
41 |
+
|
42 |
+
self.num_classes = len(self.classes)
|
43 |
+
self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))
|
44 |
+
self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))
|
45 |
+
|
46 |
+
if valid_class_ids is not None:
|
47 |
+
self.valid_class_ids = valid_class_ids
|
48 |
+
else:
|
49 |
+
self.valid_class_ids = self.coco.getCatIds()
|
50 |
+
self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]
|
51 |
+
|
52 |
+
self.cats = self.coco.cats
|
53 |
+
|
54 |
+
# Also update self.cat2obj
|
55 |
+
self.db = self._get_db()
|
56 |
+
|
57 |
+
self.num_shots = num_shots
|
58 |
+
|
59 |
+
if not test_mode:
|
60 |
+
# Update every training epoch
|
61 |
+
self.random_paired_samples()
|
62 |
+
else:
|
63 |
+
self.num_queries = num_queries
|
64 |
+
self.num_episodes = num_episodes
|
65 |
+
self.make_paired_samples()
|
66 |
+
|
67 |
+
|
68 |
+
def random_paired_samples(self):
|
69 |
+
num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]
|
70 |
+
|
71 |
+
# balance the dataset
|
72 |
+
max_num_data = max(num_datas)
|
73 |
+
|
74 |
+
all_samples = []
|
75 |
+
for cls in self.valid_class_ids:
|
76 |
+
for i in range(max_num_data):
|
77 |
+
shot = random.sample(self.cat2obj[cls], self.num_shots + 1)
|
78 |
+
all_samples.append(shot)
|
79 |
+
|
80 |
+
self.paired_samples = np.array(all_samples)
|
81 |
+
np.random.shuffle(self.paired_samples)
|
82 |
+
|
83 |
+
def make_paired_samples(self):
|
84 |
+
random.seed(1)
|
85 |
+
np.random.seed(0)
|
86 |
+
|
87 |
+
all_samples = []
|
88 |
+
for cls in self.valid_class_ids:
|
89 |
+
for _ in range(self.num_episodes):
|
90 |
+
shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)
|
91 |
+
sample_ids = shots[:self.num_shots]
|
92 |
+
query_ids = shots[self.num_shots:]
|
93 |
+
for query_id in query_ids:
|
94 |
+
all_samples.append(sample_ids + [query_id])
|
95 |
+
|
96 |
+
self.paired_samples = np.array(all_samples)
|
97 |
+
|
98 |
+
def _select_kpt(self, obj, kpt_id):
|
99 |
+
obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id+1]
|
100 |
+
obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id+1]
|
101 |
+
obj['kpt_id'] = kpt_id
|
102 |
+
|
103 |
+
return obj
|
104 |
+
|
105 |
+
@staticmethod
|
106 |
+
def _get_mapping_id_name(imgs):
|
107 |
+
"""
|
108 |
+
Args:
|
109 |
+
imgs (dict): dict of image info.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
tuple: Image name & id mapping dicts.
|
113 |
+
|
114 |
+
- id2name (dict): Mapping image id to name.
|
115 |
+
- name2id (dict): Mapping image name to id.
|
116 |
+
"""
|
117 |
+
id2name = {}
|
118 |
+
name2id = {}
|
119 |
+
for image_id, image in imgs.items():
|
120 |
+
file_name = image['file_name']
|
121 |
+
id2name[image_id] = file_name
|
122 |
+
name2id[file_name] = image_id
|
123 |
+
|
124 |
+
return id2name, name2id
|
125 |
+
|
126 |
+
def _get_db(self):
|
127 |
+
"""Ground truth bbox and keypoints."""
|
128 |
+
self.obj_id = 0
|
129 |
+
|
130 |
+
self.cat2obj = {}
|
131 |
+
for i in self.coco.getCatIds():
|
132 |
+
self.cat2obj.update({i: []})
|
133 |
+
|
134 |
+
gt_db = []
|
135 |
+
for img_id in self.img_ids:
|
136 |
+
gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
|
137 |
+
return gt_db
|
138 |
+
|
139 |
+
def _load_coco_keypoint_annotation_kernel(self, img_id):
|
140 |
+
"""load annotation from COCOAPI.
|
141 |
+
|
142 |
+
Note:
|
143 |
+
bbox:[x1, y1, w, h]
|
144 |
+
Args:
|
145 |
+
img_id: coco image id
|
146 |
+
Returns:
|
147 |
+
dict: db entry
|
148 |
+
"""
|
149 |
+
img_ann = self.coco.loadImgs(img_id)[0]
|
150 |
+
width = img_ann['width']
|
151 |
+
height = img_ann['height']
|
152 |
+
|
153 |
+
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
|
154 |
+
objs = self.coco.loadAnns(ann_ids)
|
155 |
+
|
156 |
+
# sanitize bboxes
|
157 |
+
valid_objs = []
|
158 |
+
for obj in objs:
|
159 |
+
if 'bbox' not in obj:
|
160 |
+
continue
|
161 |
+
x, y, w, h = obj['bbox']
|
162 |
+
x1 = max(0, x)
|
163 |
+
y1 = max(0, y)
|
164 |
+
x2 = min(width - 1, x1 + max(0, w - 1))
|
165 |
+
y2 = min(height - 1, y1 + max(0, h - 1))
|
166 |
+
if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
|
167 |
+
obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
|
168 |
+
valid_objs.append(obj)
|
169 |
+
objs = valid_objs
|
170 |
+
|
171 |
+
bbox_id = 0
|
172 |
+
rec = []
|
173 |
+
for obj in objs:
|
174 |
+
if 'keypoints' not in obj:
|
175 |
+
continue
|
176 |
+
if max(obj['keypoints']) == 0:
|
177 |
+
continue
|
178 |
+
if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
|
179 |
+
continue
|
180 |
+
|
181 |
+
category_id = obj['category_id']
|
182 |
+
# the number of keypoint for this specific category
|
183 |
+
cat_kpt_num = int(len(obj['keypoints']) / 3)
|
184 |
+
|
185 |
+
joints_3d = np.zeros((cat_kpt_num, 3), dtype=np.float32)
|
186 |
+
joints_3d_visible = np.zeros((cat_kpt_num, 3), dtype=np.float32)
|
187 |
+
|
188 |
+
keypoints = np.array(obj['keypoints']).reshape(-1, 3)
|
189 |
+
joints_3d[:, :2] = keypoints[:, :2]
|
190 |
+
joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])
|
191 |
+
|
192 |
+
center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
|
193 |
+
|
194 |
+
image_file = os.path.join(self.img_prefix, self.id2name[img_id])
|
195 |
+
|
196 |
+
self.cat2obj[category_id].append(self.obj_id)
|
197 |
+
|
198 |
+
rec.append({
|
199 |
+
'image_file': image_file,
|
200 |
+
'center': center,
|
201 |
+
'scale': scale,
|
202 |
+
'rotation': 0,
|
203 |
+
'bbox': obj['clean_bbox'][:4],
|
204 |
+
'bbox_score': 1,
|
205 |
+
'joints_3d': joints_3d,
|
206 |
+
'joints_3d_visible': joints_3d_visible,
|
207 |
+
'category_id': category_id,
|
208 |
+
'cat_kpt_num': cat_kpt_num,
|
209 |
+
'bbox_id': self.obj_id,
|
210 |
+
'skeleton': self.coco.cats[obj['category_id']]['skeleton'],
|
211 |
+
})
|
212 |
+
bbox_id = bbox_id + 1
|
213 |
+
self.obj_id += 1
|
214 |
+
|
215 |
+
return rec
|
216 |
+
|
217 |
+
def _xywh2cs(self, x, y, w, h):
|
218 |
+
"""This encodes bbox(x,y,w,w) into (center, scale)
|
219 |
+
|
220 |
+
Args:
|
221 |
+
x, y, w, h
|
222 |
+
|
223 |
+
Returns:
|
224 |
+
tuple: A tuple containing center and scale.
|
225 |
+
|
226 |
+
- center (np.ndarray[float32](2,)): center of the bbox (x, y).
|
227 |
+
- scale (np.ndarray[float32](2,)): scale of the bbox w & h.
|
228 |
+
"""
|
229 |
+
aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]
|
230 |
+
center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
|
231 |
+
#
|
232 |
+
# if (not self.test_mode) and np.random.rand() < 0.3:
|
233 |
+
# center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
|
234 |
+
|
235 |
+
if w > aspect_ratio * h:
|
236 |
+
h = w * 1.0 / aspect_ratio
|
237 |
+
elif w < aspect_ratio * h:
|
238 |
+
w = h * aspect_ratio
|
239 |
+
|
240 |
+
# pixel std is 200.0
|
241 |
+
scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
|
242 |
+
# padding to include proper amount of context
|
243 |
+
scale = scale * 1.25
|
244 |
+
|
245 |
+
return center, scale
|
246 |
+
|
247 |
+
def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):
|
248 |
+
"""Evaluate interhand2d keypoint results. The pose prediction results
|
249 |
+
will be saved in `${res_folder}/result_keypoints.json`.
|
250 |
+
|
251 |
+
Note:
|
252 |
+
batch_size: N
|
253 |
+
num_keypoints: K
|
254 |
+
heatmap height: H
|
255 |
+
heatmap width: W
|
256 |
+
|
257 |
+
Args:
|
258 |
+
outputs (list(preds, boxes, image_path, output_heatmap))
|
259 |
+
:preds (np.ndarray[N,K,3]): The first two dimensions are
|
260 |
+
coordinates, score is the third dimension of the array.
|
261 |
+
:boxes (np.ndarray[N,6]): [center[0], center[1], scale[0]
|
262 |
+
, scale[1],area, score]
|
263 |
+
:image_paths (list[str]): For example, ['C', 'a', 'p', 't',
|
264 |
+
'u', 'r', 'e', '1', '2', '/', '0', '3', '9', '0', '_',
|
265 |
+
'd', 'h', '_', 't', 'o', 'u', 'c', 'h', 'R', 'O', 'M',
|
266 |
+
'/', 'c', 'a', 'm', '4', '1', '0', '2', '0', '9', '/',
|
267 |
+
'i', 'm', 'a', 'g', 'e', '6', '2', '4', '3', '4', '.',
|
268 |
+
'j', 'p', 'g']
|
269 |
+
:output_heatmap (np.ndarray[N, K, H, W]): model outpus.
|
270 |
+
|
271 |
+
res_folder (str): Path of directory to save the results.
|
272 |
+
metric (str | list[str]): Metric to be performed.
|
273 |
+
Options: 'PCK', 'AUC', 'EPE'.
|
274 |
+
|
275 |
+
Returns:
|
276 |
+
dict: Evaluation results for evaluation metric.
|
277 |
+
"""
|
278 |
+
metrics = metric if isinstance(metric, list) else [metric]
|
279 |
+
allowed_metrics = ['PCK', 'AUC', 'EPE']
|
280 |
+
for metric in metrics:
|
281 |
+
if metric not in allowed_metrics:
|
282 |
+
raise KeyError(f'metric {metric} is not supported')
|
283 |
+
|
284 |
+
res_file = os.path.join(res_folder, 'result_keypoints.json')
|
285 |
+
|
286 |
+
kpts = []
|
287 |
+
for output in outputs:
|
288 |
+
preds = output['preds']
|
289 |
+
boxes = output['boxes']
|
290 |
+
image_paths = output['image_paths']
|
291 |
+
bbox_ids = output['bbox_ids']
|
292 |
+
|
293 |
+
batch_size = len(image_paths)
|
294 |
+
for i in range(batch_size):
|
295 |
+
image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
|
296 |
+
|
297 |
+
kpts.append({
|
298 |
+
'keypoints': preds[i].tolist(),
|
299 |
+
'center': boxes[i][0:2].tolist(),
|
300 |
+
'scale': boxes[i][2:4].tolist(),
|
301 |
+
'area': float(boxes[i][4]),
|
302 |
+
'score': float(boxes[i][5]),
|
303 |
+
'image_id': image_id,
|
304 |
+
'bbox_id': bbox_ids[i]
|
305 |
+
})
|
306 |
+
kpts = self._sort_and_unique_bboxes(kpts)
|
307 |
+
|
308 |
+
self._write_keypoint_results(kpts, res_file)
|
309 |
+
info_str = self._report_metric(res_file, metrics)
|
310 |
+
name_value = OrderedDict(info_str)
|
311 |
+
|
312 |
+
return name_value
|
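As a quick sanity check of the bbox encoding above, the following standalone sketch mirrors the `_xywh2cs` logic outside the class; the pixel std of 200 and the 1.25 padding factor come straight from the code above, while the 256x256 image size and the toy box are assumed example values only.

import numpy as np

def xywh2cs_demo(x, y, w, h, image_size=(256, 256)):
    # Mirrors _xywh2cs: fix the aspect ratio, then normalize by the
    # pixel std of 200 and pad the crop by a factor of 1.25.
    aspect_ratio = image_size[0] / image_size[1]
    center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
    if w > aspect_ratio * h:
        h = w / aspect_ratio
    elif w < aspect_ratio * h:
        w = h * aspect_ratio
    scale = np.array([w / 200.0, h / 200.0], dtype=np.float32) * 1.25
    return center, scale

# A 100x50 box at (10, 20) becomes center (60, 45) and scale (0.625, 0.625).
print(xywh2cs_demo(10, 20, 100, 50))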
EdgeCape/datasets/datasets/mp100/test_base_dataset.py
ADDED
@@ -0,0 +1,226 @@
import copy
from abc import ABCMeta, abstractmethod

import json_tricks as json
import numpy as np
from mmcv.parallel import DataContainer as DC
from mmpose.core.evaluation.top_down_eval import (keypoint_auc, keypoint_epe,
                                                  keypoint_nme,
                                                  keypoint_pck_accuracy)
from torch.utils.data import Dataset
from mmpose.datasets import DATASETS
from mmpose.datasets.pipelines import Compose


@DATASETS.register_module()
class TestBaseDataset(Dataset, metaclass=ABCMeta):

    def __init__(self,
                 ann_file,
                 img_prefix,
                 data_cfg,
                 pipeline,
                 test_mode=True,
                 PCK_threshold_list=[0.05, 0.1, 0.15, 0.2, 0.25]):
        self.image_info = {}
        self.ann_info = {}

        self.annotations_path = ann_file
        if not img_prefix.endswith('/'):
            img_prefix = img_prefix + '/'
        self.img_prefix = img_prefix
        self.pipeline = pipeline
        self.test_mode = test_mode
        self.PCK_threshold_list = PCK_threshold_list

        self.ann_info['image_size'] = np.array(data_cfg['image_size'])
        self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
        self.ann_info['num_joints'] = data_cfg['num_joints']

        self.ann_info['flip_pairs'] = None

        self.ann_info['inference_channel'] = data_cfg['inference_channel']
        self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
        self.ann_info['dataset_channel'] = data_cfg['dataset_channel']

        self.db = []
        self.num_shots = 1
        self.paired_samples = []
        self.pipeline = Compose(self.pipeline)

    @abstractmethod
    def _get_db(self):
        """Load dataset."""
        raise NotImplementedError

    @abstractmethod
    def _select_kpt(self, obj, kpt_id):
        """Select kpt."""
        raise NotImplementedError

    @abstractmethod
    def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
        """Evaluate keypoint results."""
        raise NotImplementedError

    @staticmethod
    def _write_keypoint_results(keypoints, res_file):
        """Write results into a json file."""
        with open(res_file, 'w') as f:
            json.dump(keypoints, f, sort_keys=True, indent=4)

    def _report_metric(self, res_file, metrics):
        """Keypoint evaluation.

        Args:
            res_file (str): Json file storing prediction results.
            metrics (str | list[str]): Metric to be performed.
                Options: 'PCK', 'PCKh', 'AUC', 'EPE'.

        Returns:
            List: Evaluation results for evaluation metric.
        """
        info_str = []

        with open(res_file, 'r') as fin:
            preds = json.load(fin)
        assert len(preds) == len(self.paired_samples)

        outputs = []
        gts = []
        masks = []
        threshold_bbox = []
        threshold_head_box = []

        for pred, pair in zip(preds, self.paired_samples):
            item = self.db[pair[-1]]
            outputs.append(np.array(pred['keypoints'])[:, :-1])
            gts.append(np.array(item['joints_3d'])[:, :-1])

            mask_query = ((np.array(item['joints_3d_visible'])[:, 0]) > 0)
            mask_sample = ((np.array(self.db[pair[0]]['joints_3d_visible'])[:, 0]) > 0)
            for id_s in pair[:-1]:
                mask_sample = np.bitwise_and(
                    mask_sample,
                    ((np.array(self.db[id_s]['joints_3d_visible'])[:, 0]) > 0))
            masks.append(np.bitwise_and(mask_query, mask_sample))

            if 'PCK' in metrics or 'NME' in metrics or 'AUC' in metrics:
                bbox = np.array(item['bbox'])
                bbox_thr = np.max(bbox[2:])
                threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
            if 'PCKh' in metrics:
                head_box_thr = item['head_size']
                threshold_head_box.append(
                    np.array([head_box_thr, head_box_thr]))

        if 'PCK' in metrics:
            pck_results = dict()
            for pck_thr in self.PCK_threshold_list:
                pck_results[pck_thr] = []

            for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
                for pck_thr in self.PCK_threshold_list:
                    _, pck, _ = keypoint_pck_accuracy(
                        np.expand_dims(output, 0), np.expand_dims(gt, 0),
                        np.expand_dims(mask, 0), pck_thr,
                        np.expand_dims(thr_bbox, 0))
                    pck_results[pck_thr].append(pck)

            mPCK = 0
            for pck_thr in self.PCK_threshold_list:
                info_str.append(['PCK@' + str(pck_thr), np.mean(pck_results[pck_thr])])
                mPCK += np.mean(pck_results[pck_thr])
            info_str.append(['mPCK', mPCK / len(self.PCK_threshold_list)])

        if 'NME' in metrics:
            nme_results = []
            for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
                nme = keypoint_nme(np.expand_dims(output, 0), np.expand_dims(gt, 0),
                                   np.expand_dims(mask, 0), np.expand_dims(thr_bbox, 0))
                nme_results.append(nme)
            info_str.append(['NME', np.mean(nme_results)])

        if 'AUC' in metrics:
            auc_results = []
            for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
                auc = keypoint_auc(np.expand_dims(output, 0), np.expand_dims(gt, 0),
                                   np.expand_dims(mask, 0), thr_bbox[0])
                auc_results.append(auc)
            info_str.append(['AUC', np.mean(auc_results)])

        if 'EPE' in metrics:
            epe_results = []
            for (output, gt, mask) in zip(outputs, gts, masks):
                epe = keypoint_epe(np.expand_dims(output, 0), np.expand_dims(gt, 0),
                                   np.expand_dims(mask, 0))
                epe_results.append(epe)
            info_str.append(['EPE', np.mean(epe_results)])
        return info_str

    def _merge_obj(self, Xs_list, Xq, idx):
        """Merge Xs_list and Xq.

        :param Xs_list: N-shot samples X
        :param Xq: query X
        :param idx: id of paired_samples
        :return: Xall
        """
        Xall = dict()
        Xall['img_s'] = [Xs['img'] for Xs in Xs_list]
        Xall['target_s'] = [Xs['target'] for Xs in Xs_list]
        Xall['target_weight_s'] = [Xs['target_weight'] for Xs in Xs_list]
        xs_img_metas = [Xs['img_metas'].data for Xs in Xs_list]

        Xall['img_q'] = Xq['img']
        Xall['target_q'] = Xq['target']
        Xall['target_weight_q'] = Xq['target_weight']
        xq_img_metas = Xq['img_metas'].data

        img_metas = dict()
        for key in xq_img_metas.keys():
            img_metas['sample_' + key] = [xs_img_meta[key] for xs_img_meta in xs_img_metas]
            img_metas['query_' + key] = xq_img_metas[key]
        img_metas['bbox_id'] = idx

        Xall['img_metas'] = DC(img_metas, cpu_only=True)

        return Xall

    def __len__(self):
        """Get the size of the dataset."""
        return len(self.paired_samples)

    def __getitem__(self, idx):
        """Get the sample given index."""
        pair_ids = self.paired_samples[idx]  # [support ids * num_shots, query id]
        assert len(pair_ids) == self.num_shots + 1
        sample_id_list = pair_ids[:self.num_shots]
        query_id = pair_ids[-1]

        sample_obj_list = []
        for sample_id in sample_id_list:
            sample_obj = copy.deepcopy(self.db[sample_id])
            sample_obj['ann_info'] = copy.deepcopy(self.ann_info)
            sample_obj_list.append(sample_obj)

        query_obj = copy.deepcopy(self.db[query_id])
        query_obj['ann_info'] = copy.deepcopy(self.ann_info)

        Xs_list = []
        for sample_obj in sample_obj_list:
            Xs = self.pipeline(sample_obj)  # dict with ['img', 'target', 'target_weight', 'img_metas']
            Xs_list.append(Xs)  # Xs['target'] is of shape [100, map_h, map_w]
        Xq = self.pipeline(query_obj)

        Xall = self._merge_obj(Xs_list, Xq, idx)
        Xall['skeleton'] = self.db[query_id]['skeleton']

        return Xall

    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
        """Sort kpts and remove the repeated ones."""
        kpts = sorted(kpts, key=lambda x: x[key])
        num = len(kpts)
        for i in range(num - 1, 0, -1):
            if kpts[i][key] == kpts[i - 1][key]:
                del kpts[i]

        return kpts
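For intuition, here is a minimal, self-contained sketch of the PCK computation that `_report_metric` delegates to `keypoint_pck_accuracy`: a keypoint counts as correct when its error, normalized by the longest bbox side, falls below the threshold. The arrays below are made-up toy values and the helper name is illustrative, not part of the dataset code.

import numpy as np

def pck_sketch(pred, gt, mask, thr, norm):
    # pred, gt: [K, 2] keypoints; mask: [K] bool visibility; norm: max bbox side.
    dist = np.linalg.norm(pred - gt, axis=-1) / norm
    return (dist[mask] <= thr).mean()

pred = np.array([[10., 10.], [50., 52.], [90., 40.]])
gt = np.array([[12., 11.], [50., 50.], [60., 40.]])
mask = np.array([True, True, True])
print(pck_sketch(pred, gt, mask, thr=0.2, norm=100.0))  # 2 of 3 within 0.2 -> ~0.667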
EdgeCape/datasets/datasets/mp100/test_dataset.py
ADDED
@@ -0,0 +1,319 @@
1 |
+
from mmpose.datasets import DATASETS
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
from collections import OrderedDict
|
6 |
+
from xtcocotools.coco import COCO
|
7 |
+
from .test_base_dataset import TestBaseDataset
|
8 |
+
|
9 |
+
@DATASETS.register_module()
|
10 |
+
class TestPoseDataset(TestBaseDataset):
|
11 |
+
|
12 |
+
def __init__(self,
|
13 |
+
ann_file,
|
14 |
+
img_prefix,
|
15 |
+
data_cfg,
|
16 |
+
pipeline,
|
17 |
+
valid_class_ids,
|
18 |
+
max_kpt_num=None,
|
19 |
+
num_shots=1,
|
20 |
+
num_queries=100,
|
21 |
+
num_episodes=1,
|
22 |
+
pck_threshold_list=[0.05, 0.1, 0.15, 0.20, 0.25],
|
23 |
+
test_mode=True):
|
24 |
+
super().__init__(
|
25 |
+
ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode, PCK_threshold_list=pck_threshold_list)
|
26 |
+
|
27 |
+
self.ann_info['flip_pairs'] = []
|
28 |
+
|
29 |
+
self.ann_info['upper_body_ids'] = []
|
30 |
+
self.ann_info['lower_body_ids'] = []
|
31 |
+
|
32 |
+
self.ann_info['use_different_joint_weights'] = False
|
33 |
+
self.ann_info['joint_weights'] = np.array([1.,],
|
34 |
+
dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
|
35 |
+
|
36 |
+
self.coco = COCO(ann_file)
|
37 |
+
|
38 |
+
self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
|
39 |
+
self.img_ids = self.coco.getImgIds()
|
40 |
+
self.classes = [
|
41 |
+
cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
|
42 |
+
]
|
43 |
+
|
44 |
+
self.num_classes = len(self.classes)
|
45 |
+
self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))
|
46 |
+
self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))
|
47 |
+
|
48 |
+
if valid_class_ids is not None: # None by default
|
49 |
+
self.valid_class_ids = valid_class_ids
|
50 |
+
else:
|
51 |
+
self.valid_class_ids = self.coco.getCatIds()
|
52 |
+
self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]
|
53 |
+
|
54 |
+
self.cats = self.coco.cats
|
55 |
+
self.max_kpt_num = max_kpt_num
|
56 |
+
|
57 |
+
# Also update self.cat2obj
|
58 |
+
self.db = self._get_db()
|
59 |
+
|
60 |
+
self.num_shots = num_shots
|
61 |
+
|
62 |
+
if not test_mode:
|
63 |
+
# Update every training epoch
|
64 |
+
self.random_paired_samples()
|
65 |
+
else:
|
66 |
+
self.num_queries = num_queries
|
67 |
+
self.num_episodes = num_episodes
|
68 |
+
self.make_paired_samples()
|
69 |
+
|
70 |
+
|
71 |
+
def random_paired_samples(self):
|
72 |
+
num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]
|
73 |
+
|
74 |
+
# balance the dataset
|
75 |
+
max_num_data = max(num_datas)
|
76 |
+
|
77 |
+
all_samples = []
|
78 |
+
for cls in self.valid_class_ids:
|
79 |
+
for i in range(max_num_data):
|
80 |
+
shot = random.sample(self.cat2obj[cls], self.num_shots + 1)
|
81 |
+
all_samples.append(shot)
|
82 |
+
|
83 |
+
self.paired_samples = np.array(all_samples)
|
84 |
+
np.random.shuffle(self.paired_samples)
|
85 |
+
|
86 |
+
def make_paired_samples(self):
|
87 |
+
random.seed(1)
|
88 |
+
np.random.seed(0)
|
89 |
+
|
90 |
+
all_samples = []
|
91 |
+
for cls in self.valid_class_ids:
|
92 |
+
for _ in range(self.num_episodes):
|
93 |
+
shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)
|
94 |
+
sample_ids = shots[:self.num_shots]
|
95 |
+
query_ids = shots[self.num_shots:]
|
96 |
+
for query_id in query_ids:
|
97 |
+
all_samples.append(sample_ids + [query_id])
|
98 |
+
|
99 |
+
self.paired_samples = np.array(all_samples)
|
100 |
+
|
101 |
+
def _select_kpt(self, obj, kpt_id):
|
102 |
+
obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id+1]
|
103 |
+
obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id+1]
|
104 |
+
obj['kpt_id'] = kpt_id
|
105 |
+
|
106 |
+
return obj
|
107 |
+
|
108 |
+
@staticmethod
|
109 |
+
def _get_mapping_id_name(imgs):
|
110 |
+
"""
|
111 |
+
Args:
|
112 |
+
imgs (dict): dict of image info.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
tuple: Image name & id mapping dicts.
|
116 |
+
|
117 |
+
- id2name (dict): Mapping image id to name.
|
118 |
+
- name2id (dict): Mapping image name to id.
|
119 |
+
"""
|
120 |
+
id2name = {}
|
121 |
+
name2id = {}
|
122 |
+
for image_id, image in imgs.items():
|
123 |
+
file_name = image['file_name']
|
124 |
+
id2name[image_id] = file_name
|
125 |
+
name2id[file_name] = image_id
|
126 |
+
|
127 |
+
return id2name, name2id
|
128 |
+
|
129 |
+
def _get_db(self):
|
130 |
+
"""Ground truth bbox and keypoints."""
|
131 |
+
self.obj_id = 0
|
132 |
+
|
133 |
+
self.cat2obj = {}
|
134 |
+
for i in self.coco.getCatIds():
|
135 |
+
self.cat2obj.update({i: []})
|
136 |
+
|
137 |
+
gt_db = []
|
138 |
+
for img_id in self.img_ids:
|
139 |
+
gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
|
140 |
+
return gt_db
|
141 |
+
|
142 |
+
def _load_coco_keypoint_annotation_kernel(self, img_id):
|
143 |
+
"""load annotation from COCOAPI.
|
144 |
+
|
145 |
+
Note:
|
146 |
+
bbox:[x1, y1, w, h]
|
147 |
+
Args:
|
148 |
+
img_id: coco image id
|
149 |
+
Returns:
|
150 |
+
dict: db entry
|
151 |
+
"""
|
152 |
+
img_ann = self.coco.loadImgs(img_id)[0]
|
153 |
+
width = img_ann['width']
|
154 |
+
height = img_ann['height']
|
155 |
+
|
156 |
+
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
|
157 |
+
objs = self.coco.loadAnns(ann_ids)
|
158 |
+
|
159 |
+
# sanitize bboxes
|
160 |
+
valid_objs = []
|
161 |
+
for obj in objs:
|
162 |
+
if 'bbox' not in obj:
|
163 |
+
continue
|
164 |
+
x, y, w, h = obj['bbox']
|
165 |
+
x1 = max(0, x)
|
166 |
+
y1 = max(0, y)
|
167 |
+
x2 = min(width - 1, x1 + max(0, w - 1))
|
168 |
+
y2 = min(height - 1, y1 + max(0, h - 1))
|
169 |
+
if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
|
170 |
+
obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
|
171 |
+
valid_objs.append(obj)
|
172 |
+
objs = valid_objs
|
173 |
+
|
174 |
+
bbox_id = 0
|
175 |
+
rec = []
|
176 |
+
for obj in objs:
|
177 |
+
if 'keypoints' not in obj:
|
178 |
+
continue
|
179 |
+
if max(obj['keypoints']) == 0:
|
180 |
+
continue
|
181 |
+
if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
|
182 |
+
continue
|
183 |
+
|
184 |
+
category_id = obj['category_id']
|
185 |
+
# the number of keypoint for this specific category
|
186 |
+
cat_kpt_num = int(len(obj['keypoints']) / 3)
|
187 |
+
if self.max_kpt_num is None:
|
188 |
+
kpt_num = cat_kpt_num
|
189 |
+
else:
|
190 |
+
kpt_num = self.max_kpt_num
|
191 |
+
|
192 |
+
joints_3d = np.zeros((kpt_num, 3), dtype=np.float32)
|
193 |
+
joints_3d_visible = np.zeros((kpt_num, 3), dtype=np.float32)
|
194 |
+
|
195 |
+
keypoints = np.array(obj['keypoints']).reshape(-1, 3)
|
196 |
+
joints_3d[:cat_kpt_num, :2] = keypoints[:, :2]
|
197 |
+
joints_3d_visible[:cat_kpt_num, :2] = np.minimum(1, keypoints[:, 2:3])
|
198 |
+
|
199 |
+
center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
|
200 |
+
|
201 |
+
image_file = os.path.join(self.img_prefix, self.id2name[img_id])
|
202 |
+
|
203 |
+
self.cat2obj[category_id].append(self.obj_id)
|
204 |
+
|
205 |
+
rec.append({
|
206 |
+
'image_file': image_file,
|
207 |
+
'center': center,
|
208 |
+
'scale': scale,
|
209 |
+
'rotation': 0,
|
210 |
+
'bbox': obj['clean_bbox'][:4],
|
211 |
+
'bbox_score': 1,
|
212 |
+
'joints_3d': joints_3d,
|
213 |
+
'joints_3d_visible': joints_3d_visible,
|
214 |
+
'category_id': category_id,
|
215 |
+
'cat_kpt_num': cat_kpt_num,
|
216 |
+
'bbox_id': self.obj_id,
|
217 |
+
'skeleton': self.coco.cats[obj['category_id']]['skeleton'],
|
218 |
+
})
|
219 |
+
bbox_id = bbox_id + 1
|
220 |
+
self.obj_id += 1
|
221 |
+
|
222 |
+
return rec
|
223 |
+
|
224 |
+
def _xywh2cs(self, x, y, w, h):
|
225 |
+
"""This encodes bbox(x,y,w,w) into (center, scale)
|
226 |
+
|
227 |
+
Args:
|
228 |
+
x, y, w, h
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
tuple: A tuple containing center and scale.
|
232 |
+
|
233 |
+
- center (np.ndarray[float32](2,)): center of the bbox (x, y).
|
234 |
+
- scale (np.ndarray[float32](2,)): scale of the bbox w & h.
|
235 |
+
"""
|
236 |
+
aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]
|
237 |
+
center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
|
238 |
+
#
|
239 |
+
# if (not self.test_mode) and np.random.rand() < 0.3:
|
240 |
+
# center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
|
241 |
+
|
242 |
+
if w > aspect_ratio * h:
|
243 |
+
h = w * 1.0 / aspect_ratio
|
244 |
+
elif w < aspect_ratio * h:
|
245 |
+
w = h * aspect_ratio
|
246 |
+
|
247 |
+
# pixel std is 200.0
|
248 |
+
scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
|
249 |
+
# padding to include proper amount of context
|
250 |
+
scale = scale * 1.25
|
251 |
+
|
252 |
+
return center, scale
|
253 |
+
|
254 |
+
def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):
|
255 |
+
"""Evaluate interhand2d keypoint results. The pose prediction results
|
256 |
+
will be saved in `${res_folder}/result_keypoints.json`.
|
257 |
+
|
258 |
+
Note:
|
259 |
+
batch_size: N
|
260 |
+
num_keypoints: K
|
261 |
+
heatmap height: H
|
262 |
+
heatmap width: W
|
263 |
+
|
264 |
+
Args:
|
265 |
+
outputs (list(preds, boxes, image_path, output_heatmap))
|
266 |
+
:preds (np.ndarray[N,K,3]): The first two dimensions are
|
267 |
+
coordinates, score is the third dimension of the array.
|
268 |
+
:boxes (np.ndarray[N,6]): [center[0], center[1], scale[0]
|
269 |
+
, scale[1],area, score]
|
270 |
+
:image_paths (list[str]): For example, ['C', 'a', 'p', 't',
|
271 |
+
'u', 'r', 'e', '1', '2', '/', '0', '3', '9', '0', '_',
|
272 |
+
'd', 'h', '_', 't', 'o', 'u', 'c', 'h', 'R', 'O', 'M',
|
273 |
+
'/', 'c', 'a', 'm', '4', '1', '0', '2', '0', '9', '/',
|
274 |
+
'i', 'm', 'a', 'g', 'e', '6', '2', '4', '3', '4', '.',
|
275 |
+
'j', 'p', 'g']
|
276 |
+
:output_heatmap (np.ndarray[N, K, H, W]): model outpus.
|
277 |
+
|
278 |
+
res_folder (str): Path of directory to save the results.
|
279 |
+
metric (str | list[str]): Metric to be performed.
|
280 |
+
Options: 'PCK', 'AUC', 'EPE'.
|
281 |
+
|
282 |
+
Returns:
|
283 |
+
dict: Evaluation results for evaluation metric.
|
284 |
+
"""
|
285 |
+
metrics = metric if isinstance(metric, list) else [metric]
|
286 |
+
allowed_metrics = ['PCK', 'AUC', 'EPE', 'NME']
|
287 |
+
for metric in metrics:
|
288 |
+
if metric not in allowed_metrics:
|
289 |
+
raise KeyError(f'metric {metric} is not supported')
|
290 |
+
|
291 |
+
res_file = os.path.join(res_folder, 'result_keypoints.json')
|
292 |
+
|
293 |
+
kpts = []
|
294 |
+
for output in outputs:
|
295 |
+
preds = output['preds']
|
296 |
+
boxes = output['boxes']
|
297 |
+
image_paths = output['image_paths']
|
298 |
+
bbox_ids = output['bbox_ids']
|
299 |
+
|
300 |
+
batch_size = len(image_paths)
|
301 |
+
for i in range(batch_size):
|
302 |
+
image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
|
303 |
+
|
304 |
+
kpts.append({
|
305 |
+
'keypoints': preds[i].tolist(),
|
306 |
+
'center': boxes[i][0:2].tolist(),
|
307 |
+
'scale': boxes[i][2:4].tolist(),
|
308 |
+
'area': float(boxes[i][4]),
|
309 |
+
'score': float(boxes[i][5]),
|
310 |
+
'image_id': image_id,
|
311 |
+
'bbox_id': bbox_ids[i]
|
312 |
+
})
|
313 |
+
kpts = self._sort_and_unique_bboxes(kpts)
|
314 |
+
|
315 |
+
self._write_keypoint_results(kpts, res_file)
|
316 |
+
info_str = self._report_metric(res_file, metrics)
|
317 |
+
name_value = OrderedDict(info_str)
|
318 |
+
|
319 |
+
return name_value
|
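The episodic test pairing built by `make_paired_samples` above can be summarized with the short sketch below: for every valid category, each episode fixes `num_shots` support ids and pairs each of `num_queries` query ids with them. The function and toy dictionary are illustrative stand-ins, not code from this repository.

import random

def make_episodes(cat2obj, valid_class_ids, num_shots=1, num_queries=3,
                  num_episodes=2, seed=1):
    # Mirrors make_paired_samples: each episode reuses the same supports
    # for every sampled query of that category.
    random.seed(seed)
    pairs = []
    for cls in valid_class_ids:
        for _ in range(num_episodes):
            shots = random.sample(cat2obj[cls], num_shots + num_queries)
            supports, queries = shots[:num_shots], shots[num_shots:]
            pairs.extend([supports + [q] for q in queries])
    return pairs

toy_cat2obj = {1: list(range(10)), 2: list(range(10, 20))}
print(make_episodes(toy_cat2obj, valid_class_ids=[1, 2]))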
EdgeCape/datasets/datasets/mp100/transformer_base_dataset.py
ADDED
@@ -0,0 +1,209 @@
import copy
from abc import ABCMeta, abstractmethod

import json_tricks as json
import numpy as np
from mmcv.parallel import DataContainer as DC
from mmpose.core.evaluation.top_down_eval import (keypoint_auc, keypoint_epe,
                                                  keypoint_pck_accuracy)
from torch.utils.data import Dataset
from mmpose.datasets import DATASETS
from mmpose.datasets.pipelines import Compose


@DATASETS.register_module()
class TransformerBaseDataset(Dataset, metaclass=ABCMeta):

    def __init__(self,
                 ann_file,
                 img_prefix,
                 data_cfg,
                 pipeline,
                 masking_ratio=0.3,
                 test_mode=False):
        self.image_info = {}
        self.ann_info = {}

        self.annotations_path = ann_file
        if not img_prefix.endswith('/'):
            img_prefix = img_prefix + '/'
        self.img_prefix = img_prefix
        self.pipeline = pipeline
        self.test_mode = test_mode
        self.masking_ratio = masking_ratio
        self.ann_info['image_size'] = np.array(data_cfg['image_size'])
        self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
        self.ann_info['num_joints'] = data_cfg['num_joints']

        self.ann_info['flip_pairs'] = None

        self.ann_info['inference_channel'] = data_cfg['inference_channel']
        self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
        self.ann_info['dataset_channel'] = data_cfg['dataset_channel']

        self.db = []
        self.num_shots = 1
        self.paired_samples = []
        self.pipeline = Compose(self.pipeline)

    @abstractmethod
    def _get_db(self):
        """Load dataset."""
        raise NotImplementedError

    @abstractmethod
    def _select_kpt(self, obj, kpt_id):
        """Select kpt."""
        raise NotImplementedError

    @abstractmethod
    def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
        """Evaluate keypoint results."""
        raise NotImplementedError

    @staticmethod
    def _write_keypoint_results(keypoints, res_file):
        """Write results into a json file."""
        with open(res_file, 'w') as f:
            json.dump(keypoints, f, sort_keys=True, indent=4)

    def _report_metric(self,
                       res_file,
                       metrics,
                       pck_thr=0.2,
                       pckh_thr=0.7,
                       auc_nor=30):
        """Keypoint evaluation.

        Args:
            res_file (str): Json file storing prediction results.
            metrics (str | list[str]): Metric to be performed.
                Options: 'PCK', 'PCKh', 'AUC', 'EPE'.
            pck_thr (float): PCK threshold, default as 0.2.
            pckh_thr (float): PCKh threshold, default as 0.7.
            auc_nor (float): AUC normalization factor, default as 30 pixel.

        Returns:
            List: Evaluation results for evaluation metric.
        """
        info_str = []

        with open(res_file, 'r') as fin:
            preds = json.load(fin)
        assert len(preds) == len(self.paired_samples)

        outputs = []
        gts = []
        masks = []
        threshold_bbox = []
        threshold_head_box = []

        for pred, pair in zip(preds, self.paired_samples):
            item = self.db[pair[-1]]
            outputs.append(np.array(pred['keypoints'])[:, :-1])
            gts.append(np.array(item['joints_3d'])[:, :-1])

            mask_query = ((np.array(item['joints_3d_visible'])[:, 0]) > 0)
            mask_sample = ((np.array(self.db[pair[0]]['joints_3d_visible'])[:, 0]) > 0)
            for id_s in pair[:-1]:
                mask_sample = np.bitwise_and(
                    mask_sample,
                    ((np.array(self.db[id_s]['joints_3d_visible'])[:, 0]) > 0))
            masks.append(np.bitwise_and(mask_query, mask_sample))

            if 'PCK' in metrics:
                bbox = np.array(item['bbox'])
                bbox_thr = np.max(bbox[2:])
                threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
            if 'PCKh' in metrics:
                head_box_thr = item['head_size']
                threshold_head_box.append(
                    np.array([head_box_thr, head_box_thr]))

        if 'PCK' in metrics:
            pck_avg = []
            for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
                _, pck, _ = keypoint_pck_accuracy(
                    np.expand_dims(output, 0), np.expand_dims(gt, 0),
                    np.expand_dims(mask, 0), pck_thr,
                    np.expand_dims(thr_bbox, 0))
                pck_avg.append(pck)
            info_str.append(('PCK', np.mean(pck_avg)))

        return info_str

    def _merge_obj(self, Xs_list, Xq, idx):
        """Merge Xs_list and Xq.

        :param Xs_list: N-shot samples X
        :param Xq: query X
        :param idx: id of paired_samples
        :return: Xall
        """
        Xall = dict()
        Xall['img_s'] = [Xs['img'] for Xs in Xs_list]
        Xall['target_s'] = [Xs['target'] for Xs in Xs_list]
        Xall['target_weight_s'] = [Xs['target_weight'] for Xs in Xs_list]
        xs_img_metas = [Xs['img_metas'].data for Xs in Xs_list]

        Xall['img_q'] = Xq['img']
        Xall['target_q'] = Xq['target']
        Xall['target_weight_q'] = Xq['target_weight']
        xq_img_metas = Xq['img_metas'].data

        img_metas = dict()
        for key in xq_img_metas.keys():
            img_metas['sample_' + key] = [xs_img_meta[key] for xs_img_meta in xs_img_metas]
            img_metas['query_' + key] = xq_img_metas[key]
        img_metas['bbox_id'] = idx

        Xall['img_metas'] = DC(img_metas, cpu_only=True)

        return Xall

    def __len__(self):
        """Get the size of the dataset."""
        return len(self.paired_samples)

    def __getitem__(self, idx):
        """Get the sample given index."""
        pair_ids = self.paired_samples[idx]  # [support ids * num_shots, query id]
        assert len(pair_ids) == self.num_shots + 1
        sample_id_list = pair_ids[:self.num_shots]
        query_id = pair_ids[-1]

        sample_obj_list = []
        for sample_id in sample_id_list:
            sample_obj = copy.deepcopy(self.db[sample_id])
            sample_obj['ann_info'] = copy.deepcopy(self.ann_info)
            sample_obj_list.append(sample_obj)

        query_obj = copy.deepcopy(self.db[query_id])
        query_obj['ann_info'] = copy.deepcopy(self.ann_info)

        Xs_list = []
        for sample_obj in sample_obj_list:
            Xs = self.pipeline(sample_obj)  # dict with ['img', 'target', 'target_weight', 'img_metas']
            Xs_list.append(Xs)  # Xs['target'] is of shape [100, map_h, map_w]
        Xq = self.pipeline(query_obj)

        Xall = self._merge_obj(Xs_list, Xq, idx)
        Xall['skeleton'] = self.db[query_id]['skeleton']
        Xall['rand_mask'] = self.rand_mask(Xall['target_weight_s'])
        return Xall

    def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
        """Sort kpts and remove the repeated ones."""
        kpts = sorted(kpts, key=lambda x: x[key])
        num = len(kpts)
        for i in range(num - 1, 0, -1):
            if kpts[i][key] == kpts[i - 1][key]:
                del kpts[i]

        return kpts

    def rand_mask(self, target_weight_s):
        """Randomly zero out a fraction of the keypoints that are visible in all support samples."""
        mask_s = target_weight_s[0]
        for target_weight in target_weight_s:
            mask_s = mask_s * target_weight
        num_to_mask = int(np.sum(mask_s) * self.masking_ratio)
        true_indices = np.where(mask_s == 1)[0]
        rand_mask = np.random.permutation(true_indices)[:num_to_mask]
        mask_s[rand_mask] = 0
        return mask_s
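The random keypoint masking above can be seen in isolation with the toy sketch below: only keypoints visible in every support sample are eligible, and roughly `masking_ratio` of them are zeroed. The standalone function, the 1-D toy weight vectors, and the fixed seed are assumptions for illustration; the dataset itself works on per-keypoint weight arrays produced by the pipeline.

import numpy as np

def rand_mask_sketch(target_weight_s, masking_ratio=0.3, seed=0):
    # Same idea as rand_mask above: intersect the support visibility masks,
    # then zero out a random masking_ratio fraction of the surviving keypoints.
    np.random.seed(seed)
    mask = target_weight_s[0].copy()
    for w in target_weight_s[1:]:
        mask = mask * w
    num_to_mask = int(mask.sum() * masking_ratio)
    visible = np.where(mask == 1)[0]
    mask[np.random.permutation(visible)[:num_to_mask]] = 0
    return mask

# Two support samples, 6 keypoints each (1 = visible).
weights = [np.array([1, 1, 0, 1, 1, 1]), np.array([1, 1, 1, 1, 0, 1])]
print(rand_mask_sketch(weights))  # one of the four jointly visible keypoints is masked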
EdgeCape/datasets/datasets/mp100/transformer_dataset.py
ADDED
@@ -0,0 +1,319 @@
1 |
+
from mmpose.datasets import DATASETS
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
from collections import OrderedDict
|
6 |
+
from xtcocotools.coco import COCO
|
7 |
+
from .transformer_base_dataset import TransformerBaseDataset
|
8 |
+
|
9 |
+
@DATASETS.register_module()
|
10 |
+
class TransformerPoseDataset(TransformerBaseDataset):
|
11 |
+
|
12 |
+
def __init__(self,
|
13 |
+
ann_file,
|
14 |
+
img_prefix,
|
15 |
+
data_cfg,
|
16 |
+
pipeline,
|
17 |
+
valid_class_ids,
|
18 |
+
max_kpt_num=None,
|
19 |
+
num_shots=1,
|
20 |
+
num_queries=100,
|
21 |
+
num_episodes=1,
|
22 |
+
test_mode=False):
|
23 |
+
super().__init__(
|
24 |
+
ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode)
|
25 |
+
|
26 |
+
self.ann_info['flip_pairs'] = []
|
27 |
+
|
28 |
+
self.ann_info['upper_body_ids'] = []
|
29 |
+
self.ann_info['lower_body_ids'] = []
|
30 |
+
|
31 |
+
self.ann_info['use_different_joint_weights'] = False
|
32 |
+
self.ann_info['joint_weights'] = np.array([1.,],
|
33 |
+
dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
|
34 |
+
|
35 |
+
self.coco = COCO(ann_file)
|
36 |
+
|
37 |
+
self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
|
38 |
+
self.img_ids = self.coco.getImgIds()
|
39 |
+
self.classes = [
|
40 |
+
cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
|
41 |
+
]
|
42 |
+
|
43 |
+
self.num_classes = len(self.classes)
|
44 |
+
self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))
|
45 |
+
self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))
|
46 |
+
|
47 |
+
if valid_class_ids is not None: # None by default
|
48 |
+
self.valid_class_ids = valid_class_ids
|
49 |
+
else:
|
50 |
+
self.valid_class_ids = self.coco.getCatIds()
|
51 |
+
self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]
|
52 |
+
|
53 |
+
self.cats = self.coco.cats
|
54 |
+
self.max_kpt_num = max_kpt_num
|
55 |
+
|
56 |
+
# Also update self.cat2obj
|
57 |
+
self.db = self._get_db()
|
58 |
+
|
59 |
+
self.num_shots = num_shots
|
60 |
+
|
61 |
+
if not test_mode:
|
62 |
+
# Update every training epoch
|
63 |
+
self.random_paired_samples()
|
64 |
+
else:
|
65 |
+
self.num_queries = num_queries
|
66 |
+
self.num_episodes = num_episodes
|
67 |
+
self.make_paired_samples()
|
68 |
+
|
69 |
+
|
70 |
+
def random_paired_samples(self):
|
71 |
+
num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]
|
72 |
+
|
73 |
+
# balance the dataset
|
74 |
+
max_num_data = max(num_datas)
|
75 |
+
|
76 |
+
all_samples = []
|
77 |
+
for cls in self.valid_class_ids:
|
78 |
+
for i in range(max_num_data):
|
79 |
+
shot = random.sample(self.cat2obj[cls], self.num_shots + 1)
|
80 |
+
all_samples.append(shot)
|
81 |
+
|
82 |
+
self.paired_samples = np.array(all_samples)
|
83 |
+
np.random.shuffle(self.paired_samples)
|
84 |
+
|
85 |
+
def make_paired_samples(self):
|
86 |
+
random.seed(1)
|
87 |
+
np.random.seed(0)
|
88 |
+
|
89 |
+
all_samples = []
|
90 |
+
for cls in self.valid_class_ids:
|
91 |
+
for _ in range(self.num_episodes):
|
92 |
+
shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)
|
93 |
+
sample_ids = shots[:self.num_shots]
|
94 |
+
query_ids = shots[self.num_shots:]
|
95 |
+
for query_id in query_ids:
|
96 |
+
all_samples.append(sample_ids + [query_id])
|
97 |
+
|
98 |
+
self.paired_samples = np.array(all_samples)
|
99 |
+
|
100 |
+
def _select_kpt(self, obj, kpt_id):
|
101 |
+
obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id+1]
|
102 |
+
obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id+1]
|
103 |
+
obj['kpt_id'] = kpt_id
|
104 |
+
|
105 |
+
return obj
|
106 |
+
|
107 |
+
@staticmethod
|
108 |
+
def _get_mapping_id_name(imgs):
|
109 |
+
"""
|
110 |
+
Args:
|
111 |
+
imgs (dict): dict of image info.
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
tuple: Image name & id mapping dicts.
|
115 |
+
|
116 |
+
- id2name (dict): Mapping image id to name.
|
117 |
+
- name2id (dict): Mapping image name to id.
|
118 |
+
"""
|
119 |
+
id2name = {}
|
120 |
+
name2id = {}
|
121 |
+
for image_id, image in imgs.items():
|
122 |
+
file_name = image['file_name']
|
123 |
+
id2name[image_id] = file_name
|
124 |
+
name2id[file_name] = image_id
|
125 |
+
|
126 |
+
return id2name, name2id
|
127 |
+
|
128 |
+
def _get_db(self):
|
129 |
+
"""Ground truth bbox and keypoints."""
|
130 |
+
self.obj_id = 0
|
131 |
+
|
132 |
+
self.cat2obj = {}
|
133 |
+
for i in self.coco.getCatIds():
|
134 |
+
self.cat2obj.update({i: []})
|
135 |
+
|
136 |
+
gt_db = []
|
137 |
+
for img_id in self.img_ids:
|
138 |
+
gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
|
139 |
+
|
140 |
+
return gt_db
|
141 |
+
|
142 |
+
def _load_coco_keypoint_annotation_kernel(self, img_id):
|
143 |
+
"""load annotation from COCOAPI.
|
144 |
+
|
145 |
+
Note:
|
146 |
+
bbox:[x1, y1, w, h]
|
147 |
+
Args:
|
148 |
+
img_id: coco image id
|
149 |
+
Returns:
|
150 |
+
dict: db entry
|
151 |
+
"""
|
152 |
+
img_ann = self.coco.loadImgs(img_id)[0]
|
153 |
+
width = img_ann['width']
|
154 |
+
height = img_ann['height']
|
155 |
+
|
156 |
+
ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
|
157 |
+
objs = self.coco.loadAnns(ann_ids)
|
158 |
+
|
159 |
+
# sanitize bboxes
|
160 |
+
valid_objs = []
|
161 |
+
for obj in objs:
|
162 |
+
if 'bbox' not in obj:
|
163 |
+
continue
|
164 |
+
x, y, w, h = obj['bbox']
|
165 |
+
x1 = max(0, x)
|
166 |
+
y1 = max(0, y)
|
167 |
+
x2 = min(width - 1, x1 + max(0, w - 1))
|
168 |
+
y2 = min(height - 1, y1 + max(0, h - 1))
|
169 |
+
if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
|
170 |
+
obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
|
171 |
+
valid_objs.append(obj)
|
172 |
+
objs = valid_objs
|
173 |
+
|
174 |
+
bbox_id = 0
|
175 |
+
rec = []
|
176 |
+
for obj in objs:
|
177 |
+
if 'keypoints' not in obj:
|
178 |
+
continue
|
179 |
+
if max(obj['keypoints']) == 0:
|
180 |
+
continue
|
181 |
+
if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
|
182 |
+
continue
|
183 |
+
|
184 |
+
category_id = obj['category_id']
|
185 |
+
# the number of keypoint for this specific category
|
186 |
+
cat_kpt_num = int(len(obj['keypoints']) / 3)
|
187 |
+
if self.max_kpt_num is None:
|
188 |
+
kpt_num = cat_kpt_num
|
189 |
+
else:
|
190 |
+
kpt_num = self.max_kpt_num
|
191 |
+
|
192 |
+
joints_3d = np.zeros((kpt_num, 3), dtype=np.float32)
|
193 |
+
joints_3d_visible = np.zeros((kpt_num, 3), dtype=np.float32)
|
194 |
+
|
195 |
+
keypoints = np.array(obj['keypoints']).reshape(-1, 3)
|
196 |
+
joints_3d[:cat_kpt_num, :2] = keypoints[:, :2]
|
197 |
+
joints_3d_visible[:cat_kpt_num, :2] = np.minimum(1, keypoints[:, 2:3])
|
198 |
+
|
199 |
+
center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
|
200 |
+
|
201 |
+
image_file = os.path.join(self.img_prefix, self.id2name[img_id])
|
202 |
+
if os.path.exists(image_file):
|
203 |
+
self.cat2obj[category_id].append(self.obj_id)
|
204 |
+
|
205 |
+
rec.append({
|
206 |
+
'image_file': image_file,
|
207 |
+
'center': center,
|
208 |
+
'scale': scale,
|
209 |
+
'rotation': 0,
|
210 |
+
'bbox': obj['clean_bbox'][:4],
|
211 |
+
'bbox_score': 1,
|
212 |
+
'joints_3d': joints_3d,
|
213 |
+
'joints_3d_visible': joints_3d_visible,
|
214 |
+
'category_id': category_id,
|
215 |
+
'cat_kpt_num': cat_kpt_num,
|
216 |
+
'bbox_id': self.obj_id,
|
217 |
+
'skeleton': self.coco.cats[obj['category_id']]['skeleton'],
|
218 |
+
})
|
219 |
+
bbox_id = bbox_id + 1
|
220 |
+
self.obj_id += 1
|
221 |
+
|
222 |
+
return rec
|
223 |
+
|
224 |
+
def _xywh2cs(self, x, y, w, h):
|
225 |
+
"""This encodes bbox(x,y,w,w) into (center, scale)
|
226 |
+
|
227 |
+
Args:
|
228 |
+
x, y, w, h
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
tuple: A tuple containing center and scale.
|
232 |
+
|
233 |
+
- center (np.ndarray[float32](2,)): center of the bbox (x, y).
|
234 |
+
- scale (np.ndarray[float32](2,)): scale of the bbox w & h.
|
235 |
+
"""
|
236 |
+
aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]
|
237 |
+
center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
|
238 |
+
#
|
239 |
+
# if (not self.test_mode) and np.random.rand() < 0.3:
|
240 |
+
# center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
|
241 |
+
|
242 |
+
if w > aspect_ratio * h:
|
243 |
+
h = w * 1.0 / aspect_ratio
|
244 |
+
elif w < aspect_ratio * h:
|
245 |
+
w = h * aspect_ratio
|
246 |
+
|
247 |
+
# pixel std is 200.0
|
248 |
+
scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
|
249 |
+
# padding to include proper amount of context
|
250 |
+
scale = scale * 1.25
|
251 |
+
|
252 |
+
return center, scale
|
253 |
+
|
254 |
+
def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):
|
255 |
+
"""Evaluate interhand2d keypoint results. The pose prediction results
|
256 |
+
will be saved in `${res_folder}/result_keypoints.json`.
|
257 |
+
|
258 |
+
Note:
|
259 |
+
batch_size: N
|
260 |
+
num_keypoints: K
|
261 |
+
heatmap height: H
|
262 |
+
heatmap width: W
|
263 |
+
|
264 |
+
Args:
|
265 |
+
outputs (list(preds, boxes, image_path, output_heatmap))
|
266 |
+
:preds (np.ndarray[N,K,3]): The first two dimensions are
|
267 |
+
coordinates, score is the third dimension of the array.
|
268 |
+
:boxes (np.ndarray[N,6]): [center[0], center[1], scale[0]
|
269 |
+
, scale[1],area, score]
|
270 |
+
:image_paths (list[str]): For example, ['C', 'a', 'p', 't',
|
271 |
+
'u', 'r', 'e', '1', '2', '/', '0', '3', '9', '0', '_',
|
272 |
+
'd', 'h', '_', 't', 'o', 'u', 'c', 'h', 'R', 'O', 'M',
|
273 |
+
'/', 'c', 'a', 'm', '4', '1', '0', '2', '0', '9', '/',
|
274 |
+
'i', 'm', 'a', 'g', 'e', '6', '2', '4', '3', '4', '.',
|
275 |
+
'j', 'p', 'g']
|
276 |
+
:output_heatmap (np.ndarray[N, K, H, W]): model outpus.
|
277 |
+
|
278 |
+
res_folder (str): Path of directory to save the results.
|
279 |
+
metric (str | list[str]): Metric to be performed.
|
280 |
+
Options: 'PCK', 'AUC', 'EPE'.
|
281 |
+
|
282 |
+
Returns:
|
283 |
+
dict: Evaluation results for evaluation metric.
|
284 |
+
"""
|
285 |
+
metrics = metric if isinstance(metric, list) else [metric]
|
286 |
+
allowed_metrics = ['PCK', 'AUC', 'EPE', 'NME']
|
287 |
+
for metric in metrics:
|
288 |
+
if metric not in allowed_metrics:
|
289 |
+
raise KeyError(f'metric {metric} is not supported')
|
290 |
+
|
291 |
+
res_file = os.path.join(res_folder, 'result_keypoints.json')
|
292 |
+
|
293 |
+
kpts = []
|
294 |
+
for output in outputs:
|
295 |
+
preds = output['preds']
|
296 |
+
boxes = output['boxes']
|
297 |
+
image_paths = output['image_paths']
|
298 |
+
bbox_ids = output['bbox_ids']
|
299 |
+
|
300 |
+
batch_size = len(image_paths)
|
301 |
+
for i in range(batch_size):
|
302 |
+
image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
|
303 |
+
|
304 |
+
kpts.append({
|
305 |
+
'keypoints': preds[i].tolist(),
|
306 |
+
'center': boxes[i][0:2].tolist(),
|
307 |
+
'scale': boxes[i][2:4].tolist(),
|
308 |
+
'area': float(boxes[i][4]),
|
309 |
+
'score': float(boxes[i][5]),
|
310 |
+
'image_id': image_id,
|
311 |
+
'bbox_id': bbox_ids[i]
|
312 |
+
})
|
313 |
+
kpts = self._sort_and_unique_bboxes(kpts)
|
314 |
+
|
315 |
+
self._write_keypoint_results(kpts, res_file)
|
316 |
+
info_str = self._report_metric(res_file, metrics)
|
317 |
+
name_value = OrderedDict(info_str)
|
318 |
+
|
319 |
+
return name_value
|
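The constructor arguments of `TransformerPoseDataset` map directly onto an mmpose-style config entry. The snippet below is a hypothetical sketch of such an entry: the annotation path, image prefix, and the `data_cfg` values are placeholders, not values taken from this repository.

# Hypothetical config entry (paths and data_cfg values are placeholders).
data_cfg = dict(
    image_size=[256, 256],
    heatmap_size=[64, 64],
    num_joints=100,
    num_output_channels=100,
    dataset_channel=list(range(100)),
    inference_channel=list(range(100)),
)

train_dataset = dict(
    type='TransformerPoseDataset',
    ann_file='data/mp100/annotations/mp100_split1_train.json',
    img_prefix='data/mp100/images/',
    data_cfg=data_cfg,
    pipeline=[],           # filled with the transforms from EdgeCape/datasets/pipelines
    valid_class_ids=None,  # fall back to all category ids in the annotation file
    num_shots=1,
)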
EdgeCape/datasets/pipelines/__init__.py
ADDED
@@ -0,0 +1,8 @@
from .top_down_transform import (TopDownAffineFewShot,
                                 TopDownGenerateTargetFewShot,
                                 LoadDepthFromFile,
                                 DepthTopDownAffineFewShot)

__all__ = [
    'TopDownGenerateTargetFewShot', 'TopDownAffineFewShot', 'LoadDepthFromFile', 'DepthTopDownAffineFewShot',
]
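Because these transforms register themselves with mmpose's PIPELINES registry, they are normally referenced by name inside a pipeline config. The sketch below is illustrative only: the surrounding load/format stages are generic mmpose steps and the parameter values are assumptions, not a pipeline taken from this repository.

# Hypothetical pipeline using the transforms exported above.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='TopDownAffineFewShot'),
    dict(type='ToTensor'),
    dict(type='TopDownGenerateTargetFewShot', sigma=2),
    dict(type='Collect',
         keys=['img', 'target', 'target_weight'],
         meta_keys=['image_file', 'center', 'scale', 'rotation', 'bbox_id']),
]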
EdgeCape/datasets/pipelines/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (373 Bytes). View file
|
|
EdgeCape/datasets/pipelines/__pycache__/post_transforms.cpython-39.pyc
ADDED
Binary file (3.39 kB). View file
|
|
EdgeCape/datasets/pipelines/__pycache__/top_down_transform.cpython-39.pyc
ADDED
Binary file (18 kB). View file
|
|
EdgeCape/datasets/pipelines/post_transforms.py
ADDED
@@ -0,0 +1,121 @@
# ------------------------------------------------------------------------------
# Adapted from https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
# Original licence: Copyright (c) Microsoft, under the MIT License.
# ------------------------------------------------------------------------------

import cv2
import numpy as np


def get_affine_transform(center,
                         scale,
                         rot,
                         output_size,
                         shift=(0., 0.),
                         inv=False):
    """Get the affine transform matrix, given the center/scale/rot/output_size.

    Args:
        center (np.ndarray[2, ]): Center of the bounding box (x, y).
        scale (np.ndarray[2, ]): Scale of the bounding box
            wrt [width, height].
        rot (float): Rotation angle (degree).
        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
        shift (0-100%): Shift translation ratio wrt the width/height.
            Default (0., 0.).
        inv (bool): Option to inverse the affine transform direction.
            (inv=False: src->dst or inv=True: dst->src)

    Returns:
        np.ndarray: The transform matrix.
    """
    assert len(center) == 2
    assert len(scale) == 2
    assert len(output_size) == 2
    assert len(shift) == 2

    # pixel_std is 200.
    scale_tmp = scale * 200.0

    shift = np.array(shift)
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = rotate_point([0., src_w * -0.5], rot_rad)
    dst_dir = np.array([0., dst_w * -0.5])

    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    src[2, :] = _get_3rd_point(src[0, :], src[1, :])

    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans


def affine_transform(pt, trans_mat):
    """Apply an affine transformation to the points.

    Args:
        pt (np.ndarray): a 2 dimensional point to be transformed
        trans_mat (np.ndarray): 2x3 matrix of an affine transform

    Returns:
        np.ndarray: Transformed points.
    """
    assert len(pt) == 2
    new_pt = np.array(trans_mat) @ np.array([pt[0], pt[1], 1.])

    return new_pt


def _get_3rd_point(a, b):
    """To calculate the affine matrix, three pairs of points are required. This
    function is used to get the 3rd point, given 2D points a & b.

    The 3rd point is defined by rotating vector `a - b` by 90 degrees
    anticlockwise, using b as the rotation center.

    Args:
        a (np.ndarray): point(x,y)
        b (np.ndarray): point(x,y)

    Returns:
        np.ndarray: The 3rd point.
    """
    assert len(a) == 2
    assert len(b) == 2
    direction = a - b
    third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)

    return third_pt


def rotate_point(pt, angle_rad):
    """Rotate a point by an angle.

    Args:
        pt (list[float]): 2 dimensional point to be rotated
        angle_rad (float): rotation angle in radians

    Returns:
        list[float]: Rotated point.
    """
    assert len(pt) == 2
    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
    new_x = pt[0] * cs - pt[1] * sn
    new_y = pt[0] * sn + pt[1] * cs
    rotated_pt = [new_x, new_y]

    return rotated_pt
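A minimal usage sketch of the two helpers above, assuming they are in scope (e.g. imported from this module): build the crop transform from a center/scale pair in the dataset's convention (scale = bbox size / 200, already padded) and map a keypoint into crop coordinates. The numeric values are arbitrary examples.

import numpy as np

# Warp a keypoint into a 256x256 crop using the helpers above.
center = np.array([60., 45.], dtype=np.float32)
scale = np.array([0.625, 0.625], dtype=np.float32)
trans = get_affine_transform(center, scale, rot=0., output_size=np.array([256, 256]))
keypoint = np.array([80., 50.], dtype=np.float32)
print(affine_transform(keypoint, trans))  # keypoint expressed in crop coordinates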
EdgeCape/datasets/pipelines/top_down_transform.py
ADDED
@@ -0,0 +1,716 @@
1 |
+
import os
|
2 |
+
import warnings
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import mmcv
|
8 |
+
import numpy as np
|
9 |
+
from mmcv import fileio
|
10 |
+
|
11 |
+
from mmpose.datasets.builder import PIPELINES
|
12 |
+
from .post_transforms import (affine_transform,
|
13 |
+
get_affine_transform)
|
14 |
+
from mmpose.core.post_processing import (affine_transform, fliplr_joints,
|
15 |
+
get_affine_transform, get_warp_matrix,
|
16 |
+
warp_affine_joints)
|
17 |
+
|
18 |
+
@PIPELINES.register_module()
|
19 |
+
class TopDownAffineFewShot:
|
20 |
+
"""Affine transform the image to make input.
|
21 |
+
|
22 |
+
Required keys:'img', 'joints_3d', 'joints_3d_visible', 'ann_info','scale',
|
23 |
+
'rotation' and 'center'. Modified keys:'img', 'joints_3d', and
|
24 |
+
'joints_3d_visible'.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
use_udp (bool): To use unbiased data processing.
|
28 |
+
Paper ref: Huang et al. The Devil is in the Details: Delving into
|
29 |
+
Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
|
30 |
+
"""
|
31 |
+
|
32 |
+
def __init__(self, use_udp=False):
|
33 |
+
self.use_udp = use_udp
|
34 |
+
|
35 |
+
def __call__(self, results):
|
36 |
+
image_size = results['ann_info']['image_size']
|
37 |
+
|
38 |
+
img = results['img']
|
39 |
+
joints_3d = results['joints_3d']
|
40 |
+
joints_3d_visible = results['joints_3d_visible']
|
41 |
+
c = results['center']
|
42 |
+
s = results['scale']
|
43 |
+
r = results['rotation']
|
44 |
+
|
45 |
+
if self.use_udp:
|
46 |
+
trans = get_warp_matrix(r, c * 2.0, image_size - 1.0, s * 200.0)
|
47 |
+
img = cv2.warpAffine(
|
48 |
+
img,
|
49 |
+
trans, (int(image_size[0]), int(image_size[1])),
|
50 |
+
flags=cv2.INTER_LINEAR)
|
51 |
+
joints_3d[:, 0:2] = \
|
52 |
+
warp_affine_joints(joints_3d[:, 0:2].copy(), trans)
|
53 |
+
else:
|
54 |
+
trans = get_affine_transform(c, s, r, image_size)
|
55 |
+
img = cv2.warpAffine(
|
56 |
+
img,
|
57 |
+
trans, (int(image_size[0]), int(image_size[1])),
|
58 |
+
flags=cv2.INTER_LINEAR)
|
59 |
+
for i in range(len(joints_3d)):
|
60 |
+
if joints_3d_visible[i, 0] > 0.0:
|
61 |
+
joints_3d[i, 0:2] = affine_transform(joints_3d[i, 0:2], trans)
|
62 |
+
|
63 |
+
results['img'] = img
|
64 |
+
results['joints_3d'] = joints_3d
|
65 |
+
results['joints_3d_visible'] = joints_3d_visible
|
66 |
+
|
67 |
+
return results
|
68 |
+
|
69 |
+
|
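A minimal sketch of what the non-UDP branch above does for a single image, using the same mmpose helpers the class itself calls (`get_affine_transform`, `affine_transform`); the image, center, scale and keypoint values below are made up for illustration:

import cv2
import numpy as np
from mmpose.core.post_processing import affine_transform, get_affine_transform

image_size = np.array([256, 256])                       # target input resolution (W, H)
img = np.zeros((480, 640, 3), dtype=np.uint8)           # dummy source image
center = np.array([320.0, 240.0])                       # bbox center in source pixels
scale = np.array([1.6, 2.4])                            # bbox scale in the usual 200-px units
rotation = 0.0

trans = get_affine_transform(center, scale, rotation, image_size)
warped = cv2.warpAffine(img, trans, (int(image_size[0]), int(image_size[1])),
                        flags=cv2.INTER_LINEAR)
# Visible keypoints are mapped with the same 2x3 matrix as the pixels.
kpt = affine_transform(np.array([300.0, 220.0]), trans)
print(warped.shape, kpt)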
70 |
+
@PIPELINES.register_module()
|
71 |
+
class TopDownGenerateTargetFewShot:
|
72 |
+
"""Generate the target heatmap.
|
73 |
+
|
74 |
+
Required keys: 'joints_3d', 'joints_3d_visible', 'ann_info'.
|
75 |
+
Modified keys: 'target', and 'target_weight'.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
sigma: Sigma of heatmap gaussian for 'MSRA' approach.
|
79 |
+
kernel: Kernel of heatmap gaussian for 'Megvii' approach.
|
80 |
+
encoding (str): Approach to generate target heatmaps.
|
81 |
+
Currently supported approaches: 'MSRA', 'Megvii', 'UDP'.
|
82 |
+
Default:'MSRA'
|
83 |
+
|
84 |
+
unbiased_encoding (bool): Option to use unbiased
|
85 |
+
encoding methods.
|
86 |
+
Paper ref: Zhang et al. Distribution-Aware Coordinate
|
87 |
+
Representation for Human Pose Estimation (CVPR 2020).
|
88 |
+
keypoint_pose_distance: Keypoint pose distance for UDP.
|
89 |
+
Paper ref: Huang et al. The Devil is in the Details: Delving into
|
90 |
+
Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
|
91 |
+
target_type (str): supported targets: 'GaussianHeatMap',
|
92 |
+
'CombinedTarget'. Default:'GaussianHeatMap'
|
93 |
+
CombinedTarget: The combination of classification target
|
94 |
+
(response map) and regression target (offset map).
|
95 |
+
Paper ref: Huang et al. The Devil is in the Details: Delving into
|
96 |
+
Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(self,
|
100 |
+
sigma=2,
|
101 |
+
kernel=(11, 11),
|
102 |
+
valid_radius_factor=0.0546875,
|
103 |
+
target_type='GaussianHeatMap',
|
104 |
+
encoding='MSRA',
|
105 |
+
unbiased_encoding=False):
|
106 |
+
self.sigma = sigma
|
107 |
+
self.unbiased_encoding = unbiased_encoding
|
108 |
+
self.kernel = kernel
|
109 |
+
self.valid_radius_factor = valid_radius_factor
|
110 |
+
self.target_type = target_type
|
111 |
+
self.encoding = encoding
|
112 |
+
|
113 |
+
def _msra_generate_target(self, cfg, joints_3d, joints_3d_visible, sigma):
|
114 |
+
"""Generate the target heatmap via "MSRA" approach.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
cfg (dict): data config
|
118 |
+
joints_3d: np.ndarray ([num_joints, 3])
|
119 |
+
joints_3d_visible: np.ndarray ([num_joints, 3])
|
120 |
+
sigma: Sigma of heatmap gaussian
|
121 |
+
Returns:
|
122 |
+
tuple: A tuple containing targets.
|
123 |
+
|
124 |
+
- target: Target heatmaps.
|
125 |
+
- target_weight: (1: visible, 0: invisible)
|
126 |
+
"""
|
127 |
+
num_joints = len(joints_3d)
|
128 |
+
image_size = cfg['image_size']
|
129 |
+
W, H = cfg['heatmap_size']
|
130 |
+
joint_weights = cfg['joint_weights']
|
131 |
+
use_different_joint_weights = cfg['use_different_joint_weights']
|
132 |
+
assert not use_different_joint_weights
|
133 |
+
|
134 |
+
target_weight = np.zeros((num_joints, 1), dtype=np.float32)
|
135 |
+
target = np.zeros((num_joints, H, W), dtype=np.float32)
|
136 |
+
|
137 |
+
# 3-sigma rule
|
138 |
+
tmp_size = sigma * 3
|
139 |
+
|
140 |
+
if self.unbiased_encoding:
|
141 |
+
for joint_id in range(num_joints):
|
142 |
+
target_weight[joint_id] = joints_3d_visible[joint_id, 0]
|
143 |
+
|
144 |
+
feat_stride = image_size / [W, H]
|
145 |
+
mu_x = joints_3d[joint_id][0] / feat_stride[0]
|
146 |
+
mu_y = joints_3d[joint_id][1] / feat_stride[1]
|
147 |
+
# Check that any part of the gaussian is in-bounds
|
148 |
+
ul = [mu_x - tmp_size, mu_y - tmp_size]
|
149 |
+
br = [mu_x + tmp_size + 1, mu_y + tmp_size + 1]
|
150 |
+
if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:
|
151 |
+
target_weight[joint_id] = 0
|
152 |
+
|
153 |
+
if target_weight[joint_id] == 0:
|
154 |
+
continue
|
155 |
+
|
156 |
+
x = np.arange(0, W, 1, np.float32)
|
157 |
+
y = np.arange(0, H, 1, np.float32)
|
158 |
+
y = y[:, None]
|
159 |
+
|
160 |
+
if target_weight[joint_id] > 0.5:
|
161 |
+
target[joint_id] = np.exp(-((x - mu_x)**2 +
|
162 |
+
(y - mu_y)**2) /
|
163 |
+
(2 * sigma**2))
|
164 |
+
else:
|
165 |
+
for joint_id in range(num_joints):
|
166 |
+
target_weight[joint_id] = joints_3d_visible[joint_id, 0]
|
167 |
+
|
168 |
+
feat_stride = image_size / [W, H]
|
169 |
+
mu_x = int(joints_3d[joint_id][0] / feat_stride[0] + 0.5)
|
170 |
+
mu_y = int(joints_3d[joint_id][1] / feat_stride[1] + 0.5)
|
171 |
+
# Check that any part of the gaussian is in-bounds
|
172 |
+
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
|
173 |
+
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
|
174 |
+
if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:
|
175 |
+
target_weight[joint_id] = 0
|
176 |
+
|
177 |
+
if target_weight[joint_id] > 0.5:
|
178 |
+
size = 2 * tmp_size + 1
|
179 |
+
x = np.arange(0, size, 1, np.float32)
|
180 |
+
y = x[:, None]
|
181 |
+
x0 = y0 = size // 2
|
182 |
+
# The gaussian is not normalized,
|
183 |
+
# we want the center value to equal 1
|
184 |
+
g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
|
185 |
+
|
186 |
+
# Usable gaussian range
|
187 |
+
g_x = max(0, -ul[0]), min(br[0], W) - ul[0]
|
188 |
+
g_y = max(0, -ul[1]), min(br[1], H) - ul[1]
|
189 |
+
# Image range
|
190 |
+
img_x = max(0, ul[0]), min(br[0], W)
|
191 |
+
img_y = max(0, ul[1]), min(br[1], H)
|
192 |
+
|
193 |
+
target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
|
194 |
+
g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
|
195 |
+
|
196 |
+
if use_different_joint_weights:
|
197 |
+
target_weight = np.multiply(target_weight, joint_weights)
|
198 |
+
|
199 |
+
return target, target_weight
|
200 |
+
|
201 |
+
def _udp_generate_target(self, cfg, joints_3d, joints_3d_visible, factor,
|
202 |
+
target_type):
|
203 |
+
"""Generate the target heatmap via 'UDP' approach. Paper ref: Huang et
|
204 |
+
al. The Devil is in the Details: Delving into Unbiased Data Processing
|
205 |
+
for Human Pose Estimation (CVPR 2020).
|
206 |
+
|
207 |
+
Note:
|
208 |
+
num keypoints: K
|
209 |
+
heatmap height: H
|
210 |
+
heatmap width: W
|
211 |
+
num target channels: C
|
212 |
+
C = K if target_type=='GaussianHeatMap'
|
213 |
+
C = 3*K if target_type=='CombinedTarget'
|
214 |
+
|
215 |
+
Args:
|
216 |
+
cfg (dict): data config
|
217 |
+
joints_3d (np.ndarray[K, 3]): Annotated keypoints.
|
218 |
+
joints_3d_visible (np.ndarray[K, 3]): Visibility of keypoints.
|
219 |
+
factor (float): kernel factor for GaussianHeatMap target or
|
220 |
+
valid radius factor for CombinedTarget.
|
221 |
+
target_type (str): 'GaussianHeatMap' or 'CombinedTarget'.
|
222 |
+
GaussianHeatMap: Heatmap target with gaussian distribution.
|
223 |
+
CombinedTarget: The combination of classification target
|
224 |
+
(response map) and regression target (offset map).
|
225 |
+
|
226 |
+
Returns:
|
227 |
+
tuple: A tuple containing targets.
|
228 |
+
|
229 |
+
- target (np.ndarray[C, H, W]): Target heatmaps.
|
230 |
+
- target_weight (np.ndarray[K, 1]): (1: visible, 0: invisible)
|
231 |
+
"""
|
232 |
+
num_joints = len(joints_3d)
|
233 |
+
image_size = cfg['image_size']
|
234 |
+
heatmap_size = cfg['heatmap_size']
|
235 |
+
joint_weights = cfg['joint_weights']
|
236 |
+
use_different_joint_weights = cfg['use_different_joint_weights']
|
237 |
+
assert not use_different_joint_weights
|
238 |
+
|
239 |
+
target_weight = np.ones((num_joints, 1), dtype=np.float32)
|
240 |
+
target_weight[:, 0] = joints_3d_visible[:, 0]
|
241 |
+
|
242 |
+
assert target_type in ['GaussianHeatMap', 'CombinedTarget']
|
243 |
+
|
244 |
+
if target_type == 'GaussianHeatMap':
|
245 |
+
target = np.zeros((num_joints, heatmap_size[1], heatmap_size[0]),
|
246 |
+
dtype=np.float32)
|
247 |
+
|
248 |
+
tmp_size = factor * 3
|
249 |
+
|
250 |
+
# prepare for gaussian
|
251 |
+
size = 2 * tmp_size + 1
|
252 |
+
x = np.arange(0, size, 1, np.float32)
|
253 |
+
y = x[:, None]
|
254 |
+
|
255 |
+
for joint_id in range(num_joints):
|
256 |
+
feat_stride = (image_size - 1.0) / (heatmap_size - 1.0)
|
257 |
+
mu_x = int(joints_3d[joint_id][0] / feat_stride[0] + 0.5)
|
258 |
+
mu_y = int(joints_3d[joint_id][1] / feat_stride[1] + 0.5)
|
259 |
+
# Check that any part of the gaussian is in-bounds
|
260 |
+
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
|
261 |
+
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
|
262 |
+
if ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] \
|
263 |
+
or br[0] < 0 or br[1] < 0:
|
264 |
+
# If not, skip this joint and mark it as invisible
|
265 |
+
target_weight[joint_id] = 0
|
266 |
+
continue
|
267 |
+
|
268 |
+
# Generate gaussian
|
269 |
+
mu_x_ac = joints_3d[joint_id][0] / feat_stride[0]
|
270 |
+
mu_y_ac = joints_3d[joint_id][1] / feat_stride[1]
|
271 |
+
x0 = y0 = size // 2
|
272 |
+
x0 += mu_x_ac - mu_x
|
273 |
+
y0 += mu_y_ac - mu_y
|
274 |
+
g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * factor**2))
|
275 |
+
|
276 |
+
# Usable gaussian range
|
277 |
+
g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]
|
278 |
+
g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]
|
279 |
+
# Image range
|
280 |
+
img_x = max(0, ul[0]), min(br[0], heatmap_size[0])
|
281 |
+
img_y = max(0, ul[1]), min(br[1], heatmap_size[1])
|
282 |
+
|
283 |
+
v = target_weight[joint_id]
|
284 |
+
if v > 0.5:
|
285 |
+
target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
|
286 |
+
g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
|
287 |
+
elif target_type == 'CombinedTarget':
|
288 |
+
target = np.zeros(
|
289 |
+
(num_joints, 3, heatmap_size[1] * heatmap_size[0]),
|
290 |
+
dtype=np.float32)
|
291 |
+
feat_width = heatmap_size[0]
|
292 |
+
feat_height = heatmap_size[1]
|
293 |
+
feat_x_int = np.arange(0, feat_width)
|
294 |
+
feat_y_int = np.arange(0, feat_height)
|
295 |
+
feat_x_int, feat_y_int = np.meshgrid(feat_x_int, feat_y_int)
|
296 |
+
feat_x_int = feat_x_int.flatten()
|
297 |
+
feat_y_int = feat_y_int.flatten()
|
298 |
+
# Calculate the radius of the positive area in classification
|
299 |
+
# heatmap.
|
300 |
+
valid_radius = factor * heatmap_size[1]
|
301 |
+
feat_stride = (image_size - 1.0) / (heatmap_size - 1.0)
|
302 |
+
for joint_id in range(num_joints):
|
303 |
+
mu_x = joints_3d[joint_id][0] / feat_stride[0]
|
304 |
+
mu_y = joints_3d[joint_id][1] / feat_stride[1]
|
305 |
+
x_offset = (mu_x - feat_x_int) / valid_radius
|
306 |
+
y_offset = (mu_y - feat_y_int) / valid_radius
|
307 |
+
dis = x_offset**2 + y_offset**2
|
308 |
+
keep_pos = np.where(dis <= 1)[0]
|
309 |
+
v = target_weight[joint_id]
|
310 |
+
if v > 0.5:
|
311 |
+
target[joint_id, 0, keep_pos] = 1
|
312 |
+
target[joint_id, 1, keep_pos] = x_offset[keep_pos]
|
313 |
+
target[joint_id, 2, keep_pos] = y_offset[keep_pos]
|
314 |
+
target = target.reshape(num_joints * 3, heatmap_size[1],
|
315 |
+
heatmap_size[0])
|
316 |
+
|
317 |
+
if use_different_joint_weights:
|
318 |
+
target_weight = np.multiply(target_weight, joint_weights)
|
319 |
+
|
320 |
+
return target, target_weight
|
321 |
+
|
322 |
+
def __call__(self, results):
|
323 |
+
"""Generate the target heatmap."""
|
324 |
+
joints_3d = results['joints_3d']
|
325 |
+
joints_3d_visible = results['joints_3d_visible']
|
326 |
+
|
327 |
+
assert self.encoding in ['MSRA', 'UDP']
|
328 |
+
|
329 |
+
if self.encoding == 'MSRA':
|
330 |
+
if isinstance(self.sigma, list):
|
331 |
+
num_sigmas = len(self.sigma)
|
332 |
+
cfg = results['ann_info']
|
333 |
+
num_joints = len(joints_3d)
|
334 |
+
heatmap_size = cfg['heatmap_size']
|
335 |
+
|
336 |
+
target = np.empty(
|
337 |
+
(0, num_joints, heatmap_size[1], heatmap_size[0]),
|
338 |
+
dtype=np.float32)
|
339 |
+
target_weight = np.empty((0, num_joints, 1), dtype=np.float32)
|
340 |
+
for i in range(num_sigmas):
|
341 |
+
target_i, target_weight_i = self._msra_generate_target(
|
342 |
+
cfg, joints_3d, joints_3d_visible, self.sigma[i])
|
343 |
+
target = np.concatenate([target, target_i[None]], axis=0)
|
344 |
+
target_weight = np.concatenate(
|
345 |
+
[target_weight, target_weight_i[None]], axis=0)
|
346 |
+
else:
|
347 |
+
target, target_weight = self._msra_generate_target(
|
348 |
+
results['ann_info'], joints_3d, joints_3d_visible,
|
349 |
+
self.sigma)
|
350 |
+
elif self.encoding == 'UDP':
|
351 |
+
if self.target_type == 'CombinedTarget':
|
352 |
+
factors = self.valid_radius_factor
|
353 |
+
channel_factor = 3
|
354 |
+
elif self.target_type == 'GaussianHeatMap':
|
355 |
+
factors = self.sigma
|
356 |
+
channel_factor = 1
|
357 |
+
if isinstance(factors, list):
|
358 |
+
num_factors = len(factors)
|
359 |
+
cfg = results['ann_info']
|
360 |
+
num_joints = len(joints_3d)
|
361 |
+
W, H = cfg['heatmap_size']
|
362 |
+
|
363 |
+
target = np.empty((0, channel_factor * num_joints, H, W),
|
364 |
+
dtype=np.float32)
|
365 |
+
target_weight = np.empty((0, num_joints, 1), dtype=np.float32)
|
366 |
+
for i in range(num_factors):
|
367 |
+
target_i, target_weight_i = self._udp_generate_target(
|
368 |
+
cfg, joints_3d, joints_3d_visible, factors[i],
|
369 |
+
self.target_type)
|
370 |
+
target = np.concatenate([target, target_i[None]], axis=0)
|
371 |
+
target_weight = np.concatenate(
|
372 |
+
[target_weight, target_weight_i[None]], axis=0)
|
373 |
+
else:
|
374 |
+
target, target_weight = self._udp_generate_target(
|
375 |
+
results['ann_info'], joints_3d, joints_3d_visible, factors,
|
376 |
+
self.target_type)
|
377 |
+
else:
|
378 |
+
raise ValueError(
|
379 |
+
f'Encoding approach {self.encoding} is not supported!')
|
380 |
+
|
381 |
+
results['target'] = target
|
382 |
+
results['target_weight'] = target_weight
|
383 |
+
|
384 |
+
return results
|
385 |
+
|
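The MSRA branch above amounts to stamping an un-normalised Gaussian around each visible keypoint on the stride-reduced heatmap grid. A NumPy-only sketch of that core computation for one keypoint (sizes are illustrative, boundary clipping omitted):

import numpy as np

sigma = 2
image_size = np.array([256, 256])            # network input size (W, H)
heatmap_size = np.array([64, 64])            # target heatmap size (W, H)
joint = np.array([120.0, 80.0])              # keypoint location in input-image pixels

W, H = heatmap_size
feat_stride = image_size / heatmap_size
mu_x, mu_y = joint / feat_stride             # keypoint in heatmap coordinates
xs = np.arange(W, dtype=np.float32)
ys = np.arange(H, dtype=np.float32)[:, None]
heatmap = np.exp(-((xs - mu_x) ** 2 + (ys - mu_y) ** 2) / (2 * sigma ** 2))
print(heatmap.shape, heatmap.max())          # (64, 64), peak of 1.0 at the keypoint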
386 |
+
@PIPELINES.register_module()
|
387 |
+
class LoadDepthFromFile:
|
388 |
+
"""Load depthmap from file.
|
389 |
+
|
390 |
+
Required Keys:
|
391 |
+
|
392 |
+
- depth_file
|
393 |
+
|
394 |
+
Modified Keys:
|
395 |
+
|
396 |
+
- depth
|
397 |
+
|
398 |
+
Args:
|
399 |
+
to_float32 (bool): Whether to convert the loaded depth to a float32
|
400 |
+
numpy array. If set to False, the loaded depth is an uint8 array.
|
401 |
+
Defaults to False.
|
402 |
+
color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
|
403 |
+
Defaults to 'color'.
|
404 |
+
imdecode_backend (str): The depth decoding backend type. The backend
|
405 |
+
argument for :func:`mmcv.imfrombytes`.
|
406 |
+
See :func:`mmcv.imfrombytes` for details.
|
407 |
+
Defaults to 'cv2'.
|
408 |
+
file_client_args (dict, optional): Arguments to instantiate a
|
409 |
+
FileClient. See :class:`mmengine.fileio.FileClient` for details.
|
410 |
+
Defaults to None. It will be deprecated in future. Please use
|
411 |
+
``backend_args`` instead.
|
412 |
+
Deprecated in version 2.0.0rc4.
|
413 |
+
ignore_empty (bool): Whether to allow loading empty depth or file path
|
414 |
+
not existent. Defaults to False.
|
415 |
+
backend_args (dict, optional): Instantiates the corresponding file
|
416 |
+
backend. It may contain `backend` key to specify the file
|
417 |
+
backend. If it contains, the file backend corresponding to this
|
418 |
+
value will be used and initialized with the remaining values,
|
419 |
+
otherwise the corresponding file backend will be selected
|
420 |
+
based on the prefix of the file path. Defaults to None.
|
421 |
+
New in version 2.0.0rc4.
|
422 |
+
"""
|
423 |
+
|
424 |
+
def __init__(self,
|
425 |
+
to_float32=False,
|
426 |
+
color_type='color',
|
427 |
+
channel_order='rgb',
|
428 |
+
file_client_args=dict(backend='disk')):
|
429 |
+
self.to_float32 = to_float32
|
430 |
+
self.color_type = color_type
|
431 |
+
self.channel_order = channel_order
|
432 |
+
self.file_client_args = file_client_args.copy()
|
433 |
+
self.file_client = None
|
434 |
+
|
435 |
+
def _read_depth(self, path):
|
436 |
+
img = np.load(path)['depth']
|
437 |
+
if img is None:
|
438 |
+
raise ValueError(f'Fail to read {path}')
|
439 |
+
if self.to_float32:
|
440 |
+
img = img.astype(np.float32)
|
441 |
+
return img
|
442 |
+
|
443 |
+
def __call__(self, results: dict) -> Optional[dict]:
|
444 |
+
"""Functions to load depth.
|
445 |
+
|
446 |
+
Args:
|
447 |
+
results (dict): Result dict from
|
448 |
+
:class:`mmengine.dataset.BaseDataset`.
|
449 |
+
|
450 |
+
Returns:
|
451 |
+
dict: The dict contains loaded depth and meta information.
|
452 |
+
"""
|
453 |
+
|
454 |
+
"""Loading depth(s) from file."""
|
455 |
+
if self.file_client is None:
|
456 |
+
self.file_client = mmcv.FileClient(**self.file_client_args)
|
457 |
+
|
458 |
+
depth_file = results.get('depth_file', None)
|
459 |
+
# Replace the file extension with .npz
|
460 |
+
pre, ext = os.path.splitext(depth_file)
|
461 |
+
depth_file = pre + '.npz'
|
462 |
+
if isinstance(depth_file, (list, tuple)):
|
463 |
+
# Load depths from a list of paths
|
464 |
+
results['depth'] = [self._read_depth(path) for path in depth_file]
|
465 |
+
elif depth_file is not None:
|
466 |
+
# Load single depth from path
|
467 |
+
results['depth'] = self._read_depth(depth_file)
|
468 |
+
else:
|
469 |
+
if 'depth' not in results:
|
470 |
+
# If `depth_file` is not in results, the depth must already be
|
471 |
+
# present. This is for compatibility when the depth is set
|
472 |
+
# manually outside the pipeline.
|
473 |
+
raise KeyError('Either `depth_file` or `depth` should exist in '
|
474 |
+
'results.')
|
475 |
+
if isinstance(results['depth'], (list, tuple)):
|
476 |
+
assert isinstance(results['depth'][0], np.ndarray)
|
477 |
+
else:
|
478 |
+
assert isinstance(results['depth'], np.ndarray)
|
479 |
+
results['depth_file'] = None
|
480 |
+
|
481 |
+
return results
|
482 |
+
|
483 |
+
def __repr__(self):
|
484 |
+
repr_str = (f'{self.__class__.__name__}('
|
485 |
+
f'to_float32={self.to_float32}, '
|
486 |
+
f"color_type='{self.color_type}', "
|
487 |
+
f'file_client_args={self.file_client_args})')
|
488 |
+
return repr_str
|
489 |
+
|
490 |
+
|
491 |
+
@PIPELINES.register_module()
|
492 |
+
class DepthTopDownAffineFewShot:
|
493 |
+
"""Affine transform the image to make input.
|
494 |
+
|
495 |
+
Required keys:'img', 'depth', 'joints_3d', 'joints_3d_visible', 'ann_info','scale',
|
496 |
+
'rotation' and 'center'. Modified keys:'img', 'joints_3d', and
|
497 |
+
'joints_3d_visible'.
|
498 |
+
|
499 |
+
Args:
|
500 |
+
use_udp (bool): To use unbiased data processing.
|
501 |
+
Paper ref: Huang et al. The Devil is in the Details: Delving into
|
502 |
+
Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
|
503 |
+
"""
|
504 |
+
|
505 |
+
def __init__(self, use_udp=False):
|
506 |
+
self.use_udp = use_udp
|
507 |
+
|
508 |
+
def __call__(self, results):
|
509 |
+
image_size = results['ann_info']['image_size']
|
510 |
+
|
511 |
+
img = results['img']
|
512 |
+
depth = results['depth']
|
513 |
+
joints_3d = results['joints_3d']
|
514 |
+
joints_3d_visible = results['joints_3d_visible']
|
515 |
+
c = results['center']
|
516 |
+
s = results['scale']
|
517 |
+
r = results['rotation']
|
518 |
+
|
519 |
+
if self.use_udp:
|
520 |
+
trans = get_warp_matrix(r, c * 2.0, image_size - 1.0, s * 200.0)
|
521 |
+
img = cv2.warpAffine(
|
522 |
+
img,
|
523 |
+
trans, (int(image_size[0]), int(image_size[1])),
|
524 |
+
flags=cv2.INTER_LINEAR)
|
525 |
+
depth = cv2.warpAffine(
|
526 |
+
depth,
|
527 |
+
trans, (int(image_size[0]), int(image_size[1])),
|
528 |
+
flags=cv2.INTER_LINEAR)
|
529 |
+
joints_3d[:, 0:2] = warp_affine_joints(joints_3d[:, 0:2].copy(), trans)
|
530 |
+
else:
|
531 |
+
trans = get_affine_transform(c, s, r, image_size)
|
532 |
+
img = cv2.warpAffine(
|
533 |
+
img,
|
534 |
+
trans, (int(image_size[0]), int(image_size[1])),
|
535 |
+
flags=cv2.INTER_LINEAR)
|
536 |
+
depth = cv2.warpAffine(
|
537 |
+
depth,
|
538 |
+
trans, (int(image_size[0]), int(image_size[1])),
|
539 |
+
flags=cv2.INTER_LINEAR)
|
540 |
+
for i in range(len(joints_3d)):
|
541 |
+
if joints_3d_visible[i, 0] > 0.0:
|
542 |
+
joints_3d[i, 0:2] = affine_transform(joints_3d[i, 0:2], trans)
|
543 |
+
|
544 |
+
results['img'] = img
|
545 |
+
results['depth'] = depth
|
546 |
+
results['joints_3d'] = joints_3d
|
547 |
+
results['joints_3d_visible'] = joints_3d_visible
|
548 |
+
|
549 |
+
return results
|
550 |
+
|
551 |
+
|
552 |
+
|
553 |
+
|
554 |
+
@PIPELINES.register_module()
|
555 |
+
class LoadFeatFromFile:
|
556 |
+
"""Load depthmap from file.
|
557 |
+
|
558 |
+
Required Keys:
|
559 |
+
|
560 |
+
- depth_path
|
561 |
+
|
562 |
+
Modified Keys:
|
563 |
+
|
564 |
+
- depth
|
565 |
+
|
566 |
+
Args:
|
567 |
+
to_float32 (bool): Whether to convert the loaded depth to a float32
|
568 |
+
numpy array. If set to False, the loaded depth is an uint8 array.
|
569 |
+
Defaults to False.
|
570 |
+
color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
|
571 |
+
Defaults to 'color'.
|
572 |
+
imdecode_backend (str): The depth decoding backend type. The backend
|
573 |
+
argument for :func:`mmcv.imfrombytes`.
|
574 |
+
See :func:`mmcv.imfrombytes` for details.
|
575 |
+
Defaults to 'cv2'.
|
576 |
+
file_client_args (dict, optional): Arguments to instantiate a
|
577 |
+
FileClient. See :class:`mmengine.fileio.FileClient` for details.
|
578 |
+
Defaults to None. It will be deprecated in future. Please use
|
579 |
+
``backend_args`` instead.
|
580 |
+
Deprecated in version 2.0.0rc4.
|
581 |
+
ignore_empty (bool): Whether to allow loading empty depth or file path
|
582 |
+
not existent. Defaults to False.
|
583 |
+
backend_args (dict, optional): Instantiates the corresponding file
|
584 |
+
backend. It may contain `backend` key to specify the file
|
585 |
+
backend. If it contains, the file backend corresponding to this
|
586 |
+
value will be used and initialized with the remaining values,
|
587 |
+
otherwise the corresponding file backend will be selected
|
588 |
+
based on the prefix of the file path. Defaults to None.
|
589 |
+
New in version 2.0.0rc4.
|
590 |
+
"""
|
591 |
+
|
592 |
+
def __init__(self,
|
593 |
+
to_float32=False,
|
594 |
+
color_type='color',
|
595 |
+
channel_order='rgb',
|
596 |
+
file_client_args=dict(backend='disk')):
|
597 |
+
self.to_float32 = to_float32
|
598 |
+
self.color_type = color_type
|
599 |
+
self.channel_order = channel_order
|
600 |
+
self.file_client_args = file_client_args.copy()
|
601 |
+
self.file_client = None
|
602 |
+
|
603 |
+
def _read_depth(self, path):
|
604 |
+
img = np.load(path)['feat']
|
605 |
+
if img is None:
|
606 |
+
raise ValueError(f'Fail to read {path}')
|
607 |
+
if self.to_float32:
|
608 |
+
img = img.astype(np.float32)
|
609 |
+
return img
|
610 |
+
|
611 |
+
def __call__(self, results: dict) -> Optional[dict]:
|
612 |
+
"""Functions to load depth.
|
613 |
+
|
614 |
+
Args:
|
615 |
+
results (dict): Result dict from
|
616 |
+
:class:`mmengine.dataset.BaseDataset`.
|
617 |
+
|
618 |
+
Returns:
|
619 |
+
dict: The dict contains loaded depth and meta information.
|
620 |
+
"""
|
621 |
+
|
622 |
+
"""Loading depth(s) from file."""
|
623 |
+
if self.file_client is None:
|
624 |
+
self.file_client = mmcv.FileClient(**self.file_client_args)
|
625 |
+
|
626 |
+
feat_file = results.get('feat_file', None)
|
627 |
+
# Replace the file extension with .npz
|
628 |
+
pre, ext = os.path.splitext(feat_file)
|
629 |
+
feat_file = pre + '.npz'
|
630 |
+
if isinstance(feat_file, (list, tuple)):
|
631 |
+
# Load depths from a list of paths
|
632 |
+
results['feat'] = [self._read_depth(path) for path in feat_file]
|
633 |
+
elif feat_file is not None:
|
634 |
+
# Load single depth from path
|
635 |
+
results['feat'] = self._read_depth(feat_file)
|
636 |
+
else:
|
637 |
+
if 'feat_file' not in results:
|
638 |
+
# If `depth_file`` is not in results, check the `img` exists
|
639 |
+
# and format the depth. This for compatibility when the depth
|
640 |
+
# is manually set outside the pipeline.
|
641 |
+
raise KeyError('Either `feat_file` or `img` should exist in results.')
|
642 |
+
if isinstance(results['feat'], (list, tuple)):
|
643 |
+
assert isinstance(results['feat'][0], np.ndarray)
|
644 |
+
else:
|
645 |
+
assert isinstance(results['feat'], np.ndarray)
|
646 |
+
results['feat_file'] = None
|
647 |
+
|
648 |
+
return results
|
649 |
+
|
650 |
+
def __repr__(self):
|
651 |
+
repr_str = (f'{self.__class__.__name__}('
|
652 |
+
f'to_float32={self.to_float32}, '
|
653 |
+
f"color_type='{self.color_type}', "
|
654 |
+
f'file_client_args={self.file_client_args})')
|
655 |
+
return repr_str
|
656 |
+
|
657 |
+
|
658 |
+
@PIPELINES.register_module()
|
659 |
+
class FeatTopDownAffineFewShot:
|
660 |
+
"""Affine transform the image to make input.
|
661 |
+
|
662 |
+
Required keys:'img', 'feat', 'joints_3d', 'joints_3d_visible', 'ann_info','scale',
|
663 |
+
'rotation' and 'center'. Modified keys:'img', 'joints_3d', and
|
664 |
+
'joints_3d_visible'.
|
665 |
+
|
666 |
+
Args:
|
667 |
+
use_udp (bool): To use unbiased data processing.
|
668 |
+
Paper ref: Huang et al. The Devil is in the Details: Delving into
|
669 |
+
Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
|
670 |
+
"""
|
671 |
+
|
672 |
+
def __init__(self, use_udp=False):
|
673 |
+
self.use_udp = use_udp
|
674 |
+
|
675 |
+
def __call__(self, results):
|
676 |
+
image_size = results['ann_info']['image_size']
|
677 |
+
|
678 |
+
img = results['img']
|
679 |
+
feat = results['feat']
|
680 |
+
joints_3d = results['joints_3d']
|
681 |
+
joints_3d_visible = results['joints_3d_visible']
|
682 |
+
c = results['center']
|
683 |
+
s = results['scale']
|
684 |
+
r = results['rotation']
|
685 |
+
|
686 |
+
if self.use_udp:
|
687 |
+
trans = get_warp_matrix(r, c * 2.0, image_size - 1.0, s * 200.0)
|
688 |
+
img = cv2.warpAffine(
|
689 |
+
img,
|
690 |
+
trans, (int(image_size[0]), int(image_size[1])),
|
691 |
+
flags=cv2.INTER_LINEAR)
|
692 |
+
feat = cv2.warpAffine(
|
693 |
+
feat,
|
694 |
+
trans, (int(image_size[0]), int(image_size[1])),
|
695 |
+
flags=cv2.INTER_LINEAR)
|
696 |
+
joints_3d[:, 0:2] = warp_affine_joints(joints_3d[:, 0:2].copy(), trans)
|
697 |
+
else:
|
698 |
+
trans = get_affine_transform(c, s, r, image_size)
|
699 |
+
img = cv2.warpAffine(
|
700 |
+
img,
|
701 |
+
trans, (int(image_size[0]), int(image_size[1])),
|
702 |
+
flags=cv2.INTER_LINEAR)
|
703 |
+
feat = cv2.warpAffine(
|
704 |
+
feat,
|
705 |
+
trans, (int(image_size[0]), int(image_size[1])),
|
706 |
+
flags=cv2.INTER_LINEAR)
|
707 |
+
for i in range(len(joints_3d)):
|
708 |
+
if joints_3d_visible[i, 0] > 0.0:
|
709 |
+
joints_3d[i, 0:2] = affine_transform(joints_3d[i, 0:2], trans)
|
710 |
+
|
711 |
+
results['img'] = img
|
712 |
+
# The warped feature map is written to the 'depth' key (the same output key
# produced by DepthTopDownAffineFewShot).
results['depth'] = feat
|
713 |
+
results['joints_3d'] = joints_3d
|
714 |
+
results['joints_3d_visible'] = joints_3d_visible
|
715 |
+
|
716 |
+
return results
|
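Taken together, the transforms registered in this file are meant to be chained in an mmpose-style pipeline config. The snippet below is a hypothetical illustration of such a chain, not a config taken from the repository; the transform names match the registered classes above, but the ordering and argument values are assumptions:

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadDepthFromFile'),                               # adds results['depth'] from the .npz next to the image
    dict(type='DepthTopDownAffineFewShot', use_udp=False),        # warps image, depth and keypoints to the input size
    dict(type='TopDownGenerateTargetFewShot', sigma=2, encoding='MSRA'),
]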
EdgeCape/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
1 |
+
from .detectors import * # noqa
|
2 |
+
from .keypoint_heads import * # noqa
|
3 |
+
from .backbones import * # noqa
|
EdgeCape/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (227 Bytes). View file
|
|
EdgeCape/models/backbones/__pycache__/adapter.cpython-39.pyc
ADDED
Binary file (27.8 kB). View file
|
|
EdgeCape/models/backbones/__pycache__/dino.cpython-39.pyc
ADDED
Binary file (5.48 kB). View file
|
|
EdgeCape/models/backbones/adapter.py
ADDED
@@ -0,0 +1,935 @@
1 |
+
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import fvcore.nn.weight_init as weight_init
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.nn.functional import interpolate
|
8 |
+
|
9 |
+
"""
|
10 |
+
Code is based on: https://github.com/mbanani/probe3d
|
11 |
+
"""
|
12 |
+
|
13 |
+
|
14 |
+
class SurfaceNormalHead(nn.Module):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
feat_dim,
|
18 |
+
head_type="multiscale",
|
19 |
+
uncertainty_aware=False,
|
20 |
+
hidden_dim=512,
|
21 |
+
kernel_size=1,
|
22 |
+
):
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
self.uncertainty_aware = uncertainty_aware
|
26 |
+
output_dim = 4 if uncertainty_aware else 3
|
27 |
+
|
28 |
+
self.kernel_size = kernel_size
|
29 |
+
|
30 |
+
assert head_type in ["linear", "multiscale", "dpt"]
|
31 |
+
name = f"snorm_{head_type}_k{kernel_size}"
|
32 |
+
self.name = f"{name}_UA" if uncertainty_aware else name
|
33 |
+
|
34 |
+
if head_type == "linear":
|
35 |
+
self.head = Linear(feat_dim, output_dim, kernel_size)
|
36 |
+
elif head_type == "multiscale":
|
37 |
+
self.head = MultiscaleHead(feat_dim, output_dim, hidden_dim, kernel_size)
|
38 |
+
elif head_type == "dpt":
|
39 |
+
self.head = DPT(feat_dim, output_dim, hidden_dim, kernel_size)
|
40 |
+
else:
|
41 |
+
raise ValueError(f"Unknown head type: {self.head_type}")
|
42 |
+
|
43 |
+
def forward(self, feats):
|
44 |
+
return self.head(feats)
|
45 |
+
|
46 |
+
|
47 |
+
class DepthHead(nn.Module):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
feat_dim,
|
51 |
+
head_type="multiscale",
|
52 |
+
min_depth=0.001,
|
53 |
+
max_depth=10,
|
54 |
+
prediction_type="bindepth",
|
55 |
+
hidden_dim=512,
|
56 |
+
kernel_size=1,
|
57 |
+
):
|
58 |
+
super().__init__()
|
59 |
+
|
60 |
+
self.kernel_size = kernel_size
|
61 |
+
self.name = f"{prediction_type}_{head_type}_k{kernel_size}"
|
62 |
+
|
63 |
+
if prediction_type == "bindepth":
|
64 |
+
output_dim = 256
|
65 |
+
self.predict = DepthBinPrediction(min_depth, max_depth, n_bins=output_dim)
|
66 |
+
elif prediction_type == "sigdepth":
|
67 |
+
output_dim = 1
|
68 |
+
self.predict = DepthSigmoidPrediction(min_depth, max_depth)
|
69 |
+
else:
|
70 |
+
raise ValueError()
|
71 |
+
|
72 |
+
if head_type == "linear":
|
73 |
+
self.head = Linear(feat_dim, output_dim, kernel_size)
|
74 |
+
elif head_type == "multiscale":
|
75 |
+
self.head = MultiscaleHead(feat_dim, output_dim, hidden_dim, kernel_size)
|
76 |
+
elif head_type == "dpt":
|
77 |
+
self.head = DPT(feat_dim, output_dim, hidden_dim, kernel_size)
|
78 |
+
else:
|
79 |
+
raise ValueError(f"Unknown head type: {self.head_type}")
|
80 |
+
|
81 |
+
def forward(self, feats):
|
82 |
+
"""Prediction each pixel."""
|
83 |
+
feats = self.head(feats)
|
84 |
+
depth = self.predict(feats)
|
85 |
+
return depth
|
86 |
+
|
87 |
+
|
88 |
+
class DepthBinPrediction(nn.Module):
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
min_depth=0.001,
|
92 |
+
max_depth=10,
|
93 |
+
n_bins=256,
|
94 |
+
bins_strategy="UD",
|
95 |
+
norm_strategy="linear",
|
96 |
+
):
|
97 |
+
super().__init__()
|
98 |
+
self.n_bins = n_bins
|
99 |
+
self.min_depth = min_depth
|
100 |
+
self.max_depth = max_depth
|
101 |
+
self.norm_strategy = norm_strategy
|
102 |
+
self.bins_strategy = bins_strategy
|
103 |
+
|
104 |
+
def forward(self, prob):
|
105 |
+
if self.bins_strategy == "UD":
|
106 |
+
bins = torch.linspace(
|
107 |
+
self.min_depth, self.max_depth, self.n_bins, device=prob.device
|
108 |
+
)
|
109 |
+
elif self.bins_strategy == "SID":
|
110 |
+
bins = torch.logspace(
|
111 |
+
self.min_depth, self.max_depth, self.n_bins, device=prob.device
|
112 |
+
)
|
113 |
+
|
114 |
+
# following Adabins, default linear
|
115 |
+
if self.norm_strategy == "linear":
|
116 |
+
prob = torch.relu(prob)
|
117 |
+
eps = 0.1
|
118 |
+
prob = prob + eps
|
119 |
+
prob = prob / prob.sum(dim=1, keepdim=True)
|
120 |
+
elif self.norm_strategy == "softmax":
|
121 |
+
prob = torch.softmax(prob, dim=1)
|
122 |
+
elif self.norm_strategy == "sigmoid":
|
123 |
+
prob = torch.sigmoid(prob)
|
124 |
+
prob = prob / prob.sum(dim=1, keepdim=True)
|
125 |
+
|
126 |
+
depth = torch.einsum("ikhw,k->ihw", [prob, bins])
|
127 |
+
depth = depth.unsqueeze(dim=1)
|
128 |
+
return depth
|
129 |
+
|
130 |
+
|
131 |
+
class DepthSigmoidPrediction(nn.Module):
|
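The bin head above converts per-pixel logits over `n_bins` depth bins into metric depth by taking the probability-weighted average of the bin centres. A quick shape sanity check, assuming the class above is in scope (batch and spatial sizes are arbitrary):

import torch

head = DepthBinPrediction(min_depth=0.001, max_depth=10, n_bins=256)
logits = torch.randn(2, 256, 32, 32)      # (B, n_bins, H, W) raw head output
depth = head(logits)                      # (B, 1, H, W), values inside [min_depth, max_depth]
print(depth.shape, depth.min().item(), depth.max().item())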
132 |
+
def __init__(self, min_depth=0.001, max_depth=10):
|
133 |
+
super().__init__()
|
134 |
+
self.min_depth = min_depth
|
135 |
+
self.max_depth = max_depth
|
136 |
+
|
137 |
+
def forward(self, pred):
|
138 |
+
depth = pred.sigmoid()
|
139 |
+
depth = self.min_depth + depth * (self.max_depth - self.min_depth)
|
140 |
+
return depth
|
141 |
+
|
142 |
+
|
143 |
+
class FeatureFusionBlock(nn.Module):
|
144 |
+
def __init__(self, features, kernel_size, with_skip=True):
|
145 |
+
super().__init__()
|
146 |
+
self.with_skip = with_skip
|
147 |
+
if self.with_skip:
|
148 |
+
self.resConfUnit1 = ResidualConvUnit(features, kernel_size)
|
149 |
+
|
150 |
+
self.resConfUnit2 = ResidualConvUnit(features, kernel_size)
|
151 |
+
|
152 |
+
def forward(self, x, skip_x=None):
|
153 |
+
if skip_x is not None:
|
154 |
+
assert self.with_skip and skip_x.shape == x.shape
|
155 |
+
x = self.resConfUnit1(x) + skip_x
|
156 |
+
|
157 |
+
x = self.resConfUnit2(x)
|
158 |
+
return x
|
159 |
+
|
160 |
+
|
161 |
+
class ResidualConvUnit(nn.Module):
|
162 |
+
def __init__(self, features, kernel_size):
|
163 |
+
super().__init__()
|
164 |
+
assert kernel_size % 2 == 1, "Kernel size needs to be odd"
|
165 |
+
padding = kernel_size // 2
|
166 |
+
self.conv = nn.Sequential(
|
167 |
+
nn.Conv2d(features, features, kernel_size, padding=padding),
|
168 |
+
nn.ReLU(True),
|
169 |
+
nn.Conv2d(features, features, kernel_size, padding=padding),
|
170 |
+
nn.ReLU(True),
|
171 |
+
)
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
return self.conv(x) + x
|
175 |
+
|
176 |
+
|
177 |
+
class DPT(nn.Module):
|
178 |
+
def __init__(self, input_dims, output_dim, hidden_dim=512, kernel_size=3, hr=False, swin=False):
|
179 |
+
super().__init__()
|
180 |
+
assert len(input_dims) == 4
|
181 |
+
self.hr = hr
|
182 |
+
self.conv_0 = nn.Conv2d(input_dims[0], hidden_dim, 1, padding=0)
|
183 |
+
self.conv_1 = nn.Conv2d(input_dims[1], hidden_dim, 1, padding=0)
|
184 |
+
self.conv_2 = nn.Conv2d(input_dims[2], hidden_dim, 1, padding=0)
|
185 |
+
self.conv_3 = nn.Conv2d(input_dims[3], hidden_dim, 1, padding=0)
|
186 |
+
|
187 |
+
self.ref_0 = FeatureFusionBlock(hidden_dim, kernel_size)
|
188 |
+
self.ref_1 = FeatureFusionBlock(hidden_dim, kernel_size)
|
189 |
+
self.ref_2 = FeatureFusionBlock(hidden_dim, kernel_size)
|
190 |
+
self.ref_3 = FeatureFusionBlock(hidden_dim, kernel_size, with_skip=False)
|
191 |
+
|
192 |
+
self.out_conv = nn.Sequential(
|
193 |
+
nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
|
194 |
+
nn.ReLU(True),
|
195 |
+
nn.Conv2d(hidden_dim, output_dim, 3, padding=1),
|
196 |
+
)
|
197 |
+
|
198 |
+
if swin:
|
199 |
+
self.scale_factor = [1, 2, 4, 4]
|
200 |
+
else:
|
201 |
+
self.scale_factor = [2, 2, 2, 2]
|
202 |
+
|
203 |
+
def forward(self, features):
|
204 |
+
"""Prediction each pixel."""
|
205 |
+
assert len(features) == 4
|
206 |
+
feats = features.copy()
|
207 |
+
feats[0] = self.conv_0(feats[0])
|
208 |
+
feats[1] = self.conv_1(feats[1])
|
209 |
+
feats[2] = self.conv_2(feats[2])
|
210 |
+
feats[3] = self.conv_3(feats[3])
|
211 |
+
|
212 |
+
feats = [interpolate(x, scale_factor=scale_factor) for x, scale_factor in zip(feats, self.scale_factor)]
|
213 |
+
|
214 |
+
out = self.ref_3(feats[3], None)
|
215 |
+
out = self.ref_2(feats[2], out)
|
216 |
+
out = self.ref_1(feats[1], out)
|
217 |
+
out = self.ref_0(feats[0], out)
|
218 |
+
if not self.hr:
|
219 |
+
return self.out_conv(out)
|
220 |
+
out = interpolate(out, scale_factor=4)
|
221 |
+
out = self.out_conv(out)
|
222 |
+
# out = interpolate(out, scale_factor=2)
|
223 |
+
return out
|
224 |
+
|
225 |
+
|
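The DPT head above projects each of the four backbone feature maps to `hidden_dim` channels, upsamples them by the per-level scale factors, fuses them coarse-to-fine with the residual fusion blocks, and projects to `output_dim`. A shape-only sketch, assuming the class above is in scope and using made-up ViT-like inputs:

import torch

dpt = DPT(input_dims=[768, 768, 768, 768], output_dim=3, hidden_dim=256, kernel_size=3)
feats = [torch.randn(1, 768, 32, 32) for _ in range(4)]   # four feature levels from a backbone
out = dpt(feats)                                           # every level ends up at 64x64 after x2 upsampling
print(out.shape)                                           # torch.Size([1, 3, 64, 64])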
226 |
+
def make_conv(input_dim, hidden_dim, output_dim, num_layers, kernel_size=1):
|
227 |
+
    # Reconstructed body (the original returned an undefined `conv`): stack
    # `num_layers` convolutions going input_dim -> hidden_dim ... -> output_dim,
    # with ReLUs in between, matching how MultiscaleHead calls this helper.
    dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
    layers = []
    for i in range(num_layers):
        if i > 0:
            layers.append(nn.ReLU(True))
        layers.append(nn.Conv2d(dims[i], dims[i + 1], kernel_size, padding=kernel_size // 2))
    return nn.Sequential(*layers)
|
228 |
+
|
229 |
+
|
230 |
+
class Linear(nn.Module):
|
231 |
+
def __init__(self, input_dim, output_dim, kernel_size=1):
|
232 |
+
super().__init__()
|
233 |
+
if type(input_dim) is not int:
|
234 |
+
input_dim = sum(input_dim)
|
235 |
+
|
236 |
+
assert type(input_dim) is int
|
237 |
+
padding = kernel_size // 2
|
238 |
+
self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, padding=padding)
|
239 |
+
|
240 |
+
def forward(self, feats):
|
241 |
+
if type(feats) is list:
|
242 |
+
feats = torch.cat(feats, dim=1)
|
243 |
+
|
244 |
+
feats = interpolate(feats, scale_factor=4, mode="bilinear")
|
245 |
+
return self.conv(feats)
|
246 |
+
|
247 |
+
|
248 |
+
class MultiscaleHead(nn.Module):
|
249 |
+
def __init__(self, input_dims, output_dim, hidden_dim=512, kernel_size=1):
|
250 |
+
super().__init__()
|
251 |
+
|
252 |
+
self.convs = nn.ModuleList(
|
253 |
+
[make_conv(in_d, None, hidden_dim, 1, kernel_size) for in_d in input_dims]
|
254 |
+
)
|
255 |
+
interm_dim = len(input_dims) * hidden_dim
|
256 |
+
self.conv_mid = make_conv(interm_dim, hidden_dim, hidden_dim, 3, kernel_size)
|
257 |
+
self.conv_out = make_conv(hidden_dim, hidden_dim, output_dim, 2, kernel_size)
|
258 |
+
|
259 |
+
def forward(self, feats):
|
260 |
+
num_feats = len(feats)
|
261 |
+
feats = [self.convs[i](feats[i]) for i in range(num_feats)]
|
262 |
+
|
263 |
+
h, w = feats[-1].shape[-2:]
|
264 |
+
feats = [interpolate(feat, (h, w), mode="bilinear") for feat in feats]
|
265 |
+
feats = torch.cat(feats, dim=1).relu()
|
266 |
+
|
267 |
+
# upsample
|
268 |
+
feats = interpolate(feats, scale_factor=2, mode="bilinear")
|
269 |
+
feats = self.conv_mid(feats).relu()
|
270 |
+
feats = interpolate(feats, scale_factor=4, mode="bilinear")
|
271 |
+
return self.conv_out(feats)
|
272 |
+
|
273 |
+
def get_norm(norm, out_channels, num_norm_groups=32):
|
274 |
+
"""
|
275 |
+
Args:
|
276 |
+
norm (str or callable): currently only "GN" is supported as a string;
|
277 |
+
or a callable that takes a channel number and returns
|
278 |
+
the normalization layer as a nn.Module.
|
279 |
+
Returns:
|
280 |
+
nn.Module or None: the normalization layer
|
281 |
+
"""
|
282 |
+
if norm is None:
|
283 |
+
return None
|
284 |
+
if isinstance(norm, str):
|
285 |
+
if len(norm) == 0:
|
286 |
+
return None
|
287 |
+
norm = {
|
288 |
+
"GN": lambda channels: nn.GroupNorm(num_norm_groups, channels),
|
289 |
+
}[norm]
|
290 |
+
return norm(out_channels)
|
291 |
+
|
292 |
+
|
293 |
+
def get_activation(activation):
|
294 |
+
"""
|
295 |
+
Args:
|
296 |
+
activation (str or callable): either one of relu, lrelu, prelu, leaky_relu,
|
297 |
+
sigmoid, tanh, elu, selu; or a callable that takes a
|
298 |
+
tensor and returns a tensor.
|
299 |
+
Returns:
|
300 |
+
nn.Module or None: the activation layer
|
301 |
+
"""
|
302 |
+
if activation is None:
|
303 |
+
return None
|
304 |
+
if isinstance(activation, str):
|
305 |
+
if len(activation) == 0:
|
306 |
+
return None
|
307 |
+
activation = {
|
308 |
+
"relu": nn.ReLU,
|
309 |
+
"lrelu": nn.LeakyReLU,
|
310 |
+
"prelu": nn.PReLU,
|
311 |
+
"leaky_relu": nn.LeakyReLU,
|
312 |
+
"sigmoid": nn.Sigmoid,
|
313 |
+
"tanh": nn.Tanh,
|
314 |
+
"elu": nn.ELU,
|
315 |
+
"selu": nn.SELU,
|
316 |
+
}[activation]
|
317 |
+
return activation()
|
318 |
+
|
319 |
+
|
320 |
+
# SCE crisscross + diags
|
321 |
+
class EfficientSpatialContextNet(nn.Module):
|
322 |
+
def __init__(self, kernel_size=7, in_channels=768, out_channels=768, use_cuda=True):
|
323 |
+
super(EfficientSpatialContextNet, self).__init__()
|
324 |
+
self.kernel_size = kernel_size
|
325 |
+
self.pad = kernel_size // 2
|
326 |
+
self.conv = torch.nn.Conv2d(
|
327 |
+
in_channels + 4 * self.kernel_size,
|
328 |
+
out_channels,
|
329 |
+
1,
|
330 |
+
bias=True,
|
331 |
+
padding_mode="zeros",
|
332 |
+
)
|
333 |
+
|
334 |
+
if use_cuda:
|
335 |
+
self.conv = self.conv.cuda()
|
336 |
+
|
337 |
+
def forward(self, feature):
|
338 |
+
b, c, h, w = feature.size()
|
339 |
+
feature_normalized = F.normalize(feature, p=2, dim=1)
|
340 |
+
feature_pad = F.pad(
|
341 |
+
feature_normalized, (self.pad, self.pad, self.pad, self.pad), "constant", 0
|
342 |
+
)
|
343 |
+
output = torch.zeros(
|
344 |
+
[4 * self.kernel_size, b, h, w],
|
345 |
+
dtype=feature.dtype,
|
346 |
+
requires_grad=feature.requires_grad,
|
347 |
+
)
|
348 |
+
if feature.is_cuda:
|
349 |
+
output = output.cuda(feature.get_device())
|
350 |
+
|
351 |
+
# left-top to right-bottom
|
352 |
+
for i in range(self.kernel_size):
|
353 |
+
c = i
|
354 |
+
r = i
|
355 |
+
output[i] = (feature_pad[:, :, r: (h + r), c: (w + c)] * feature_normalized).sum(1)
|
356 |
+
|
357 |
+
# col
|
358 |
+
for i in range(self.kernel_size):
|
359 |
+
c = self.kernel_size // 2
|
360 |
+
r = i
|
361 |
+
output[1 * self.kernel_size + i] = (feature_pad[:, :, r: (h + r), c: (w + c)] * feature_normalized).sum(1)
|
362 |
+
|
363 |
+
# right-top to left-bottom
|
364 |
+
for i in range(self.kernel_size):
|
365 |
+
c = (self.kernel_size - 1) - i
|
366 |
+
r = i
|
367 |
+
output[2 * self.kernel_size + i] = (feature_pad[:, :, r: (h + r), c: (w + c)] * feature_normalized).sum(1)
|
368 |
+
|
369 |
+
# row
|
370 |
+
for i in range(self.kernel_size):
|
371 |
+
c = i
|
372 |
+
r = self.kernel_size // 2
|
373 |
+
output[3 * self.kernel_size + i] = (feature_pad[:, :, r: (h + r), c: (w + c)] * feature_normalized).sum(1)
|
374 |
+
|
375 |
+
output = output.transpose(0, 1).contiguous()
|
376 |
+
output = torch.cat((feature, output), 1)
|
377 |
+
output = self.conv(output)
|
378 |
+
# output = F.relu(output)
|
379 |
+
|
380 |
+
return output
|
381 |
+
|
382 |
+
|
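EfficientSpatialContextNet augments every pixel's feature with its cosine similarity to the pixels along the two diagonals, the centre column and the centre row of a kernel_size x kernel_size window (4 * kernel_size extra channels), then mixes everything back to `out_channels` with a 1x1 conv. A CPU-only usage sketch, assuming the class above is in scope:

import torch

sce = EfficientSpatialContextNet(kernel_size=7, in_channels=768, out_channels=768, use_cuda=False)
x = torch.randn(1, 768, 24, 24)
y = sce(x)
print(y.shape)            # torch.Size([1, 768, 24, 24])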
383 |
+
class Conv2d(nn.Conv2d):
|
384 |
+
"""
|
385 |
+
A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
|
386 |
+
"""
|
387 |
+
|
388 |
+
def __init__(self, *args, **kwargs):
|
389 |
+
"""
|
390 |
+
Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
|
391 |
+
Args:
|
392 |
+
norm (nn.Module, optional): a normalization layer
|
393 |
+
activation (callable(Tensor) -> Tensor): a callable activation function
|
394 |
+
It assumes that norm layer is used before activation.
|
395 |
+
"""
|
396 |
+
norm = kwargs.pop("norm", None)
|
397 |
+
activation = kwargs.pop("activation", None)
|
398 |
+
super().__init__(*args, **kwargs)
|
399 |
+
|
400 |
+
self.norm = norm
|
401 |
+
self.activation = activation
|
402 |
+
|
403 |
+
def forward(self, x):
|
404 |
+
x = F.conv2d(
|
405 |
+
x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
406 |
+
)
|
407 |
+
if self.norm is not None:
|
408 |
+
x = self.norm(x)
|
409 |
+
if self.activation is not None:
|
410 |
+
x = self.activation(x)
|
411 |
+
return x
|
412 |
+
|
413 |
+
|
414 |
+
class CNNBlockBase(nn.Module):
|
415 |
+
"""
|
416 |
+
A CNN block is assumed to have input channels, output channels and a stride.
|
417 |
+
The input and output of `forward()` method must be NCHW tensors.
|
418 |
+
The method can perform arbitrary computation but must match the given
|
419 |
+
channels and stride specification.
|
420 |
+
Attribute:
|
421 |
+
in_channels (int):
|
422 |
+
out_channels (int):
|
423 |
+
stride (int):
|
424 |
+
"""
|
425 |
+
|
426 |
+
def __init__(self, in_channels, out_channels, stride):
|
427 |
+
"""
|
428 |
+
The `__init__` method of any subclass should also contain these arguments.
|
429 |
+
Args:
|
430 |
+
in_channels (int):
|
431 |
+
out_channels (int):
|
432 |
+
stride (int):
|
433 |
+
"""
|
434 |
+
super().__init__()
|
435 |
+
self.in_channels = in_channels
|
436 |
+
self.out_channels = out_channels
|
437 |
+
self.stride = stride
|
438 |
+
|
439 |
+
|
440 |
+
class BottleneckBlock(CNNBlockBase):
|
441 |
+
"""
|
442 |
+
The standard bottleneck residual block used by ResNet-50, 101 and 152
|
443 |
+
defined in :paper:`ResNet`. It contains 3 conv layers with kernels
|
444 |
+
1x1, 3x3, 1x1, and a projection shortcut if needed.
|
445 |
+
"""
|
446 |
+
|
447 |
+
def __init__(
|
448 |
+
self,
|
449 |
+
in_channels,
|
450 |
+
out_channels,
|
451 |
+
*,
|
452 |
+
bottleneck_channels,
|
453 |
+
stride=1,
|
454 |
+
num_groups=1,
|
455 |
+
norm="GN",
|
456 |
+
stride_in_1x1=False,
|
457 |
+
dilation=1,
|
458 |
+
num_norm_groups=32,
|
459 |
+
kernel_size=(1, 3, 1)
|
460 |
+
):
|
461 |
+
"""
|
462 |
+
Args:
|
463 |
+
bottleneck_channels (int): number of output channels for the 3x3
|
464 |
+
"bottleneck" conv layers.
|
465 |
+
num_groups (int): number of groups for the 3x3 conv layer.
|
466 |
+
norm (str or callable): normalization for all conv layers.
|
467 |
+
See :func:`layers.get_norm` for supported format.
|
468 |
+
stride_in_1x1 (bool): when stride>1, whether to put stride in the
|
469 |
+
first 1x1 convolution or the bottleneck 3x3 convolution.
|
470 |
+
dilation (int): the dilation rate of the 3x3 conv layer.
|
471 |
+
"""
|
472 |
+
super().__init__(in_channels, out_channels, stride)
|
473 |
+
|
474 |
+
if in_channels != out_channels:
|
475 |
+
self.shortcut = Conv2d(
|
476 |
+
in_channels,
|
477 |
+
out_channels,
|
478 |
+
kernel_size=1,
|
479 |
+
stride=stride,
|
480 |
+
bias=False,
|
481 |
+
norm=get_norm(norm, out_channels, num_norm_groups),
|
482 |
+
)
|
483 |
+
else:
|
484 |
+
self.shortcut = None
|
485 |
+
|
486 |
+
# The original MSRA ResNet models have stride in the first 1x1 conv
|
487 |
+
# The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
|
488 |
+
# stride in the 3x3 conv
|
489 |
+
stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
|
490 |
+
|
491 |
+
self.conv1 = Conv2d(
|
492 |
+
in_channels,
|
493 |
+
bottleneck_channels,
|
494 |
+
kernel_size=kernel_size[0],
|
495 |
+
stride=stride_1x1,
|
496 |
+
padding=(kernel_size[0] - 1) // 2,
|
497 |
+
bias=False,
|
498 |
+
norm=get_norm(norm, bottleneck_channels, num_norm_groups),
|
499 |
+
)
|
500 |
+
|
501 |
+
self.conv2 = Conv2d(
|
502 |
+
bottleneck_channels,
|
503 |
+
bottleneck_channels,
|
504 |
+
kernel_size=kernel_size[1],
|
505 |
+
stride=stride_3x3,
|
506 |
+
padding=dilation * (kernel_size[1] - 1) // 2,
|
507 |
+
bias=False,
|
508 |
+
groups=num_groups,
|
509 |
+
dilation=dilation,
|
510 |
+
norm=get_norm(norm, bottleneck_channels, num_norm_groups),
|
511 |
+
)
|
512 |
+
|
513 |
+
self.conv3 = Conv2d(
|
514 |
+
bottleneck_channels,
|
515 |
+
out_channels,
|
516 |
+
kernel_size=kernel_size[2],
|
517 |
+
bias=False,
|
518 |
+
norm=get_norm(norm, out_channels, num_norm_groups),
|
519 |
+
)
|
520 |
+
|
521 |
+
for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
|
522 |
+
if layer is not None: # shortcut can be None
|
523 |
+
weight_init.c2_msra_fill(layer)
|
524 |
+
|
525 |
+
# Zero-initialize the last normalization in each residual branch,
|
526 |
+
# so that at the beginning, the residual branch starts with zeros,
|
527 |
+
# and each residual block behaves like an identity.
|
528 |
+
# See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
|
529 |
+
# "For BN layers, the learnable scaling coefficient γ is initialized
|
530 |
+
# to be 1, except for each residual block's last BN
|
531 |
+
# where γ is initialized to be 0."
|
532 |
+
|
533 |
+
# nn.init.constant_(self.conv3.norm.weight, 0)
|
534 |
+
# TODO this somehow hurts performance when training GN models from scratch.
|
535 |
+
# Add it as an option when we need to use this code to train a backbone.
|
536 |
+
|
537 |
+
def forward(self, x):
|
538 |
+
out = self.conv1(x)
|
539 |
+
out = F.relu_(out)
|
540 |
+
|
541 |
+
out = self.conv2(out)
|
542 |
+
out = F.relu_(out)
|
543 |
+
|
544 |
+
out = self.conv3(out)
|
545 |
+
|
546 |
+
if self.shortcut is not None:
|
547 |
+
shortcut = self.shortcut(x)
|
548 |
+
else:
|
549 |
+
shortcut = x
|
550 |
+
|
551 |
+
out += shortcut
|
552 |
+
out = F.relu_(out)
|
553 |
+
return out
|
554 |
+
|
555 |
+
|
556 |
+
class ResNet(nn.Module):
|
557 |
+
"""
|
558 |
+
Implement :paper:`ResNet`.
|
559 |
+
"""
|
560 |
+
|
561 |
+
def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
|
562 |
+
"""
|
563 |
+
Args:
|
564 |
+
stem (nn.Module): a stem module
|
565 |
+
stages (list[list[CNNBlockBase]]): several (typically 4) stages,
|
566 |
+
each contains multiple :class:`CNNBlockBase`.
|
567 |
+
num_classes (None or int): if None, will not perform classification.
|
568 |
+
Otherwise, will create a linear layer.
|
569 |
+
out_features (list[str]): name of the layers whose outputs should
|
570 |
+
be returned in forward. Can be anything in "stem", "linear", or "res2" ...
|
571 |
+
If None, will return the output of the last layer.
|
572 |
+
freeze_at (int): The number of stages at the beginning to freeze.
|
573 |
+
see :meth:`freeze` for detailed explanation.
|
574 |
+
"""
|
575 |
+
super().__init__()
|
576 |
+
self.stem = stem
|
577 |
+
self.num_classes = num_classes
|
578 |
+
|
579 |
+
current_stride = self.stem.stride
|
580 |
+
self._out_feature_strides = {"stem": current_stride}
|
581 |
+
self._out_feature_channels = {"stem": self.stem.out_channels}
|
582 |
+
|
583 |
+
self.stage_names, self.stages = [], []
|
584 |
+
|
585 |
+
if out_features is not None:
|
586 |
+
# Avoid keeping unused layers in this module. They consume extra memory
|
587 |
+
# and may cause allreduce to fail
|
588 |
+
num_stages = max(
|
589 |
+
[{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
|
590 |
+
)
|
591 |
+
stages = stages[:num_stages]
|
592 |
+
for i, blocks in enumerate(stages):
|
593 |
+
assert len(blocks) > 0, len(blocks)
|
594 |
+
for block in blocks:
|
595 |
+
assert isinstance(block, CNNBlockBase), block
|
596 |
+
|
597 |
+
name = "res" + str(i + 2)
|
598 |
+
stage = nn.Sequential(*blocks)
|
599 |
+
|
600 |
+
self.add_module(name, stage)
|
601 |
+
self.stage_names.append(name)
|
602 |
+
self.stages.append(stage)
|
603 |
+
|
604 |
+
self._out_feature_strides[name] = current_stride = int(
|
605 |
+
current_stride * np.prod([k.stride for k in blocks])
|
606 |
+
)
|
607 |
+
self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
|
608 |
+
self.stage_names = tuple(self.stage_names) # Make it static for scripting
|
609 |
+
|
610 |
+
if num_classes is not None:
|
611 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
612 |
+
self.linear = nn.Linear(curr_channels, num_classes)
|
613 |
+
|
614 |
+
# Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
|
615 |
+
# "The 1000-way fully-connected layer is initialized by
|
616 |
+
# drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
|
617 |
+
nn.init.normal_(self.linear.weight, std=0.01)
|
618 |
+
name = "linear"
|
619 |
+
|
620 |
+
if out_features is None:
|
621 |
+
out_features = [name]
|
622 |
+
self._out_features = out_features
|
623 |
+
assert len(self._out_features)
|
624 |
+
children = [x[0] for x in self.named_children()]
|
625 |
+
for out_feature in self._out_features:
|
626 |
+
assert out_feature in children, "Available children: {}".format(", ".join(children))
|
627 |
+
self.freeze(freeze_at)
|
628 |
+
|
629 |
+
def forward(self, x):
|
630 |
+
"""
|
631 |
+
Args:
|
632 |
+
x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
|
633 |
+
Returns:
|
634 |
+
dict[str->Tensor]: names and the corresponding features
|
635 |
+
"""
|
636 |
+
assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
|
637 |
+
outputs = {}
|
638 |
+
x = self.stem(x)
|
639 |
+
if "stem" in self._out_features:
|
640 |
+
outputs["stem"] = x
|
641 |
+
for name, stage in zip(self.stage_names, self.stages):
|
642 |
+
x = stage(x)
|
643 |
+
if name in self._out_features:
|
644 |
+
outputs[name] = x
|
645 |
+
if self.num_classes is not None:
|
646 |
+
x = self.avgpool(x)
|
647 |
+
x = torch.flatten(x, 1)
|
648 |
+
x = self.linear(x)
|
649 |
+
if "linear" in self._out_features:
|
650 |
+
outputs["linear"] = x
|
651 |
+
return outputs
|
652 |
+
|
653 |
+
def freeze(self, freeze_at=0):
|
654 |
+
"""
|
655 |
+
Freeze the first several stages of the ResNet. Commonly used in
|
656 |
+
fine-tuning.
|
657 |
+
Layers that produce the same feature map spatial size are defined as one
|
658 |
+
"stage" by :paper:`FPN`.
|
659 |
+
Args:
|
660 |
+
freeze_at (int): number of stages to freeze.
|
661 |
+
`1` means freezing the stem. `2` means freezing the stem and
|
662 |
+
one residual stage, etc.
|
663 |
+
Returns:
|
664 |
+
nn.Module: this ResNet itself
|
665 |
+
"""
|
666 |
+
if freeze_at >= 1:
|
667 |
+
self.stem.freeze()
|
668 |
+
for idx, stage in enumerate(self.stages, start=2):
|
669 |
+
if freeze_at >= idx:
|
670 |
+
for block in stage.children():
|
671 |
+
block.freeze()
|
672 |
+
return self
|
673 |
+
|
674 |
+
@staticmethod
|
675 |
+
def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
|
676 |
+
"""
|
677 |
+
Create a list of blocks of the same type that forms one ResNet stage.
|
678 |
+
Args:
|
679 |
+
block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
|
680 |
+
stage. A module of this type must not change spatial resolution of inputs unless its
|
681 |
+
stride != 1.
|
682 |
+
num_blocks (int): number of blocks in this stage
|
683 |
+
in_channels (int): input channels of the entire stage.
|
684 |
+
out_channels (int): output channels of **every block** in the stage.
|
685 |
+
kwargs: other arguments passed to the constructor of
|
686 |
+
`block_class`. If the argument name is "xx_per_block", the
|
687 |
+
argument is a list of values to be passed to each block in the
|
688 |
+
stage. Otherwise, the same argument is passed to every block
|
689 |
+
in the stage.
|
690 |
+
Returns:
|
691 |
+
list[CNNBlockBase]: a list of block modules.
|
692 |
+
Examples:
|
693 |
+
::
|
694 |
+
stage = ResNet.make_stage(
|
695 |
+
BottleneckBlock, 3, in_channels=16, out_channels=64,
|
696 |
+
bottleneck_channels=16, num_groups=1,
|
697 |
+
stride_per_block=[2, 1, 1],
|
698 |
+
dilations_per_block=[1, 1, 2]
|
699 |
+
)
|
700 |
+
Usually, layers that produce the same feature map spatial size are defined as one
|
701 |
+
"stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
|
702 |
+
all be 1.
|
703 |
+
"""
|
704 |
+
blocks = []
|
705 |
+
for i in range(num_blocks):
|
706 |
+
curr_kwargs = {}
|
707 |
+
for k, v in kwargs.items():
|
708 |
+
if k.endswith("_per_block"):
|
709 |
+
assert len(v) == num_blocks, (
|
710 |
+
f"Argument '{k}' of make_stage should have the "
|
711 |
+
f"same length as num_blocks={num_blocks}."
|
712 |
+
)
|
713 |
+
newk = k[: -len("_per_block")]
|
714 |
+
assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
|
715 |
+
curr_kwargs[newk] = v[i]
|
716 |
+
else:
|
717 |
+
curr_kwargs[k] = v
|
718 |
+
|
719 |
+
blocks.append(
|
720 |
+
block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
|
721 |
+
)
|
722 |
+
in_channels = out_channels
|
723 |
+
return blocks
|
724 |
+
|
725 |
+
@staticmethod
|
726 |
+
def make_default_stages(depth, block_class=None, **kwargs):
|
727 |
+
"""
|
728 |
+
Create a list of ResNet stages from a pre-defined depth (one of 18, 34, 50, 101, 152).
|
729 |
+
If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
|
730 |
+
instead for fine-grained customization.
|
731 |
+
Args:
|
732 |
+
depth (int): depth of ResNet
|
733 |
+
block_class (type): the CNN block class. Has to accept
|
734 |
+
`bottleneck_channels` argument for depth > 50.
|
735 |
+
By default it is BasicBlock or BottleneckBlock, based on the
|
736 |
+
depth.
|
737 |
+
kwargs:
|
738 |
+
other arguments to pass to `make_stage`. Should not contain
|
739 |
+
stride and channels, as they are predefined for each depth.
|
740 |
+
Returns:
|
741 |
+
list[list[CNNBlockBase]]: modules in all stages; see arguments of
|
742 |
+
:class:`ResNet.__init__`.
|
743 |
+
"""
|
744 |
+
num_blocks_per_stage = {
|
745 |
+
18: [2, 2, 2, 2],
|
746 |
+
34: [3, 4, 6, 3],
|
747 |
+
50: [3, 4, 6, 3],
|
748 |
+
101: [3, 4, 23, 3],
|
749 |
+
152: [3, 8, 36, 3],
|
750 |
+
}[depth]
|
751 |
+
if block_class is None:
|
752 |
+
block_class = BasicBlock if depth < 50 else BottleneckBlock
|
753 |
+
if depth < 50:
|
754 |
+
in_channels = [64, 64, 128, 256]
|
755 |
+
out_channels = [64, 128, 256, 512]
|
756 |
+
else:
|
757 |
+
in_channels = [64, 256, 512, 1024]
|
758 |
+
out_channels = [256, 512, 1024, 2048]
|
759 |
+
ret = []
|
760 |
+
for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
|
761 |
+
if depth >= 50:
|
762 |
+
kwargs["bottleneck_channels"] = o // 4
|
763 |
+
ret.append(
|
764 |
+
ResNet.make_stage(
|
765 |
+
block_class=block_class,
|
766 |
+
num_blocks=n,
|
767 |
+
stride_per_block=[s] + [1] * (n - 1),
|
768 |
+
in_channels=i,
|
769 |
+
out_channels=o,
|
770 |
+
**kwargs,
|
771 |
+
)
|
772 |
+
)
|
773 |
+
return ret
|
774 |
+
|
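A minimal usage sketch (not part of the uploaded file): build one stage with ResNet.make_stage the same way AggregationNetwork below does, and run a dummy tensor through it. The batch and spatial sizes are made-up examples; BottleneckBlock is the block class this file already uses for those stages.

import torch
import torch.nn as nn

stage = ResNet.make_stage(
    BottleneckBlock,              # block class used by AggregationNetwork below
    num_blocks=1,
    in_channels=640,
    bottleneck_channels=96,       # projection_dim // 4, as in AggregationNetwork
    out_channels=384,
    norm="GN",
    num_norm_groups=32,
    kernel_size=[1, 3, 1],
)
layer = nn.Sequential(*stage)
feats = torch.randn(2, 640, 32, 32)
print(layer(feats).shape)         # torch.Size([2, 384, 32, 32]); stride 1 keeps H, W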
775 |
+
class DummyAggregationNetwork(nn.Module): # for testing, return the input
|
776 |
+
def __init__(self):
|
777 |
+
super(DummyAggregationNetwork, self).__init__()
|
778 |
+
# dummy parameter
|
779 |
+
self.dummy = nn.Parameter(torch.ones([]))
|
780 |
+
|
781 |
+
def forward(self, batch, pose=None):
|
782 |
+
return batch * self.dummy
|
783 |
+
|
784 |
+
|
785 |
+
class AggregationNetwork(nn.Module):
|
786 |
+
"""
|
787 |
+
Module for aggregating feature maps across time and space.
|
788 |
+
Design inspired by the Feature Extractor from ODISE (Xu et al., CVPR 2023).
|
789 |
+
https://github.com/NVlabs/ODISE/blob/5836c0adfcd8d7fd1f8016ff5604d4a31dd3b145/odise/modeling/backbone/feature_extractor.py
|
790 |
+
"""
|
791 |
+
|
792 |
+
def __init__(
|
793 |
+
self,
|
794 |
+
device,
|
795 |
+
feature_dims=[640, 1280, 1280, 768],
|
796 |
+
projection_dim=384,
|
797 |
+
num_norm_groups=32,
|
798 |
+
save_timestep=[1],
|
799 |
+
kernel_size=[1, 3, 1],
|
800 |
+
contrastive_temp=10,
|
801 |
+
feat_map_dropout=0.0,
|
802 |
+
):
|
803 |
+
super().__init__()
|
804 |
+
self.skip_connection = True
|
805 |
+
self.feat_map_dropout = feat_map_dropout
|
806 |
+
self.azimuth_embedding = None
|
807 |
+
self.pos_embedding = None
|
808 |
+
self.bottleneck_layers = nn.ModuleList()
|
809 |
+
self.feature_dims = feature_dims
|
810 |
+
# For CLIP symmetric cross entropy loss during training
|
811 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
812 |
+
self.self_logit_scale = nn.Parameter(torch.ones([]) * np.log(contrastive_temp))
|
813 |
+
self.device = device
|
814 |
+
self.save_timestep = save_timestep
|
815 |
+
|
816 |
+
self.mixing_weights_names = []
|
817 |
+
for l, feature_dim in enumerate(self.feature_dims):
|
818 |
+
bottleneck_layer = nn.Sequential(
|
819 |
+
*ResNet.make_stage(
|
820 |
+
BottleneckBlock,
|
821 |
+
num_blocks=1,
|
822 |
+
in_channels=feature_dim,
|
823 |
+
bottleneck_channels=projection_dim // 4,
|
824 |
+
out_channels=projection_dim,
|
825 |
+
norm="GN",
|
826 |
+
num_norm_groups=num_norm_groups,
|
827 |
+
kernel_size=kernel_size
|
828 |
+
)
|
829 |
+
)
|
830 |
+
self.bottleneck_layers.append(bottleneck_layer)
|
831 |
+
for t in save_timestep:
|
832 |
+
# 1-index the layer name following prior work
|
833 |
+
self.mixing_weights_names.append(f"timestep-{t}_layer-{l + 1}")
|
834 |
+
self.last_layer = None
|
835 |
+
self.bottleneck_layers = self.bottleneck_layers.to(device)
|
836 |
+
mixing_weights = torch.ones(len(self.bottleneck_layers) * len(save_timestep))
|
837 |
+
self.mixing_weights = nn.Parameter(mixing_weights.to(device))
|
838 |
+
# count number of parameters
|
839 |
+
num_params = 0
|
840 |
+
for param in self.parameters():
|
841 |
+
num_params += param.numel()
|
842 |
+
print(f"AggregationNetwork has {num_params} parameters.")
|
843 |
+
|
844 |
+
def load_pretrained_weights(self, pretrained_dict):
|
845 |
+
custom_dict = self.state_dict()
|
846 |
+
|
847 |
+
# Handle size mismatch
|
848 |
+
if 'mixing_weights' in custom_dict and 'mixing_weights' in pretrained_dict and custom_dict[
|
849 |
+
'mixing_weights'].shape != pretrained_dict['mixing_weights'].shape:
|
850 |
+
# Keep the first four weights from the pretrained model, and zero-initialize the fifth weight
|
851 |
+
custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]
|
852 |
+
custom_dict['mixing_weights'][4] = torch.zeros_like(custom_dict['mixing_weights'][4])
|
853 |
+
else:
|
854 |
+
custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]
|
855 |
+
|
856 |
+
# Load the weights that do match
|
857 |
+
matching_keys = {k: v for k, v in pretrained_dict.items() if k in custom_dict and k != 'mixing_weights'}
|
858 |
+
custom_dict.update(matching_keys)
|
859 |
+
|
860 |
+
# Now load the updated state_dict
|
861 |
+
self.load_state_dict(custom_dict, strict=False)
|
862 |
+
|
863 |
+
def forward(self, batch, pose=None):
|
864 |
+
"""
|
865 |
+
Assumes batch is shape (B, C, H, W) where C is the concatenation of all layer features.
|
866 |
+
"""
|
867 |
+
if self.feat_map_dropout > 0 and self.training:
|
868 |
+
batch = F.dropout(batch, p=self.feat_map_dropout)
|
869 |
+
|
870 |
+
output_feature = None
|
871 |
+
start = 0
|
872 |
+
mixing_weights = torch.nn.functional.softmax(self.mixing_weights, dim=0)
|
873 |
+
if self.pos_embedding is not None: # position embedding
|
874 |
+
batch = torch.cat((batch, self.pos_embedding), dim=1)
|
875 |
+
for i in range(len(mixing_weights)):
|
876 |
+
# Share bottleneck layers across timesteps
|
877 |
+
bottleneck_layer = self.bottleneck_layers[i % len(self.feature_dims)]
|
878 |
+
# Chunk the batch according to the layer
|
879 |
+
# Account for looping if there are multiple timesteps
|
880 |
+
end = start + self.feature_dims[i % len(self.feature_dims)]
|
881 |
+
feats = batch[:, start:end, :, :]
|
882 |
+
start = end
|
883 |
+
# Downsample the number of channels and weight the layer
|
884 |
+
bottlenecked_feature = bottleneck_layer(feats)
|
885 |
+
bottlenecked_feature = mixing_weights[i] * bottlenecked_feature
|
886 |
+
if output_feature is None:
|
887 |
+
output_feature = bottlenecked_feature
|
888 |
+
else:
|
889 |
+
output_feature += bottlenecked_feature
|
890 |
+
|
891 |
+
if self.last_layer is not None:
|
892 |
+
|
893 |
+
output_feature_after = self.last_layer(output_feature)
|
894 |
+
if self.skip_connection:
|
895 |
+
# skip connection
|
896 |
+
output_feature = output_feature + output_feature_after
|
897 |
+
return output_feature
|
898 |
+
|
899 |
+
|
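The forward pass above slices the channel-concatenated input back into per-layer chunks, projects each through its bottleneck, and sums them under softmax-normalised mixing weights. A standalone sketch of just that chunk-and-mix arithmetic, with made-up feature_dims and a trivial stand-in for the bottleneck projection (not the module itself):

import torch
import torch.nn.functional as F

feature_dims = [4, 6, 2]                                # illustrative per-layer channel counts
mixing_weights = F.softmax(torch.ones(len(feature_dims)), dim=0)
batch = torch.randn(1, sum(feature_dims), 8, 8)         # channel-concatenated features

output, start = None, 0
for i, dim in enumerate(feature_dims):
    chunk = batch[:, start:start + dim]                 # per-layer slice, as in forward()
    start += dim
    projected = chunk.mean(dim=1, keepdim=True)         # stand-in for the bottleneck layer
    weighted = mixing_weights[i] * projected
    output = weighted if output is None else output + weighted
print(output.shape)                                     # torch.Size([1, 1, 8, 8])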
900 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
901 |
+
"""1x1 convolution without padding"""
|
902 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
|
903 |
+
|
904 |
+
|
905 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
906 |
+
"""3x3 convolution with padding"""
|
907 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
908 |
+
|
909 |
+
|
910 |
+
class BasicBlock(nn.Module):
|
911 |
+
def __init__(self, in_planes, planes, stride=1):
|
912 |
+
super().__init__()
|
913 |
+
self.conv1 = conv3x3(in_planes, planes, stride)
|
914 |
+
self.conv2 = conv3x3(planes, planes)
|
915 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
916 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
917 |
+
self.relu = nn.ReLU(inplace=True)
|
918 |
+
|
919 |
+
if stride == 1:
|
920 |
+
self.downsample = None
|
921 |
+
else:
|
922 |
+
self.downsample = nn.Sequential(
|
923 |
+
conv1x1(in_planes, planes, stride=stride),
|
924 |
+
nn.BatchNorm2d(planes)
|
925 |
+
)
|
926 |
+
|
927 |
+
def forward(self, x):
|
928 |
+
y = x
|
929 |
+
y = self.relu(self.bn1(self.conv1(y)))
|
930 |
+
y = self.bn2(self.conv2(y))
|
931 |
+
|
932 |
+
if self.downsample is not None:
|
933 |
+
x = self.downsample(x)
|
934 |
+
|
935 |
+
return self.relu(x + y)
|
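A quick shape check for the BasicBlock defined above (sketch only, arbitrary sizes): with stride=2 the 1x1 downsample branch keeps the residual addition valid while halving the spatial resolution.

import torch

block = BasicBlock(in_planes=32, planes=64, stride=2)
x = torch.randn(1, 32, 56, 56)
print(block(x).shape)    # torch.Size([1, 64, 28, 28])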
EdgeCape/models/backbones/dino.py
ADDED
@@ -0,0 +1,206 @@
1 |
+
import einops as E
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from transformers.models.vit_mae.modeling_vit_mae import (
|
6 |
+
get_2d_sincos_pos_embed_from_grid,
|
7 |
+
)
|
8 |
+
|
9 |
+
|
10 |
+
def resize_pos_embed(
|
11 |
+
pos_embed: torch.Tensor, hw: tuple[int, int], has_cls_token: bool = True
|
12 |
+
):
|
13 |
+
"""
|
14 |
+
Resize positional embedding for arbitrary image resolution. Resizing is done
|
15 |
+
via bicubic interpolation.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
pos_embed: Positional embedding tensor of shape ``(n_patches, embed_dim)``.
|
19 |
+
hw: Target height and width of the tensor after interpolation.
|
20 |
+
has_cls_token: Whether ``pos_embed[0]`` is for the ``[cls]`` token.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
Tensor of shape ``(new_n_patches, embed_dim)`` of resized embedding.
|
24 |
+
``new_n_patches`` is ``new_height * new_width`` if ``has_cls_token`` is False,
|
25 |
+
else ``1 + new_height * new_width``.
|
26 |
+
"""
|
27 |
+
|
28 |
+
n_grid = pos_embed.shape[0] - 1 if has_cls_token else pos_embed.shape[0]
|
29 |
+
|
30 |
+
# Do not resize if already in same shape.
|
31 |
+
if n_grid == hw[0] * hw[1]:
|
32 |
+
return pos_embed
|
33 |
+
|
34 |
+
# Get original position embedding and extract ``[cls]`` token.
|
35 |
+
if has_cls_token:
|
36 |
+
cls_embed, pos_embed = pos_embed[[0]], pos_embed[1:]
|
37 |
+
|
38 |
+
orig_dim = int(pos_embed.shape[0] ** 0.5)
|
39 |
+
|
40 |
+
pos_embed = E.rearrange(pos_embed, "(h w) c -> 1 c h w", h=orig_dim)
|
41 |
+
pos_embed = F.interpolate(
|
42 |
+
pos_embed, hw, mode="bicubic", align_corners=False, antialias=True
|
43 |
+
)
|
44 |
+
pos_embed = E.rearrange(pos_embed, "1 c h w -> (h w) c")
|
45 |
+
|
46 |
+
# Add embedding of ``[cls]`` token back after resizing.
|
47 |
+
if has_cls_token:
|
48 |
+
pos_embed = torch.cat([cls_embed, pos_embed], dim=0)
|
49 |
+
|
50 |
+
return pos_embed
|
51 |
+
|
52 |
+
|
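A small shape sketch for resize_pos_embed (illustrative numbers): a ViT positional embedding for a 16x16 patch grid plus the [cls] slot is interpolated to a 24x32 grid.

import torch

pos_embed = torch.randn(1 + 16 * 16, 384)                  # [cls] + 16x16 grid, dim 384
resized = resize_pos_embed(pos_embed, hw=(24, 32), has_cls_token=True)
print(resized.shape)                                        # torch.Size([769, 384]) = 1 + 24*32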
53 |
+
def center_padding(images, patch_size):
|
54 |
+
_, _, h, w = images.shape
|
55 |
+
diff_h = h % patch_size
|
56 |
+
diff_w = w % patch_size
|
57 |
+
|
58 |
+
if diff_h == 0 and diff_w == 0:
|
59 |
+
return images
|
60 |
+
|
61 |
+
pad_h = patch_size - diff_h
|
62 |
+
pad_w = patch_size - diff_w
|
63 |
+
|
64 |
+
pad_t = pad_h // 2
|
65 |
+
pad_l = pad_w // 2
|
66 |
+
pad_r = pad_w - pad_l
|
67 |
+
pad_b = pad_h - pad_t
|
68 |
+
|
69 |
+
images = F.pad(images, (pad_l, pad_r, pad_t, pad_b))
|
70 |
+
return images
|
71 |
+
|
72 |
+
|
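center_padding symmetrically pads H and W up to the next multiple of the patch size; a quick sketch with the 14-pixel DINOv2 patch:

import torch

imgs = torch.randn(2, 3, 100, 90)
padded = center_padding(imgs, patch_size=14)
print(padded.shape)    # torch.Size([2, 3, 112, 98]): 100 -> 112, 90 -> 98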
73 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
|
74 |
+
"""
|
75 |
+
COPIED FROM TRANSFORMERS PACKAGE AND EDITED TO ALLOW FOR DIFFERENT WIDTH-HEIGHT
|
76 |
+
Create 2D sin/cos positional embeddings.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
embed_dim (`int`):
|
80 |
+
Embedding dimension.
|
81 |
+
grid_size (`int`):
|
82 |
+
The grid height and width.
|
83 |
+
add_cls_token (`bool`, *optional*, defaults to `False`):
|
84 |
+
Whether or not to add a classification (CLS) token.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
(`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or
|
88 |
+
(1+grid_size*grid_size, embed_dim): the
|
89 |
+
position embeddings (with or without classification token)
|
90 |
+
"""
|
91 |
+
grid_h = np.arange(grid_size[0], dtype=np.float32)
|
92 |
+
grid_w = np.arange(grid_size[1], dtype=np.float32)
|
93 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
94 |
+
grid = np.stack(grid, axis=0)
|
95 |
+
|
96 |
+
grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
|
97 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
98 |
+
if add_cls_token:
|
99 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
100 |
+
return pos_embed
|
101 |
+
|
102 |
+
|
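The edited get_2d_sincos_pos_embed takes a rectangular (height, width) grid rather than a single integer. A sketch of the resulting shape, relying on the numpy and transformers helpers imported at the top of this file:

pos = get_2d_sincos_pos_embed(embed_dim=384, grid_size=(16, 24), add_cls_token=True)
print(pos.shape)    # (385, 384): one [cls] row plus 16 * 24 grid positions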
103 |
+
def tokens_to_output(output_type, dense_tokens, cls_token, feat_hw):
|
104 |
+
if output_type == "cls":
|
105 |
+
assert cls_token is not None
|
106 |
+
output = cls_token
|
107 |
+
elif output_type == "gap":
|
108 |
+
output = dense_tokens.mean(dim=1)
|
109 |
+
elif output_type == "dense":
|
110 |
+
h, w = feat_hw
|
111 |
+
dense_tokens = E.rearrange(dense_tokens, "b (h w) c -> b c h w", h=h, w=w)
|
112 |
+
output = dense_tokens.contiguous()
|
113 |
+
elif output_type == "dense-cls":
|
114 |
+
assert cls_token is not None
|
115 |
+
h, w = feat_hw
|
116 |
+
dense_tokens = E.rearrange(dense_tokens, "b (h w) c -> b c h w", h=h, w=w)
|
117 |
+
cls_token = cls_token[:, :, None, None].repeat(1, 1, h, w)
|
118 |
+
output = torch.cat((dense_tokens, cls_token), dim=1).contiguous()
|
119 |
+
else:
|
120 |
+
raise ValueError()
|
121 |
+
|
122 |
+
return output
|
123 |
+
|
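tokens_to_output turns a token sequence back into a feature map; for "dense-cls" the [cls] token is broadcast over the grid and concatenated channel-wise, doubling the channel count. Sketch with illustrative sizes:

import torch

dense = torch.randn(2, 16 * 16, 384)    # (B, h*w, C) patch tokens
cls = torch.randn(2, 384)               # (B, C) [cls] token
out = tokens_to_output("dense-cls", dense, cls, feat_hw=(16, 16))
print(out.shape)                         # torch.Size([2, 768, 16, 16])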
124 |
+
class DINO(torch.nn.Module):
|
125 |
+
def __init__(
|
126 |
+
self,
|
127 |
+
dino_name="dinov2",
|
128 |
+
model_name="vits14",
|
129 |
+
output="dense-cls",
|
130 |
+
layer=-1,
|
131 |
+
return_multilayer=True,
|
132 |
+
):
|
133 |
+
super().__init__()
|
134 |
+
feat_dims = {
|
135 |
+
"vits14": 384,
|
136 |
+
"vitb8": 768,
|
137 |
+
"vitb16": 768,
|
138 |
+
"vitb14": 768,
|
139 |
+
"vitb14_reg": 768,
|
140 |
+
"vitl14": 1024,
|
141 |
+
"vitg14": 1536,
|
142 |
+
}
|
143 |
+
|
144 |
+
# get model
|
145 |
+
self.model_name = dino_name
|
146 |
+
self.checkpoint_name = f"{dino_name}_{model_name}"
|
147 |
+
dino_vit = torch.hub.load(f"facebookresearch/{dino_name}", self.checkpoint_name)
|
148 |
+
self.vit = dino_vit.eval().to(torch.float32)
|
149 |
+
self.has_registers = "_reg" in model_name
|
150 |
+
|
151 |
+
assert output in ["cls", "gap", "dense", "dense-cls"]
|
152 |
+
self.output = output
|
153 |
+
self.patch_size = self.vit.patch_embed.proj.kernel_size[0]
|
154 |
+
|
155 |
+
feat_dim = feat_dims[model_name]
|
156 |
+
feat_dim = feat_dim * 2 if output == "dense-cls" else feat_dim
|
157 |
+
|
158 |
+
num_layers = len(self.vit.blocks)
|
159 |
+
multilayers = [
|
160 |
+
num_layers // 4 - 1,
|
161 |
+
num_layers // 2 - 1,
|
162 |
+
num_layers // 4 * 3 - 1,
|
163 |
+
num_layers - 1,
|
164 |
+
]
|
165 |
+
|
166 |
+
if return_multilayer:
|
167 |
+
self.feat_dim = [feat_dim, feat_dim, feat_dim, feat_dim]
|
168 |
+
self.multilayers = multilayers
|
169 |
+
else:
|
170 |
+
self.feat_dim = feat_dim
|
171 |
+
layer = multilayers[-1] if layer == -1 else layer
|
172 |
+
self.multilayers = [layer]
|
173 |
+
|
174 |
+
# define layer name (for logging)
|
175 |
+
self.layer = "-".join(str(_x) for _x in self.multilayers)
|
176 |
+
|
177 |
+
def forward(self, images):
|
178 |
+
|
179 |
+
# pad images (if needed) to ensure it matches patch_size
|
180 |
+
images = center_padding(images, self.patch_size)
|
181 |
+
h, w = images.shape[-2:]
|
182 |
+
h, w = h // self.patch_size, w // self.patch_size
|
183 |
+
|
184 |
+
if self.model_name == "dinov2":
|
185 |
+
x = self.vit.prepare_tokens_with_masks(images, None)
|
186 |
+
else:
|
187 |
+
x = self.vit.prepare_tokens(images)
|
188 |
+
|
189 |
+
embeds = []
|
190 |
+
for i, blk in enumerate(self.vit.blocks):
|
191 |
+
x = blk(x)
|
192 |
+
if i in self.multilayers:
|
193 |
+
embeds.append(x)
|
194 |
+
if len(embeds) == len(self.multilayers):
|
195 |
+
break
|
196 |
+
|
197 |
+
num_spatial = h * w
|
198 |
+
outputs = []
|
199 |
+
for i, x_i in enumerate(embeds):
|
200 |
+
cls_tok = x_i[:, 0]
|
201 |
+
# ignoring register tokens
|
202 |
+
spatial = x_i[:, -1 * num_spatial :]
|
203 |
+
x_i = tokens_to_output(self.output, spatial, cls_tok, (h, w))
|
204 |
+
outputs.append(x_i)
|
205 |
+
|
206 |
+
return outputs[0] if len(outputs) == 1 else outputs
|
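A usage sketch for the DINO wrapper above (not part of the upload): torch.hub.load fetches the DINOv2 ViT-S/14 checkpoint on first use, so it needs network access. With patch size 14 and output="dense-cls", a 224x224 input yields four 768-channel 16x16 maps.

import torch

backbone = DINO(dino_name="dinov2", model_name="vits14",
                output="dense-cls", return_multilayer=True)
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])   # 4 maps, each (1, 768, 16, 16)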
EdgeCape/models/detectors/EdgeCape.py
ADDED
@@ -0,0 +1,392 @@
1 |
+
import math
|
2 |
+
import cv2
|
3 |
+
import mmcv
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from mmcv.image import imwrite
|
9 |
+
from mmcv.visualization.image import imshow
|
10 |
+
from mmpose.models import builder
|
11 |
+
from mmpose.models.builder import POSENETS
|
12 |
+
from mmpose.models.detectors.base import BasePose
|
13 |
+
from EdgeCape.models.backbones.adapter import DPT
|
14 |
+
from EdgeCape.models.backbones.dino import DINO
|
15 |
+
|
16 |
+
|
17 |
+
@POSENETS.register_module()
|
18 |
+
class EdgeCape(BasePose):
|
19 |
+
"""
|
20 |
+
EdgeCape: Edge-aware Context-Aware Pose Estimation.
|
21 |
+
Args:
|
22 |
+
keypoint_head (dict): Config for keypoint head.
|
23 |
+
encoder_config (dict): Config for encoder.
|
24 |
+
train_cfg (dict): Config for training. Default: None.
|
25 |
+
test_cfg (dict): Config for testing. Default: None.
|
26 |
+
freeze_backbone (bool): If True, freeze backbone. Default: False.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self,
|
30 |
+
keypoint_head,
|
31 |
+
encoder_config,
|
32 |
+
train_cfg=None,
|
33 |
+
test_cfg=None,
|
34 |
+
freeze_backbone=False):
|
35 |
+
super().__init__()
|
36 |
+
feature_output_setting = encoder_config.get('output', 'dense-cls')
|
37 |
+
model_name = encoder_config.get('model_name', 'vits14')
|
38 |
+
self.encoder_sample = self.encoder_query = DINO(output=feature_output_setting, model_name=model_name)
|
39 |
+
self.probe = DPT(input_dims=self.encoder_query.feat_dim, output_dim=768)
|
40 |
+
self.backbone = 'dino_extractor'
|
41 |
+
self.freeze_backbone = freeze_backbone
|
42 |
+
if keypoint_head.get('freeze', None) is not None:
|
43 |
+
self.freeze_backbone = True
|
44 |
+
|
45 |
+
self.keypoint_head_module = builder.build_head(keypoint_head)
|
46 |
+
self.keypoint_head_module.init_weights()
|
47 |
+
|
48 |
+
self.train_cfg = train_cfg
|
49 |
+
self.test_cfg = test_cfg
|
50 |
+
self.target_type = test_cfg.get('target_type',
|
51 |
+
'GaussianHeatMap') # GaussianHeatMap
|
52 |
+
|
53 |
+
@property
|
54 |
+
def with_keypoint(self):
|
55 |
+
"""Check if has keypoint_head."""
|
56 |
+
return hasattr(self, 'keypoint_head_module')
|
57 |
+
|
58 |
+
def init_weights(self, pretrained=None):
|
59 |
+
"""Weight initialization for model."""
|
60 |
+
self.encoder_sample.init_weights(pretrained)
|
61 |
+
self.encoder_query.init_weights(pretrained)
|
62 |
+
self.keypoint_head_module.init_weights()
|
63 |
+
|
64 |
+
def forward(self,
|
65 |
+
img_s,
|
66 |
+
img_q,
|
67 |
+
target_s=None,
|
68 |
+
target_weight_s=None,
|
69 |
+
target_q=None,
|
70 |
+
target_weight_q=None,
|
71 |
+
img_metas=None,
|
72 |
+
return_loss=True,
|
73 |
+
**kwargs):
|
74 |
+
"""Calls either forward_train or forward_test depending on whether
|
75 |
+
return_loss=True. Note this setting will change the expected inputs.
|
76 |
+
When `return_loss=True`, img and img_meta are single-nested (i.e.
|
77 |
+
Tensor and List[dict]), and when `return_loss=False`, img and img_meta
|
78 |
+
should be double nested (i.e. List[Tensor], List[List[dict]]), with
|
79 |
+
the outer list indicating test time augmentations.
|
80 |
+
"""
|
81 |
+
if return_loss:
|
82 |
+
return self.forward_train(img_s, target_s, target_weight_s, img_q,
|
83 |
+
target_q, target_weight_q, img_metas,
|
84 |
+
**kwargs)
|
85 |
+
else:
|
86 |
+
return self.forward_test(img_s, target_s, target_weight_s, img_q,
|
87 |
+
target_q, target_weight_q, img_metas,
|
88 |
+
**kwargs)
|
89 |
+
|
90 |
+
def forward_train(self,
|
91 |
+
img_s,
|
92 |
+
target_s,
|
93 |
+
target_weight_s,
|
94 |
+
img_q,
|
95 |
+
target_q,
|
96 |
+
target_weight_q,
|
97 |
+
img_metas,
|
98 |
+
**kwargs):
|
99 |
+
"""Defines the computation performed at every call when training."""
|
100 |
+
bs, _, h, w = img_q.shape
|
101 |
+
random_mask = kwargs.get('rand_mask', None)
|
102 |
+
output, initial_proposals, similarity_map, mask_s, reconstructed_keypoints = self.predict(img_s,
|
103 |
+
target_s,
|
104 |
+
target_weight_s,
|
105 |
+
img_q,
|
106 |
+
img_metas,
|
107 |
+
random_mask)
|
108 |
+
|
109 |
+
# parse the img meta to get the target keypoints
|
110 |
+
device = output.device
|
111 |
+
target_keypoints = self.parse_keypoints_from_img_meta(img_metas,
|
112 |
+
device,
|
113 |
+
keyword='query')
|
114 |
+
|
115 |
+
target_sizes = torch.tensor(
|
116 |
+
[img_q.shape[-2], img_q.shape[-1]]).unsqueeze(0).repeat(
|
117 |
+
img_q.shape[0], 1, 1)
|
118 |
+
|
119 |
+
losses = dict()
|
120 |
+
if self.with_keypoint:
|
121 |
+
keypoint_losses = self.keypoint_head_module.get_loss(output,
|
122 |
+
initial_proposals,
|
123 |
+
similarity_map,
|
124 |
+
target_keypoints,
|
125 |
+
target_q,
|
126 |
+
target_weight_q * mask_s,
|
127 |
+
target_sizes,
|
128 |
+
reconstructed_keypoints,
|
129 |
+
)
|
130 |
+
losses.update(keypoint_losses)
|
131 |
+
keypoint_accuracy = self.keypoint_head_module.get_accuracy(output[-1],
|
132 |
+
target_keypoints,
|
133 |
+
target_weight_q * mask_s,
|
134 |
+
target_sizes,
|
135 |
+
height=h)
|
136 |
+
losses.update(keypoint_accuracy)
|
137 |
+
return losses
|
138 |
+
|
139 |
+
def forward_test(self,
|
140 |
+
img_s,
|
141 |
+
target_s,
|
142 |
+
target_weight_s,
|
143 |
+
img_q,
|
144 |
+
target_q,
|
145 |
+
target_weight_q,
|
146 |
+
img_metas=None,
|
147 |
+
vis_offset=True,
|
148 |
+
**kwargs):
|
149 |
+
|
150 |
+
"""Defines the computation performed at every call when testing."""
|
151 |
+
batch_size, _, img_height, img_width = img_q.shape
|
152 |
+
output, initial_proposals, similarity_map, mask_s, reconstructed_keypoints = self.predict(img_s,
|
153 |
+
target_s,
|
154 |
+
target_weight_s,
|
155 |
+
img_q,
|
156 |
+
img_metas
|
157 |
+
)
|
158 |
+
predicted_pose = output[-1].detach().cpu().numpy()
|
159 |
+
result = {}
|
160 |
+
|
161 |
+
if self.with_keypoint:
|
162 |
+
keypoint_result = self.keypoint_head_module.decode(img_metas, predicted_pose, img_size=[img_width, img_height])
|
163 |
+
result.update(keypoint_result)
|
164 |
+
|
165 |
+
if vis_offset:
|
166 |
+
result.update({"points": torch.cat((initial_proposals[None], output)).cpu().numpy()})
|
167 |
+
|
168 |
+
result.update({"sample_image_file": [img_metas[i]['sample_image_file'] for i in range(len(img_metas))]})
|
169 |
+
|
170 |
+
return result
|
171 |
+
|
172 |
+
def predict(self,
|
173 |
+
img_s,
|
174 |
+
target_s,
|
175 |
+
target_weight_s,
|
176 |
+
img_q,
|
177 |
+
img_metas=None,
|
178 |
+
random_mask=None):
|
179 |
+
|
180 |
+
batch_size, _, img_height, img_width = img_q.shape
|
181 |
+
assert [i['sample_skeleton'][0] != i['query_skeleton'] for i in img_metas]  # NOTE: asserts the list is non-empty, not that every comparison holds
|
182 |
+
mask_s = target_weight_s[0]
|
183 |
+
for target_weight in target_weight_s:
|
184 |
+
mask_s = mask_s * target_weight
|
185 |
+
feature_q, feature_s = self.extract_features(img_s, img_q)
|
186 |
+
skeleton_lst = [i['sample_skeleton'][0] for i in img_metas]
|
187 |
+
|
188 |
+
(output, initial_proposals, similarity_map, reconstructed_keypoints) = self.keypoint_head_module(
|
189 |
+
feature_q, feature_s, target_s, mask_s, skeleton_lst, random_mask=random_mask)
|
190 |
+
|
191 |
+
return output, initial_proposals, similarity_map, mask_s, reconstructed_keypoints
|
192 |
+
|
193 |
+
def extract_features(self, img_s, img_q):
|
194 |
+
with torch.no_grad():
|
195 |
+
dino_feature_s = [self.encoder_sample(img) for img in img_s]
|
196 |
+
dino_feature_q = self.encoder_query(img_q)  # img_q: [bs, 3, h, w]
|
197 |
+
if self.freeze_backbone:
|
198 |
+
with torch.no_grad():
|
199 |
+
feature_s = [self.probe(f) for f in dino_feature_s]
|
200 |
+
feature_q = self.probe(dino_feature_q)
|
201 |
+
else:
|
202 |
+
feature_s = [self.probe(f) for f in dino_feature_s]
|
203 |
+
feature_q = self.probe(dino_feature_q)
|
204 |
+
|
205 |
+
return feature_q, feature_s
|
206 |
+
|
207 |
+
def parse_keypoints_from_img_meta(self, img_meta, device, keyword='query'):
|
208 |
+
"""Parse keypoints from the img_meta.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
img_meta (dict): Image meta info.
|
212 |
+
device (torch.device): Device of the output keypoints.
|
213 |
+
keyword (str): 'query' or 'sample'. Default: 'query'.
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
Tensor: Keypoints coordinates of query images.
|
217 |
+
"""
|
218 |
+
|
219 |
+
if keyword == 'query':
|
220 |
+
query_kpt = torch.stack([
|
221 |
+
torch.tensor(info[f'{keyword}_joints_3d']).to(device) for info in img_meta], dim=0)[:, :, :2]
|
222 |
+
else:
|
223 |
+
query_kpt = []
|
224 |
+
for info in img_meta:
|
225 |
+
if isinstance(info[f'{keyword}_joints_3d'][0], torch.Tensor):
|
226 |
+
samples = torch.stack(info[f'{keyword}_joints_3d'])
|
227 |
+
else:
|
228 |
+
samples = np.array(info[f'{keyword}_joints_3d'])
|
229 |
+
query_kpt.append(torch.tensor(samples).to(device)[:, :, :2])
|
230 |
+
query_kpt = torch.stack(query_kpt, dim=0)  # [bs, num_samples, num_query, 2]
|
231 |
+
return query_kpt
|
232 |
+
|
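For the 'query' branch, the parser above simply stacks each meta entry's query_joints_3d and drops the visibility column. A standalone sketch of that stacking with fabricated meta entries (not real MP-100 annotations):

import torch

img_metas = [{"query_joints_3d": [[12.0, 34.0, 1.0], [56.0, 78.0, 1.0]]},
             {"query_joints_3d": [[10.0, 20.0, 0.0], [30.0, 40.0, 1.0]]}]
query_kpt = torch.stack(
    [torch.tensor(info["query_joints_3d"]) for info in img_metas], dim=0)[:, :, :2]
print(query_kpt.shape)    # torch.Size([2, 2, 2]): (batch, num_keypoints, xy)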
233 |
+
def get_full_similarity_map(self, feature_q, feature_s, h, w):
|
234 |
+
resized_feature_q = F.interpolate(feature_q, size=(h, w),
|
235 |
+
mode='bilinear')
|
236 |
+
resized_feature_s = [F.interpolate(s, size=(h, w), mode='bilinear') for
|
237 |
+
s in feature_s]
|
238 |
+
return [self.chunk_cosine_sim(f_s, resized_feature_q) for f_s in
|
239 |
+
resized_feature_s]
|
240 |
+
|
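get_full_similarity_map resizes query and support features to a common resolution and compares them with self.chunk_cosine_sim, a helper that is not defined in this file. Purely as an illustrative stand-in (not the repository's helper), a dense cosine-similarity map between two (B, C, H, W) feature maps can be computed like this:

import torch
import torch.nn.functional as F

def dense_cosine_sim(feat_a, feat_b):
    # (B, C, H, W) x (B, C, H, W) -> (B, H*W, H*W) cosine similarity between positions
    a = F.normalize(feat_a.flatten(2), dim=1)   # (B, C, HW)
    b = F.normalize(feat_b.flatten(2), dim=1)
    return torch.einsum("bci,bcj->bij", a, b)

sim = dense_cosine_sim(torch.randn(2, 768, 16, 16), torch.randn(2, 768, 16, 16))
print(sim.shape)    # torch.Size([2, 256, 256])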
241 |
+
# UNMODIFIED
|
242 |
+
def show_result(self,
|
243 |
+
img,
|
244 |
+
result,
|
245 |
+
skeleton=None,
|
246 |
+
kpt_score_thr=0.3,
|
247 |
+
bbox_color='green',
|
248 |
+
pose_kpt_color=None,
|
249 |
+
pose_limb_color=None,
|
250 |
+
radius=4,
|
251 |
+
text_color=(255, 0, 0),
|
252 |
+
thickness=1,
|
253 |
+
font_scale=0.5,
|
254 |
+
win_name='',
|
255 |
+
show=False,
|
256 |
+
wait_time=0,
|
257 |
+
out_file=None):
|
258 |
+
"""Draw `result` over `img`.
|
259 |
+
|
260 |
+
Args:
|
261 |
+
img (str or Tensor): The image to be displayed.
|
262 |
+
result (list[dict]): The results to draw over `img`
|
263 |
+
(bbox_result, pose_result).
|
264 |
+
kpt_score_thr (float, optional): Minimum score of keypoints
|
265 |
+
to be shown. Default: 0.3.
|
266 |
+
bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
|
267 |
+
pose_kpt_color (np.array[Nx3]`): Color of N keypoints.
|
268 |
+
If None, do not draw keypoints.
|
269 |
+
pose_limb_color (np.array[Mx3]): Color of M limbs.
|
270 |
+
If None, do not draw limbs.
|
271 |
+
text_color (str or tuple or :obj:`Color`): Color of texts.
|
272 |
+
thickness (int): Thickness of lines.
|
273 |
+
font_scale (float): Font scales of texts.
|
274 |
+
win_name (str): The window name.
|
275 |
+
wait_time (int): Value of waitKey param.
|
276 |
+
Default: 0.
|
277 |
+
out_file (str or None): The filename to write the image.
|
278 |
+
Default: None.
|
279 |
+
|
280 |
+
Returns:
|
281 |
+
Tensor: Visualized img, only if not `show` or `out_file`.
|
282 |
+
"""
|
283 |
+
|
284 |
+
img = mmcv.imread(img)
|
285 |
+
img = img.copy()
|
286 |
+
img_h, img_w, _ = img.shape
|
287 |
+
|
288 |
+
bbox_result = []
|
289 |
+
pose_result = []
|
290 |
+
for res in result:
|
291 |
+
bbox_result.append(res['bbox'])
|
292 |
+
pose_result.append(res['keypoints'])
|
293 |
+
|
294 |
+
if len(bbox_result) > 0:
|
295 |
+
bboxes = np.vstack(bbox_result)
|
296 |
+
# draw bounding boxes
|
297 |
+
mmcv.imshow_bboxes(
|
298 |
+
img,
|
299 |
+
bboxes,
|
300 |
+
colors=bbox_color,
|
301 |
+
top_k=-1,
|
302 |
+
thickness=thickness,
|
303 |
+
show=False,
|
304 |
+
win_name=win_name,
|
305 |
+
wait_time=wait_time,
|
306 |
+
out_file=None)
|
307 |
+
|
308 |
+
for person_id, kpts in enumerate(pose_result):
|
309 |
+
# draw each point on image
|
310 |
+
if pose_kpt_color is not None:
|
311 |
+
assert len(pose_kpt_color) == len(kpts), (
|
312 |
+
len(pose_kpt_color), len(kpts))
|
313 |
+
for kid, kpt in enumerate(kpts):
|
314 |
+
x_coord, y_coord, kpt_score = int(kpt[0]), int(
|
315 |
+
kpt[1]), kpt[2]
|
316 |
+
if kpt_score > kpt_score_thr:
|
317 |
+
img_copy = img.copy()
|
318 |
+
r, g, b = pose_kpt_color[kid]
|
319 |
+
cv2.circle(img_copy, (int(x_coord), int(y_coord)),
|
320 |
+
radius, (int(r), int(g), int(b)), -1)
|
321 |
+
transparency = max(0, min(1, kpt_score))
|
322 |
+
cv2.addWeighted(
|
323 |
+
img_copy,
|
324 |
+
transparency,
|
325 |
+
img,
|
326 |
+
1 - transparency,
|
327 |
+
0,
|
328 |
+
dst=img)
|
329 |
+
|
330 |
+
# draw limbs
|
331 |
+
if skeleton is not None and pose_limb_color is not None:
|
332 |
+
assert len(pose_limb_color) == len(skeleton)
|
333 |
+
for sk_id, sk in enumerate(skeleton):
|
334 |
+
pos1 = (int(kpts[sk[0] - 1, 0]), int(kpts[sk[0] - 1,
|
335 |
+
1]))
|
336 |
+
pos2 = (int(kpts[sk[1] - 1, 0]), int(kpts[sk[1] - 1,
|
337 |
+
1]))
|
338 |
+
if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0
|
339 |
+
and pos1[1] < img_h and pos2[0] > 0
|
340 |
+
and pos2[0] < img_w and pos2[1] > 0
|
341 |
+
and pos2[1] < img_h
|
342 |
+
and kpts[sk[0] - 1, 2] > kpt_score_thr
|
343 |
+
and kpts[sk[1] - 1, 2] > kpt_score_thr):
|
344 |
+
img_copy = img.copy()
|
345 |
+
X = (pos1[0], pos2[0])
|
346 |
+
Y = (pos1[1], pos2[1])
|
347 |
+
mX = np.mean(X)
|
348 |
+
mY = np.mean(Y)
|
349 |
+
length = ((Y[0] - Y[1]) ** 2 + (
|
350 |
+
X[0] - X[1]) ** 2) ** 0.5
|
351 |
+
angle = math.degrees(
|
352 |
+
math.atan2(Y[0] - Y[1], X[0] - X[1]))
|
353 |
+
stickwidth = 2
|
354 |
+
polygon = cv2.ellipse2Poly(
|
355 |
+
(int(mX), int(mY)),
|
356 |
+
(int(length / 2), int(stickwidth)), int(angle),
|
357 |
+
0, 360, 1)
|
358 |
+
|
359 |
+
r, g, b = pose_limb_color[sk_id]
|
360 |
+
cv2.fillConvexPoly(img_copy, polygon,
|
361 |
+
(int(r), int(g), int(b)))
|
362 |
+
transparency = max(
|
363 |
+
0,
|
364 |
+
min(
|
365 |
+
1, 0.5 *
|
366 |
+
(kpts[sk[0] - 1, 2] + kpts[
|
367 |
+
sk[1] - 1, 2])))
|
368 |
+
cv2.addWeighted(
|
369 |
+
img_copy,
|
370 |
+
transparency,
|
371 |
+
img,
|
372 |
+
1 - transparency,
|
373 |
+
0,
|
374 |
+
dst=img)
|
375 |
+
|
376 |
+
show, wait_time = 1, 1
|
377 |
+
if show:
|
378 |
+
height, width = img.shape[:2]
|
379 |
+
max_ = max(height, width)
|
380 |
+
|
381 |
+
factor = min(1, 800 / max_)
|
382 |
+
enlarge = cv2.resize(
|
383 |
+
img, (0, 0),
|
384 |
+
fx=factor,
|
385 |
+
fy=factor,
|
386 |
+
interpolation=cv2.INTER_CUBIC)
|
387 |
+
imshow(enlarge, win_name, wait_time)
|
388 |
+
|
389 |
+
if out_file is not None:
|
390 |
+
imwrite(img, out_file)
|
391 |
+
|
392 |
+
return img
|
EdgeCape/models/detectors/__init__.py
ADDED
@@ -0,0 +1,3 @@
1 |
+
from .EdgeCape import EdgeCape
|
2 |
+
|
3 |
+
__all__ = ['EdgeCape']
|
EdgeCape/models/detectors/__pycache__/EdgeCape.cpython-39.pyc
ADDED
Binary file (11.6 kB).
|
|