orhir commited on
Commit
184241a
·
verified ·
1 Parent(s): dbe4dc3

Upload 114 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. EdgeCape/VERSION +1 -0
  2. EdgeCape/__init__.py +3 -0
  3. EdgeCape/__pycache__/__init__.cpython-39.pyc +0 -0
  4. EdgeCape/apis/__init__.py +5 -0
  5. EdgeCape/apis/__pycache__/__init__.cpython-39.pyc +0 -0
  6. EdgeCape/apis/__pycache__/test.cpython-39.pyc +0 -0
  7. EdgeCape/apis/__pycache__/train.cpython-39.pyc +0 -0
  8. EdgeCape/apis/test.py +198 -0
  9. EdgeCape/apis/train.py +124 -0
  10. EdgeCape/core/__init__.py +1 -0
  11. EdgeCape/core/__pycache__/__init__.cpython-39.pyc +0 -0
  12. EdgeCape/core/custom_hooks/__pycache__/shuffle_hooks.cpython-39.pyc +0 -0
  13. EdgeCape/core/custom_hooks/shuffle_hooks.py +28 -0
  14. EdgeCape/datasets/__init__.py +3 -0
  15. EdgeCape/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  16. EdgeCape/datasets/__pycache__/builder.cpython-39.pyc +0 -0
  17. EdgeCape/datasets/builder.py +55 -0
  18. EdgeCape/datasets/datasets/__init__.py +6 -0
  19. EdgeCape/datasets/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  20. EdgeCape/datasets/datasets/mp100/__init__.py +13 -0
  21. EdgeCape/datasets/datasets/mp100/__pycache__/__init__.cpython-39.pyc +0 -0
  22. EdgeCape/datasets/datasets/mp100/__pycache__/custom_test_dataset.cpython-39.pyc +0 -0
  23. EdgeCape/datasets/datasets/mp100/__pycache__/fewshot_base_dataset.cpython-39.pyc +0 -0
  24. EdgeCape/datasets/datasets/mp100/__pycache__/fewshot_dataset.cpython-39.pyc +0 -0
  25. EdgeCape/datasets/datasets/mp100/__pycache__/test_base_dataset.cpython-39.pyc +0 -0
  26. EdgeCape/datasets/datasets/mp100/__pycache__/test_dataset.cpython-39.pyc +0 -0
  27. EdgeCape/datasets/datasets/mp100/__pycache__/transformer_base_dataset.cpython-39.pyc +0 -0
  28. EdgeCape/datasets/datasets/mp100/__pycache__/transformer_dataset.cpython-39.pyc +0 -0
  29. EdgeCape/datasets/datasets/mp100/custom_test_dataset.py +355 -0
  30. EdgeCape/datasets/datasets/mp100/fewshot_base_dataset.py +223 -0
  31. EdgeCape/datasets/datasets/mp100/fewshot_dataset.py +312 -0
  32. EdgeCape/datasets/datasets/mp100/test_base_dataset.py +226 -0
  33. EdgeCape/datasets/datasets/mp100/test_dataset.py +319 -0
  34. EdgeCape/datasets/datasets/mp100/transformer_base_dataset.py +209 -0
  35. EdgeCape/datasets/datasets/mp100/transformer_dataset.py +319 -0
  36. EdgeCape/datasets/pipelines/__init__.py +8 -0
  37. EdgeCape/datasets/pipelines/__pycache__/__init__.cpython-39.pyc +0 -0
  38. EdgeCape/datasets/pipelines/__pycache__/post_transforms.cpython-39.pyc +0 -0
  39. EdgeCape/datasets/pipelines/__pycache__/top_down_transform.cpython-39.pyc +0 -0
  40. EdgeCape/datasets/pipelines/post_transforms.py +121 -0
  41. EdgeCape/datasets/pipelines/top_down_transform.py +716 -0
  42. EdgeCape/models/__init__.py +3 -0
  43. EdgeCape/models/__pycache__/__init__.cpython-39.pyc +0 -0
  44. EdgeCape/models/backbones/__pycache__/adapter.cpython-39.pyc +0 -0
  45. EdgeCape/models/backbones/__pycache__/dino.cpython-39.pyc +0 -0
  46. EdgeCape/models/backbones/adapter.py +935 -0
  47. EdgeCape/models/backbones/dino.py +206 -0
  48. EdgeCape/models/detectors/EdgeCape.py +392 -0
  49. EdgeCape/models/detectors/__init__.py +3 -0
  50. EdgeCape/models/detectors/__pycache__/EdgeCape.cpython-39.pyc +0 -0
EdgeCape/VERSION ADDED
@@ -0,0 +1 @@
 
 
1
+ 0.2.0
EdgeCape/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .core import * # noqa
2
+ from .datasets import * # noqa
3
+ from .models import * # noqa
EdgeCape/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (206 Bytes). View file
 
EdgeCape/apis/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .train import train_model
2
+
3
+ __all__ = [
4
+ 'train_model'
5
+ ]
EdgeCape/apis/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (217 Bytes). View file
 
EdgeCape/apis/__pycache__/test.cpython-39.pyc ADDED
Binary file (5.14 kB). View file
 
EdgeCape/apis/__pycache__/train.cpython-39.pyc ADDED
Binary file (3.19 kB). View file
 
EdgeCape/apis/test.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os.path as osp
3
+ import pickle
4
+ import shutil
5
+ import tempfile
6
+
7
+ import mmcv
8
+ import numpy as np
9
+ import torch
10
+ import torch.distributed as dist
11
+ from mmcv.runner import get_dist_info
12
+
13
+
14
+ def single_gpu_test(model, data_loader):
15
+ """Test model with a single gpu.
16
+
17
+ This method tests model with a single gpu and displays test progress bar.
18
+
19
+ Args:
20
+ model (nn.Module): Model to be tested.
21
+ data_loader (nn.Dataloader): Pytorch data loader.
22
+
23
+
24
+ Returns:
25
+ list: The prediction results.
26
+ """
27
+ model.eval()
28
+ results = []
29
+ dataset = data_loader.dataset
30
+ prog_bar = mmcv.ProgressBar(len(dataset))
31
+ for data in data_loader:
32
+ with torch.no_grad():
33
+ result = model(return_loss=False, **data)
34
+ batch_size = len(next(iter(data.values()))[0])
35
+ # results.append(result)
36
+ if 'preds' in result:
37
+ for i in range(batch_size):
38
+ results.append({
39
+ 'preds': result['preds'][i][None],
40
+ 'boxes': result['boxes'][i][None],
41
+ 'bbox_ids': [result['bbox_ids'][i]],
42
+ 'image_paths': [result['image_paths'][i]],
43
+ })
44
+ # use the first key as main key to calculate the batch size
45
+ # for _ in range(batch_size):
46
+ prog_bar.update(batch_size)
47
+ return results
48
+
49
+
50
+ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
51
+ """Test model with multiple gpus.
52
+
53
+ This method tests model with multiple gpus and collects the results
54
+ under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
55
+ it encodes results to gpu tensors and use gpu communication for results
56
+ collection. On cpu mode it saves the results on different gpus to 'tmpdir'
57
+ and collects them by the rank 0 worker.
58
+
59
+ Args:
60
+ model (nn.Module): Model to be tested.
61
+ data_loader (nn.Dataloader): Pytorch data loader.
62
+ tmpdir (str): Path of directory to save the temporary results from
63
+ different gpus under cpu mode.
64
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
65
+
66
+ Returns:
67
+ list: The prediction results.
68
+ """
69
+ model.eval()
70
+ results = []
71
+ dataset = data_loader.dataset
72
+ rank, world_size = get_dist_info()
73
+ if rank == 0:
74
+ prog_bar = mmcv.ProgressBar(len(dataset))
75
+ for data in data_loader:
76
+ with torch.no_grad():
77
+ result = model(return_loss=False, **data)
78
+ results.append(result)
79
+
80
+ if rank == 0:
81
+ # use the first key as main key to calculate the batch size
82
+ batch_size = len(next(iter(data.values())))
83
+ for _ in range(batch_size * world_size):
84
+ prog_bar.update()
85
+
86
+ # collect results from all ranks
87
+ if gpu_collect:
88
+ results = collect_results_gpu(results, len(dataset))
89
+ else:
90
+ results = collect_results_cpu(results, len(dataset), tmpdir)
91
+ return results
92
+
93
+
94
+ def collect_results_cpu(result_part, size, tmpdir=None):
95
+ """Collect results in cpu mode.
96
+
97
+ It saves the results on different gpus to 'tmpdir' and collects
98
+ them by the rank 0 worker.
99
+
100
+ Args:
101
+ result_part (list): Results to be collected
102
+ size (int): Result size.
103
+ tmpdir (str): Path of directory to save the temporary results from
104
+ different gpus under cpu mode. Default: None
105
+
106
+ Returns:
107
+ list: Ordered results.
108
+ """
109
+ rank, world_size = get_dist_info()
110
+ # create a tmp dir if it is not specified
111
+ if tmpdir is None:
112
+ MAX_LEN = 512
113
+ # 32 is whitespace
114
+ dir_tensor = torch.full((MAX_LEN, ),
115
+ 32,
116
+ dtype=torch.uint8,
117
+ device='cuda')
118
+ if rank == 0:
119
+ mmcv.mkdir_or_exist('.dist_test')
120
+ tmpdir = tempfile.mkdtemp(dir='.dist_test')
121
+ tmpdir = torch.tensor(
122
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
123
+ dir_tensor[:len(tmpdir)] = tmpdir
124
+ dist.broadcast(dir_tensor, 0)
125
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
126
+ else:
127
+ mmcv.mkdir_or_exist(tmpdir)
128
+ # synchronizes all processes to make sure tmpdir exist
129
+ dist.barrier()
130
+ # dump the part result to the dir
131
+ mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
132
+ # synchronizes all processes for loading pickle file
133
+ dist.barrier()
134
+ # collect all parts
135
+ if rank != 0:
136
+ return None
137
+
138
+ # load results of all parts from tmp dir
139
+ part_list = []
140
+ for i in range(world_size):
141
+ part_file = osp.join(tmpdir, f'part_{i}.pkl')
142
+ part_list.append(mmcv.load(part_file))
143
+ # sort the results
144
+ ordered_results = []
145
+ for res in zip(*part_list):
146
+ ordered_results.extend(list(res))
147
+ # the dataloader may pad some samples
148
+ ordered_results = ordered_results[:size]
149
+ # remove tmp dir
150
+ shutil.rmtree(tmpdir)
151
+ return ordered_results
152
+
153
+
154
+ def collect_results_gpu(result_part, size):
155
+ """Collect results in gpu mode.
156
+
157
+ It encodes results to gpu tensors and use gpu communication for results
158
+ collection.
159
+
160
+ Args:
161
+ result_part (list): Results to be collected
162
+ size (int): Result size.
163
+
164
+ Returns:
165
+ list: Ordered results.
166
+ """
167
+
168
+ rank, world_size = get_dist_info()
169
+ # dump result part to tensor with pickle
170
+ part_tensor = torch.tensor(
171
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
172
+ # gather all result part tensor shape
173
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
174
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
175
+ dist.all_gather(shape_list, shape_tensor)
176
+ # padding result part tensor to max length
177
+ shape_max = torch.tensor(shape_list).max()
178
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
179
+ part_send[:shape_tensor[0]] = part_tensor
180
+ part_recv_list = [
181
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
182
+ ]
183
+ # gather all result part
184
+ dist.all_gather(part_recv_list, part_send)
185
+
186
+ if rank == 0:
187
+ part_list = []
188
+ for recv, shape in zip(part_recv_list, shape_list):
189
+ part_list.append(
190
+ pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
191
+ # sort the results
192
+ ordered_results = []
193
+ for res in zip(*part_list):
194
+ ordered_results.extend(list(res))
195
+ # the dataloader may pad some samples
196
+ ordered_results = ordered_results[:size]
197
+ return ordered_results
198
+ return None
EdgeCape/apis/train.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
4
+ from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
5
+ build_optimizer)
6
+
7
+ from mmpose.core import DistEvalHook, EvalHook, Fp16OptimizerHook
8
+ from mmpose.datasets import build_dataloader
9
+ from mmpose.utils import get_root_logger
10
+ from EdgeCape.core.custom_hooks.shuffle_hooks import ShufflePairedSamplesHook
11
+
12
+ def train_model(model,
13
+ dataset,
14
+ val_dataset,
15
+ cfg,
16
+ distributed=False,
17
+ validate=False,
18
+ timestamp=None,
19
+ meta=None):
20
+ """Train model entry function.
21
+
22
+ Args:
23
+ model (nn.Module): The model to be trained.
24
+ dataset (Dataset): Train dataset.
25
+ cfg (dict): The config dict for training.
26
+ distributed (bool): Whether to use distributed training.
27
+ Default: False.
28
+ validate (bool): Whether to do evaluation. Default: False.
29
+ timestamp (str | None): Local time for runner. Default: None.
30
+ meta (dict | None): Meta dict to record some important information.
31
+ Default: None
32
+ """
33
+ logger = get_root_logger(cfg.log_level)
34
+
35
+ # prepare data loaders
36
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
37
+ dataloader_setting = dict(
38
+ samples_per_gpu=cfg.data.get('samples_per_gpu', {}),
39
+ workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
40
+ # cfg.gpus will be ignored if distributed
41
+ num_gpus=len(cfg.gpu_ids),
42
+ dist=distributed,
43
+ seed=cfg.seed,
44
+ pin_memory=False,
45
+ )
46
+ dataloader_setting = dict(dataloader_setting,
47
+ **cfg.data.get('train_dataloader', {}))
48
+
49
+ data_loaders = [
50
+ build_dataloader(ds, **dataloader_setting) for ds in dataset
51
+ ]
52
+
53
+ # put model on gpus
54
+ if distributed:
55
+ find_unused_parameters = cfg.get('find_unused_parameters', True) # NOTE: True has been modified to False for faster training.
56
+ # Sets the `find_unused_parameters` parameter in
57
+ # torch.nn.parallel.DistributedDataParallel
58
+ model = MMDistributedDataParallel(
59
+ model.cuda(),
60
+ device_ids=[torch.cuda.current_device()],
61
+ broadcast_buffers=False,
62
+ find_unused_parameters=find_unused_parameters)
63
+ else:
64
+ model = MMDataParallel(
65
+ model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
66
+
67
+ # build runner
68
+ optimizer = build_optimizer(model, cfg.optimizer)
69
+ runner = EpochBasedRunner(
70
+ model,
71
+ optimizer=optimizer,
72
+ work_dir=cfg.work_dir,
73
+ logger=logger,
74
+ meta=meta)
75
+ # an ugly workaround to make .log and .log.json filenames the same
76
+ runner.timestamp = timestamp
77
+
78
+ # fp16 setting
79
+ fp16_cfg = cfg.get('fp16', None)
80
+ if fp16_cfg is not None:
81
+ optimizer_config = Fp16OptimizerHook(
82
+ **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
83
+ elif distributed and 'type' not in cfg.optimizer_config:
84
+ optimizer_config = OptimizerHook(**cfg.optimizer_config)
85
+ else:
86
+ optimizer_config = cfg.optimizer_config
87
+
88
+ # register hooks
89
+ runner.register_training_hooks(cfg.lr_config, optimizer_config,
90
+ cfg.checkpoint_config, cfg.log_config,
91
+ cfg.get('momentum_config', None))
92
+ if distributed:
93
+ runner.register_hook(DistSamplerSeedHook())
94
+
95
+ shuffle_cfg = cfg.get('shuffle_cfg', None)
96
+ if shuffle_cfg is not None:
97
+ for data_loader in data_loaders:
98
+ runner.register_hook(ShufflePairedSamplesHook(data_loader, **shuffle_cfg))
99
+
100
+ # register eval hooks
101
+ if validate:
102
+ eval_cfg = cfg.get('evaluation', {})
103
+ eval_cfg['res_folder'] = os.path.join(cfg.work_dir, eval_cfg['res_folder'])
104
+ dataloader_setting = dict(
105
+ # samples_per_gpu=cfg.data.get('samples_per_gpu', {}),
106
+ samples_per_gpu=1,
107
+ workers_per_gpu=cfg.data.get('workers_per_gpu', {}),
108
+ # cfg.gpus will be ignored if distributed
109
+ num_gpus=len(cfg.gpu_ids),
110
+ dist=distributed,
111
+ shuffle=False,
112
+ pin_memory=False,
113
+ )
114
+ dataloader_setting = dict(dataloader_setting,
115
+ **cfg.data.get('val_dataloader', {}))
116
+ val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
117
+ eval_hook = DistEvalHook if distributed else EvalHook
118
+ runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
119
+
120
+ if cfg.resume_from:
121
+ runner.resume(cfg.resume_from)
122
+ elif cfg.load_from:
123
+ runner.load_checkpoint(cfg.load_from)
124
+ runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
EdgeCape/core/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
EdgeCape/core/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (152 Bytes). View file
 
EdgeCape/core/custom_hooks/__pycache__/shuffle_hooks.cpython-39.pyc ADDED
Binary file (1.27 kB). View file
 
EdgeCape/core/custom_hooks/shuffle_hooks.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmcv.runner import Hook
2
+ from torch.utils.data import DataLoader
3
+ from mmpose.utils import get_root_logger
4
+
5
+ class ShufflePairedSamplesHook(Hook):
6
+ """Non-Distributed ShufflePairedSamples.
7
+ After each training epoch, run FewShotKeypointDataset.random_paired_samples()
8
+ """
9
+
10
+ def __init__(self,
11
+ dataloader,
12
+ interval=1):
13
+ if not isinstance(dataloader, DataLoader):
14
+ raise TypeError(f'dataloader must be a pytorch DataLoader, '
15
+ f'but got {type(dataloader)}')
16
+
17
+ self.dataloader = dataloader
18
+ self.interval = interval
19
+ self.logger = get_root_logger()
20
+
21
+ def after_train_epoch(self, runner):
22
+ """Called after every training epoch to evaluate the results."""
23
+ if not self.every_n_epochs(runner, self.interval):
24
+ return
25
+ # self.logger.info("Run random_paired_samples()")
26
+ # self.logger.info(f"Before: {self.dataloader.dataset.paired_samples[0]}")
27
+ self.dataloader.dataset.random_paired_samples()
28
+ # self.logger.info(f"After: {self.dataloader.dataset.paired_samples[0]}")
EdgeCape/datasets/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .builder import * # noqa
2
+ from .datasets import * # noqa
3
+ from .pipelines import * # noqa
EdgeCape/datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (221 Bytes). View file
 
EdgeCape/datasets/__pycache__/builder.cpython-39.pyc ADDED
Binary file (1.9 kB). View file
 
EdgeCape/datasets/builder.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmcv.utils import build_from_cfg
2
+ from torch.utils.data.dataset import ConcatDataset
3
+
4
+ from mmpose.datasets.dataset_wrappers import RepeatDataset
5
+ from mmpose.datasets.builder import DATASETS
6
+
7
+
8
+ def _concat_cfg(cfg):
9
+ replace = ['ann_file', 'img_prefix']
10
+ channels = ['num_joints', 'dataset_channel']
11
+ concat_cfg = []
12
+ for i in range(len(cfg['type'])):
13
+ cfg_tmp = cfg.deepcopy()
14
+ cfg_tmp['type'] = cfg['type'][i]
15
+ for item in replace:
16
+ assert item in cfg_tmp
17
+ assert len(cfg['type']) == len(cfg[item]), (cfg[item])
18
+ cfg_tmp[item] = cfg[item][i]
19
+ for item in channels:
20
+ assert item in cfg_tmp['data_cfg']
21
+ assert len(cfg['type']) == len(cfg['data_cfg'][item])
22
+ cfg_tmp['data_cfg'][item] = cfg['data_cfg'][item][i]
23
+ concat_cfg.append(cfg_tmp)
24
+ return concat_cfg
25
+
26
+
27
+ def _check_vaild(cfg):
28
+ replace = ['num_joints', 'dataset_channel']
29
+ if isinstance(cfg['data_cfg'][replace[0]], (list, tuple)):
30
+ for item in replace:
31
+ cfg['data_cfg'][item] = cfg['data_cfg'][item][0]
32
+ return cfg
33
+
34
+
35
+ def build_dataset(cfg, default_args=None):
36
+ """Build a dataset from config dict.
37
+
38
+ Args:
39
+ cfg (dict): Config dict. It should at least contain the key "type".
40
+ default_args (dict, optional): Default initialization arguments.
41
+ Default: None.
42
+
43
+ Returns:
44
+ Dataset: The constructed dataset.
45
+ """
46
+ if isinstance(cfg['type'], (list, tuple)): # In training, type=TransformerPoseDataset
47
+ dataset = ConcatDataset(
48
+ [build_dataset(c, default_args) for c in _concat_cfg(cfg)])
49
+ elif cfg['type'] == 'RepeatDataset':
50
+ dataset = RepeatDataset(
51
+ build_dataset(cfg['dataset'], default_args), cfg['times'])
52
+ else:
53
+ cfg = _check_vaild(cfg)
54
+ dataset = build_from_cfg(cfg, DATASETS, default_args)
55
+ return dataset
EdgeCape/datasets/datasets/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .mp100 import (FewShotKeypointDataset, FewShotBaseDataset,
2
+ TransformerBaseDataset, TransformerPoseDataset,)
3
+
4
+ __all__ = ['FewShotBaseDataset', 'FewShotKeypointDataset',
5
+ 'TransformerBaseDataset', 'TransformerPoseDataset',
6
+ ]
EdgeCape/datasets/datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (353 Bytes). View file
 
EdgeCape/datasets/datasets/mp100/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .fewshot_dataset import FewShotKeypointDataset
2
+ from .fewshot_base_dataset import FewShotBaseDataset
3
+ from .transformer_dataset import TransformerPoseDataset
4
+ from .transformer_base_dataset import TransformerBaseDataset
5
+ from .test_base_dataset import TestBaseDataset
6
+ from .test_dataset import TestPoseDataset
7
+ from .custom_test_dataset import CustomTestPoseDataset
8
+
9
+ __all__ = [
10
+ 'FewShotKeypointDataset', 'FewShotBaseDataset',
11
+ 'TransformerPoseDataset', 'TransformerBaseDataset',
12
+ 'TestBaseDataset', 'TestPoseDataset', 'CustomTestPoseDataset'
13
+ ]
EdgeCape/datasets/datasets/mp100/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (663 Bytes). View file
 
EdgeCape/datasets/datasets/mp100/__pycache__/custom_test_dataset.cpython-39.pyc ADDED
Binary file (10.2 kB). View file
 
EdgeCape/datasets/datasets/mp100/__pycache__/fewshot_base_dataset.cpython-39.pyc ADDED
Binary file (7.22 kB). View file
 
EdgeCape/datasets/datasets/mp100/__pycache__/fewshot_dataset.cpython-39.pyc ADDED
Binary file (8.95 kB). View file
 
EdgeCape/datasets/datasets/mp100/__pycache__/test_base_dataset.cpython-39.pyc ADDED
Binary file (7.4 kB). View file
 
EdgeCape/datasets/datasets/mp100/__pycache__/test_dataset.cpython-39.pyc ADDED
Binary file (9.02 kB). View file
 
EdgeCape/datasets/datasets/mp100/__pycache__/transformer_base_dataset.cpython-39.pyc ADDED
Binary file (7.23 kB). View file
 
EdgeCape/datasets/datasets/mp100/__pycache__/transformer_dataset.cpython-39.pyc ADDED
Binary file (9.06 kB). View file
 
EdgeCape/datasets/datasets/mp100/custom_test_dataset.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmpose.datasets import DATASETS
2
+ import random
3
+ import numpy as np
4
+ import os
5
+ from collections import OrderedDict
6
+ from xtcocotools.coco import COCO
7
+ from .test_base_dataset import TestBaseDataset
8
+
9
+ @DATASETS.register_module()
10
+ class CustomTestPoseDataset(TestBaseDataset):
11
+
12
+ def __init__(self,
13
+ ann_file,
14
+ img_prefix,
15
+ data_cfg,
16
+ pipeline,
17
+ valid_class_ids,
18
+ max_kpt_num=None,
19
+ num_shots=1,
20
+ num_queries=100,
21
+ num_episodes=1,
22
+ pck_threshold_list=[0.05, 0.1, 0.15, 0.20, 0.25],
23
+ test_mode=True):
24
+ super().__init__(
25
+ ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode, PCK_threshold_list=pck_threshold_list)
26
+
27
+ self.ann_info['flip_pairs'] = []
28
+
29
+ self.ann_info['upper_body_ids'] = []
30
+ self.ann_info['lower_body_ids'] = []
31
+
32
+ self.ann_info['use_different_joint_weights'] = False
33
+ self.ann_info['joint_weights'] = np.array([1.,],
34
+ dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
35
+
36
+ self.coco = COCO(ann_file)
37
+
38
+ self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
39
+ self.img_ids = self.coco.getImgIds()
40
+
41
+ cat = None
42
+ relevant_names = [
43
+ '000000052046',
44
+ '000000052152'
45
+
46
+ # '000000027059',
47
+ # '000000030361'
48
+ # '000000027936',
49
+ # 'Pileated_Woodpecker_0004_180307', 'American_Three_Toed_Woodpecker_0019_179870'
50
+ # '000000016379', '000000008869'
51
+ # 'commonwarthog_115',
52
+ # 'commonwarthog_78'
53
+ # '000000027059', '000000030361', '000000027936'
54
+ # 'klipspringer_66', '000000008333', '000000026814', '000000047543', '000000052080', 'Common_Tern_0050_148928'
55
+ ]
56
+ if len(relevant_names) > 0:
57
+ if cat is not None:
58
+ relevant_names = [os.path.join(cat, name) for name in relevant_names]
59
+ self.img_ids = [img_id for img_id in self.img_ids if self.id2name[img_id] in relevant_names]
60
+ else:
61
+ new_ids = []
62
+ for relevant_name in relevant_names:
63
+ new_ids += [img_id for img_id in self.img_ids if relevant_name in self.id2name[img_id]]
64
+ self.img_ids = new_ids
65
+ else:
66
+ self.img_ids = [img_id for img_id in self.img_ids if cat == self.id2name[img_id].split('/')[0]]
67
+
68
+ self.classes = [
69
+ cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
70
+ ]
71
+
72
+ self.num_classes = len(self.classes)
73
+ self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))
74
+ self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))
75
+
76
+ if valid_class_ids is not None: # None by default
77
+ self.valid_class_ids = valid_class_ids
78
+ else:
79
+ self.valid_class_ids = self.coco.getCatIds()
80
+
81
+ self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]
82
+ self.cats = self.coco.cats
83
+ self.max_kpt_num = max_kpt_num
84
+
85
+ # Also update self.cat2obj
86
+ self.db = self._get_db()
87
+
88
+ self.num_shots = num_shots
89
+
90
+ if not test_mode:
91
+ # Update every training epoch
92
+ self.random_paired_samples()
93
+ else:
94
+ self.num_queries = num_queries
95
+ self.num_episodes = num_episodes
96
+ self.make_paired_samples()
97
+
98
+
99
+ def random_paired_samples(self):
100
+ num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]
101
+
102
+ # balance the dataset
103
+ max_num_data = max(num_datas)
104
+
105
+ all_samples = []
106
+ for cls in self.valid_class_ids:
107
+ for i in range(max_num_data):
108
+ shot = random.sample(self.cat2obj[cls], self.num_shots + 1)
109
+ all_samples.append(shot)
110
+
111
+ self.paired_samples = np.array(all_samples)
112
+ np.random.shuffle(self.paired_samples)
113
+
114
+ def make_paired_samples(self):
115
+ random.seed(1)
116
+ np.random.seed(0)
117
+ all_samples = []
118
+ self.num_episodes = 1000
119
+ for cls in self.valid_class_ids:
120
+ for _ in range(self.num_episodes):
121
+ if self.cat2obj[cls] == []:
122
+ continue
123
+ self.num_queries = 1
124
+ self.num_shots = 1
125
+ if len(self.cat2obj[cls]) < self.num_shots + self.num_queries:
126
+ shots = random.choices(self.cat2obj[cls], k=self.num_shots + self.num_queries)
127
+ else:
128
+ shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)
129
+ sample_ids = shots[:self.num_shots]
130
+ query_ids = shots[self.num_shots:]
131
+ for query_id in query_ids:
132
+ all_samples.append(sample_ids + [query_id])
133
+ all_samples.append([query_id] + [query_id])
134
+
135
+ self.paired_samples = np.array(list(set(tuple(x) for x in all_samples)))
136
+
137
+ def _select_kpt(self, obj, kpt_id):
138
+ obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id+1]
139
+ obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id+1]
140
+ obj['kpt_id'] = kpt_id
141
+
142
+ return obj
143
+
144
+ @staticmethod
145
+ def _get_mapping_id_name(imgs):
146
+ """
147
+ Args:
148
+ imgs (dict): dict of image info.
149
+
150
+ Returns:
151
+ tuple: Image name & id mapping dicts.
152
+
153
+ - id2name (dict): Mapping image id to name.
154
+ - name2id (dict): Mapping image name to id.
155
+ """
156
+ id2name = {}
157
+ name2id = {}
158
+ for image_id, image in imgs.items():
159
+ file_name = image['file_name']
160
+ id2name[image_id] = file_name
161
+ name2id[file_name] = image_id
162
+
163
+ return id2name, name2id
164
+
165
+ def _get_db(self):
166
+ """Ground truth bbox and keypoints."""
167
+ self.obj_id = 0
168
+
169
+ self.cat2obj = {}
170
+ for i in self.coco.getCatIds():
171
+ self.cat2obj.update({i: []})
172
+
173
+ gt_db = []
174
+ for img_id in self.img_ids:
175
+ gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
176
+ return gt_db
177
+
178
+ def _load_coco_keypoint_annotation_kernel(self, img_id):
179
+ """load annotation from COCOAPI.
180
+
181
+ Note:
182
+ bbox:[x1, y1, w, h]
183
+ Args:
184
+ img_id: coco image id
185
+ Returns:
186
+ dict: db entry
187
+ """
188
+ img_ann = self.coco.loadImgs(img_id)[0]
189
+ width = img_ann['width']
190
+ height = img_ann['height']
191
+
192
+ ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
193
+ objs = self.coco.loadAnns(ann_ids)
194
+
195
+ # sanitize bboxes
196
+ valid_objs = []
197
+ for obj in objs:
198
+ if 'bbox' not in obj:
199
+ continue
200
+ x, y, w, h = obj['bbox']
201
+ x1 = max(0, x)
202
+ y1 = max(0, y)
203
+ x2 = min(width - 1, x1 + max(0, w - 1))
204
+ y2 = min(height - 1, y1 + max(0, h - 1))
205
+ if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
206
+ obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
207
+ valid_objs.append(obj)
208
+ objs = valid_objs
209
+
210
+ bbox_id = 0
211
+ rec = []
212
+ for obj in objs:
213
+ if 'keypoints' not in obj:
214
+ continue
215
+ if max(obj['keypoints']) == 0:
216
+ continue
217
+ if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
218
+ continue
219
+
220
+ category_id = obj['category_id']
221
+ # the number of keypoint for this specific category
222
+ cat_kpt_num = int(len(obj['keypoints']) / 3)
223
+ if self.max_kpt_num is None:
224
+ kpt_num = cat_kpt_num
225
+ else:
226
+ kpt_num = self.max_kpt_num
227
+
228
+ joints_3d = np.zeros((kpt_num, 3), dtype=np.float32)
229
+ joints_3d_visible = np.zeros((kpt_num, 3), dtype=np.float32)
230
+
231
+ keypoints = np.array(obj['keypoints']).reshape(-1, 3)
232
+ joints_3d[:cat_kpt_num, :2] = keypoints[:, :2]
233
+ joints_3d_visible[:cat_kpt_num, :2] = np.minimum(1, keypoints[:, 2:3])
234
+
235
+ center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
236
+
237
+ image_file = os.path.join(self.img_prefix, self.id2name[img_id])
238
+
239
+ self.cat2obj[category_id].append(self.obj_id)
240
+
241
+ rec.append({
242
+ 'image_file': image_file,
243
+ 'center': center,
244
+ 'scale': scale,
245
+ 'rotation': 0,
246
+ 'bbox': obj['clean_bbox'][:4],
247
+ 'bbox_score': 1,
248
+ 'joints_3d': joints_3d,
249
+ 'joints_3d_visible': joints_3d_visible,
250
+ 'category_id': category_id,
251
+ 'cat_kpt_num': cat_kpt_num,
252
+ 'bbox_id': self.obj_id,
253
+ 'skeleton': self.coco.cats[obj['category_id']]['skeleton'],
254
+ })
255
+ bbox_id = bbox_id + 1
256
+ self.obj_id += 1
257
+
258
+ return rec
259
+
260
+ def _xywh2cs(self, x, y, w, h):
261
+ """This encodes bbox(x,y,w,w) into (center, scale)
262
+
263
+ Args:
264
+ x, y, w, h
265
+
266
+ Returns:
267
+ tuple: A tuple containing center and scale.
268
+
269
+ - center (np.ndarray[float32](2,)): center of the bbox (x, y).
270
+ - scale (np.ndarray[float32](2,)): scale of the bbox w & h.
271
+ """
272
+ aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]
273
+ center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
274
+ #
275
+ # if (not self.test_mode) and np.random.rand() < 0.3:
276
+ # center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
277
+
278
+ if w > aspect_ratio * h:
279
+ h = w * 1.0 / aspect_ratio
280
+ elif w < aspect_ratio * h:
281
+ w = h * aspect_ratio
282
+
283
+ # pixel std is 200.0
284
+ scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
285
+ # padding to include proper amount of context
286
+ scale = scale * 1.25
287
+
288
+ return center, scale
289
+
290
+ def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):
291
+ """Evaluate interhand2d keypoint results. The pose prediction results
292
+ will be saved in `${res_folder}/result_keypoints.json`.
293
+
294
+ Note:
295
+ batch_size: N
296
+ num_keypoints: K
297
+ heatmap height: H
298
+ heatmap width: W
299
+
300
+ Args:
301
+ outputs (list(preds, boxes, image_path, output_heatmap))
302
+ :preds (np.ndarray[N,K,3]): The first two dimensions are
303
+ coordinates, score is the third dimension of the array.
304
+ :boxes (np.ndarray[N,6]): [center[0], center[1], scale[0]
305
+ , scale[1],area, score]
306
+ :image_paths (list[str]): For example, ['C', 'a', 'p', 't',
307
+ 'u', 'r', 'e', '1', '2', '/', '0', '3', '9', '0', '_',
308
+ 'd', 'h', '_', 't', 'o', 'u', 'c', 'h', 'R', 'O', 'M',
309
+ '/', 'c', 'a', 'm', '4', '1', '0', '2', '0', '9', '/',
310
+ 'i', 'm', 'a', 'g', 'e', '6', '2', '4', '3', '4', '.',
311
+ 'j', 'p', 'g']
312
+ :output_heatmap (np.ndarray[N, K, H, W]): model outpus.
313
+
314
+ res_folder (str): Path of directory to save the results.
315
+ metric (str | list[str]): Metric to be performed.
316
+ Options: 'PCK', 'AUC', 'EPE'.
317
+
318
+ Returns:
319
+ dict: Evaluation results for evaluation metric.
320
+ """
321
+ metrics = metric if isinstance(metric, list) else [metric]
322
+ allowed_metrics = ['PCK', 'AUC', 'EPE', 'NME']
323
+ for metric in metrics:
324
+ if metric not in allowed_metrics:
325
+ raise KeyError(f'metric {metric} is not supported')
326
+
327
+ res_file = os.path.join(res_folder, 'result_keypoints.json')
328
+
329
+ kpts = []
330
+ for output in outputs:
331
+ preds = output['preds']
332
+ boxes = output['boxes']
333
+ image_paths = output['image_paths']
334
+ bbox_ids = output['bbox_ids']
335
+
336
+ batch_size = len(image_paths)
337
+ for i in range(batch_size):
338
+ image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
339
+
340
+ kpts.append({
341
+ 'keypoints': preds[i].tolist(),
342
+ 'center': boxes[i][0:2].tolist(),
343
+ 'scale': boxes[i][2:4].tolist(),
344
+ 'area': float(boxes[i][4]),
345
+ 'score': float(boxes[i][5]),
346
+ 'image_id': image_id,
347
+ 'bbox_id': bbox_ids[i]
348
+ })
349
+ kpts = self._sort_and_unique_bboxes(kpts)
350
+
351
+ self._write_keypoint_results(kpts, res_file)
352
+ info_str = self._report_metric(res_file, metrics)
353
+ name_value = OrderedDict(info_str)
354
+
355
+ return name_value
EdgeCape/datasets/datasets/mp100/fewshot_base_dataset.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from abc import ABCMeta, abstractmethod
3
+ import json_tricks as json
4
+ import numpy as np
5
+
6
+ from mmcv.parallel import DataContainer as DC
7
+ from mmpose.core.evaluation.top_down_eval import (keypoint_auc, keypoint_epe,
8
+ keypoint_pck_accuracy)
9
+ from torch.utils.data import Dataset
10
+ from mmpose.datasets import DATASETS
11
+ from mmpose.datasets.pipelines import Compose
12
+
13
+ @DATASETS.register_module()
14
+ class FewShotBaseDataset(Dataset, metaclass=ABCMeta):
15
+
16
+ def __init__(self,
17
+ ann_file,
18
+ img_prefix,
19
+ data_cfg,
20
+ pipeline,
21
+ test_mode=False):
22
+ self.image_info = {}
23
+ self.ann_info = {}
24
+
25
+ self.annotations_path = ann_file
26
+ if not img_prefix.endswith('/'):
27
+ img_prefix = img_prefix + '/'
28
+ self.img_prefix = img_prefix
29
+ self.pipeline = pipeline
30
+ self.test_mode = test_mode
31
+
32
+ self.ann_info['image_size'] = np.array(data_cfg['image_size'])
33
+ self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
34
+ self.ann_info['num_joints'] = data_cfg['num_joints']
35
+
36
+ self.ann_info['flip_pairs'] = None
37
+
38
+ self.ann_info['inference_channel'] = data_cfg['inference_channel']
39
+ self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
40
+ self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
41
+
42
+ self.db = []
43
+ self.num_shots = 1
44
+ self.paired_samples = []
45
+ self.pipeline = Compose(self.pipeline)
46
+
47
+ @abstractmethod
48
+ def _get_db(self):
49
+ """Load dataset."""
50
+ raise NotImplementedError
51
+
52
+ @abstractmethod
53
+ def _select_kpt(self, obj, kpt_id):
54
+ """Select kpt."""
55
+ raise NotImplementedError
56
+
57
+ @abstractmethod
58
+ def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
59
+ """Evaluate keypoint results."""
60
+ raise NotImplementedError
61
+
62
+ @staticmethod
63
+ def _write_keypoint_results(keypoints, res_file):
64
+ """Write results into a json file."""
65
+
66
+ with open(res_file, 'w') as f:
67
+ json.dump(keypoints, f, sort_keys=True, indent=4)
68
+
69
+ def _report_metric(self,
70
+ res_file,
71
+ metrics,
72
+ pck_thr=0.2,
73
+ pckh_thr=0.7,
74
+ auc_nor=30):
75
+ """Keypoint evaluation.
76
+
77
+ Args:
78
+ res_file (str): Json file stored prediction results.
79
+ metrics (str | list[str]): Metric to be performed.
80
+ Options: 'PCK', 'PCKh', 'AUC', 'EPE'.
81
+ pck_thr (float): PCK threshold, default as 0.2.
82
+ pckh_thr (float): PCKh threshold, default as 0.7.
83
+ auc_nor (float): AUC normalization factor, default as 30 pixel.
84
+
85
+ Returns:
86
+ List: Evaluation results for evaluation metric.
87
+ """
88
+ info_str = []
89
+
90
+ with open(res_file, 'r') as fin:
91
+ preds = json.load(fin)
92
+ assert len(preds) == len(self.paired_samples)
93
+
94
+ outputs = []
95
+ gts = []
96
+ masks = []
97
+ threshold_bbox = []
98
+ threshold_head_box = []
99
+
100
+ for pred, pair in zip(preds, self.paired_samples):
101
+ item = self.db[pair[-1]]
102
+ outputs.append(np.array(pred['keypoints'])[:, :-1])
103
+ gts.append(np.array(item['joints_3d'])[:, :-1])
104
+
105
+ mask_query = ((np.array(item['joints_3d_visible'])[:, 0]) > 0)
106
+ mask_sample = ((np.array(self.db[pair[0]]['joints_3d_visible'])[:, 0]) > 0)
107
+ for id_s in pair[:-1]:
108
+ mask_sample = np.bitwise_and(mask_sample, ((np.array(self.db[id_s]['joints_3d_visible'])[:, 0]) > 0))
109
+ masks.append(np.bitwise_and(mask_query, mask_sample))
110
+
111
+ if 'PCK' in metrics:
112
+ bbox = np.array(item['bbox'])
113
+ bbox_thr = np.max(bbox[2:])
114
+ threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
115
+ if 'PCKh' in metrics:
116
+ head_box_thr = item['head_size']
117
+ threshold_head_box.append(
118
+ np.array([head_box_thr, head_box_thr]))
119
+
120
+ if 'PCK' in metrics:
121
+ pck_avg = []
122
+ for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
123
+ _, pck, _ = keypoint_pck_accuracy(np.expand_dims(output, 0), np.expand_dims(gt,0), np.expand_dims(mask,0), pck_thr, np.expand_dims(thr_bbox,0))
124
+ pck_avg.append(pck)
125
+ info_str.append(('PCK', np.mean(pck_avg)))
126
+
127
+ return info_str
128
+
129
+ def _merge_obj(self, Xs_list, Xq, idx):
130
+ """ merge Xs_list and Xq.
131
+
132
+ :param Xs_list: N-shot samples X
133
+ :param Xq: query X
134
+ :param idx: id of paired_samples
135
+ :return: Xall
136
+ """
137
+ Xall = dict()
138
+ Xall['img_s'] = [Xs['img'] for Xs in Xs_list]
139
+ Xall['target_s'] = [Xs['target'] for Xs in Xs_list]
140
+ Xall['target_weight_s'] = [Xs['target_weight'] for Xs in Xs_list]
141
+ xs_img_metas = [Xs['img_metas'].data for Xs in Xs_list]
142
+
143
+ Xall['img_q'] = Xq['img']
144
+ Xall['target_q'] = Xq['target']
145
+ Xall['target_weight_q'] = Xq['target_weight']
146
+ xq_img_metas = Xq['img_metas'].data
147
+
148
+ img_metas = dict()
149
+ for key in xq_img_metas.keys():
150
+ img_metas['sample_' + key] = [xs_img_meta[key] for xs_img_meta in xs_img_metas]
151
+ img_metas['query_' + key] = xq_img_metas[key]
152
+ img_metas['bbox_id'] = idx
153
+
154
+ Xall['img_metas'] = DC(img_metas, cpu_only=True)
155
+
156
+ return Xall
157
+
158
+ def __len__(self):
159
+ """Get the size of the dataset."""
160
+ return len(self.paired_samples)
161
+
162
+ def __getitem__(self, idx):
163
+ """Get the sample given index."""
164
+
165
+ pair_ids = self.paired_samples[idx]
166
+ assert len(pair_ids) == self.num_shots + 1
167
+ sample_id_list = pair_ids[:self.num_shots]
168
+ query_id = pair_ids[-1]
169
+
170
+ sample_obj_list = []
171
+ for sample_id in sample_id_list:
172
+ sample_obj = copy.deepcopy(self.db[sample_id])
173
+ sample_obj['ann_info'] = copy.deepcopy(self.ann_info)
174
+ sample_obj_list.append(sample_obj)
175
+
176
+ query_obj = copy.deepcopy(self.db[query_id])
177
+ query_obj['ann_info'] = copy.deepcopy(self.ann_info)
178
+
179
+ if not self.test_mode:
180
+ # randomly select "one" keypoint
181
+ sample_valid = (sample_obj_list[0]['joints_3d_visible'][:, 0] > 0)
182
+ for sample_obj in sample_obj_list:
183
+ sample_valid = sample_valid & (sample_obj['joints_3d_visible'][:, 0] > 0)
184
+ query_valid = (query_obj['joints_3d_visible'][:, 0] > 0)
185
+
186
+ valid_s = np.where(sample_valid)[0]
187
+ valid_q = np.where(query_valid)[0]
188
+ valid_sq = np.where(sample_valid & query_valid)[0]
189
+ if len(valid_sq) > 0:
190
+ kpt_id = np.random.choice(valid_sq)
191
+ elif len(valid_s) > 0:
192
+ kpt_id = np.random.choice(valid_s)
193
+ elif len(valid_q) > 0:
194
+ kpt_id = np.random.choice(valid_q)
195
+ else:
196
+ kpt_id = np.random.choice(np.array(range(len(query_valid))))
197
+
198
+ for i in range(self.num_shots):
199
+ sample_obj_list[i] = self._select_kpt(sample_obj_list[i], kpt_id)
200
+ query_obj = self._select_kpt(query_obj, kpt_id)
201
+
202
+ # when test, all keypoints will be preserved.
203
+
204
+ Xs_list = []
205
+ for sample_obj in sample_obj_list:
206
+ Xs = self.pipeline(sample_obj)
207
+ Xs_list.append(Xs)
208
+ Xq = self.pipeline(query_obj)
209
+
210
+ Xall = self._merge_obj(Xs_list, Xq, idx)
211
+ Xall['skeleton'] = self.db[query_id]['skeleton']
212
+
213
+ return Xall
214
+
215
+ def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
216
+ """sort kpts and remove the repeated ones."""
217
+ kpts = sorted(kpts, key=lambda x: x[key])
218
+ num = len(kpts)
219
+ for i in range(num - 1, 0, -1):
220
+ if kpts[i][key] == kpts[i - 1][key]:
221
+ del kpts[i]
222
+
223
+ return kpts
EdgeCape/datasets/datasets/mp100/fewshot_dataset.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from mmpose.datasets import DATASETS
2
+ import random
3
+ import numpy as np
4
+ import os
5
+ from collections import OrderedDict
6
+ from xtcocotools.coco import COCO
7
+ from .fewshot_base_dataset import FewShotBaseDataset
8
+
9
+ @DATASETS.register_module()
10
+ class FewShotKeypointDataset(FewShotBaseDataset):
11
+
12
+ def __init__(self,
13
+ ann_file,
14
+ img_prefix,
15
+ data_cfg,
16
+ pipeline,
17
+ valid_class_ids,
18
+ num_shots = 1,
19
+ num_queries = 100,
20
+ num_episodes = 1,
21
+ test_mode=False):
22
+ super().__init__(
23
+ ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode)
24
+
25
+ self.ann_info['flip_pairs'] = []
26
+
27
+ self.ann_info['upper_body_ids'] = []
28
+ self.ann_info['lower_body_ids'] = []
29
+
30
+ self.ann_info['use_different_joint_weights'] = False
31
+ self.ann_info['joint_weights'] = np.array([1.,],
32
+ dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
33
+
34
+ self.coco = COCO(ann_file)
35
+
36
+ self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
37
+ self.img_ids = self.coco.getImgIds()
38
+ self.classes = [
39
+ cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
40
+ ]
41
+
42
+ self.num_classes = len(self.classes)
43
+ self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))
44
+ self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))
45
+
46
+ if valid_class_ids is not None:
47
+ self.valid_class_ids = valid_class_ids
48
+ else:
49
+ self.valid_class_ids = self.coco.getCatIds()
50
+ self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]
51
+
52
+ self.cats = self.coco.cats
53
+
54
+ # Also update self.cat2obj
55
+ self.db = self._get_db()
56
+
57
+ self.num_shots = num_shots
58
+
59
+ if not test_mode:
60
+ # Update every training epoch
61
+ self.random_paired_samples()
62
+ else:
63
+ self.num_queries = num_queries
64
+ self.num_episodes = num_episodes
65
+ self.make_paired_samples()
66
+
67
+
68
+ def random_paired_samples(self):
69
+ num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]
70
+
71
+ # balance the dataset
72
+ max_num_data = max(num_datas)
73
+
74
+ all_samples = []
75
+ for cls in self.valid_class_ids:
76
+ for i in range(max_num_data):
77
+ shot = random.sample(self.cat2obj[cls], self.num_shots + 1)
78
+ all_samples.append(shot)
79
+
80
+ self.paired_samples = np.array(all_samples)
81
+ np.random.shuffle(self.paired_samples)
82
+
83
+ def make_paired_samples(self):
84
+ random.seed(1)
85
+ np.random.seed(0)
86
+
87
+ all_samples = []
88
+ for cls in self.valid_class_ids:
89
+ for _ in range(self.num_episodes):
90
+ shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)
91
+ sample_ids = shots[:self.num_shots]
92
+ query_ids = shots[self.num_shots:]
93
+ for query_id in query_ids:
94
+ all_samples.append(sample_ids + [query_id])
95
+
96
+ self.paired_samples = np.array(all_samples)
97
+
98
+ def _select_kpt(self, obj, kpt_id):
99
+ obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id+1]
100
+ obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id+1]
101
+ obj['kpt_id'] = kpt_id
102
+
103
+ return obj
104
+
105
+ @staticmethod
106
+ def _get_mapping_id_name(imgs):
107
+ """
108
+ Args:
109
+ imgs (dict): dict of image info.
110
+
111
+ Returns:
112
+ tuple: Image name & id mapping dicts.
113
+
114
+ - id2name (dict): Mapping image id to name.
115
+ - name2id (dict): Mapping image name to id.
116
+ """
117
+ id2name = {}
118
+ name2id = {}
119
+ for image_id, image in imgs.items():
120
+ file_name = image['file_name']
121
+ id2name[image_id] = file_name
122
+ name2id[file_name] = image_id
123
+
124
+ return id2name, name2id
125
+
126
+ def _get_db(self):
127
+ """Ground truth bbox and keypoints."""
128
+ self.obj_id = 0
129
+
130
+ self.cat2obj = {}
131
+ for i in self.coco.getCatIds():
132
+ self.cat2obj.update({i: []})
133
+
134
+ gt_db = []
135
+ for img_id in self.img_ids:
136
+ gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
137
+ return gt_db
138
+
139
+ def _load_coco_keypoint_annotation_kernel(self, img_id):
140
+ """load annotation from COCOAPI.
141
+
142
+ Note:
143
+ bbox:[x1, y1, w, h]
144
+ Args:
145
+ img_id: coco image id
146
+ Returns:
147
+ dict: db entry
148
+ """
149
+ img_ann = self.coco.loadImgs(img_id)[0]
150
+ width = img_ann['width']
151
+ height = img_ann['height']
152
+
153
+ ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
154
+ objs = self.coco.loadAnns(ann_ids)
155
+
156
+ # sanitize bboxes
157
+ valid_objs = []
158
+ for obj in objs:
159
+ if 'bbox' not in obj:
160
+ continue
161
+ x, y, w, h = obj['bbox']
162
+ x1 = max(0, x)
163
+ y1 = max(0, y)
164
+ x2 = min(width - 1, x1 + max(0, w - 1))
165
+ y2 = min(height - 1, y1 + max(0, h - 1))
166
+ if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
167
+ obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
168
+ valid_objs.append(obj)
169
+ objs = valid_objs
170
+
171
+ bbox_id = 0
172
+ rec = []
173
+ for obj in objs:
174
+ if 'keypoints' not in obj:
175
+ continue
176
+ if max(obj['keypoints']) == 0:
177
+ continue
178
+ if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
179
+ continue
180
+
181
+ category_id = obj['category_id']
182
+ # the number of keypoint for this specific category
183
+ cat_kpt_num = int(len(obj['keypoints']) / 3)
184
+
185
+ joints_3d = np.zeros((cat_kpt_num, 3), dtype=np.float32)
186
+ joints_3d_visible = np.zeros((cat_kpt_num, 3), dtype=np.float32)
187
+
188
+ keypoints = np.array(obj['keypoints']).reshape(-1, 3)
189
+ joints_3d[:, :2] = keypoints[:, :2]
190
+ joints_3d_visible[:, :2] = np.minimum(1, keypoints[:, 2:3])
191
+
192
+ center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
193
+
194
+ image_file = os.path.join(self.img_prefix, self.id2name[img_id])
195
+
196
+ self.cat2obj[category_id].append(self.obj_id)
197
+
198
+ rec.append({
199
+ 'image_file': image_file,
200
+ 'center': center,
201
+ 'scale': scale,
202
+ 'rotation': 0,
203
+ 'bbox': obj['clean_bbox'][:4],
204
+ 'bbox_score': 1,
205
+ 'joints_3d': joints_3d,
206
+ 'joints_3d_visible': joints_3d_visible,
207
+ 'category_id': category_id,
208
+ 'cat_kpt_num': cat_kpt_num,
209
+ 'bbox_id': self.obj_id,
210
+ 'skeleton': self.coco.cats[obj['category_id']]['skeleton'],
211
+ })
212
+ bbox_id = bbox_id + 1
213
+ self.obj_id += 1
214
+
215
+ return rec
216
+
217
+ def _xywh2cs(self, x, y, w, h):
218
+ """This encodes bbox(x,y,w,w) into (center, scale)
219
+
220
+ Args:
221
+ x, y, w, h
222
+
223
+ Returns:
224
+ tuple: A tuple containing center and scale.
225
+
226
+ - center (np.ndarray[float32](2,)): center of the bbox (x, y).
227
+ - scale (np.ndarray[float32](2,)): scale of the bbox w & h.
228
+ """
229
+ aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]
230
+ center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
231
+ #
232
+ # if (not self.test_mode) and np.random.rand() < 0.3:
233
+ # center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
234
+
235
+ if w > aspect_ratio * h:
236
+ h = w * 1.0 / aspect_ratio
237
+ elif w < aspect_ratio * h:
238
+ w = h * aspect_ratio
239
+
240
+ # pixel std is 200.0
241
+ scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
242
+ # padding to include proper amount of context
243
+ scale = scale * 1.25
244
+
245
+ return center, scale
246
+
247
+ def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):
248
+ """Evaluate interhand2d keypoint results. The pose prediction results
249
+ will be saved in `${res_folder}/result_keypoints.json`.
250
+
251
+ Note:
252
+ batch_size: N
253
+ num_keypoints: K
254
+ heatmap height: H
255
+ heatmap width: W
256
+
257
+ Args:
258
+ outputs (list(preds, boxes, image_path, output_heatmap))
259
+ :preds (np.ndarray[N,K,3]): The first two dimensions are
260
+ coordinates, score is the third dimension of the array.
261
+ :boxes (np.ndarray[N,6]): [center[0], center[1], scale[0]
262
+ , scale[1],area, score]
263
+ :image_paths (list[str]): For example, ['C', 'a', 'p', 't',
264
+ 'u', 'r', 'e', '1', '2', '/', '0', '3', '9', '0', '_',
265
+ 'd', 'h', '_', 't', 'o', 'u', 'c', 'h', 'R', 'O', 'M',
266
+ '/', 'c', 'a', 'm', '4', '1', '0', '2', '0', '9', '/',
267
+ 'i', 'm', 'a', 'g', 'e', '6', '2', '4', '3', '4', '.',
268
+ 'j', 'p', 'g']
269
+ :output_heatmap (np.ndarray[N, K, H, W]): model outpus.
270
+
271
+ res_folder (str): Path of directory to save the results.
272
+ metric (str | list[str]): Metric to be performed.
273
+ Options: 'PCK', 'AUC', 'EPE'.
274
+
275
+ Returns:
276
+ dict: Evaluation results for evaluation metric.
277
+ """
278
+ metrics = metric if isinstance(metric, list) else [metric]
279
+ allowed_metrics = ['PCK', 'AUC', 'EPE']
280
+ for metric in metrics:
281
+ if metric not in allowed_metrics:
282
+ raise KeyError(f'metric {metric} is not supported')
283
+
284
+ res_file = os.path.join(res_folder, 'result_keypoints.json')
285
+
286
+ kpts = []
287
+ for output in outputs:
288
+ preds = output['preds']
289
+ boxes = output['boxes']
290
+ image_paths = output['image_paths']
291
+ bbox_ids = output['bbox_ids']
292
+
293
+ batch_size = len(image_paths)
294
+ for i in range(batch_size):
295
+ image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
296
+
297
+ kpts.append({
298
+ 'keypoints': preds[i].tolist(),
299
+ 'center': boxes[i][0:2].tolist(),
300
+ 'scale': boxes[i][2:4].tolist(),
301
+ 'area': float(boxes[i][4]),
302
+ 'score': float(boxes[i][5]),
303
+ 'image_id': image_id,
304
+ 'bbox_id': bbox_ids[i]
305
+ })
306
+ kpts = self._sort_and_unique_bboxes(kpts)
307
+
308
+ self._write_keypoint_results(kpts, res_file)
309
+ info_str = self._report_metric(res_file, metrics)
310
+ name_value = OrderedDict(info_str)
311
+
312
+ return name_value
EdgeCape/datasets/datasets/mp100/test_base_dataset.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from abc import ABCMeta, abstractmethod
3
+ import json_tricks as json
4
+ import numpy as np
5
+
6
+ from mmcv.parallel import DataContainer as DC
7
+ from mmpose.core.evaluation.top_down_eval import (keypoint_auc, keypoint_epe, keypoint_nme,
8
+ keypoint_pck_accuracy)
9
+ from torch.utils.data import Dataset
10
+ from mmpose.datasets import DATASETS
11
+ from mmpose.datasets.pipelines import Compose
12
+
13
+ @DATASETS.register_module()
14
+ class TestBaseDataset(Dataset, metaclass=ABCMeta):
15
+
16
+ def __init__(self,
17
+ ann_file,
18
+ img_prefix,
19
+ data_cfg,
20
+ pipeline,
21
+ test_mode=True,
22
+ PCK_threshold_list=[0.05, 0.1, 0.15, 0.2, 0.25]):
23
+ self.image_info = {}
24
+ self.ann_info = {}
25
+
26
+ self.annotations_path = ann_file
27
+ if not img_prefix.endswith('/'):
28
+ img_prefix = img_prefix + '/'
29
+ self.img_prefix = img_prefix
30
+ self.pipeline = pipeline
31
+ self.test_mode = test_mode
32
+ self.PCK_threshold_list = PCK_threshold_list
33
+
34
+ self.ann_info['image_size'] = np.array(data_cfg['image_size'])
35
+ self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
36
+ self.ann_info['num_joints'] = data_cfg['num_joints']
37
+
38
+ self.ann_info['flip_pairs'] = None
39
+
40
+ self.ann_info['inference_channel'] = data_cfg['inference_channel']
41
+ self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
42
+ self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
43
+
44
+ self.db = []
45
+ self.num_shots = 1
46
+ self.paired_samples = []
47
+ self.pipeline = Compose(self.pipeline)
48
+
49
+ @abstractmethod
50
+ def _get_db(self):
51
+ """Load dataset."""
52
+ raise NotImplementedError
53
+
54
+ @abstractmethod
55
+ def _select_kpt(self, obj, kpt_id):
56
+ """Select kpt."""
57
+ raise NotImplementedError
58
+
59
+ @abstractmethod
60
+ def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
61
+ """Evaluate keypoint results."""
62
+ raise NotImplementedError
63
+
64
+ @staticmethod
65
+ def _write_keypoint_results(keypoints, res_file):
66
+ """Write results into a json file."""
67
+
68
+ with open(res_file, 'w') as f:
69
+ json.dump(keypoints, f, sort_keys=True, indent=4)
70
+
71
+ def _report_metric(self,
72
+ res_file,
73
+ metrics):
74
+ """Keypoint evaluation.
75
+
76
+ Args:
77
+ res_file (str): Json file stored prediction results.
78
+ metrics (str | list[str]): Metric to be performed.
79
+ Options: 'PCK', 'PCKh', 'AUC', 'EPE'.
80
+ pck_thr (float): PCK threshold, default as 0.2.
81
+ pckh_thr (float): PCKh threshold, default as 0.7.
82
+ auc_nor (float): AUC normalization factor, default as 30 pixel.
83
+
84
+ Returns:
85
+ List: Evaluation results for evaluation metric.
86
+ """
87
+ info_str = []
88
+
89
+ with open(res_file, 'r') as fin:
90
+ preds = json.load(fin)
91
+ assert len(preds) == len(self.paired_samples)
92
+
93
+ outputs = []
94
+ gts = []
95
+ masks = []
96
+ threshold_bbox = []
97
+ threshold_head_box = []
98
+
99
+ for pred, pair in zip(preds, self.paired_samples):
100
+ item = self.db[pair[-1]]
101
+ outputs.append(np.array(pred['keypoints'])[:, :-1])
102
+ gts.append(np.array(item['joints_3d'])[:, :-1])
103
+
104
+ mask_query = ((np.array(item['joints_3d_visible'])[:, 0]) > 0)
105
+ mask_sample = ((np.array(self.db[pair[0]]['joints_3d_visible'])[:, 0]) > 0)
106
+ for id_s in pair[:-1]:
107
+ mask_sample = np.bitwise_and(mask_sample, ((np.array(self.db[id_s]['joints_3d_visible'])[:, 0]) > 0))
108
+ masks.append(np.bitwise_and(mask_query, mask_sample))
109
+
110
+ if 'PCK' in metrics or 'NME' in metrics or 'AUC' in metrics:
111
+ bbox = np.array(item['bbox'])
112
+ bbox_thr = np.max(bbox[2:])
113
+ threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
114
+ if 'PCKh' in metrics:
115
+ head_box_thr = item['head_size']
116
+ threshold_head_box.append(
117
+ np.array([head_box_thr, head_box_thr]))
118
+
119
+ if 'PCK' in metrics:
120
+ pck_results = dict()
121
+ for pck_thr in self.PCK_threshold_list:
122
+ pck_results[pck_thr] = []
123
+
124
+ for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
125
+ for pck_thr in self.PCK_threshold_list:
126
+ _, pck, _ = keypoint_pck_accuracy(np.expand_dims(output, 0), np.expand_dims(gt,0), np.expand_dims(mask,0), pck_thr, np.expand_dims(thr_bbox,0))
127
+ pck_results[pck_thr].append(pck)
128
+
129
+ mPCK = 0
130
+ for pck_thr in self.PCK_threshold_list:
131
+ info_str.append(['PCK@' + str(pck_thr), np.mean(pck_results[pck_thr])])
132
+ mPCK += np.mean(pck_results[pck_thr])
133
+ info_str.append(['mPCK', mPCK / len(self.PCK_threshold_list)])
134
+
135
+ if 'NME' in metrics:
136
+ nme_results = []
137
+ for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
138
+ nme = keypoint_nme(np.expand_dims(output, 0), np.expand_dims(gt,0), np.expand_dims(mask,0), np.expand_dims(thr_bbox,0))
139
+ nme_results.append(nme)
140
+ info_str.append(['NME', np.mean(nme_results)])
141
+
142
+ if 'AUC' in metrics:
143
+ auc_results = []
144
+ for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
145
+ auc = keypoint_auc(np.expand_dims(output, 0), np.expand_dims(gt,0), np.expand_dims(mask,0), thr_bbox[0])
146
+ auc_results.append(auc)
147
+ info_str.append(['AUC', np.mean(auc_results)])
148
+
149
+ if 'EPE' in metrics:
150
+ epe_results = []
151
+ for (output, gt, mask) in zip(outputs, gts, masks):
152
+ epe = keypoint_epe(np.expand_dims(output, 0), np.expand_dims(gt,0), np.expand_dims(mask,0))
153
+ epe_results.append(epe)
154
+ info_str.append(['EPE', np.mean(epe_results)])
155
+ return info_str
156
+
157
+ def _merge_obj(self, Xs_list, Xq, idx):
158
+ """ merge Xs_list and Xq.
159
+
160
+ :param Xs_list: N-shot samples X
161
+ :param Xq: query X
162
+ :param idx: id of paired_samples
163
+ :return: Xall
164
+ """
165
+ Xall = dict()
166
+ Xall['img_s'] = [Xs['img'] for Xs in Xs_list]
167
+ Xall['target_s'] = [Xs['target'] for Xs in Xs_list]
168
+ Xall['target_weight_s'] = [Xs['target_weight'] for Xs in Xs_list]
169
+ xs_img_metas = [Xs['img_metas'].data for Xs in Xs_list]
170
+
171
+ Xall['img_q'] = Xq['img']
172
+ Xall['target_q'] = Xq['target']
173
+ Xall['target_weight_q'] = Xq['target_weight']
174
+ xq_img_metas = Xq['img_metas'].data
175
+
176
+ img_metas = dict()
177
+ for key in xq_img_metas.keys():
178
+ img_metas['sample_' + key] = [xs_img_meta[key] for xs_img_meta in xs_img_metas]
179
+ img_metas['query_' + key] = xq_img_metas[key]
180
+ img_metas['bbox_id'] = idx
181
+
182
+ Xall['img_metas'] = DC(img_metas, cpu_only=True)
183
+
184
+ return Xall
185
+
186
+ def __len__(self):
187
+ """Get the size of the dataset."""
188
+ return len(self.paired_samples)
189
+
190
+ def __getitem__(self, idx):
191
+ """Get the sample given index."""
192
+
193
+ pair_ids = self.paired_samples[idx] # [num_shots support ids, then the query id]
194
+ assert len(pair_ids) == self.num_shots + 1
195
+ sample_id_list = pair_ids[:self.num_shots]
196
+ query_id = pair_ids[-1]
197
+
198
+ sample_obj_list = []
199
+ for sample_id in sample_id_list:
200
+ sample_obj = copy.deepcopy(self.db[sample_id])
201
+ sample_obj['ann_info'] = copy.deepcopy(self.ann_info)
202
+ sample_obj_list.append(sample_obj)
203
+
204
+ query_obj = copy.deepcopy(self.db[query_id])
205
+ query_obj['ann_info'] = copy.deepcopy(self.ann_info)
206
+
207
+ Xs_list = []
208
+ for sample_obj in sample_obj_list:
209
+ Xs = self.pipeline(sample_obj) # dict with ['img', 'target', 'target_weight', 'img_metas'],
210
+ Xs_list.append(Xs) # Xs['target'] is of shape [100, map_h, map_w]
211
+ Xq = self.pipeline(query_obj)
212
+
213
+ Xall = self._merge_obj(Xs_list, Xq, idx)
214
+ Xall['skeleton'] = self.db[query_id]['skeleton']
215
+
216
+ return Xall
217
+
218
+ def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
219
+ """sort kpts and remove the repeated ones."""
220
+ kpts = sorted(kpts, key=lambda x: x[key])
221
+ num = len(kpts)
222
+ for i in range(num - 1, 0, -1):
223
+ if kpts[i][key] == kpts[i - 1][key]:
224
+ del kpts[i]
225
+
226
+ return kpts
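The `_report_metric` above normalizes every keypoint error by the query bounding box before thresholding (PCK, NME, AUC) and reports raw pixel error for EPE. As a standalone sketch of the per-sample PCK quantity that `keypoint_pck_accuracy` contributes here, assuming the same bbox-based normalization (the `pck_single` helper and the toy arrays are illustrative, not part of the repository):

import numpy as np

def pck_single(pred, gt, mask, thr, norm):
    """Fraction of visible keypoints whose normalized error is below `thr`."""
    # pred, gt: [K, 2]; mask: [K] bool; norm: [2] bbox side lengths used as scale
    dist = np.linalg.norm((pred - gt) / norm, axis=-1)
    if not mask.any():
        return 0.0
    return float((dist[mask] < thr).mean())

pred = np.array([[10., 10.], [52., 48.], [90., 95.]])
gt = np.array([[12., 11.], [50., 50.], [60., 60.]])
mask = np.array([True, True, True])
# PCK@0.2 with a 100 px bbox threshold on both axes: 2 of 3 keypoints are correct
print(pck_single(pred, gt, mask, thr=0.2, norm=np.array([100., 100.])))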
EdgeCape/datasets/datasets/mp100/test_dataset.py ADDED
@@ -0,0 +1,319 @@
1
+ from mmpose.datasets import DATASETS
2
+ import random
3
+ import numpy as np
4
+ import os
5
+ from collections import OrderedDict
6
+ from xtcocotools.coco import COCO
7
+ from .test_base_dataset import TestBaseDataset
8
+
9
+ @DATASETS.register_module()
10
+ class TestPoseDataset(TestBaseDataset):
11
+
12
+ def __init__(self,
13
+ ann_file,
14
+ img_prefix,
15
+ data_cfg,
16
+ pipeline,
17
+ valid_class_ids,
18
+ max_kpt_num=None,
19
+ num_shots=1,
20
+ num_queries=100,
21
+ num_episodes=1,
22
+ pck_threshold_list=[0.05, 0.1, 0.15, 0.20, 0.25],
23
+ test_mode=True):
24
+ super().__init__(
25
+ ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode, PCK_threshold_list=pck_threshold_list)
26
+
27
+ self.ann_info['flip_pairs'] = []
28
+
29
+ self.ann_info['upper_body_ids'] = []
30
+ self.ann_info['lower_body_ids'] = []
31
+
32
+ self.ann_info['use_different_joint_weights'] = False
33
+ self.ann_info['joint_weights'] = np.array([1.,],
34
+ dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
35
+
36
+ self.coco = COCO(ann_file)
37
+
38
+ self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
39
+ self.img_ids = self.coco.getImgIds()
40
+ self.classes = [
41
+ cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
42
+ ]
43
+
44
+ self.num_classes = len(self.classes)
45
+ self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))
46
+ self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))
47
+
48
+ if valid_class_ids is not None: # None by default
49
+ self.valid_class_ids = valid_class_ids
50
+ else:
51
+ self.valid_class_ids = self.coco.getCatIds()
52
+ self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]
53
+
54
+ self.cats = self.coco.cats
55
+ self.max_kpt_num = max_kpt_num
56
+
57
+ # Also update self.cat2obj
58
+ self.db = self._get_db()
59
+
60
+ self.num_shots = num_shots
61
+
62
+ if not test_mode:
63
+ # Update every training epoch
64
+ self.random_paired_samples()
65
+ else:
66
+ self.num_queries = num_queries
67
+ self.num_episodes = num_episodes
68
+ self.make_paired_samples()
69
+
70
+
71
+ def random_paired_samples(self):
72
+ num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]
73
+
74
+ # balance the dataset
75
+ max_num_data = max(num_datas)
76
+
77
+ all_samples = []
78
+ for cls in self.valid_class_ids:
79
+ for i in range(max_num_data):
80
+ shot = random.sample(self.cat2obj[cls], self.num_shots + 1)
81
+ all_samples.append(shot)
82
+
83
+ self.paired_samples = np.array(all_samples)
84
+ np.random.shuffle(self.paired_samples)
85
+
86
+ def make_paired_samples(self):
87
+ random.seed(1)
88
+ np.random.seed(0)
89
+
90
+ all_samples = []
91
+ for cls in self.valid_class_ids:
92
+ for _ in range(self.num_episodes):
93
+ shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)
94
+ sample_ids = shots[:self.num_shots]
95
+ query_ids = shots[self.num_shots:]
96
+ for query_id in query_ids:
97
+ all_samples.append(sample_ids + [query_id])
98
+
99
+ self.paired_samples = np.array(all_samples)
100
+
101
+ def _select_kpt(self, obj, kpt_id):
102
+ obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id+1]
103
+ obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id+1]
104
+ obj['kpt_id'] = kpt_id
105
+
106
+ return obj
107
+
108
+ @staticmethod
109
+ def _get_mapping_id_name(imgs):
110
+ """
111
+ Args:
112
+ imgs (dict): dict of image info.
113
+
114
+ Returns:
115
+ tuple: Image name & id mapping dicts.
116
+
117
+ - id2name (dict): Mapping image id to name.
118
+ - name2id (dict): Mapping image name to id.
119
+ """
120
+ id2name = {}
121
+ name2id = {}
122
+ for image_id, image in imgs.items():
123
+ file_name = image['file_name']
124
+ id2name[image_id] = file_name
125
+ name2id[file_name] = image_id
126
+
127
+ return id2name, name2id
128
+
129
+ def _get_db(self):
130
+ """Ground truth bbox and keypoints."""
131
+ self.obj_id = 0
132
+
133
+ self.cat2obj = {}
134
+ for i in self.coco.getCatIds():
135
+ self.cat2obj.update({i: []})
136
+
137
+ gt_db = []
138
+ for img_id in self.img_ids:
139
+ gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
140
+ return gt_db
141
+
142
+ def _load_coco_keypoint_annotation_kernel(self, img_id):
143
+ """load annotation from COCOAPI.
144
+
145
+ Note:
146
+ bbox:[x1, y1, w, h]
147
+ Args:
148
+ img_id: coco image id
149
+ Returns:
150
+ list[dict]: valid annotation records for the image.
151
+ """
152
+ img_ann = self.coco.loadImgs(img_id)[0]
153
+ width = img_ann['width']
154
+ height = img_ann['height']
155
+
156
+ ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
157
+ objs = self.coco.loadAnns(ann_ids)
158
+
159
+ # sanitize bboxes
160
+ valid_objs = []
161
+ for obj in objs:
162
+ if 'bbox' not in obj:
163
+ continue
164
+ x, y, w, h = obj['bbox']
165
+ x1 = max(0, x)
166
+ y1 = max(0, y)
167
+ x2 = min(width - 1, x1 + max(0, w - 1))
168
+ y2 = min(height - 1, y1 + max(0, h - 1))
169
+ if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
170
+ obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
171
+ valid_objs.append(obj)
172
+ objs = valid_objs
173
+
174
+ bbox_id = 0
175
+ rec = []
176
+ for obj in objs:
177
+ if 'keypoints' not in obj:
178
+ continue
179
+ if max(obj['keypoints']) == 0:
180
+ continue
181
+ if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
182
+ continue
183
+
184
+ category_id = obj['category_id']
185
+ # the number of keypoints for this specific category
186
+ cat_kpt_num = int(len(obj['keypoints']) / 3)
187
+ if self.max_kpt_num is None:
188
+ kpt_num = cat_kpt_num
189
+ else:
190
+ kpt_num = self.max_kpt_num
191
+
192
+ joints_3d = np.zeros((kpt_num, 3), dtype=np.float32)
193
+ joints_3d_visible = np.zeros((kpt_num, 3), dtype=np.float32)
194
+
195
+ keypoints = np.array(obj['keypoints']).reshape(-1, 3)
196
+ joints_3d[:cat_kpt_num, :2] = keypoints[:, :2]
197
+ joints_3d_visible[:cat_kpt_num, :2] = np.minimum(1, keypoints[:, 2:3])
198
+
199
+ center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
200
+
201
+ image_file = os.path.join(self.img_prefix, self.id2name[img_id])
202
+
203
+ self.cat2obj[category_id].append(self.obj_id)
204
+
205
+ rec.append({
206
+ 'image_file': image_file,
207
+ 'center': center,
208
+ 'scale': scale,
209
+ 'rotation': 0,
210
+ 'bbox': obj['clean_bbox'][:4],
211
+ 'bbox_score': 1,
212
+ 'joints_3d': joints_3d,
213
+ 'joints_3d_visible': joints_3d_visible,
214
+ 'category_id': category_id,
215
+ 'cat_kpt_num': cat_kpt_num,
216
+ 'bbox_id': self.obj_id,
217
+ 'skeleton': self.coco.cats[obj['category_id']]['skeleton'],
218
+ })
219
+ bbox_id = bbox_id + 1
220
+ self.obj_id += 1
221
+
222
+ return rec
223
+
224
+ def _xywh2cs(self, x, y, w, h):
225
+ """Encode a bbox (x, y, w, h) into (center, scale).
226
+
227
+ Args:
228
+ x, y, w, h
229
+
230
+ Returns:
231
+ tuple: A tuple containing center and scale.
232
+
233
+ - center (np.ndarray[float32](2,)): center of the bbox (x, y).
234
+ - scale (np.ndarray[float32](2,)): scale of the bbox w & h.
235
+ """
236
+ aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]
237
+ center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
238
+ #
239
+ # if (not self.test_mode) and np.random.rand() < 0.3:
240
+ # center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
241
+
242
+ if w > aspect_ratio * h:
243
+ h = w * 1.0 / aspect_ratio
244
+ elif w < aspect_ratio * h:
245
+ w = h * aspect_ratio
246
+
247
+ # pixel std is 200.0
248
+ scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
249
+ # padding to include proper amount of context
250
+ scale = scale * 1.25
251
+
252
+ return center, scale
253
+
254
+ def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):
255
+ """Evaluate keypoint results. The pose prediction results
256
+ will be saved in `${res_folder}/result_keypoints.json`.
257
+
258
+ Note:
259
+ batch_size: N
260
+ num_keypoints: K
261
+ heatmap height: H
262
+ heatmap width: W
263
+
264
+ Args:
265
+ outputs (list(preds, boxes, image_path, output_heatmap))
266
+ :preds (np.ndarray[N,K,3]): The first two dimensions are
267
+ coordinates, score is the third dimension of the array.
268
+ :boxes (np.ndarray[N,6]): [center[0], center[1], scale[0]
269
+ , scale[1],area, score]
270
+ :image_paths (list[str]): For example,
+ ['Capture12/0390_dh_touchROM/cam410209/image62434.jpg']
+ :output_heatmap (np.ndarray[N, K, H, W]): model outputs.
277
+
278
+ res_folder (str): Path of directory to save the results.
279
+ metric (str | list[str]): Metric to be performed.
280
+ Options: 'PCK', 'AUC', 'EPE'.
281
+
282
+ Returns:
283
+ dict: Evaluation results for evaluation metric.
284
+ """
285
+ metrics = metric if isinstance(metric, list) else [metric]
286
+ allowed_metrics = ['PCK', 'AUC', 'EPE', 'NME']
287
+ for metric in metrics:
288
+ if metric not in allowed_metrics:
289
+ raise KeyError(f'metric {metric} is not supported')
290
+
291
+ res_file = os.path.join(res_folder, 'result_keypoints.json')
292
+
293
+ kpts = []
294
+ for output in outputs:
295
+ preds = output['preds']
296
+ boxes = output['boxes']
297
+ image_paths = output['image_paths']
298
+ bbox_ids = output['bbox_ids']
299
+
300
+ batch_size = len(image_paths)
301
+ for i in range(batch_size):
302
+ image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
303
+
304
+ kpts.append({
305
+ 'keypoints': preds[i].tolist(),
306
+ 'center': boxes[i][0:2].tolist(),
307
+ 'scale': boxes[i][2:4].tolist(),
308
+ 'area': float(boxes[i][4]),
309
+ 'score': float(boxes[i][5]),
310
+ 'image_id': image_id,
311
+ 'bbox_id': bbox_ids[i]
312
+ })
313
+ kpts = self._sort_and_unique_bboxes(kpts)
314
+
315
+ self._write_keypoint_results(kpts, res_file)
316
+ info_str = self._report_metric(res_file, metrics)
317
+ name_value = OrderedDict(info_str)
318
+
319
+ return name_value
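For reference, `_xywh2cs` above converts a COCO-style `[x, y, w, h]` box into the (center, scale) pair consumed by the affine pipeline: the box is first padded to the model aspect ratio, then divided by the 200-pixel std and enlarged by 25% for context. A standalone sketch of that arithmetic (the 256x256 input size and the box values are only examples):

import numpy as np

def xywh_to_center_scale(x, y, w, h, image_size=(256, 256)):
    aspect_ratio = image_size[0] / image_size[1]
    center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
    # pad the shorter side so the box matches the network aspect ratio
    if w > aspect_ratio * h:
        h = w / aspect_ratio
    elif w < aspect_ratio * h:
        w = h * aspect_ratio
    # pixel std is 200, plus 25% context padding, as in the dataset class
    scale = np.array([w / 200.0, h / 200.0], dtype=np.float32) * 1.25
    return center, scale

print(xywh_to_center_scale(40, 30, 100, 60))
# center ~ [90., 60.], scale ~ [0.625, 0.625] for a square input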
EdgeCape/datasets/datasets/mp100/transformer_base_dataset.py ADDED
@@ -0,0 +1,209 @@
1
+ import copy
2
+ from abc import ABCMeta, abstractmethod
3
+ import json_tricks as json
4
+ import numpy as np
5
+
6
+ from mmcv.parallel import DataContainer as DC
7
+ from mmpose.core.evaluation.top_down_eval import (keypoint_auc, keypoint_epe,
8
+ keypoint_pck_accuracy)
9
+ from torch.utils.data import Dataset
10
+ from mmpose.datasets import DATASETS
11
+ from mmpose.datasets.pipelines import Compose
12
+
13
+ @DATASETS.register_module()
14
+ class TransformerBaseDataset(Dataset, metaclass=ABCMeta):
15
+
16
+ def __init__(self,
17
+ ann_file,
18
+ img_prefix,
19
+ data_cfg,
20
+ pipeline,
21
+ masking_ratio=0.3,
22
+ test_mode=False):
23
+ self.image_info = {}
24
+ self.ann_info = {}
25
+
26
+ self.annotations_path = ann_file
27
+ if not img_prefix.endswith('/'):
28
+ img_prefix = img_prefix + '/'
29
+ self.img_prefix = img_prefix
30
+ self.pipeline = pipeline
31
+ self.test_mode = test_mode
32
+ self.masking_ratio = masking_ratio
33
+ self.ann_info['image_size'] = np.array(data_cfg['image_size'])
34
+ self.ann_info['heatmap_size'] = np.array(data_cfg['heatmap_size'])
35
+ self.ann_info['num_joints'] = data_cfg['num_joints']
36
+
37
+ self.ann_info['flip_pairs'] = None
38
+
39
+ self.ann_info['inference_channel'] = data_cfg['inference_channel']
40
+ self.ann_info['num_output_channels'] = data_cfg['num_output_channels']
41
+ self.ann_info['dataset_channel'] = data_cfg['dataset_channel']
42
+
43
+ self.db = []
44
+ self.num_shots = 1
45
+ self.paired_samples = []
46
+ self.pipeline = Compose(self.pipeline)
47
+
48
+ @abstractmethod
49
+ def _get_db(self):
50
+ """Load dataset."""
51
+ raise NotImplementedError
52
+
53
+ @abstractmethod
54
+ def _select_kpt(self, obj, kpt_id):
55
+ """Select kpt."""
56
+ raise NotImplementedError
57
+
58
+ @abstractmethod
59
+ def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
60
+ """Evaluate keypoint results."""
61
+ raise NotImplementedError
62
+
63
+ @staticmethod
64
+ def _write_keypoint_results(keypoints, res_file):
65
+ """Write results into a json file."""
66
+
67
+ with open(res_file, 'w') as f:
68
+ json.dump(keypoints, f, sort_keys=True, indent=4)
69
+
70
+ def _report_metric(self,
71
+ res_file,
72
+ metrics,
73
+ pck_thr=0.2,
74
+ pckh_thr=0.7,
75
+ auc_nor=30):
76
+ """Keypoint evaluation.
77
+
78
+ Args:
79
+ res_file (str): Json file storing prediction results.
80
+ metrics (str | list[str]): Metric to be performed.
81
+ Options: 'PCK', 'PCKh', 'AUC', 'EPE'.
82
+ pck_thr (float): PCK threshold, default as 0.2.
83
+ pckh_thr (float): PCKh threshold, default as 0.7.
84
+ auc_nor (float): AUC normalization factor, default as 30 pixel.
85
+
86
+ Returns:
87
+ List: Evaluation results for evaluation metric.
88
+ """
89
+ info_str = []
90
+
91
+ with open(res_file, 'r') as fin:
92
+ preds = json.load(fin)
93
+ assert len(preds) == len(self.paired_samples)
94
+
95
+ outputs = []
96
+ gts = []
97
+ masks = []
98
+ threshold_bbox = []
99
+ threshold_head_box = []
100
+
101
+ for pred, pair in zip(preds, self.paired_samples):
102
+ item = self.db[pair[-1]]
103
+ outputs.append(np.array(pred['keypoints'])[:, :-1])
104
+ gts.append(np.array(item['joints_3d'])[:, :-1])
105
+
106
+ mask_query = ((np.array(item['joints_3d_visible'])[:, 0]) > 0)
107
+ mask_sample = ((np.array(self.db[pair[0]]['joints_3d_visible'])[:, 0]) > 0)
108
+ for id_s in pair[:-1]:
109
+ mask_sample = np.bitwise_and(mask_sample, ((np.array(self.db[id_s]['joints_3d_visible'])[:, 0]) > 0))
110
+ masks.append(np.bitwise_and(mask_query, mask_sample))
111
+
112
+ if 'PCK' in metrics:
113
+ bbox = np.array(item['bbox'])
114
+ bbox_thr = np.max(bbox[2:])
115
+ threshold_bbox.append(np.array([bbox_thr, bbox_thr]))
116
+ if 'PCKh' in metrics:
117
+ head_box_thr = item['head_size']
118
+ threshold_head_box.append(
119
+ np.array([head_box_thr, head_box_thr]))
120
+
121
+ if 'PCK' in metrics:
122
+ pck_avg = []
123
+ for (output, gt, mask, thr_bbox) in zip(outputs, gts, masks, threshold_bbox):
124
+ _, pck, _ = keypoint_pck_accuracy(np.expand_dims(output, 0), np.expand_dims(gt,0), np.expand_dims(mask,0), pck_thr, np.expand_dims(thr_bbox,0))
125
+ pck_avg.append(pck)
126
+ info_str.append(('PCK', np.mean(pck_avg)))
127
+
128
+ return info_str
129
+
130
+ def _merge_obj(self, Xs_list, Xq, idx):
131
+ """ merge Xs_list and Xq.
132
+
133
+ :param Xs_list: N-shot samples X
134
+ :param Xq: query X
135
+ :param idx: id of paired_samples
136
+ :return: Xall
137
+ """
138
+ Xall = dict()
139
+ Xall['img_s'] = [Xs['img'] for Xs in Xs_list]
140
+ Xall['target_s'] = [Xs['target'] for Xs in Xs_list]
141
+ Xall['target_weight_s'] = [Xs['target_weight'] for Xs in Xs_list]
142
+ xs_img_metas = [Xs['img_metas'].data for Xs in Xs_list]
143
+
144
+ Xall['img_q'] = Xq['img']
145
+ Xall['target_q'] = Xq['target']
146
+ Xall['target_weight_q'] = Xq['target_weight']
147
+ xq_img_metas = Xq['img_metas'].data
148
+
149
+ img_metas = dict()
150
+ for key in xq_img_metas.keys():
151
+ img_metas['sample_' + key] = [xs_img_meta[key] for xs_img_meta in xs_img_metas]
152
+ img_metas['query_' + key] = xq_img_metas[key]
153
+ img_metas['bbox_id'] = idx
154
+
155
+ Xall['img_metas'] = DC(img_metas, cpu_only=True)
156
+
157
+ return Xall
158
+
159
+ def __len__(self):
160
+ """Get the size of the dataset."""
161
+ return len(self.paired_samples)
162
+
163
+ def __getitem__(self, idx):
164
+ """Get the sample given index."""
165
+
166
+ pair_ids = self.paired_samples[idx] # [num_shots support ids, then the query id]
167
+ assert len(pair_ids) == self.num_shots + 1
168
+ sample_id_list = pair_ids[:self.num_shots]
169
+ query_id = pair_ids[-1]
170
+
171
+ sample_obj_list = []
172
+ for sample_id in sample_id_list:
173
+ sample_obj = copy.deepcopy(self.db[sample_id])
174
+ sample_obj['ann_info'] = copy.deepcopy(self.ann_info)
175
+ sample_obj_list.append(sample_obj)
176
+
177
+ query_obj = copy.deepcopy(self.db[query_id])
178
+ query_obj['ann_info'] = copy.deepcopy(self.ann_info)
179
+
180
+ Xs_list = []
181
+ for sample_obj in sample_obj_list:
182
+ Xs = self.pipeline(sample_obj) # dict with ['img', 'target', 'target_weight', 'img_metas'],
183
+ Xs_list.append(Xs) # Xs['target'] is of shape [100, map_h, map_w]
184
+ Xq = self.pipeline(query_obj)
185
+
186
+ Xall = self._merge_obj(Xs_list, Xq, idx)
187
+ Xall['skeleton'] = self.db[query_id]['skeleton']
188
+ Xall['rand_mask'] = self.rand_mask(Xall['target_weight_s'])
189
+ return Xall
190
+
191
+ def _sort_and_unique_bboxes(self, kpts, key='bbox_id'):
192
+ """sort kpts and remove the repeated ones."""
193
+ kpts = sorted(kpts, key=lambda x: x[key])
194
+ num = len(kpts)
195
+ for i in range(num - 1, 0, -1):
196
+ if kpts[i][key] == kpts[i - 1][key]:
197
+ del kpts[i]
198
+
199
+ return kpts
200
+
201
+ def rand_mask(self, target_weight_s):
202
+ mask_s = target_weight_s[0]
203
+ for target_weight in target_weight_s:
204
+ mask_s = mask_s * target_weight
205
+ num_to_mask = int(np.sum(mask_s) * self.masking_ratio)
206
+ true_indices = np.where(mask_s == 1)[0]
207
+ rand_mask = np.random.permutation(true_indices)[:num_to_mask]
208
+ mask_s[rand_mask] = 0
209
+ return mask_s
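`rand_mask` above intersects the visibility weights of all support shots and then zeroes out roughly `masking_ratio` of the keypoints that stay visible, so the model has to predict part of the annotated keypoints without support evidence. A rough numpy sketch of the same behaviour; the seeded generator and the toy weights exist only to make the illustration reproducible:

import numpy as np

def rand_mask(target_weight_s, masking_ratio=0.3, seed=0):
    """Zero a random subset of the keypoints visible in all support shots."""
    rng = np.random.default_rng(seed)                 # seeded only for this demo
    mask_s = target_weight_s[0].copy()
    for target_weight in target_weight_s[1:]:
        mask_s = mask_s * target_weight               # keep kpts visible everywhere
    num_to_mask = int(mask_s.sum() * masking_ratio)
    visible = np.where(mask_s[:, 0] == 1)[0]
    mask_s[rng.permutation(visible)[:num_to_mask]] = 0
    return mask_s

weights = [np.ones((10, 1), dtype=np.float32), np.ones((10, 1), dtype=np.float32)]
weights[1][7] = 0                                     # keypoint 7 hidden in one shot
print(rand_mask(weights).ravel())                     # 9 visible kpts, 2 masked out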
EdgeCape/datasets/datasets/mp100/transformer_dataset.py ADDED
@@ -0,0 +1,319 @@
1
+ from mmpose.datasets import DATASETS
2
+ import random
3
+ import numpy as np
4
+ import os
5
+ from collections import OrderedDict
6
+ from xtcocotools.coco import COCO
7
+ from .transformer_base_dataset import TransformerBaseDataset
8
+
9
+ @DATASETS.register_module()
10
+ class TransformerPoseDataset(TransformerBaseDataset):
11
+
12
+ def __init__(self,
13
+ ann_file,
14
+ img_prefix,
15
+ data_cfg,
16
+ pipeline,
17
+ valid_class_ids,
18
+ max_kpt_num=None,
19
+ num_shots=1,
20
+ num_queries=100,
21
+ num_episodes=1,
22
+ test_mode=False):
23
+ super().__init__(
24
+ ann_file, img_prefix, data_cfg, pipeline, test_mode=test_mode)
25
+
26
+ self.ann_info['flip_pairs'] = []
27
+
28
+ self.ann_info['upper_body_ids'] = []
29
+ self.ann_info['lower_body_ids'] = []
30
+
31
+ self.ann_info['use_different_joint_weights'] = False
32
+ self.ann_info['joint_weights'] = np.array([1.,],
33
+ dtype=np.float32).reshape((self.ann_info['num_joints'], 1))
34
+
35
+ self.coco = COCO(ann_file)
36
+
37
+ self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
38
+ self.img_ids = self.coco.getImgIds()
39
+ self.classes = [
40
+ cat['name'] for cat in self.coco.loadCats(self.coco.getCatIds())
41
+ ]
42
+
43
+ self.num_classes = len(self.classes)
44
+ self._class_to_ind = dict(zip(self.classes, self.coco.getCatIds()))
45
+ self._ind_to_class = dict(zip(self.coco.getCatIds(), self.classes))
46
+
47
+ if valid_class_ids is not None: # None by default
48
+ self.valid_class_ids = valid_class_ids
49
+ else:
50
+ self.valid_class_ids = self.coco.getCatIds()
51
+ self.valid_classes = [self._ind_to_class[ind] for ind in self.valid_class_ids]
52
+
53
+ self.cats = self.coco.cats
54
+ self.max_kpt_num = max_kpt_num
55
+
56
+ # Also update self.cat2obj
57
+ self.db = self._get_db()
58
+
59
+ self.num_shots = num_shots
60
+
61
+ if not test_mode:
62
+ # Update every training epoch
63
+ self.random_paired_samples()
64
+ else:
65
+ self.num_queries = num_queries
66
+ self.num_episodes = num_episodes
67
+ self.make_paired_samples()
68
+
69
+
70
+ def random_paired_samples(self):
71
+ num_datas = [len(self.cat2obj[self._class_to_ind[cls]]) for cls in self.valid_classes]
72
+
73
+ # balance the dataset
74
+ max_num_data = max(num_datas)
75
+
76
+ all_samples = []
77
+ for cls in self.valid_class_ids:
78
+ for i in range(max_num_data):
79
+ shot = random.sample(self.cat2obj[cls], self.num_shots + 1)
80
+ all_samples.append(shot)
81
+
82
+ self.paired_samples = np.array(all_samples)
83
+ np.random.shuffle(self.paired_samples)
84
+
85
+ def make_paired_samples(self):
86
+ random.seed(1)
87
+ np.random.seed(0)
88
+
89
+ all_samples = []
90
+ for cls in self.valid_class_ids:
91
+ for _ in range(self.num_episodes):
92
+ shots = random.sample(self.cat2obj[cls], self.num_shots + self.num_queries)
93
+ sample_ids = shots[:self.num_shots]
94
+ query_ids = shots[self.num_shots:]
95
+ for query_id in query_ids:
96
+ all_samples.append(sample_ids + [query_id])
97
+
98
+ self.paired_samples = np.array(all_samples)
99
+
100
+ def _select_kpt(self, obj, kpt_id):
101
+ obj['joints_3d'] = obj['joints_3d'][kpt_id:kpt_id+1]
102
+ obj['joints_3d_visible'] = obj['joints_3d_visible'][kpt_id:kpt_id+1]
103
+ obj['kpt_id'] = kpt_id
104
+
105
+ return obj
106
+
107
+ @staticmethod
108
+ def _get_mapping_id_name(imgs):
109
+ """
110
+ Args:
111
+ imgs (dict): dict of image info.
112
+
113
+ Returns:
114
+ tuple: Image name & id mapping dicts.
115
+
116
+ - id2name (dict): Mapping image id to name.
117
+ - name2id (dict): Mapping image name to id.
118
+ """
119
+ id2name = {}
120
+ name2id = {}
121
+ for image_id, image in imgs.items():
122
+ file_name = image['file_name']
123
+ id2name[image_id] = file_name
124
+ name2id[file_name] = image_id
125
+
126
+ return id2name, name2id
127
+
128
+ def _get_db(self):
129
+ """Ground truth bbox and keypoints."""
130
+ self.obj_id = 0
131
+
132
+ self.cat2obj = {}
133
+ for i in self.coco.getCatIds():
134
+ self.cat2obj.update({i: []})
135
+
136
+ gt_db = []
137
+ for img_id in self.img_ids:
138
+ gt_db.extend(self._load_coco_keypoint_annotation_kernel(img_id))
139
+
140
+ return gt_db
141
+
142
+ def _load_coco_keypoint_annotation_kernel(self, img_id):
143
+ """load annotation from COCOAPI.
144
+
145
+ Note:
146
+ bbox:[x1, y1, w, h]
147
+ Args:
148
+ img_id: coco image id
149
+ Returns:
150
+ list[dict]: valid annotation records for the image.
151
+ """
152
+ img_ann = self.coco.loadImgs(img_id)[0]
153
+ width = img_ann['width']
154
+ height = img_ann['height']
155
+
156
+ ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=False)
157
+ objs = self.coco.loadAnns(ann_ids)
158
+
159
+ # sanitize bboxes
160
+ valid_objs = []
161
+ for obj in objs:
162
+ if 'bbox' not in obj:
163
+ continue
164
+ x, y, w, h = obj['bbox']
165
+ x1 = max(0, x)
166
+ y1 = max(0, y)
167
+ x2 = min(width - 1, x1 + max(0, w - 1))
168
+ y2 = min(height - 1, y1 + max(0, h - 1))
169
+ if ('area' not in obj or obj['area'] > 0) and x2 > x1 and y2 > y1:
170
+ obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
171
+ valid_objs.append(obj)
172
+ objs = valid_objs
173
+
174
+ bbox_id = 0
175
+ rec = []
176
+ for obj in objs:
177
+ if 'keypoints' not in obj:
178
+ continue
179
+ if max(obj['keypoints']) == 0:
180
+ continue
181
+ if 'num_keypoints' in obj and obj['num_keypoints'] == 0:
182
+ continue
183
+
184
+ category_id = obj['category_id']
185
+ # the number of keypoints for this specific category
186
+ cat_kpt_num = int(len(obj['keypoints']) / 3)
187
+ if self.max_kpt_num is None:
188
+ kpt_num = cat_kpt_num
189
+ else:
190
+ kpt_num = self.max_kpt_num
191
+
192
+ joints_3d = np.zeros((kpt_num, 3), dtype=np.float32)
193
+ joints_3d_visible = np.zeros((kpt_num, 3), dtype=np.float32)
194
+
195
+ keypoints = np.array(obj['keypoints']).reshape(-1, 3)
196
+ joints_3d[:cat_kpt_num, :2] = keypoints[:, :2]
197
+ joints_3d_visible[:cat_kpt_num, :2] = np.minimum(1, keypoints[:, 2:3])
198
+
199
+ center, scale = self._xywh2cs(*obj['clean_bbox'][:4])
200
+
201
+ image_file = os.path.join(self.img_prefix, self.id2name[img_id])
202
+ if os.path.exists(image_file):
203
+ self.cat2obj[category_id].append(self.obj_id)
204
+
205
+ rec.append({
206
+ 'image_file': image_file,
207
+ 'center': center,
208
+ 'scale': scale,
209
+ 'rotation': 0,
210
+ 'bbox': obj['clean_bbox'][:4],
211
+ 'bbox_score': 1,
212
+ 'joints_3d': joints_3d,
213
+ 'joints_3d_visible': joints_3d_visible,
214
+ 'category_id': category_id,
215
+ 'cat_kpt_num': cat_kpt_num,
216
+ 'bbox_id': self.obj_id,
217
+ 'skeleton': self.coco.cats[obj['category_id']]['skeleton'],
218
+ })
219
+ bbox_id = bbox_id + 1
220
+ self.obj_id += 1
221
+
222
+ return rec
223
+
224
+ def _xywh2cs(self, x, y, w, h):
225
+ """Encode a bbox (x, y, w, h) into (center, scale).
226
+
227
+ Args:
228
+ x, y, w, h
229
+
230
+ Returns:
231
+ tuple: A tuple containing center and scale.
232
+
233
+ - center (np.ndarray[float32](2,)): center of the bbox (x, y).
234
+ - scale (np.ndarray[float32](2,)): scale of the bbox w & h.
235
+ """
236
+ aspect_ratio = self.ann_info['image_size'][0] / self.ann_info['image_size'][1]
237
+ center = np.array([x + w * 0.5, y + h * 0.5], dtype=np.float32)
238
+ #
239
+ # if (not self.test_mode) and np.random.rand() < 0.3:
240
+ # center += 0.4 * (np.random.rand(2) - 0.5) * [w, h]
241
+
242
+ if w > aspect_ratio * h:
243
+ h = w * 1.0 / aspect_ratio
244
+ elif w < aspect_ratio * h:
245
+ w = h * aspect_ratio
246
+
247
+ # pixel std is 200.0
248
+ scale = np.array([w / 200.0, h / 200.0], dtype=np.float32)
249
+ # padding to include proper amount of context
250
+ scale = scale * 1.25
251
+
252
+ return center, scale
253
+
254
+ def evaluate(self, outputs, res_folder, metric='PCK', **kwargs):
255
+ """Evaluate keypoint results. The pose prediction results
256
+ will be saved in `${res_folder}/result_keypoints.json`.
257
+
258
+ Note:
259
+ batch_size: N
260
+ num_keypoints: K
261
+ heatmap height: H
262
+ heatmap width: W
263
+
264
+ Args:
265
+ outputs (list(preds, boxes, image_path, output_heatmap))
266
+ :preds (np.ndarray[N,K,3]): The first two dimensions are
267
+ coordinates, score is the third dimension of the array.
268
+ :boxes (np.ndarray[N,6]): [center[0], center[1], scale[0]
269
+ , scale[1],area, score]
270
+ :image_paths (list[str]): For example,
+ ['Capture12/0390_dh_touchROM/cam410209/image62434.jpg']
+ :output_heatmap (np.ndarray[N, K, H, W]): model outputs.
277
+
278
+ res_folder (str): Path of directory to save the results.
279
+ metric (str | list[str]): Metric to be performed.
280
+ Options: 'PCK', 'AUC', 'EPE'.
281
+
282
+ Returns:
283
+ dict: Evaluation results for evaluation metric.
284
+ """
285
+ metrics = metric if isinstance(metric, list) else [metric]
286
+ allowed_metrics = ['PCK', 'AUC', 'EPE', 'NME']
287
+ for metric in metrics:
288
+ if metric not in allowed_metrics:
289
+ raise KeyError(f'metric {metric} is not supported')
290
+
291
+ res_file = os.path.join(res_folder, 'result_keypoints.json')
292
+
293
+ kpts = []
294
+ for output in outputs:
295
+ preds = output['preds']
296
+ boxes = output['boxes']
297
+ image_paths = output['image_paths']
298
+ bbox_ids = output['bbox_ids']
299
+
300
+ batch_size = len(image_paths)
301
+ for i in range(batch_size):
302
+ image_id = self.name2id[image_paths[i][len(self.img_prefix):]]
303
+
304
+ kpts.append({
305
+ 'keypoints': preds[i].tolist(),
306
+ 'center': boxes[i][0:2].tolist(),
307
+ 'scale': boxes[i][2:4].tolist(),
308
+ 'area': float(boxes[i][4]),
309
+ 'score': float(boxes[i][5]),
310
+ 'image_id': image_id,
311
+ 'bbox_id': bbox_ids[i]
312
+ })
313
+ kpts = self._sort_and_unique_bboxes(kpts)
314
+
315
+ self._write_keypoint_results(kpts, res_file)
316
+ info_str = self._report_metric(res_file, metrics)
317
+ name_value = OrderedDict(info_str)
318
+
319
+ return name_value
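`TransformerPoseDataset` is registered in mmpose's `DATASETS` registry, so it is normally built from a config dict rather than constructed directly. A hedged sketch of what such a config could look like; the annotation path and image root are placeholders, and the single-joint channel settings simply mirror the `joint_weights` reshape in `__init__`:

import numpy as np
# from mmpose.datasets import build_dataset  # mmpose 0.x style dataset builder

data_cfg = dict(
    image_size=np.array([256, 256]),
    heatmap_size=np.array([64, 64]),
    num_joints=1,                 # keypoints are processed one channel at a time
    num_output_channels=1,
    dataset_channel=[[0]],
    inference_channel=[0],
)

train_set = dict(
    type='TransformerPoseDataset',
    ann_file='data/mp100/annotations/mp100_split1_train.json',   # placeholder path
    img_prefix='data/mp100/images/',                             # placeholder path
    data_cfg=data_cfg,
    pipeline=[],                  # normally the full training pipeline list
    valid_class_ids=None,         # None falls back to all COCO category ids
    num_shots=1,
)
# dataset = build_dataset(train_set)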
EdgeCape/datasets/pipelines/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .top_down_transform import (TopDownAffineFewShot,
2
+ TopDownGenerateTargetFewShot,
3
+ LoadDepthFromFile,
4
+ DepthTopDownAffineFewShot)
5
+
6
+ __all__ = [
7
+ 'TopDownGenerateTargetFewShot', 'TopDownAffineFewShot', 'LoadDepthFromFile', 'DepthTopDownAffineFewShot',
8
+ ]
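The transforms exported here are registered in mmpose's `PIPELINES` registry and are meant to be chained in a pipeline config. A hedged sketch of how the few-shot transforms could sit in such a list; `LoadImageFromFile`, `ToTensor`, `NormalizeTensor` and `Collect` come from mmpose itself, and the normalization constants and meta keys are illustrative:

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='TopDownAffineFewShot'),
    dict(type='ToTensor'),
    dict(type='NormalizeTensor',
         mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    dict(type='TopDownGenerateTargetFewShot', sigma=2),
    dict(type='Collect',
         keys=['img', 'target', 'target_weight'],
         meta_keys=['image_file', 'center', 'scale', 'rotation', 'bbox_score']),
]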
EdgeCape/datasets/pipelines/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (373 Bytes).
 
EdgeCape/datasets/pipelines/__pycache__/post_transforms.cpython-39.pyc ADDED
Binary file (3.39 kB).
 
EdgeCape/datasets/pipelines/__pycache__/top_down_transform.cpython-39.pyc ADDED
Binary file (18 kB).
 
EdgeCape/datasets/pipelines/post_transforms.py ADDED
@@ -0,0 +1,121 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Adapted from https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
3
+ # Original licence: Copyright (c) Microsoft, under the MIT License.
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import cv2
7
+ import numpy as np
8
+
9
+
10
+ def get_affine_transform(center,
11
+ scale,
12
+ rot,
13
+ output_size,
14
+ shift=(0., 0.),
15
+ inv=False):
16
+ """Get the affine transform matrix, given the center/scale/rot/output_size.
17
+
18
+ Args:
19
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
20
+ scale (np.ndarray[2, ]): Scale of the bounding box
21
+ wrt [width, height].
22
+ rot (float): Rotation angle (degree).
23
+ output_size (np.ndarray[2, ]): Size of the destination heatmaps.
24
+ shift (0-100%): Shift translation ratio wrt the width/height.
25
+ Default (0., 0.).
26
+ inv (bool): Option to inverse the affine transform direction.
27
+ (inv=False: src->dst or inv=True: dst->src)
28
+
29
+ Returns:
30
+ np.ndarray: The transform matrix.
31
+ """
32
+ assert len(center) == 2
33
+ assert len(scale) == 2
34
+ assert len(output_size) == 2
35
+ assert len(shift) == 2
36
+
37
+ # pixel_std is 200.
38
+ scale_tmp = scale * 200.0
39
+
40
+ shift = np.array(shift)
41
+ src_w = scale_tmp[0]
42
+ dst_w = output_size[0]
43
+ dst_h = output_size[1]
44
+
45
+ rot_rad = np.pi * rot / 180
46
+ src_dir = rotate_point([0., src_w * -0.5], rot_rad)
47
+ dst_dir = np.array([0., dst_w * -0.5])
48
+
49
+ src = np.zeros((3, 2), dtype=np.float32)
50
+ src[0, :] = center + scale_tmp * shift
51
+ src[1, :] = center + src_dir + scale_tmp * shift
52
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
53
+
54
+ dst = np.zeros((3, 2), dtype=np.float32)
55
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
56
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
57
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
58
+
59
+ if inv:
60
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
61
+ else:
62
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
63
+
64
+ return trans
65
+
66
+
67
+ def affine_transform(pt, trans_mat):
68
+ """Apply an affine transformation to the points.
69
+
70
+ Args:
71
+ pt (np.ndarray): a 2 dimensional point to be transformed
72
+ trans_mat (np.ndarray): 2x3 matrix of an affine transform
73
+
74
+ Returns:
75
+ np.ndarray: Transformed points.
76
+ """
77
+ assert len(pt) == 2
78
+ new_pt = np.array(trans_mat) @ np.array([pt[0], pt[1], 1.])
79
+
80
+ return new_pt
81
+
82
+
83
+ def _get_3rd_point(a, b):
84
+ """To calculate the affine matrix, three pairs of points are required. This
85
+ function is used to get the 3rd point, given 2D points a & b.
86
+
87
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
88
+ anticlockwise, using b as the rotation center.
89
+
90
+ Args:
91
+ a (np.ndarray): point(x,y)
92
+ b (np.ndarray): point(x,y)
93
+
94
+ Returns:
95
+ np.ndarray: The 3rd point.
96
+ """
97
+ assert len(a) == 2
98
+ assert len(b) == 2
99
+ direction = a - b
100
+ third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
101
+
102
+ return third_pt
103
+
104
+
105
+ def rotate_point(pt, angle_rad):
106
+ """Rotate a point by an angle.
107
+
108
+ Args:
109
+ pt (list[float]): 2 dimensional point to be rotated
110
+ angle_rad (float): rotation angle by radian
111
+
112
+ Returns:
113
+ list[float]: Rotated point.
114
+ """
115
+ assert len(pt) == 2
116
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
117
+ new_x = pt[0] * cs - pt[1] * sn
118
+ new_y = pt[0] * sn + pt[1] * cs
119
+ rotated_pt = [new_x, new_y]
120
+
121
+ return rotated_pt
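A quick usage sketch of the helpers above, assuming the package and its mmpose/cv2 dependencies are importable as uploaded in this commit: `get_affine_transform` builds the 2x3 matrix that warps the padded box onto the network input, and `affine_transform` maps single keypoints with the same matrix. The dummy image and box values are arbitrary:

import cv2
import numpy as np
from EdgeCape.datasets.pipelines.post_transforms import (affine_transform,
                                                         get_affine_transform)

center = np.array([128., 96.], dtype=np.float32)    # bbox centre in the source image
scale = np.array([1.0, 1.0], dtype=np.float32)      # bbox size / 200, pre-padded
output_size = np.array([256, 256])

trans = get_affine_transform(center, scale, 0., output_size)
img = np.zeros((480, 640, 3), dtype=np.uint8)       # dummy source image
crop = cv2.warpAffine(img, trans, (256, 256), flags=cv2.INTER_LINEAR)

kpt = affine_transform(np.array([128., 96.]), trans)
print(crop.shape, kpt)                               # (256, 256, 3), ~[128. 128.]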
EdgeCape/datasets/pipelines/top_down_transform.py ADDED
@@ -0,0 +1,716 @@
1
+ import os
2
+ import warnings
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import cv2
7
+ import mmcv
8
+ import numpy as np
9
+ from mmcv import fileio
10
+
11
+ from mmpose.datasets.builder import PIPELINES
12
+ from mmpose.core.post_processing import (affine_transform, fliplr_joints,
+ get_affine_transform, get_warp_matrix,
+ warp_affine_joints)
17
+
18
+ @PIPELINES.register_module()
19
+ class TopDownAffineFewShot:
20
+ """Affine transform the image to make input.
21
+
22
+ Required keys:'img', 'joints_3d', 'joints_3d_visible', 'ann_info','scale',
23
+ 'rotation' and 'center'. Modified keys:'img', 'joints_3d', and
24
+ 'joints_3d_visible'.
25
+
26
+ Args:
27
+ use_udp (bool): To use unbiased data processing.
28
+ Paper ref: Huang et al. The Devil is in the Details: Delving into
29
+ Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
30
+ """
31
+
32
+ def __init__(self, use_udp=False):
33
+ self.use_udp = use_udp
34
+
35
+ def __call__(self, results):
36
+ image_size = results['ann_info']['image_size']
37
+
38
+ img = results['img']
39
+ joints_3d = results['joints_3d']
40
+ joints_3d_visible = results['joints_3d_visible']
41
+ c = results['center']
42
+ s = results['scale']
43
+ r = results['rotation']
44
+
45
+ if self.use_udp:
46
+ trans = get_warp_matrix(r, c * 2.0, image_size - 1.0, s * 200.0)
47
+ img = cv2.warpAffine(
48
+ img,
49
+ trans, (int(image_size[0]), int(image_size[1])),
50
+ flags=cv2.INTER_LINEAR)
51
+ joints_3d[:, 0:2] = \
52
+ warp_affine_joints(joints_3d[:, 0:2].copy(), trans)
53
+ else:
54
+ trans = get_affine_transform(c, s, r, image_size)
55
+ img = cv2.warpAffine(
56
+ img,
57
+ trans, (int(image_size[0]), int(image_size[1])),
58
+ flags=cv2.INTER_LINEAR)
59
+ for i in range(len(joints_3d)):
60
+ if joints_3d_visible[i, 0] > 0.0:
61
+ joints_3d[i, 0:2] = affine_transform(joints_3d[i, 0:2], trans)
62
+
63
+ results['img'] = img
64
+ results['joints_3d'] = joints_3d
65
+ results['joints_3d_visible'] = joints_3d_visible
66
+
67
+ return results
68
+
69
+
70
+ @PIPELINES.register_module()
71
+ class TopDownGenerateTargetFewShot:
72
+ """Generate the target heatmap.
73
+
74
+ Required keys: 'joints_3d', 'joints_3d_visible', 'ann_info'.
75
+ Modified keys: 'target', and 'target_weight'.
76
+
77
+ Args:
78
+ sigma: Sigma of heatmap gaussian for 'MSRA' approach.
79
+ kernel: Kernel of heatmap gaussian for 'Megvii' approach.
80
+ encoding (str): Approach to generate target heatmaps.
81
+ Currently supported approaches: 'MSRA', 'Megvii', 'UDP'.
82
+ Default:'MSRA'
83
+
84
+ unbiased_encoding (bool): Option to use unbiased
85
+ encoding methods.
86
+ Paper ref: Zhang et al. Distribution-Aware Coordinate
87
+ Representation for Human Pose Estimation (CVPR 2020).
88
+ keypoint_pose_distance: Keypoint pose distance for UDP.
89
+ Paper ref: Huang et al. The Devil is in the Details: Delving into
90
+ Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
91
+ target_type (str): supported targets: 'GaussianHeatMap',
92
+ 'CombinedTarget'. Default:'GaussianHeatMap'
93
+ CombinedTarget: The combination of classification target
94
+ (response map) and regression target (offset map).
95
+ Paper ref: Huang et al. The Devil is in the Details: Delving into
96
+ Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
97
+ """
98
+
99
+ def __init__(self,
100
+ sigma=2,
101
+ kernel=(11, 11),
102
+ valid_radius_factor=0.0546875,
103
+ target_type='GaussianHeatMap',
104
+ encoding='MSRA',
105
+ unbiased_encoding=False):
106
+ self.sigma = sigma
107
+ self.unbiased_encoding = unbiased_encoding
108
+ self.kernel = kernel
109
+ self.valid_radius_factor = valid_radius_factor
110
+ self.target_type = target_type
111
+ self.encoding = encoding
112
+
113
+ def _msra_generate_target(self, cfg, joints_3d, joints_3d_visible, sigma):
114
+ """Generate the target heatmap via "MSRA" approach.
115
+
116
+ Args:
117
+ cfg (dict): data config
118
+ joints_3d: np.ndarray ([num_joints, 3])
119
+ joints_3d_visible: np.ndarray ([num_joints, 3])
120
+ sigma: Sigma of heatmap gaussian
121
+ Returns:
122
+ tuple: A tuple containing targets.
123
+
124
+ - target: Target heatmaps.
125
+ - target_weight: (1: visible, 0: invisible)
126
+ """
127
+ num_joints = len(joints_3d)
128
+ image_size = cfg['image_size']
129
+ W, H = cfg['heatmap_size']
130
+ joint_weights = cfg['joint_weights']
131
+ use_different_joint_weights = cfg['use_different_joint_weights']
132
+ assert not use_different_joint_weights
133
+
134
+ target_weight = np.zeros((num_joints, 1), dtype=np.float32)
135
+ target = np.zeros((num_joints, H, W), dtype=np.float32)
136
+
137
+ # 3-sigma rule
138
+ tmp_size = sigma * 3
139
+
140
+ if self.unbiased_encoding:
141
+ for joint_id in range(num_joints):
142
+ target_weight[joint_id] = joints_3d_visible[joint_id, 0]
143
+
144
+ feat_stride = image_size / [W, H]
145
+ mu_x = joints_3d[joint_id][0] / feat_stride[0]
146
+ mu_y = joints_3d[joint_id][1] / feat_stride[1]
147
+ # Check that any part of the gaussian is in-bounds
148
+ ul = [mu_x - tmp_size, mu_y - tmp_size]
149
+ br = [mu_x + tmp_size + 1, mu_y + tmp_size + 1]
150
+ if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:
151
+ target_weight[joint_id] = 0
152
+
153
+ if target_weight[joint_id] == 0:
154
+ continue
155
+
156
+ x = np.arange(0, W, 1, np.float32)
157
+ y = np.arange(0, H, 1, np.float32)
158
+ y = y[:, None]
159
+
160
+ if target_weight[joint_id] > 0.5:
161
+ target[joint_id] = np.exp(-((x - mu_x)**2 +
162
+ (y - mu_y)**2) /
163
+ (2 * sigma**2))
164
+ else:
165
+ for joint_id in range(num_joints):
166
+ target_weight[joint_id] = joints_3d_visible[joint_id, 0]
167
+
168
+ feat_stride = image_size / [W, H]
169
+ mu_x = int(joints_3d[joint_id][0] / feat_stride[0] + 0.5)
170
+ mu_y = int(joints_3d[joint_id][1] / feat_stride[1] + 0.5)
171
+ # Check that any part of the gaussian is in-bounds
172
+ ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
173
+ br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
174
+ if ul[0] >= W or ul[1] >= H or br[0] < 0 or br[1] < 0:
175
+ target_weight[joint_id] = 0
176
+
177
+ if target_weight[joint_id] > 0.5:
178
+ size = 2 * tmp_size + 1
179
+ x = np.arange(0, size, 1, np.float32)
180
+ y = x[:, None]
181
+ x0 = y0 = size // 2
182
+ # The gaussian is not normalized,
183
+ # we want the center value to equal 1
184
+ g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * sigma**2))
185
+
186
+ # Usable gaussian range
187
+ g_x = max(0, -ul[0]), min(br[0], W) - ul[0]
188
+ g_y = max(0, -ul[1]), min(br[1], H) - ul[1]
189
+ # Image range
190
+ img_x = max(0, ul[0]), min(br[0], W)
191
+ img_y = max(0, ul[1]), min(br[1], H)
192
+
193
+ target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
194
+ g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
195
+
196
+ if use_different_joint_weights:
197
+ target_weight = np.multiply(target_weight, joint_weights)
198
+
199
+ return target, target_weight
200
+
201
+ def _udp_generate_target(self, cfg, joints_3d, joints_3d_visible, factor,
202
+ target_type):
203
+ """Generate the target heatmap via 'UDP' approach. Paper ref: Huang et
204
+ al. The Devil is in the Details: Delving into Unbiased Data Processing
205
+ for Human Pose Estimation (CVPR 2020).
206
+
207
+ Note:
208
+ num keypoints: K
209
+ heatmap height: H
210
+ heatmap width: W
211
+ num target channels: C
212
+ C = K if target_type=='GaussianHeatMap'
213
+ C = 3*K if target_type=='CombinedTarget'
214
+
215
+ Args:
216
+ cfg (dict): data config
217
+ joints_3d (np.ndarray[K, 3]): Annotated keypoints.
218
+ joints_3d_visible (np.ndarray[K, 3]): Visibility of keypoints.
219
+ factor (float): kernel factor for GaussianHeatMap target or
220
+ valid radius factor for CombinedTarget.
221
+ target_type (str): 'GaussianHeatMap' or 'CombinedTarget'.
222
+ GaussianHeatMap: Heatmap target with gaussian distribution.
223
+ CombinedTarget: The combination of classification target
224
+ (response map) and regression target (offset map).
225
+
226
+ Returns:
227
+ tuple: A tuple containing targets.
228
+
229
+ - target (np.ndarray[C, H, W]): Target heatmaps.
230
+ - target_weight (np.ndarray[K, 1]): (1: visible, 0: invisible)
231
+ """
232
+ num_joints = len(joints_3d)
233
+ image_size = cfg['image_size']
234
+ heatmap_size = cfg['heatmap_size']
235
+ joint_weights = cfg['joint_weights']
236
+ use_different_joint_weights = cfg['use_different_joint_weights']
237
+ assert not use_different_joint_weights
238
+
239
+ target_weight = np.ones((num_joints, 1), dtype=np.float32)
240
+ target_weight[:, 0] = joints_3d_visible[:, 0]
241
+
242
+ assert target_type in ['GaussianHeatMap', 'CombinedTarget']
243
+
244
+ if target_type == 'GaussianHeatMap':
245
+ target = np.zeros((num_joints, heatmap_size[1], heatmap_size[0]),
246
+ dtype=np.float32)
247
+
248
+ tmp_size = factor * 3
249
+
250
+ # prepare for gaussian
251
+ size = 2 * tmp_size + 1
252
+ x = np.arange(0, size, 1, np.float32)
253
+ y = x[:, None]
254
+
255
+ for joint_id in range(num_joints):
256
+ feat_stride = (image_size - 1.0) / (heatmap_size - 1.0)
257
+ mu_x = int(joints_3d[joint_id][0] / feat_stride[0] + 0.5)
258
+ mu_y = int(joints_3d[joint_id][1] / feat_stride[1] + 0.5)
259
+ # Check that any part of the gaussian is in-bounds
260
+ ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
261
+ br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
262
+ if ul[0] >= heatmap_size[0] or ul[1] >= heatmap_size[1] \
263
+ or br[0] < 0 or br[1] < 0:
264
+ # If not, just return the image as is
265
+ target_weight[joint_id] = 0
266
+ continue
267
+
268
+ # # Generate gaussian
269
+ mu_x_ac = joints_3d[joint_id][0] / feat_stride[0]
270
+ mu_y_ac = joints_3d[joint_id][1] / feat_stride[1]
271
+ x0 = y0 = size // 2
272
+ x0 += mu_x_ac - mu_x
273
+ y0 += mu_y_ac - mu_y
274
+ g = np.exp(-((x - x0)**2 + (y - y0)**2) / (2 * factor**2))
275
+
276
+ # Usable gaussian range
277
+ g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]
278
+ g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]
279
+ # Image range
280
+ img_x = max(0, ul[0]), min(br[0], heatmap_size[0])
281
+ img_y = max(0, ul[1]), min(br[1], heatmap_size[1])
282
+
283
+ v = target_weight[joint_id]
284
+ if v > 0.5:
285
+ target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
286
+ g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
287
+ elif target_type == 'CombinedTarget':
288
+ target = np.zeros(
289
+ (num_joints, 3, heatmap_size[1] * heatmap_size[0]),
290
+ dtype=np.float32)
291
+ feat_width = heatmap_size[0]
292
+ feat_height = heatmap_size[1]
293
+ feat_x_int = np.arange(0, feat_width)
294
+ feat_y_int = np.arange(0, feat_height)
295
+ feat_x_int, feat_y_int = np.meshgrid(feat_x_int, feat_y_int)
296
+ feat_x_int = feat_x_int.flatten()
297
+ feat_y_int = feat_y_int.flatten()
298
+ # Calculate the radius of the positive area in classification
299
+ # heatmap.
300
+ valid_radius = factor * heatmap_size[1]
301
+ feat_stride = (image_size - 1.0) / (heatmap_size - 1.0)
302
+ for joint_id in range(num_joints):
303
+ mu_x = joints_3d[joint_id][0] / feat_stride[0]
304
+ mu_y = joints_3d[joint_id][1] / feat_stride[1]
305
+ x_offset = (mu_x - feat_x_int) / valid_radius
306
+ y_offset = (mu_y - feat_y_int) / valid_radius
307
+ dis = x_offset**2 + y_offset**2
308
+ keep_pos = np.where(dis <= 1)[0]
309
+ v = target_weight[joint_id]
310
+ if v > 0.5:
311
+ target[joint_id, 0, keep_pos] = 1
312
+ target[joint_id, 1, keep_pos] = x_offset[keep_pos]
313
+ target[joint_id, 2, keep_pos] = y_offset[keep_pos]
314
+ target = target.reshape(num_joints * 3, heatmap_size[1],
315
+ heatmap_size[0])
316
+
317
+ if use_different_joint_weights:
318
+ target_weight = np.multiply(target_weight, joint_weights)
319
+
320
+ return target, target_weight
321
+
322
+ def __call__(self, results):
323
+ """Generate the target heatmap."""
324
+ joints_3d = results['joints_3d']
325
+ joints_3d_visible = results['joints_3d_visible']
326
+
327
+ assert self.encoding in ['MSRA', 'UDP']
328
+
329
+ if self.encoding == 'MSRA':
330
+ if isinstance(self.sigma, list):
331
+ num_sigmas = len(self.sigma)
332
+ cfg = results['ann_info']
333
+ num_joints = len(joints_3d)
334
+ heatmap_size = cfg['heatmap_size']
335
+
336
+ target = np.empty(
337
+ (0, num_joints, heatmap_size[1], heatmap_size[0]),
338
+ dtype=np.float32)
339
+ target_weight = np.empty((0, num_joints, 1), dtype=np.float32)
340
+ for i in range(num_sigmas):
341
+ target_i, target_weight_i = self._msra_generate_target(
342
+ cfg, joints_3d, joints_3d_visible, self.sigma[i])
343
+ target = np.concatenate([target, target_i[None]], axis=0)
344
+ target_weight = np.concatenate(
345
+ [target_weight, target_weight_i[None]], axis=0)
346
+ else:
347
+ target, target_weight = self._msra_generate_target(
348
+ results['ann_info'], joints_3d, joints_3d_visible,
349
+ self.sigma)
350
+ elif self.encoding == 'UDP':
351
+ if self.target_type == 'CombinedTarget':
352
+ factors = self.valid_radius_factor
353
+ channel_factor = 3
354
+ elif self.target_type == 'GaussianHeatMap':
355
+ factors = self.sigma
356
+ channel_factor = 1
357
+ if isinstance(factors, list):
358
+ num_factors = len(factors)
359
+ cfg = results['ann_info']
360
+ num_joints = len(joints_3d)
361
+ W, H = cfg['heatmap_size']
362
+
363
+ target = np.empty((0, channel_factor * num_joints, H, W),
364
+ dtype=np.float32)
365
+ target_weight = np.empty((0, num_joints, 1), dtype=np.float32)
366
+ for i in range(num_factors):
367
+ target_i, target_weight_i = self._udp_generate_target(
368
+ cfg, joints_3d, joints_3d_visible, factors[i],
369
+ self.target_type)
370
+ target = np.concatenate([target, target_i[None]], axis=0)
371
+ target_weight = np.concatenate(
372
+ [target_weight, target_weight_i[None]], axis=0)
373
+ else:
374
+ target, target_weight = self._udp_generate_target(
375
+ results['ann_info'], joints_3d, joints_3d_visible, factors,
376
+ self.target_type)
377
+ else:
378
+ raise ValueError(
379
+ f'Encoding approach {self.encoding} is not supported!')
380
+
381
+ results['target'] = target
382
+ results['target_weight'] = target_weight
383
+
384
+ return results
385
+
386
+ @PIPELINES.register_module()
387
+ class LoadDepthFromFile:
388
+ """Load depthmap from file.
389
+
390
+ Required Keys:
391
+
392
+ - depth_path
393
+
394
+ Modified Keys:
395
+
396
+ - depth
397
+
398
+ Args:
399
+ to_float32 (bool): Whether to convert the loaded depth to a float32
400
+ numpy array. If set to False, the loaded depth is an uint8 array.
401
+ Defaults to False.
402
+ color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
403
+ Defaults to 'color'.
404
+ imdecode_backend (str): The depth decoding backend type. The backend
405
+ argument for :func:`mmcv.imfrombytes`.
406
+ See :func:`mmcv.imfrombytes` for details.
407
+ Defaults to 'cv2'.
408
+ file_client_args (dict, optional): Arguments to instantiate a
409
+ FileClient. See :class:`mmengine.fileio.FileClient` for details.
410
+ Defaults to None. It will be deprecated in future. Please use
411
+ ``backend_args`` instead.
412
+ Deprecated in version 2.0.0rc4.
413
+ ignore_empty (bool): Whether to allow loading empty depth or file path
414
+ not existent. Defaults to False.
415
+ backend_args (dict, optional): Instantiates the corresponding file
416
+ backend. It may contain `backend` key to specify the file
417
+ backend. If it contains, the file backend corresponding to this
418
+ value will be used and initialized with the remaining values,
419
+ otherwise the corresponding file backend will be selected
420
+ based on the prefix of the file path. Defaults to None.
421
+ New in version 2.0.0rc4.
422
+ """
423
+
424
+ def __init__(self,
425
+ to_float32=False,
426
+ color_type='color',
427
+ channel_order='rgb',
428
+ file_client_args=dict(backend='disk')):
429
+ self.to_float32 = to_float32
430
+ self.color_type = color_type
431
+ self.channel_order = channel_order
432
+ self.file_client_args = file_client_args.copy()
433
+ self.file_client = None
434
+
435
+ def _read_depth(self, path):
436
+ img = np.load(path)['depth']
437
+ if img is None:
438
+ raise ValueError(f'Fail to read {path}')
439
+ if self.to_float32:
440
+ img = img.astype(np.float32)
441
+ return img
442
+
443
+ def __call__(self, results: dict) -> Optional[dict]:
444
+ """Functions to load depth.
445
+
446
+ Args:
447
+ results (dict): Result dict from
448
+ :class:`mmengine.dataset.BaseDataset`.
449
+
450
+ Returns:
451
+ dict: The dict contains loaded depth and meta information.
452
+ """
453
+
454
+ """Loading depth(s) from file."""
455
+ if self.file_client is None:
456
+ self.file_client = mmcv.FileClient(**self.file_client_args)
457
+
458
+ depth_file = results.get('depth_file', None)
459
+ # Replace file extension with npy
460
+ pre, ext = os.path.splitext(depth_file)
461
+ depth_file = pre + '.npz'
462
+ if isinstance(depth_file, (list, tuple)):
463
+ # Load depths from a list of paths
464
+ results['depth'] = [self._read_depth(path) for path in depth_file]
465
+ elif depth_file is not None:
466
+ # Load single depth from path
467
+ results['depth'] = self._read_depth(depth_file)
468
+ else:
469
+ if 'depth' not in results:
470
+ # If `depth_file` is not given, a pre-loaded depth map must already
+ # be present in `results` (e.g. set manually outside the pipeline).
+ raise KeyError('Either `depth_file` or `depth` should exist in '
+ 'results.')
475
+ if isinstance(results['depth'], (list, tuple)):
476
+ assert isinstance(results['depth'][0], np.ndarray)
477
+ else:
478
+ assert isinstance(results['depth'], np.ndarray)
479
+ results['depth_file'] = None
480
+
481
+ return results
482
+
483
+ def __repr__(self):
484
+ repr_str = (f'{self.__class__.__name__}('
485
+ f'to_float32={self.to_float32}, '
486
+ f"color_type='{self.color_type}', "
487
+ f'file_client_args={self.file_client_args})')
488
+ return repr_str
489
+
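+ # Minimal usage sketch (assumed pipeline config; transform names other than
+ # those defined in this file follow mmpose conventions and are illustrative):
+ # train_pipeline = [
+ #     dict(type='LoadImageFromFile'),
+ #     dict(type='LoadDepthFromFile', to_float32=True),
+ #     dict(type='DepthTopDownAffineFewShot', use_udp=True),
+ # ]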
490
+
491
+ @PIPELINES.register_module()
492
+ class DepthTopDownAffineFewShot:
493
+ """Affine transform the image to make input.
494
+
495
+ Required keys: 'img', 'depth', 'joints_3d', 'joints_3d_visible', 'ann_info',
+ 'scale', 'rotation' and 'center'. Modified keys: 'img', 'depth', 'joints_3d',
+ and 'joints_3d_visible'.
498
+
499
+ Args:
500
+ use_udp (bool): To use unbiased data processing.
501
+ Paper ref: Huang et al. The Devil is in the Details: Delving into
502
+ Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
503
+ """
504
+
505
+ def __init__(self, use_udp=False):
506
+ self.use_udp = use_udp
507
+
508
+ def __call__(self, results):
509
+ image_size = results['ann_info']['image_size']
510
+
511
+ img = results['img']
512
+ depth = results['depth']
513
+ joints_3d = results['joints_3d']
514
+ joints_3d_visible = results['joints_3d_visible']
515
+ c = results['center']
516
+ s = results['scale']
517
+ r = results['rotation']
518
+
519
+ if self.use_udp:
520
+ trans = get_warp_matrix(r, c * 2.0, image_size - 1.0, s * 200.0)
521
+ img = cv2.warpAffine(
522
+ img,
523
+ trans, (int(image_size[0]), int(image_size[1])),
524
+ flags=cv2.INTER_LINEAR)
525
+ depth = cv2.warpAffine(
526
+ depth,
527
+ trans, (int(image_size[0]), int(image_size[1])),
528
+ flags=cv2.INTER_LINEAR)
529
+ joints_3d[:, 0:2] = warp_affine_joints(joints_3d[:, 0:2].copy(), trans)
530
+ else:
531
+ trans = get_affine_transform(c, s, r, image_size)
532
+ img = cv2.warpAffine(
533
+ img,
534
+ trans, (int(image_size[0]), int(image_size[1])),
535
+ flags=cv2.INTER_LINEAR)
536
+ depth = cv2.warpAffine(
537
+ depth,
538
+ trans, (int(image_size[0]), int(image_size[1])),
539
+ flags=cv2.INTER_LINEAR)
540
+ for i in range(len(joints_3d)):
541
+ if joints_3d_visible[i, 0] > 0.0:
542
+ joints_3d[i, 0:2] = affine_transform(joints_3d[i, 0:2], trans)
543
+
544
+ results['img'] = img
545
+ results['depth'] = depth
546
+ results['joints_3d'] = joints_3d
547
+ results['joints_3d_visible'] = joints_3d_visible
548
+
549
+ return results
550
+
551
+
552
+
553
+
554
+ @PIPELINES.register_module()
555
+ class LoadFeatFromFile:
556
+ """Load depthmap from file.
557
+
558
+ Required Keys:
559
+
560
+ - depth_path
561
+
562
+ Modified Keys:
563
+
564
+ - depth
565
+
566
+ Args:
567
+ to_float32 (bool): Whether to convert the loaded depth to a float32
568
+ numpy array. If set to False, the loaded depth is an uint8 array.
569
+ Defaults to False.
570
+ color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
571
+ Defaults to 'color'.
572
+ imdecode_backend (str): The depth decoding backend type. The backend
573
+ argument for :func:`mmcv.imfrombytes`.
574
+ See :func:`mmcv.imfrombytes` for details.
575
+ Defaults to 'cv2'.
576
+ file_client_args (dict, optional): Arguments to instantiate a
577
+ FileClient. See :class:`mmengine.fileio.FileClient` for details.
578
+ Defaults to None. It will be deprecated in future. Please use
579
+ ``backend_args`` instead.
580
+ Deprecated in version 2.0.0rc4.
581
+ ignore_empty (bool): Whether to allow loading empty depth or file path
582
+ not existent. Defaults to False.
583
+ backend_args (dict, optional): Instantiates the corresponding file
584
+ backend. It may contain `backend` key to specify the file
585
+ backend. If it contains, the file backend corresponding to this
586
+ value will be used and initialized with the remaining values,
587
+ otherwise the corresponding file backend will be selected
588
+ based on the prefix of the file path. Defaults to None.
589
+ New in version 2.0.0rc4.
590
+ """
591
+
592
+ def __init__(self,
593
+ to_float32=False,
594
+ color_type='color',
595
+ channel_order='rgb',
596
+ file_client_args=dict(backend='disk')):
597
+ self.to_float32 = to_float32
598
+ self.color_type = color_type
599
+ self.channel_order = channel_order
600
+ self.file_client_args = file_client_args.copy()
601
+ self.file_client = None
602
+
603
+ def _read_depth(self, path):
604
+ img = np.load(path)['feat']
605
+ if img is None:
606
+ raise ValueError(f'Fail to read {path}')
607
+ if self.to_float32:
608
+ img = img.astype(np.float32)
609
+ return img
610
+
611
+ def __call__(self, results: dict) -> Optional[dict]:
612
+ """Functions to load depth.
613
+
614
+ Args:
615
+ results (dict): Result dict from
616
+ :class:`mmengine.dataset.BaseDataset`.
617
+
618
+ Returns:
619
+ dict: The dict contains the loaded feature map and meta information.
620
+ """
621
+
622
+ """Loading depth(s) from file."""
623
+ if self.file_client is None:
624
+ self.file_client = mmcv.FileClient(**self.file_client_args)
625
+
626
+ feat_file = results.get('feat_file', None)
+ if isinstance(feat_file, (list, tuple)):
+ # Load feature maps from a list of paths; features are stored as
+ # compressed numpy archives, so swap the extension to .npz first.
+ feat_file = [os.path.splitext(p)[0] + '.npz' for p in feat_file]
+ results['feat'] = [self._read_depth(path) for path in feat_file]
+ elif feat_file is not None:
+ # Load a single feature map from its .npz path.
+ feat_file = os.path.splitext(feat_file)[0] + '.npz'
+ results['feat'] = self._read_depth(feat_file)
636
+ else:
637
+ if 'feat' not in results:
+ # If `feat_file` is not given, a pre-loaded feature map must already
+ # be present in `results` (e.g. set manually outside the pipeline).
+ raise KeyError('Either `feat_file` or `feat` should exist in results.')
642
+ if isinstance(results['feat'], (list, tuple)):
643
+ assert isinstance(results['feat'][0], np.ndarray)
644
+ else:
645
+ assert isinstance(results['feat'], np.ndarray)
646
+ results['feat_file'] = None
647
+
648
+ return results
649
+
650
+ def __repr__(self):
651
+ repr_str = (f'{self.__class__.__name__}('
652
+ f'to_float32={self.to_float32}, '
653
+ f"color_type='{self.color_type}', "
654
+ f'file_client_args={self.file_client_args})')
655
+ return repr_str
656
+
657
+
658
+ @PIPELINES.register_module()
659
+ class FeatTopDownAffineFewShot:
660
+ """Affine transform the image to make input.
661
+
662
+ Required keys: 'img', 'feat', 'joints_3d', 'joints_3d_visible', 'ann_info',
+ 'scale', 'rotation' and 'center'. Modified keys: 'img', 'depth' (which receives
+ the warped feature map), 'joints_3d', and 'joints_3d_visible'.
665
+
666
+ Args:
667
+ use_udp (bool): To use unbiased data processing.
668
+ Paper ref: Huang et al. The Devil is in the Details: Delving into
669
+ Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
670
+ """
671
+
672
+ def __init__(self, use_udp=False):
673
+ self.use_udp = use_udp
674
+
675
+ def __call__(self, results):
676
+ image_size = results['ann_info']['image_size']
677
+
678
+ img = results['img']
679
+ feat = results['feat']
680
+ joints_3d = results['joints_3d']
681
+ joints_3d_visible = results['joints_3d_visible']
682
+ c = results['center']
683
+ s = results['scale']
684
+ r = results['rotation']
685
+
686
+ if self.use_udp:
687
+ trans = get_warp_matrix(r, c * 2.0, image_size - 1.0, s * 200.0)
688
+ img = cv2.warpAffine(
689
+ img,
690
+ trans, (int(image_size[0]), int(image_size[1])),
691
+ flags=cv2.INTER_LINEAR)
692
+ feat = cv2.warpAffine(
693
+ feat,
694
+ trans, (int(image_size[0]), int(image_size[1])),
695
+ flags=cv2.INTER_LINEAR)
696
+ joints_3d[:, 0:2] = warp_affine_joints(joints_3d[:, 0:2].copy(), trans)
697
+ else:
698
+ trans = get_affine_transform(c, s, r, image_size)
699
+ img = cv2.warpAffine(
700
+ img,
701
+ trans, (int(image_size[0]), int(image_size[1])),
702
+ flags=cv2.INTER_LINEAR)
703
+ feat = cv2.warpAffine(
704
+ feat,
705
+ trans, (int(image_size[0]), int(image_size[1])),
706
+ flags=cv2.INTER_LINEAR)
707
+ for i in range(len(joints_3d)):
708
+ if joints_3d_visible[i, 0] > 0.0:
709
+ joints_3d[i, 0:2] = affine_transform(joints_3d[i, 0:2], trans)
710
+
711
+ results['img'] = img
712
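+ # Note: the warped feature map is stored under the 'depth' key, apparently so
+ # that downstream steps written for depth maps can consume it unchanged.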
+ results['depth'] = feat
713
+ results['joints_3d'] = joints_3d
714
+ results['joints_3d_visible'] = joints_3d_visible
715
+
716
+ return results
EdgeCape/models/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .detectors import * # noqa
2
+ from .keypoint_heads import * # noqa
3
+ from .backbones import * # noqa
EdgeCape/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (227 Bytes). View file
 
EdgeCape/models/backbones/__pycache__/adapter.cpython-39.pyc ADDED
Binary file (27.8 kB). View file
 
EdgeCape/models/backbones/__pycache__/dino.cpython-39.pyc ADDED
Binary file (5.48 kB). View file
 
EdgeCape/models/backbones/adapter.py ADDED
@@ -0,0 +1,935 @@
1
+
2
+ import torch.nn.functional as F
3
+ import fvcore.nn.weight_init as weight_init
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.functional import interpolate
8
+
9
+ """
10
+ Code is based on: https://github.com/mbanani/probe3d
11
+ """
12
+
13
+
14
+ class SurfaceNormalHead(nn.Module):
15
+ def __init__(
16
+ self,
17
+ feat_dim,
18
+ head_type="multiscale",
19
+ uncertainty_aware=False,
20
+ hidden_dim=512,
21
+ kernel_size=1,
22
+ ):
23
+ super().__init__()
24
+
25
+ self.uncertainty_aware = uncertainty_aware
26
+ output_dim = 4 if uncertainty_aware else 3
27
+
28
+ self.kernel_size = kernel_size
29
+
30
+ assert head_type in ["linear", "multiscale", "dpt"]
31
+ name = f"snorm_{head_type}_k{kernel_size}"
32
+ self.name = f"{name}_UA" if uncertainty_aware else name
33
+
34
+ if head_type == "linear":
35
+ self.head = Linear(feat_dim, output_dim, kernel_size)
36
+ elif head_type == "multiscale":
37
+ self.head = MultiscaleHead(feat_dim, output_dim, hidden_dim, kernel_size)
38
+ elif head_type == "dpt":
39
+ self.head = DPT(feat_dim, output_dim, hidden_dim, kernel_size)
40
+ else:
41
+ raise ValueError(f"Unknown head type: {self.head_type}")
42
+
43
+ def forward(self, feats):
44
+ return self.head(feats)
45
+
46
+
47
+ class DepthHead(nn.Module):
48
+ def __init__(
49
+ self,
50
+ feat_dim,
51
+ head_type="multiscale",
52
+ min_depth=0.001,
53
+ max_depth=10,
54
+ prediction_type="bindepth",
55
+ hidden_dim=512,
56
+ kernel_size=1,
57
+ ):
58
+ super().__init__()
59
+
60
+ self.kernel_size = kernel_size
61
+ self.name = f"{prediction_type}_{head_type}_k{kernel_size}"
62
+
63
+ if prediction_type == "bindepth":
64
+ output_dim = 256
65
+ self.predict = DepthBinPrediction(min_depth, max_depth, n_bins=output_dim)
66
+ elif prediction_type == "sigdepth":
67
+ output_dim = 1
68
+ self.predict = DepthSigmoidPrediction(min_depth, max_depth)
69
+ else:
70
+ raise ValueError()
71
+
72
+ if head_type == "linear":
73
+ self.head = Linear(feat_dim, output_dim, kernel_size)
74
+ elif head_type == "multiscale":
75
+ self.head = MultiscaleHead(feat_dim, output_dim, hidden_dim, kernel_size)
76
+ elif head_type == "dpt":
77
+ self.head = DPT(feat_dim, output_dim, hidden_dim, kernel_size)
78
+ else:
79
+ raise ValueError(f"Unknown head type: {self.head_type}")
80
+
81
+ def forward(self, feats):
82
+ """Prediction each pixel."""
83
+ feats = self.head(feats)
84
+ depth = self.predict(feats)
85
+ return depth
86
+
87
+
88
+ class DepthBinPrediction(nn.Module):
89
+ def __init__(
90
+ self,
91
+ min_depth=0.001,
92
+ max_depth=10,
93
+ n_bins=256,
94
+ bins_strategy="UD",
95
+ norm_strategy="linear",
96
+ ):
97
+ super().__init__()
98
+ self.n_bins = n_bins
99
+ self.min_depth = min_depth
100
+ self.max_depth = max_depth
101
+ self.norm_strategy = norm_strategy
102
+ self.bins_strategy = bins_strategy
103
+
104
+ def forward(self, prob):
105
+ if self.bins_strategy == "UD":
106
+ bins = torch.linspace(
107
+ self.min_depth, self.max_depth, self.n_bins, device=prob.device
108
+ )
109
+ elif self.bins_strategy == "SID":
110
+ bins = torch.logspace(
111
+ self.min_depth, self.max_depth, self.n_bins, device=prob.device
112
+ )
113
+
114
+ # following Adabins, default linear
115
+ if self.norm_strategy == "linear":
116
+ prob = torch.relu(prob)
117
+ eps = 0.1
118
+ prob = prob + eps
119
+ prob = prob / prob.sum(dim=1, keepdim=True)
120
+ elif self.norm_strategy == "softmax":
121
+ prob = torch.softmax(prob, dim=1)
122
+ elif self.norm_strategy == "sigmoid":
123
+ prob = torch.sigmoid(prob)
124
+ prob = prob / prob.sum(dim=1, keepdim=True)
125
+
126
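+ # Expected depth per pixel: probability-weighted sum over the bin centres,
+ # d = sum_k p_k * b_k (Adabins-style soft binning).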
+ depth = torch.einsum("ikhw,k->ihw", [prob, bins])
127
+ depth = depth.unsqueeze(dim=1)
128
+ return depth
129
+
130
+
131
+ class DepthSigmoidPrediction(nn.Module):
132
+ def __init__(self, min_depth=0.001, max_depth=10):
133
+ super().__init__()
134
+ self.min_depth = min_depth
135
+ self.max_depth = max_depth
136
+
137
+ def forward(self, pred):
138
+ depth = pred.sigmoid()
139
+ depth = self.min_depth + depth * (self.max_depth - self.min_depth)
140
+ return depth
141
+
142
+
143
+ class FeatureFusionBlock(nn.Module):
144
+ def __init__(self, features, kernel_size, with_skip=True):
145
+ super().__init__()
146
+ self.with_skip = with_skip
147
+ if self.with_skip:
148
+ self.resConfUnit1 = ResidualConvUnit(features, kernel_size)
149
+
150
+ self.resConfUnit2 = ResidualConvUnit(features, kernel_size)
151
+
152
+ def forward(self, x, skip_x=None):
153
+ if skip_x is not None:
154
+ assert self.with_skip and skip_x.shape == x.shape
155
+ x = self.resConfUnit1(x) + skip_x
156
+
157
+ x = self.resConfUnit2(x)
158
+ return x
159
+
160
+
161
+ class ResidualConvUnit(nn.Module):
162
+ def __init__(self, features, kernel_size):
163
+ super().__init__()
164
+ assert kernel_size % 2 == 1, "Kernel size needs to be odd"
165
+ padding = kernel_size // 2
166
+ self.conv = nn.Sequential(
167
+ nn.Conv2d(features, features, kernel_size, padding=padding),
168
+ nn.ReLU(True),
169
+ nn.Conv2d(features, features, kernel_size, padding=padding),
170
+ nn.ReLU(True),
171
+ )
172
+
173
+ def forward(self, x):
174
+ return self.conv(x) + x
175
+
176
+
177
+ class DPT(nn.Module):
178
+ def __init__(self, input_dims, output_dim, hidden_dim=512, kernel_size=3, hr=False, swin=False):
179
+ super().__init__()
180
+ assert len(input_dims) == 4
181
+ self.hr = hr
182
+ self.conv_0 = nn.Conv2d(input_dims[0], hidden_dim, 1, padding=0)
183
+ self.conv_1 = nn.Conv2d(input_dims[1], hidden_dim, 1, padding=0)
184
+ self.conv_2 = nn.Conv2d(input_dims[2], hidden_dim, 1, padding=0)
185
+ self.conv_3 = nn.Conv2d(input_dims[3], hidden_dim, 1, padding=0)
186
+
187
+ self.ref_0 = FeatureFusionBlock(hidden_dim, kernel_size)
188
+ self.ref_1 = FeatureFusionBlock(hidden_dim, kernel_size)
189
+ self.ref_2 = FeatureFusionBlock(hidden_dim, kernel_size)
190
+ self.ref_3 = FeatureFusionBlock(hidden_dim, kernel_size, with_skip=False)
191
+
192
+ self.out_conv = nn.Sequential(
193
+ nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
194
+ nn.ReLU(True),
195
+ nn.Conv2d(hidden_dim, output_dim, 3, padding=1),
196
+ )
197
+
198
+ if swin:
199
+ self.scale_factor = [1, 2, 4, 4]
200
+ else:
201
+ self.scale_factor = [2, 2, 2, 2]
202
+
203
+ def forward(self, features):
204
+ """Prediction each pixel."""
205
+ assert len(features) == 4
206
+ feats = features.copy()
207
+ feats[0] = self.conv_0(feats[0])
208
+ feats[1] = self.conv_1(feats[1])
209
+ feats[2] = self.conv_2(feats[2])
210
+ feats[3] = self.conv_3(feats[3])
211
+
212
+ feats = [interpolate(x, scale_factor=scale_factor) for x, scale_factor in zip(feats, self.scale_factor)]
213
+
214
+ out = self.ref_3(feats[3], None)
215
+ out = self.ref_2(feats[2], out)
216
+ out = self.ref_1(feats[1], out)
217
+ out = self.ref_0(feats[0], out)
218
+ if not self.hr:
219
+ return self.out_conv(out)
220
+ out = interpolate(out, scale_factor=4)
221
+ out = self.out_conv(out)
222
+ # out = interpolate(out, scale_factor=2)
223
+ return out
224
+
225
+
226
+ def make_conv(input_dim, hidden_dim, output_dim, num_layers, kernel_size=1):
+ # Stack `num_layers` Conv2d layers (hidden_dim falls back to output_dim
+ # when None) with ReLU activations between them.
+ hidden_dim = output_dim if hidden_dim is None else hidden_dim
+ dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
+ layers = []
+ for i in range(num_layers):
+ layers.append(nn.Conv2d(dims[i], dims[i + 1], kernel_size, padding=kernel_size // 2))
+ if i < num_layers - 1:
+ layers.append(nn.ReLU(True))
+ conv = nn.Sequential(*layers)
+ return conv
228
+
229
+
230
+ class Linear(nn.Module):
231
+ def __init__(self, input_dim, output_dim, kernel_size=1):
232
+ super().__init__()
233
+ if type(input_dim) is not int:
234
+ input_dim = sum(input_dim)
235
+
236
+ assert type(input_dim) is int
237
+ padding = kernel_size // 2
238
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, padding=padding)
239
+
240
+ def forward(self, feats):
241
+ if type(feats) is list:
242
+ feats = torch.cat(feats, dim=1)
243
+
244
+ feats = interpolate(feats, scale_factor=4, mode="bilinear")
245
+ return self.conv(feats)
246
+
247
+
248
+ class MultiscaleHead(nn.Module):
249
+ def __init__(self, input_dims, output_dim, hidden_dim=512, kernel_size=1):
250
+ super().__init__()
251
+
252
+ self.convs = nn.ModuleList(
253
+ [make_conv(in_d, None, hidden_dim, 1, kernel_size) for in_d in input_dims]
254
+ )
255
+ interm_dim = len(input_dims) * hidden_dim
256
+ self.conv_mid = make_conv(interm_dim, hidden_dim, hidden_dim, 3, kernel_size)
257
+ self.conv_out = make_conv(hidden_dim, hidden_dim, output_dim, 2, kernel_size)
258
+
259
+ def forward(self, feats):
260
+ num_feats = len(feats)
261
+ feats = [self.convs[i](feats[i]) for i in range(num_feats)]
262
+
263
+ h, w = feats[-1].shape[-2:]
264
+ feats = [interpolate(feat, (h, w), mode="bilinear") for feat in feats]
265
+ feats = torch.cat(feats, dim=1).relu()
266
+
267
+ # upsample
268
+ feats = interpolate(feats, scale_factor=2, mode="bilinear")
269
+ feats = self.conv_mid(feats).relu()
270
+ feats = interpolate(feats, scale_factor=4, mode="bilinear")
271
+ return self.conv_out(feats)
272
+
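+ # Usage sketch (assumed shapes): given four feature maps whose channel counts
+ # match `input_dims`, MultiscaleHead projects each one, resizes them to the
+ # last map's spatial size, concatenates them, then upsamples by 2x and by 4x,
+ # so the output is 8x the spatial size of the last input feature map.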
273
+ def get_norm(norm, out_channels, num_norm_groups=32):
274
+ """
275
+ Args:
276
+ norm (str or callable): currently only "GN" (GroupNorm) as a string;
277
+ or a callable that takes a channel number and returns
278
+ the normalization layer as a nn.Module.
279
+ Returns:
280
+ nn.Module or None: the normalization layer
281
+ """
282
+ if norm is None:
283
+ return None
284
+ if isinstance(norm, str):
285
+ if len(norm) == 0:
286
+ return None
287
+ norm = {
288
+ "GN": lambda channels: nn.GroupNorm(num_norm_groups, channels),
289
+ }[norm]
290
+ return norm(out_channels)
291
+
292
+
293
+ def get_activation(activation):
294
+ """
295
+ Args:
296
+ activation (str or callable): either one of relu, lrelu, prelu, leaky_relu,
297
+ sigmoid, tanh, elu, selu, swish, mish; or a callable that takes a
298
+ tensor and returns a tensor.
299
+ Returns:
300
+ nn.Module or None: the activation layer
301
+ """
302
+ if activation is None:
303
+ return None
304
+ if isinstance(activation, str):
305
+ if len(activation) == 0:
306
+ return None
307
+ activation = {
308
+ "relu": nn.ReLU,
309
+ "lrelu": nn.LeakyReLU,
310
+ "prelu": nn.PReLU,
311
+ "leaky_relu": nn.LeakyReLU,
312
+ "sigmoid": nn.Sigmoid,
313
+ "tanh": nn.Tanh,
314
+ "elu": nn.ELU,
315
+ "selu": nn.SELU,
316
+ }[activation]
317
+ return activation()
318
+
319
+
320
+ # SCE crisscross + diags
321
+ class EfficientSpatialContextNet(nn.Module):
322
+ def __init__(self, kernel_size=7, in_channels=768, out_channels=768, use_cuda=True):
323
+ super(EfficientSpatialContextNet, self).__init__()
324
+ self.kernel_size = kernel_size
325
+ self.pad = kernel_size // 2
326
+ self.conv = torch.nn.Conv2d(
327
+ in_channels + 4 * self.kernel_size,
328
+ out_channels,
329
+ 1,
330
+ bias=True,
331
+ padding_mode="zeros",
332
+ )
333
+
334
+ if use_cuda:
335
+ self.conv = self.conv.cuda()
336
+
337
+ def forward(self, feature):
338
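+ # For every pixel, collect cosine similarities with neighbours along the two
+ # diagonals, the centre column and the centre row of a k x k window
+ # (4 * kernel_size context channels), then fuse them with the 1x1 conv above.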
+ b, c, h, w = feature.size()
339
+ feature_normalized = F.normalize(feature, p=2, dim=1)
340
+ feature_pad = F.pad(
341
+ feature_normalized, (self.pad, self.pad, self.pad, self.pad), "constant", 0
342
+ )
343
+ output = torch.zeros(
344
+ [4 * self.kernel_size, b, h, w],
345
+ dtype=feature.dtype,
346
+ requires_grad=feature.requires_grad,
347
+ )
348
+ if feature.is_cuda:
349
+ output = output.cuda(feature.get_device())
350
+
351
+ # left-top to right-bottom
352
+ for i in range(self.kernel_size):
353
+ c = i
354
+ r = i
355
+ output[i] = (feature_pad[:, :, r: (h + r), c: (w + c)] * feature_normalized).sum(1)
356
+
357
+ # col
358
+ for i in range(self.kernel_size):
359
+ c = self.kernel_size // 2
360
+ r = i
361
+ output[1 * self.kernel_size + i] = (feature_pad[:, :, r: (h + r), c: (w + c)] * feature_normalized).sum(1)
362
+
363
+ # right-top to left-bottom
364
+ for i in range(self.kernel_size):
365
+ c = (self.kernel_size - 1) - i
366
+ r = i
367
+ output[2 * self.kernel_size + i] = (feature_pad[:, :, r: (h + r), c: (w + c)] * feature_normalized).sum(1)
368
+
369
+ # row
370
+ for i in range(self.kernel_size):
371
+ c = i
372
+ r = self.kernel_size // 2
373
+ output[3 * self.kernel_size + i] = (feature_pad[:, :, r: (h + r), c: (w + c)] * feature_normalized).sum(1)
374
+
375
+ output = output.transpose(0, 1).contiguous()
376
+ output = torch.cat((feature, output), 1)
377
+ output = self.conv(output)
378
+ # output = F.relu(output)
379
+
380
+ return output
381
+
382
+
383
+ class Conv2d(nn.Conv2d):
384
+ """
385
+ A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
386
+ """
387
+
388
+ def __init__(self, *args, **kwargs):
389
+ """
390
+ Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:
391
+ Args:
392
+ norm (nn.Module, optional): a normalization layer
393
+ activation (callable(Tensor) -> Tensor): a callable activation function
394
+ It assumes that norm layer is used before activation.
395
+ """
396
+ norm = kwargs.pop("norm", None)
397
+ activation = kwargs.pop("activation", None)
398
+ super().__init__(*args, **kwargs)
399
+
400
+ self.norm = norm
401
+ self.activation = activation
402
+
403
+ def forward(self, x):
404
+ x = F.conv2d(
405
+ x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
406
+ )
407
+ if self.norm is not None:
408
+ x = self.norm(x)
409
+ if self.activation is not None:
410
+ x = self.activation(x)
411
+ return x
412
+
413
+
414
+ class CNNBlockBase(nn.Module):
415
+ """
416
+ A CNN block is assumed to have input channels, output channels and a stride.
417
+ The input and output of `forward()` method must be NCHW tensors.
418
+ The method can perform arbitrary computation but must match the given
419
+ channels and stride specification.
420
+ Attribute:
421
+ in_channels (int):
422
+ out_channels (int):
423
+ stride (int):
424
+ """
425
+
426
+ def __init__(self, in_channels, out_channels, stride):
427
+ """
428
+ The `__init__` method of any subclass should also contain these arguments.
429
+ Args:
430
+ in_channels (int):
431
+ out_channels (int):
432
+ stride (int):
433
+ """
434
+ super().__init__()
435
+ self.in_channels = in_channels
436
+ self.out_channels = out_channels
437
+ self.stride = stride
438
+
439
+
440
+ class BottleneckBlock(CNNBlockBase):
441
+ """
442
+ The standard bottleneck residual block used by ResNet-50, 101 and 152
443
+ defined in :paper:`ResNet`. It contains 3 conv layers with kernels
444
+ 1x1, 3x3, 1x1, and a projection shortcut if needed.
445
+ """
446
+
447
+ def __init__(
448
+ self,
449
+ in_channels,
450
+ out_channels,
451
+ *,
452
+ bottleneck_channels,
453
+ stride=1,
454
+ num_groups=1,
455
+ norm="GN",
456
+ stride_in_1x1=False,
457
+ dilation=1,
458
+ num_norm_groups=32,
459
+ kernel_size=(1, 3, 1)
460
+ ):
461
+ """
462
+ Args:
463
+ bottleneck_channels (int): number of output channels for the 3x3
464
+ "bottleneck" conv layers.
465
+ num_groups (int): number of groups for the 3x3 conv layer.
466
+ norm (str or callable): normalization for all conv layers.
467
+ See :func:`layers.get_norm` for supported format.
468
+ stride_in_1x1 (bool): when stride>1, whether to put stride in the
469
+ first 1x1 convolution or the bottleneck 3x3 convolution.
470
+ dilation (int): the dilation rate of the 3x3 conv layer.
471
+ """
472
+ super().__init__(in_channels, out_channels, stride)
473
+
474
+ if in_channels != out_channels:
475
+ self.shortcut = Conv2d(
476
+ in_channels,
477
+ out_channels,
478
+ kernel_size=1,
479
+ stride=stride,
480
+ bias=False,
481
+ norm=get_norm(norm, out_channels, num_norm_groups),
482
+ )
483
+ else:
484
+ self.shortcut = None
485
+
486
+ # The original MSRA ResNet models have stride in the first 1x1 conv
487
+ # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have
488
+ # stride in the 3x3 conv
489
+ stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)
490
+
491
+ self.conv1 = Conv2d(
492
+ in_channels,
493
+ bottleneck_channels,
494
+ kernel_size=kernel_size[0],
495
+ stride=stride_1x1,
496
+ padding=(kernel_size[0] - 1) // 2,
497
+ bias=False,
498
+ norm=get_norm(norm, bottleneck_channels, num_norm_groups),
499
+ )
500
+
501
+ self.conv2 = Conv2d(
502
+ bottleneck_channels,
503
+ bottleneck_channels,
504
+ kernel_size=kernel_size[1],
505
+ stride=stride_3x3,
506
+ padding=dilation * (kernel_size[1] - 1) // 2,
507
+ bias=False,
508
+ groups=num_groups,
509
+ dilation=dilation,
510
+ norm=get_norm(norm, bottleneck_channels, num_norm_groups),
511
+ )
512
+
513
+ self.conv3 = Conv2d(
514
+ bottleneck_channels,
515
+ out_channels,
516
+ kernel_size=kernel_size[2],
517
+ bias=False,
518
+ norm=get_norm(norm, out_channels, num_norm_groups),
519
+ )
520
+
521
+ for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]:
522
+ if layer is not None: # shortcut can be None
523
+ weight_init.c2_msra_fill(layer)
524
+
525
+ # Zero-initialize the last normalization in each residual branch,
526
+ # so that at the beginning, the residual branch starts with zeros,
527
+ # and each residual block behaves like an identity.
528
+ # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
529
+ # "For BN layers, the learnable scaling coefficient γ is initialized
530
+ # to be 1, except for each residual block's last BN
531
+ # where γ is initialized to be 0."
532
+
533
+ # nn.init.constant_(self.conv3.norm.weight, 0)
534
+ # TODO this somehow hurts performance when training GN models from scratch.
535
+ # Add it as an option when we need to use this code to train a backbone.
536
+
537
+ def forward(self, x):
538
+ out = self.conv1(x)
539
+ out = F.relu_(out)
540
+
541
+ out = self.conv2(out)
542
+ out = F.relu_(out)
543
+
544
+ out = self.conv3(out)
545
+
546
+ if self.shortcut is not None:
547
+ shortcut = self.shortcut(x)
548
+ else:
549
+ shortcut = x
550
+
551
+ out += shortcut
552
+ out = F.relu_(out)
553
+ return out
554
+
555
+
556
+ class ResNet(nn.Module):
557
+ """
558
+ Implement :paper:`ResNet`.
559
+ """
560
+
561
+ def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
562
+ """
563
+ Args:
564
+ stem (nn.Module): a stem module
565
+ stages (list[list[CNNBlockBase]]): several (typically 4) stages,
566
+ each contains multiple :class:`CNNBlockBase`.
567
+ num_classes (None or int): if None, will not perform classification.
568
+ Otherwise, will create a linear layer.
569
+ out_features (list[str]): name of the layers whose outputs should
570
+ be returned in forward. Can be anything in "stem", "linear", or "res2" ...
571
+ If None, will return the output of the last layer.
572
+ freeze_at (int): The number of stages at the beginning to freeze.
573
+ see :meth:`freeze` for detailed explanation.
574
+ """
575
+ super().__init__()
576
+ self.stem = stem
577
+ self.num_classes = num_classes
578
+
579
+ current_stride = self.stem.stride
580
+ self._out_feature_strides = {"stem": current_stride}
581
+ self._out_feature_channels = {"stem": self.stem.out_channels}
582
+
583
+ self.stage_names, self.stages = [], []
584
+
585
+ if out_features is not None:
586
+ # Avoid keeping unused layers in this module. They consume extra memory
587
+ # and may cause allreduce to fail
588
+ num_stages = max(
589
+ [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
590
+ )
591
+ stages = stages[:num_stages]
592
+ for i, blocks in enumerate(stages):
593
+ assert len(blocks) > 0, len(blocks)
594
+ for block in blocks:
595
+ assert isinstance(block, CNNBlockBase), block
596
+
597
+ name = "res" + str(i + 2)
598
+ stage = nn.Sequential(*blocks)
599
+
600
+ self.add_module(name, stage)
601
+ self.stage_names.append(name)
602
+ self.stages.append(stage)
603
+
604
+ self._out_feature_strides[name] = current_stride = int(
605
+ current_stride * np.prod([k.stride for k in blocks])
606
+ )
607
+ self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
608
+ self.stage_names = tuple(self.stage_names) # Make it static for scripting
609
+
610
+ if num_classes is not None:
611
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
612
+ self.linear = nn.Linear(curr_channels, num_classes)
613
+
614
+ # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
615
+ # "The 1000-way fully-connected layer is initialized by
616
+ # drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
617
+ nn.init.normal_(self.linear.weight, std=0.01)
618
+ name = "linear"
619
+
620
+ if out_features is None:
621
+ out_features = [name]
622
+ self._out_features = out_features
623
+ assert len(self._out_features)
624
+ children = [x[0] for x in self.named_children()]
625
+ for out_feature in self._out_features:
626
+ assert out_feature in children, "Available children: {}".format(", ".join(children))
627
+ self.freeze(freeze_at)
628
+
629
+ def forward(self, x):
630
+ """
631
+ Args:
632
+ x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
633
+ Returns:
634
+ dict[str->Tensor]: names and the corresponding features
635
+ """
636
+ assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
637
+ outputs = {}
638
+ x = self.stem(x)
639
+ if "stem" in self._out_features:
640
+ outputs["stem"] = x
641
+ for name, stage in zip(self.stage_names, self.stages):
642
+ x = stage(x)
643
+ if name in self._out_features:
644
+ outputs[name] = x
645
+ if self.num_classes is not None:
646
+ x = self.avgpool(x)
647
+ x = torch.flatten(x, 1)
648
+ x = self.linear(x)
649
+ if "linear" in self._out_features:
650
+ outputs["linear"] = x
651
+ return outputs
652
+
653
+ def freeze(self, freeze_at=0):
654
+ """
655
+ Freeze the first several stages of the ResNet. Commonly used in
656
+ fine-tuning.
657
+ Layers that produce the same feature map spatial size are defined as one
658
+ "stage" by :paper:`FPN`.
659
+ Args:
660
+ freeze_at (int): number of stages to freeze.
661
+ `1` means freezing the stem. `2` means freezing the stem and
662
+ one residual stage, etc.
663
+ Returns:
664
+ nn.Module: this ResNet itself
665
+ """
666
+ if freeze_at >= 1:
667
+ self.stem.freeze()
668
+ for idx, stage in enumerate(self.stages, start=2):
669
+ if freeze_at >= idx:
670
+ for block in stage.children():
671
+ block.freeze()
672
+ return self
673
+
674
+ @staticmethod
675
+ def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
676
+ """
677
+ Create a list of blocks of the same type that forms one ResNet stage.
678
+ Args:
679
+ block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
680
+ stage. A module of this type must not change spatial resolution of inputs unless its
681
+ stride != 1.
682
+ num_blocks (int): number of blocks in this stage
683
+ in_channels (int): input channels of the entire stage.
684
+ out_channels (int): output channels of **every block** in the stage.
685
+ kwargs: other arguments passed to the constructor of
686
+ `block_class`. If the argument name is "xx_per_block", the
687
+ argument is a list of values to be passed to each block in the
688
+ stage. Otherwise, the same argument is passed to every block
689
+ in the stage.
690
+ Returns:
691
+ list[CNNBlockBase]: a list of block module.
692
+ Examples:
693
+ ::
694
+ stage = ResNet.make_stage(
695
+ BottleneckBlock, 3, in_channels=16, out_channels=64,
696
+ bottleneck_channels=16, num_groups=1,
697
+ stride_per_block=[2, 1, 1],
698
+ dilations_per_block=[1, 1, 2]
699
+ )
700
+ Usually, layers that produce the same feature map spatial size are defined as one
701
+ "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
702
+ all be 1.
703
+ """
704
+ blocks = []
705
+ for i in range(num_blocks):
706
+ curr_kwargs = {}
707
+ for k, v in kwargs.items():
708
+ if k.endswith("_per_block"):
709
+ assert len(v) == num_blocks, (
710
+ f"Argument '{k}' of make_stage should have the "
711
+ f"same length as num_blocks={num_blocks}."
712
+ )
713
+ newk = k[: -len("_per_block")]
714
+ assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
715
+ curr_kwargs[newk] = v[i]
716
+ else:
717
+ curr_kwargs[k] = v
718
+
719
+ blocks.append(
720
+ block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
721
+ )
722
+ in_channels = out_channels
723
+ return blocks
724
+
725
+ @staticmethod
726
+ def make_default_stages(depth, block_class=None, **kwargs):
727
+ """
728
+ Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152).
729
+ If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
730
+ instead for fine-grained customization.
731
+ Args:
732
+ depth (int): depth of ResNet
733
+ block_class (type): the CNN block class. Has to accept
734
+ `bottleneck_channels` argument for depth > 50.
735
+ By default it is BasicBlock or BottleneckBlock, based on the
736
+ depth.
737
+ kwargs:
738
+ other arguments to pass to `make_stage`. Should not contain
739
+ stride and channels, as they are predefined for each depth.
740
+ Returns:
741
+ list[list[CNNBlockBase]]: modules in all stages; see arguments of
742
+ :class:`ResNet.__init__`.
743
+ """
744
+ num_blocks_per_stage = {
745
+ 18: [2, 2, 2, 2],
746
+ 34: [3, 4, 6, 3],
747
+ 50: [3, 4, 6, 3],
748
+ 101: [3, 4, 23, 3],
749
+ 152: [3, 8, 36, 3],
750
+ }[depth]
751
+ if block_class is None:
752
+ block_class = BasicBlock if depth < 50 else BottleneckBlock
753
+ if depth < 50:
754
+ in_channels = [64, 64, 128, 256]
755
+ out_channels = [64, 128, 256, 512]
756
+ else:
757
+ in_channels = [64, 256, 512, 1024]
758
+ out_channels = [256, 512, 1024, 2048]
759
+ ret = []
760
+ for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
761
+ if depth >= 50:
762
+ kwargs["bottleneck_channels"] = o // 4
763
+ ret.append(
764
+ ResNet.make_stage(
765
+ block_class=block_class,
766
+ num_blocks=n,
767
+ stride_per_block=[s] + [1] * (n - 1),
768
+ in_channels=i,
769
+ out_channels=o,
770
+ **kwargs,
771
+ )
772
+ )
773
+ return ret
774
+
775
+ class DummyAggregationNetwork(nn.Module): # for testing, return the input
776
+ def __init__(self):
777
+ super(DummyAggregationNetwork, self).__init__()
778
+ # dummy parameter
779
+ self.dummy = nn.Parameter(torch.ones([]))
780
+
781
+ def forward(self, batch, pose=None):
782
+ return batch * self.dummy
783
+
784
+
785
+ class AggregationNetwork(nn.Module):
786
+ """
787
+ Module for aggregating feature maps across time and space.
788
+ Design inspired by the Feature Extractor from ODISE (Xu et. al., CVPR 2023).
789
+ https://github.com/NVlabs/ODISE/blob/5836c0adfcd8d7fd1f8016ff5604d4a31dd3b145/odise/modeling/backbone/feature_extractor.py
790
+ """
791
+
792
+ def __init__(
793
+ self,
794
+ device,
795
+ feature_dims=[640, 1280, 1280, 768],
796
+ projection_dim=384,
797
+ num_norm_groups=32,
798
+ save_timestep=[1],
799
+ kernel_size=[1, 3, 1],
800
+ contrastive_temp=10,
801
+ feat_map_dropout=0.0,
802
+ ):
803
+ super().__init__()
804
+ self.skip_connection = True
805
+ self.feat_map_dropout = feat_map_dropout
806
+ self.azimuth_embedding = None
807
+ self.pos_embedding = None
808
+ self.bottleneck_layers = nn.ModuleList()
809
+ self.feature_dims = feature_dims
810
+ # For CLIP symmetric cross entropy loss during training
811
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
812
+ self.self_logit_scale = nn.Parameter(torch.ones([]) * np.log(contrastive_temp))
813
+ self.device = device
814
+ self.save_timestep = save_timestep
815
+
816
+ self.mixing_weights_names = []
817
+ for l, feature_dim in enumerate(self.feature_dims):
818
+ bottleneck_layer = nn.Sequential(
819
+ *ResNet.make_stage(
820
+ BottleneckBlock,
821
+ num_blocks=1,
822
+ in_channels=feature_dim,
823
+ bottleneck_channels=projection_dim // 4,
824
+ out_channels=projection_dim,
825
+ norm="GN",
826
+ num_norm_groups=num_norm_groups,
827
+ kernel_size=kernel_size
828
+ )
829
+ )
830
+ self.bottleneck_layers.append(bottleneck_layer)
831
+ for t in save_timestep:
832
+ # 1-index the layer name following prior work
833
+ self.mixing_weights_names.append(f"timestep-{save_timestep}_layer-{l + 1}")
834
+ self.last_layer = None
835
+ self.bottleneck_layers = self.bottleneck_layers.to(device)
836
+ mixing_weights = torch.ones(len(self.bottleneck_layers) * len(save_timestep))
837
+ self.mixing_weights = nn.Parameter(mixing_weights.to(device))
838
+ # count number of parameters
839
+ num_params = 0
840
+ for param in self.parameters():
841
+ num_params += param.numel()
842
+ print(f"AggregationNetwork has {num_params} parameters.")
843
+
844
+ def load_pretrained_weights(self, pretrained_dict):
845
+ custom_dict = self.state_dict()
846
+
847
+ # Handle size mismatch
848
+ if 'mixing_weights' in custom_dict and 'mixing_weights' in pretrained_dict and custom_dict[
849
+ 'mixing_weights'].shape != pretrained_dict['mixing_weights'].shape:
850
+ # Keep the first four weights from the pretrained model, and randomly initialize the fifth weight
851
+ custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]
852
+ custom_dict['mixing_weights'][4] = torch.zeros_like(custom_dict['mixing_weights'][4])
853
+ else:
854
+ custom_dict['mixing_weights'][:4] = pretrained_dict['mixing_weights'][:4]
855
+
856
+ # Load the weights that do match
857
+ matching_keys = {k: v for k, v in pretrained_dict.items() if k in custom_dict and k != 'mixing_weights'}
858
+ custom_dict.update(matching_keys)
859
+
860
+ # Now load the updated state_dict
861
+ self.load_state_dict(custom_dict, strict=False)
862
+
863
+ def forward(self, batch, pose=None):
864
+ """
865
+ Assumes batch is shape (B, C, H, W) where C is the concatenation of all layer features.
866
+ """
867
+ if self.feat_map_dropout > 0 and self.training:
868
+ batch = F.dropout(batch, p=self.feat_map_dropout)
869
+
870
+ output_feature = None
871
+ start = 0
872
+ mixing_weights = torch.nn.functional.softmax(self.mixing_weights, dim=0)
873
+ if self.pos_embedding is not None: # position embedding
874
+ batch = torch.cat((batch, self.pos_embedding), dim=1)
875
+ for i in range(len(mixing_weights)):
876
+ # Share bottleneck layers across timesteps
877
+ bottleneck_layer = self.bottleneck_layers[i % len(self.feature_dims)]
878
+ # Chunk the batch according the layer
879
+ # Account for looping if there are multiple timesteps
880
+ end = start + self.feature_dims[i % len(self.feature_dims)]
881
+ feats = batch[:, start:end, :, :]
882
+ start = end
883
+ # Downsample the number of channels and weight the layer
884
+ bottlenecked_feature = bottleneck_layer(feats)
885
+ bottlenecked_feature = mixing_weights[i] * bottlenecked_feature
886
+ if output_feature is None:
887
+ output_feature = bottlenecked_feature
888
+ else:
889
+ output_feature += bottlenecked_feature
890
+
891
+ if self.last_layer is not None:
892
+
893
+ output_feature_after = self.last_layer(output_feature)
894
+ if self.skip_connection:
895
+ # skip connection
896
+ output_feature = output_feature + output_feature_after
897
+ return output_feature
898
+
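+ # Usage sketch (assumed shapes): `batch` is the channel-wise concatenation of
+ # the per-layer features, C = sum(feature_dims); the output is a single
+ # [B, projection_dim, H, W] map formed as a learned softmax-weighted sum of
+ # the bottlenecked layer features.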
899
+
900
+ def conv1x1(in_planes, out_planes, stride=1):
901
+ """1x1 convolution without padding"""
902
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
903
+
904
+
905
+ def conv3x3(in_planes, out_planes, stride=1):
906
+ """3x3 convolution with padding"""
907
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
908
+
909
+
910
+ class BasicBlock(nn.Module):
911
+ def __init__(self, in_planes, planes, stride=1):
912
+ super().__init__()
913
+ self.conv1 = conv3x3(in_planes, planes, stride)
914
+ self.conv2 = conv3x3(planes, planes)
915
+ self.bn1 = nn.BatchNorm2d(planes)
916
+ self.bn2 = nn.BatchNorm2d(planes)
917
+ self.relu = nn.ReLU(inplace=True)
918
+
919
+ if stride == 1:
920
+ self.downsample = None
921
+ else:
922
+ self.downsample = nn.Sequential(
923
+ conv1x1(in_planes, planes, stride=stride),
924
+ nn.BatchNorm2d(planes)
925
+ )
926
+
927
+ def forward(self, x):
928
+ y = x
929
+ y = self.relu(self.bn1(self.conv1(y)))
930
+ y = self.bn2(self.conv2(y))
931
+
932
+ if self.downsample is not None:
933
+ x = self.downsample(x)
934
+
935
+ return self.relu(x + y)
EdgeCape/models/backbones/dino.py ADDED
@@ -0,0 +1,206 @@
1
+ import einops as E
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from transformers.models.vit_mae.modeling_vit_mae import (
6
+ get_2d_sincos_pos_embed_from_grid,
7
+ )
8
+
9
+
10
+ def resize_pos_embed(
11
+ pos_embed: torch.Tensor, hw: tuple[int, int], has_cls_token: bool = True
12
+ ):
13
+ """
14
+ Resize positional embedding for arbitrary image resolution. Resizing is done
15
+ via bicubic interpolation.
16
+
17
+ Args:
18
+ pos_embed: Positional embedding tensor of shape ``(n_patches, embed_dim)``.
19
+ hw: Target height and width of the tensor after interpolation.
20
+ has_cls_token: Whether ``pos_embed[0]`` is for the ``[cls]`` token.
21
+
22
+ Returns:
23
+ Tensor of shape ``(new_n_patches, embed_dim)`` of resized embedding.
24
+ ``new_n_patches`` is ``new_height * new_width`` if ``has_cls`` is False,
25
+ else ``1 + new_height * new_width``.
26
+ """
27
+
28
+ n_grid = pos_embed.shape[0] - 1 if has_cls_token else pos_embed.shape[0]
29
+
30
+ # Do not resize if already in same shape.
31
+ if n_grid == hw[0] * hw[1]:
32
+ return pos_embed
33
+
34
+ # Get original position embedding and extract ``[cls]`` token.
35
+ if has_cls_token:
36
+ cls_embed, pos_embed = pos_embed[[0]], pos_embed[1:]
37
+
38
+ orig_dim = int(pos_embed.shape[0] ** 0.5)
39
+
40
+ pos_embed = E.rearrange(pos_embed, "(h w) c -> 1 c h w", h=orig_dim)
41
+ pos_embed = F.interpolate(
42
+ pos_embed, hw, mode="bicubic", align_corners=False, antialias=True
43
+ )
44
+ pos_embed = E.rearrange(pos_embed, "1 c h w -> (h w) c")
45
+
46
+ # Add embedding of ``[cls]`` token back after resizing.
47
+ if has_cls_token:
48
+ pos_embed = torch.cat([cls_embed, pos_embed], dim=0)
49
+
50
+ return pos_embed
51
+
52
+
53
+ def center_padding(images, patch_size):
54
+ _, _, h, w = images.shape
55
+ diff_h = h % patch_size
56
+ diff_w = w % patch_size
57
+
58
+ if diff_h == 0 and diff_w == 0:
59
+ return images
60
+
61
+ pad_h = patch_size - diff_h
62
+ pad_w = patch_size - diff_w
63
+
64
+ pad_t = pad_h // 2
65
+ pad_l = pad_w // 2
66
+ pad_r = pad_w - pad_l
67
+ pad_b = pad_h - pad_t
68
+
69
+ images = F.pad(images, (pad_l, pad_r, pad_t, pad_b))
70
+ return images
71
+
72
+
73
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
74
+ """
75
+ COPIED FROM TRANSFORMERS PACKAGE AND EDITED TO ALLOW FOR DIFFERENT WIDTH-HEIGHT
76
+ Create 2D sin/cos positional embeddings.
77
+
78
+ Args:
79
+ embed_dim (`int`):
80
+ Embedding dimension.
81
+ grid_size (`int`):
82
+ The grid height and width.
83
+ add_cls_token (`bool`, *optional*, defaults to `False`):
84
+ Whether or not to add a classification (CLS) token.
85
+
86
+ Returns:
87
+ (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or
88
+ (1+grid_size*grid_size, embed_dim): the
89
+ position embeddings (with or without classification token)
90
+ """
91
+ grid_h = np.arange(grid_size[0], dtype=np.float32)
92
+ grid_w = np.arange(grid_size[1], dtype=np.float32)
93
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
94
+ grid = np.stack(grid, axis=0)
95
+
96
+ grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
97
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
98
+ if add_cls_token:
99
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
100
+ return pos_embed
101
+
102
+
103
+ def tokens_to_output(output_type, dense_tokens, cls_token, feat_hw):
104
+ if output_type == "cls":
105
+ assert cls_token is not None
106
+ output = cls_token
107
+ elif output_type == "gap":
108
+ output = dense_tokens.mean(dim=1)
109
+ elif output_type == "dense":
110
+ h, w = feat_hw
111
+ dense_tokens = E.rearrange(dense_tokens, "b (h w) c -> b c h w", h=h, w=w)
112
+ output = dense_tokens.contiguous()
113
+ elif output_type == "dense-cls":
114
+ assert cls_token is not None
115
+ h, w = feat_hw
116
+ dense_tokens = E.rearrange(dense_tokens, "b (h w) c -> b c h w", h=h, w=w)
117
+ cls_token = cls_token[:, :, None, None].repeat(1, 1, h, w)
118
+ output = torch.cat((dense_tokens, cls_token), dim=1).contiguous()
119
+ else:
120
+ raise ValueError()
121
+
122
+ return output
123
+
124
+ class DINO(torch.nn.Module):
125
+ def __init__(
126
+ self,
127
+ dino_name="dinov2",
128
+ model_name="vits14",
129
+ output="dense-cls",
130
+ layer=-1,
131
+ return_multilayer=True,
132
+ ):
133
+ super().__init__()
134
+ feat_dims = {
135
+ "vits14": 384,
136
+ "vitb8": 768,
137
+ "vitb16": 768,
138
+ "vitb14": 768,
139
+ "vitb14_reg": 768,
140
+ "vitl14": 1024,
141
+ "vitg14": 1536,
142
+ }
143
+
144
+ # get model
145
+ self.model_name = dino_name
146
+ self.checkpoint_name = f"{dino_name}_{model_name}"
147
+ dino_vit = torch.hub.load(f"facebookresearch/{dino_name}", self.checkpoint_name)
148
+ self.vit = dino_vit.eval().to(torch.float32)
149
+ self.has_registers = "_reg" in model_name
150
+
151
+ assert output in ["cls", "gap", "dense", "dense-cls"]
152
+ self.output = output
153
+ self.patch_size = self.vit.patch_embed.proj.kernel_size[0]
154
+
155
+ feat_dim = feat_dims[model_name]
156
+ feat_dim = feat_dim * 2 if output == "dense-cls" else feat_dim
157
+
158
+ num_layers = len(self.vit.blocks)
159
+ multilayers = [
160
+ num_layers // 4 - 1,
161
+ num_layers // 2 - 1,
162
+ num_layers // 4 * 3 - 1,
163
+ num_layers - 1,
164
+ ]
165
+
166
+ if return_multilayer:
167
+ self.feat_dim = [feat_dim, feat_dim, feat_dim, feat_dim]
168
+ self.multilayers = multilayers
169
+ else:
170
+ self.feat_dim = feat_dim
171
+ layer = multilayers[-1] if layer == -1 else layer
172
+ self.multilayers = [layer]
173
+
174
+ # define layer name (for logging)
175
+ self.layer = "-".join(str(_x) for _x in self.multilayers)
176
+
177
+ def forward(self, images):
178
+
179
+ # pad images (if needed) to ensure it matches patch_size
180
+ images = center_padding(images, self.patch_size)
181
+ h, w = images.shape[-2:]
182
+ h, w = h // self.patch_size, w // self.patch_size
183
+
184
+ if self.model_name == "dinov2":
185
+ x = self.vit.prepare_tokens_with_masks(images, None)
186
+ else:
187
+ x = self.vit.prepare_tokens(images)
188
+
189
+ embeds = []
190
+ for i, blk in enumerate(self.vit.blocks):
191
+ x = blk(x)
192
+ if i in self.multilayers:
193
+ embeds.append(x)
194
+ if len(embeds) == len(self.multilayers):
195
+ break
196
+
197
+ num_spatial = h * w
198
+ outputs = []
199
+ for i, x_i in enumerate(embeds):
200
+ cls_tok = x_i[:, 0]
201
+ # ignoring register tokens
202
+ spatial = x_i[:, -1 * num_spatial :]
203
+ x_i = tokens_to_output(self.output, spatial, cls_tok, (h, w))
204
+ outputs.append(x_i)
205
+
206
+ return outputs[0] if len(outputs) == 1 else outputs
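+ # Usage sketch (assumed input size): for a 224 x 224 image batch, DINO with
+ # model_name='vits14', output='dense-cls' and return_multilayer=True returns
+ # four tensors of shape [B, 768, 16, 16] (384 dense channels + 384 CLS channels).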
EdgeCape/models/detectors/EdgeCape.py ADDED
@@ -0,0 +1,392 @@
1
+ import math
2
+ import cv2
3
+ import mmcv
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn
7
+ import torch.nn.functional as F
8
+ from mmcv.image import imwrite
9
+ from mmcv.visualization.image import imshow
10
+ from mmpose.models import builder
11
+ from mmpose.models.builder import POSENETS
12
+ from mmpose.models.detectors.base import BasePose
13
+ from EdgeCape.models.backbones.adapter import DPT
14
+ from EdgeCape.models.backbones.dino import DINO
15
+
16
+
17
+ @POSENETS.register_module()
18
+ class EdgeCape(BasePose):
19
+ """
20
+ EdgeCape: edge-aware model for Category-Agnostic Pose Estimation (CAPE).
21
+ Args:
22
+ keypoint_head (dict): Config for keypoint head.
23
+ encoder_config (dict): Config for encoder.
24
+ train_cfg (dict): Config for training. Default: None.
25
+ test_cfg (dict): Config for testing. Default: None.
26
+ freeze_backbone (bool): If True, freeze backbone. Default: False.
27
+ """
28
+
29
+ def __init__(self,
30
+ keypoint_head,
31
+ encoder_config,
32
+ train_cfg=None,
33
+ test_cfg=None,
34
+ freeze_backbone=False):
35
+ super().__init__()
36
+ feature_output_setting = encoder_config.get('output', 'dense-cls')
37
+ model_name = encoder_config.get('model_name', 'vits14')
38
+ self.encoder_sample = self.encoder_query = DINO(output=feature_output_setting, model_name=model_name)
39
+ self.probe = DPT(input_dims=self.encoder_query.feat_dim, output_dim=768)
40
+ self.backbone = 'dino_extractor'
41
+ self.freeze_backbone = freeze_backbone
42
+ if keypoint_head.get('freeze', None) is not None:
43
+ self.freeze_backbone = True
44
+
45
+ self.keypoint_head_module = builder.build_head(keypoint_head)
46
+ self.keypoint_head_module.init_weights()
47
+
48
+ self.train_cfg = train_cfg
49
+ self.test_cfg = test_cfg
50
+ self.target_type = test_cfg.get('target_type',
51
+ 'GaussianHeatMap') # GaussianHeatMap
52
+
53
+ @property
54
+ def with_keypoint(self):
55
+ """Check if has keypoint_head."""
56
+ return hasattr(self, 'keypoint_head_module')
57
+
58
+ def init_weights(self, pretrained=None):
59
+ """Weight initialization for model."""
60
+ self.encoder_sample.init_weights(pretrained)
61
+ self.encoder_query.init_weights(pretrained)
62
+ self.keypoint_head_module.init_weights()
63
+
64
+ def forward(self,
65
+ img_s,
66
+ img_q,
67
+ target_s=None,
68
+ target_weight_s=None,
69
+ target_q=None,
70
+ target_weight_q=None,
71
+ img_metas=None,
72
+ return_loss=True,
73
+ **kwargs):
74
+ """Calls either forward_train or forward_test depending on whether
75
+ return_loss=True. Note this setting will change the expected inputs.
76
+ When `return_loss=True`, img and img_meta are single-nested (i.e.
77
+ Tensor and List[dict]), and when `return_loss=False`, img and img_meta
78
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
79
+ the outer list indicating test time augmentations.
80
+ """
81
+ if return_loss:
82
+ return self.forward_train(img_s, target_s, target_weight_s, img_q,
83
+ target_q, target_weight_q, img_metas,
84
+ **kwargs)
85
+ else:
86
+ return self.forward_test(img_s, target_s, target_weight_s, img_q,
87
+ target_q, target_weight_q, img_metas,
88
+ **kwargs)
89
+
90
+ def forward_train(self,
91
+ img_s,
92
+ target_s,
93
+ target_weight_s,
94
+ img_q,
95
+ target_q,
96
+ target_weight_q,
97
+ img_metas,
98
+ **kwargs):
99
+ """Defines the computation performed at every call when training."""
100
+ bs, _, h, w = img_q.shape
101
+ random_mask = kwargs.get('rand_mask', None)
102
+ output, initial_proposals, similarity_map, mask_s, reconstructed_keypoints = self.predict(img_s,
103
+ target_s,
104
+ target_weight_s,
105
+ img_q,
106
+ img_metas,
107
+ random_mask)
108
+
109
+ # parse the img meta to get the target keypoints
110
+ device = output.device
111
+ target_keypoints = self.parse_keypoints_from_img_meta(img_metas,
112
+ device,
113
+ keyword='query')
114
+
115
+ target_sizes = torch.tensor(
116
+ [img_q.shape[-2], img_q.shape[-1]]).unsqueeze(0).repeat(
117
+ img_q.shape[0], 1, 1)
118
+
119
+ losses = dict()
120
+ if self.with_keypoint:
121
+ keypoint_losses = self.keypoint_head_module.get_loss(output,
122
+ initial_proposals,
123
+ similarity_map,
124
+ target_keypoints,
125
+ target_q,
126
+ target_weight_q * mask_s,
127
+ target_sizes,
128
+ reconstructed_keypoints,
129
+ )
130
+ losses.update(keypoint_losses)
131
+ keypoint_accuracy = self.keypoint_head_module.get_accuracy(output[-1],
132
+ target_keypoints,
133
+ target_weight_q * mask_s,
134
+ target_sizes,
135
+ height=h)
136
+ losses.update(keypoint_accuracy)
137
+ return losses
138
+
139
+ def forward_test(self,
140
+ img_s,
141
+ target_s,
142
+ target_weight_s,
143
+ img_q,
144
+ target_q,
145
+ target_weight_q,
146
+ img_metas=None,
147
+ vis_offset=True,
148
+ **kwargs):
149
+
150
+ """Defines the computation performed at every call when testing."""
151
+ batch_size, _, img_height, img_width = img_q.shape
152
+ output, initial_proposals, similarity_map, mask_s, reconstructed_keypoints = self.predict(img_s,
153
+ target_s,
154
+ target_weight_s,
155
+ img_q,
156
+ img_metas
157
+ )
158
+ predicted_pose = output[-1].detach().cpu().numpy()
159
+ result = {}
160
+
161
+ if self.with_keypoint:
162
+ keypoint_result = self.keypoint_head_module.decode(img_metas, predicted_pose, img_size=[img_width, img_height])
163
+ result.update(keypoint_result)
164
+
165
+ if vis_offset:
166
+ result.update({"points": torch.cat((initial_proposals[None], output)).cpu().numpy()})
167
+
168
+ result.update({"sample_image_file": [img_metas[i]['sample_image_file'] for i in range(len(img_metas))]})
169
+
170
+ return result
171
+
172
+     def predict(self,
+                 img_s,
+                 target_s,
+                 target_weight_s,
+                 img_q,
+                 img_metas=None,
+                 random_mask=None):
+         batch_size, _, img_height, img_width = img_q.shape
+         # NOTE: a non-empty list comprehension is always truthy, so this assert is a no-op as written.
+         assert [i['sample_skeleton'][0] != i['query_skeleton'] for i in img_metas]
+         mask_s = target_weight_s[0]
+         for target_weight in target_weight_s:
+             mask_s = mask_s * target_weight
+         feature_q, feature_s = self.extract_features(img_s, img_q)
+         skeleton_lst = [i['sample_skeleton'][0] for i in img_metas]
+
+         (output, initial_proposals, similarity_map, reconstructed_keypoints) = self.keypoint_head_module(
+             feature_q, feature_s, target_s, mask_s, skeleton_lst, random_mask=random_mask)
+
+         return output, initial_proposals, similarity_map, mask_s, reconstructed_keypoints
+
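# Numeric sketch of the support-mask reduction above: mask_s keeps only the
# keypoints that are valid in *every* support shot, so a keypoint missing from
# any shot is excluded from the head and from the training loss (values below
# are illustrative).
import torch

target_weight_s = [torch.tensor([[[1.], [1.], [0.]]]),   # shot 1: keypoint 2 invalid
                   torch.tensor([[[1.], [0.], [1.]]])]   # shot 2: keypoint 1 invalid
mask_s = target_weight_s[0]
for target_weight in target_weight_s:
    mask_s = mask_s * target_weight
print(mask_s.squeeze(-1))  # tensor([[1., 0., 0.]]) -> only keypoint 0 survives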
+     def extract_features(self, img_s, img_q):
+         with torch.no_grad():
+             dino_feature_s = [self.encoder_sample(img) for img in img_s]
+             dino_feature_q = self.encoder_query(img_q)  # [bs, 3, h, w]
+         if self.freeze_backbone:
+             with torch.no_grad():
+                 feature_s = [self.probe(f) for f in dino_feature_s]
+                 feature_q = self.probe(dino_feature_q)
+         else:
+             feature_s = [self.probe(f) for f in dino_feature_s]
+             feature_q = self.probe(dino_feature_q)
+
+         return feature_q, feature_s
+
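# Sketch of the gradient behaviour implied above: the DINO encoders always run
# under no_grad, so only the probe can receive gradients, and even that path is
# disabled when freeze_backbone is True. Toy modules stand in for the real
# encoder/probe, which are defined elsewhere in the repo.
import torch
import torch.nn as nn

encoder = nn.Conv2d(3, 8, 3, padding=1)   # stand-in for the frozen DINO encoder
probe = nn.Conv2d(8, 8, 1)                # stand-in for self.probe

img_q = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    dino_feature_q = encoder(img_q)
feature_q = probe(dino_feature_q)
print(dino_feature_q.requires_grad, feature_q.requires_grad)  # False True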
+     def parse_keypoints_from_img_meta(self, img_meta, device, keyword='query'):
+         """Parse keypoints from the img_meta.
+
+         Args:
+             img_meta (list[dict]): Image meta info.
+             device (torch.device): Device of the output keypoints.
+             keyword (str): 'query' or 'sample'. Default: 'query'.
+
+         Returns:
+             Tensor: Keypoint coordinates of the query (or sample) images.
+         """
+         if keyword == 'query':
+             query_kpt = torch.stack([
+                 torch.tensor(info[f'{keyword}_joints_3d']).to(device)
+                 for info in img_meta], dim=0)[:, :, :2]
+         else:
+             query_kpt = []
+             for info in img_meta:
+                 if isinstance(info[f'{keyword}_joints_3d'][0], torch.Tensor):
+                     samples = torch.stack(info[f'{keyword}_joints_3d'])
+                 else:
+                     samples = np.array(info[f'{keyword}_joints_3d'])
+                 query_kpt.append(torch.tensor(samples).to(device)[:, :, :2])
+             query_kpt = torch.stack(query_kpt, dim=0)  # [bs, num_samples, num_query, 2]
+         return query_kpt
+
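# Shape sketch for the 'query' branch above: each img_meta entry carries a
# (num_kpts, 3) 'query_joints_3d' array; stacking a batch of two and keeping
# only the x, y columns yields (2, num_kpts, 2). Values are illustrative.
import numpy as np
import torch

img_meta = [{'query_joints_3d': np.zeros((17, 3))} for _ in range(2)]
query_kpt = torch.stack(
    [torch.tensor(m['query_joints_3d']) for m in img_meta], dim=0)[:, :, :2]
print(query_kpt.shape)  # torch.Size([2, 17, 2])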
+     def get_full_similarity_map(self, feature_q, feature_s, h, w):
+         resized_feature_q = F.interpolate(feature_q, size=(h, w), mode='bilinear')
+         resized_feature_s = [F.interpolate(s, size=(h, w), mode='bilinear')
+                              for s in feature_s]
+         return [self.chunk_cosine_sim(f_s, resized_feature_q)
+                 for f_s in resized_feature_s]
+
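# chunk_cosine_sim is defined elsewhere in the repo; the helper below is a
# minimal sketch of the kind of map it plausibly produces (all names and the
# output layout here are assumptions, not the repo's implementation):
# entry (i, j) is the cosine similarity between spatial location i of feat_a
# and location j of feat_b.
import torch
import torch.nn.functional as F

def cosine_similarity_map(feat_a, feat_b):
    # feat_a, feat_b: [bs, C, h, w] -> similarity: [bs, h*w, h*w]
    a = F.normalize(feat_a.flatten(2), dim=1)   # [bs, C, h*w], unit-norm channels
    b = F.normalize(feat_b.flatten(2), dim=1)
    return torch.einsum('bci,bcj->bij', a, b)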
+     # UNMODIFIED
+     def show_result(self,
+                     img,
+                     result,
+                     skeleton=None,
+                     kpt_score_thr=0.3,
+                     bbox_color='green',
+                     pose_kpt_color=None,
+                     pose_limb_color=None,
+                     radius=4,
+                     text_color=(255, 0, 0),
+                     thickness=1,
+                     font_scale=0.5,
+                     win_name='',
+                     show=False,
+                     wait_time=0,
+                     out_file=None):
+         """Draw `result` over `img`.
+
+         Args:
+             img (str or Tensor): The image to be displayed.
+             result (list[dict]): The results to draw over `img`
+                 (bbox_result, pose_result).
+             skeleton (list[list[int]]): Pairs of 1-based keypoint indices
+                 defining the limbs. If None, do not draw limbs.
+             kpt_score_thr (float, optional): Minimum score of keypoints
+                 to be shown. Default: 0.3.
+             bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
+             pose_kpt_color (np.array[Nx3]): Color of N keypoints.
+                 If None, do not draw keypoints.
+             pose_limb_color (np.array[Mx3]): Color of M limbs.
+                 If None, do not draw limbs.
+             radius (int): Radius of the keypoint circles.
+             text_color (str or tuple or :obj:`Color`): Color of texts.
+             thickness (int): Thickness of lines.
+             font_scale (float): Font scale of texts.
+             win_name (str): The window name.
+             show (bool): Whether to show the image. Default: False.
+             wait_time (int): Value of waitKey param. Default: 0.
+             out_file (str or None): The filename to write the image.
+                 Default: None.
+
+         Returns:
+             ndarray: Visualized img, only if not `show` or `out_file`.
+         """
+         img = mmcv.imread(img)
+         img = img.copy()
+         img_h, img_w, _ = img.shape
+
+         bbox_result = []
+         pose_result = []
+         for res in result:
+             bbox_result.append(res['bbox'])
+             pose_result.append(res['keypoints'])
+
+         if len(bbox_result) > 0:
+             bboxes = np.vstack(bbox_result)
+             # draw bounding boxes
+             mmcv.imshow_bboxes(
+                 img,
+                 bboxes,
+                 colors=bbox_color,
+                 top_k=-1,
+                 thickness=thickness,
+                 show=False,
+                 win_name=win_name,
+                 wait_time=wait_time,
+                 out_file=None)
+
+             for person_id, kpts in enumerate(pose_result):
+                 # draw each point on image
+                 if pose_kpt_color is not None:
+                     assert len(pose_kpt_color) == len(kpts), (
+                         len(pose_kpt_color), len(kpts))
+                     for kid, kpt in enumerate(kpts):
+                         x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
+                         if kpt_score > kpt_score_thr:
+                             img_copy = img.copy()
+                             r, g, b = pose_kpt_color[kid]
+                             cv2.circle(img_copy, (int(x_coord), int(y_coord)),
+                                        radius, (int(r), int(g), int(b)), -1)
+                             transparency = max(0, min(1, kpt_score))
+                             cv2.addWeighted(
+                                 img_copy,
+                                 transparency,
+                                 img,
+                                 1 - transparency,
+                                 0,
+                                 dst=img)
+
+                 # draw limbs
+                 if skeleton is not None and pose_limb_color is not None:
+                     assert len(pose_limb_color) == len(skeleton)
+                     for sk_id, sk in enumerate(skeleton):
+                         pos1 = (int(kpts[sk[0] - 1, 0]), int(kpts[sk[0] - 1, 1]))
+                         pos2 = (int(kpts[sk[1] - 1, 0]), int(kpts[sk[1] - 1, 1]))
+                         if (pos1[0] > 0 and pos1[0] < img_w and pos1[1] > 0
+                                 and pos1[1] < img_h and pos2[0] > 0
+                                 and pos2[0] < img_w and pos2[1] > 0
+                                 and pos2[1] < img_h
+                                 and kpts[sk[0] - 1, 2] > kpt_score_thr
+                                 and kpts[sk[1] - 1, 2] > kpt_score_thr):
+                             img_copy = img.copy()
+                             X = (pos1[0], pos2[0])
+                             Y = (pos1[1], pos2[1])
+                             mX = np.mean(X)
+                             mY = np.mean(Y)
+                             length = ((Y[0] - Y[1]) ** 2 + (X[0] - X[1]) ** 2) ** 0.5
+                             angle = math.degrees(math.atan2(Y[0] - Y[1], X[0] - X[1]))
+                             stickwidth = 2
+                             polygon = cv2.ellipse2Poly(
+                                 (int(mX), int(mY)),
+                                 (int(length / 2), int(stickwidth)), int(angle),
+                                 0, 360, 1)
+
+                             r, g, b = pose_limb_color[sk_id]
+                             cv2.fillConvexPoly(img_copy, polygon,
+                                                (int(r), int(g), int(b)))
+                             transparency = max(
+                                 0, min(1, 0.5 * (kpts[sk[0] - 1, 2] + kpts[sk[1] - 1, 2])))
+                             cv2.addWeighted(
+                                 img_copy,
+                                 transparency,
+                                 img,
+                                 1 - transparency,
+                                 0,
+                                 dst=img)
+
+         show, wait_time = 1, 1  # NOTE: overrides the `show`/`wait_time` arguments, forcing display
+         if show:
+             height, width = img.shape[:2]
+             max_ = max(height, width)
+
+             factor = min(1, 800 / max_)
+             enlarge = cv2.resize(
+                 img, (0, 0),
+                 fx=factor,
+                 fy=factor,
+                 interpolation=cv2.INTER_CUBIC)
+             imshow(enlarge, win_name, wait_time)
+
+         if out_file is not None:
+             imwrite(img, out_file)
+
+         return img
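# Usage sketch for show_result, assuming a constructed EdgeCape instance
# `detector` (hypothetical name): each entry of `result` needs a 'bbox' row
# that np.vstack can stack and a (K, 3) 'keypoints' array of (x, y, score),
# as read by the drawing loops above. Values are illustrative.
import numpy as np

result = [{
    'bbox': np.array([[20., 30., 180., 220., 0.9]]),
    'keypoints': np.array([[50., 60., 0.95],
                           [90., 75., 0.80]]),
}]
# detector.show_result('query.jpg', result,
#                      pose_kpt_color=np.array([[255, 0, 0], [0, 255, 0]]),
#                      out_file='vis.jpg')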
EdgeCape/models/detectors/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .EdgeCape import EdgeCape
+
+ __all__ = ['EdgeCape']
EdgeCape/models/detectors/__pycache__/EdgeCape.cpython-39.pyc ADDED
Binary file (11.6 kB).