cheng-hust committed
Commit e8861c0 · verified · 1 Parent(s): 7c5b9cf

Upload 91 files

This view is limited to 50 files because it contains too many changes; see the raw diff for the full set.

Files changed (50)
  1. src/__init__.py +5 -0
  2. src/__pycache__/__init__.cpython-310.pyc +0 -0
  3. src/core/__init__.py +7 -0
  4. src/core/__pycache__/__init__.cpython-310.pyc +0 -0
  5. src/core/__pycache__/config.cpython-310.pyc +0 -0
  6. src/core/__pycache__/yaml_config.cpython-310.pyc +0 -0
  7. src/core/__pycache__/yaml_utils.cpython-310.pyc +0 -0
  8. src/core/config.py +264 -0
  9. src/core/yaml_config.py +152 -0
  10. src/core/yaml_utils.py +208 -0
  11. src/data/__init__.py +7 -0
  12. src/data/__pycache__/__init__.cpython-310.pyc +0 -0
  13. src/data/__pycache__/dataloader.cpython-310.pyc +0 -0
  14. src/data/__pycache__/transforms.cpython-310.pyc +0 -0
  15. src/data/cifar10/__init__.py +14 -0
  16. src/data/cifar10/__pycache__/__init__.cpython-310.pyc +0 -0
  17. src/data/coco/__init__.py +9 -0
  18. src/data/coco/__pycache__/__init__.cpython-310.pyc +0 -0
  19. src/data/coco/__pycache__/coco_dataset.cpython-310.pyc +0 -0
  20. src/data/coco/__pycache__/coco_eval.cpython-310.pyc +0 -0
  21. src/data/coco/__pycache__/coco_utils.cpython-310.pyc +0 -0
  22. src/data/coco/coco_dataset.py +238 -0
  23. src/data/coco/coco_eval.py +269 -0
  24. src/data/coco/coco_utils.py +184 -0
  25. src/data/dataloader.py +28 -0
  26. src/data/functional.py +169 -0
  27. src/data/transforms.py +142 -0
  28. src/misc/__init__.py +3 -0
  29. src/misc/__pycache__/__init__.cpython-310.pyc +0 -0
  30. src/misc/__pycache__/dist.cpython-310.pyc +0 -0
  31. src/misc/__pycache__/logger.cpython-310.pyc +0 -0
  32. src/misc/__pycache__/visualizer.cpython-310.pyc +0 -0
  33. src/misc/dist.py +190 -0
  34. src/misc/logger.py +239 -0
  35. src/misc/visualizer.py +34 -0
  36. src/nn/__init__.py +7 -0
  37. src/nn/__pycache__/__init__.cpython-310.pyc +0 -0
  38. src/nn/arch/__init__.py +1 -0
  39. src/nn/arch/__pycache__/__init__.cpython-310.pyc +0 -0
  40. src/nn/arch/__pycache__/classification.cpython-310.pyc +0 -0
  41. src/nn/arch/classification.py +41 -0
  42. src/nn/backbone/__init__.py +5 -0
  43. src/nn/backbone/__pycache__/__init__.cpython-310.pyc +0 -0
  44. src/nn/backbone/__pycache__/common.cpython-310.pyc +0 -0
  45. src/nn/backbone/__pycache__/presnet.cpython-310.pyc +0 -0
  46. src/nn/backbone/__pycache__/test_resnet.cpython-310.pyc +0 -0
  47. src/nn/backbone/common.py +102 -0
  48. src/nn/backbone/presnet.py +225 -0
  49. src/nn/backbone/test_resnet.py +81 -0
  50. src/nn/backbone/utils.py +58 -0
src/__init__.py ADDED
@@ -0,0 +1,5 @@
+
+ from . import data
+ from . import nn
+ from . import optim
+ from . import zoo
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (269 Bytes).
 
src/core/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """by lyuwenyu
+ """
+
+ # from .yaml_utils import register, create, load_config, merge_config, merge_dict
+ from .yaml_utils import *
+ from .config import BaseConfig
+ from .yaml_config import YAMLConfig
src/core/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (309 Bytes).
 
src/core/__pycache__/config.cpython-310.pyc ADDED
Binary file (6.52 kB).
 
src/core/__pycache__/yaml_config.cpython-310.pyc ADDED
Binary file (4.68 kB).
 
src/core/__pycache__/yaml_utils.cpython-310.pyc ADDED
Binary file (4.24 kB).
 
src/core/config.py ADDED
@@ -0,0 +1,264 @@
+ """by lyuwenyu
+ """
+
+ from pprint import pprint
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset, DataLoader
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import LRScheduler
+ from torch.cuda.amp.grad_scaler import GradScaler
+
+ from typing import Callable, List, Dict
+
+
+ __all__ = ['BaseConfig', ]
+
+
+
+ class BaseConfig(object):
+     # TODO property
+
+
+     def __init__(self) -> None:
+         super().__init__()
+
+         self.task :str = None
+
+         self._model :nn.Module = None
+         self._postprocessor :nn.Module = None
+         self._criterion :nn.Module = None
+         self._optimizer :Optimizer = None
+         self._lr_scheduler :LRScheduler = None
+         self._train_dataloader :DataLoader = None
+         self._val_dataloader :DataLoader = None
+         self._ema :nn.Module = None
+         self._scaler :GradScaler = None
+
+         self.train_dataset :Dataset = None
+         self.val_dataset :Dataset = None
+         self.num_workers :int = 0
+         self.collate_fn :Callable = None
+
+         self.batch_size :int = None
+         self._train_batch_size :int = None
+         self._val_batch_size :int = None
+         self._train_shuffle: bool = None
+         self._val_shuffle: bool = None
+
+         self.evaluator :Callable[[nn.Module, DataLoader, str], ] = None
+
+         # runtime
+         self.resume :str = None
+         self.tuning :str = None
+
+         self.epoches :int = None
+         self.last_epoch :int = -1
+         self.end_epoch :int = None
+
+         self.use_amp :bool = False
+         self.use_ema :bool = False
+         self.sync_bn :bool = False
+         self.clip_max_norm : float = None
+         self.find_unused_parameters :bool = None
+         # self.ema_decay: float = 0.9999
+         # self.grad_clip_: Callable = None
+
+         self.log_dir :str = './logs/'
+         self.log_step :int = 10
+         self._output_dir :str = None
+         self._print_freq :int = None
+         self.checkpoint_step :int = 1
+
+         # self.device :str = torch.device('cpu')
+         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.device = torch.device(device)
+
+
+     @property
+     def model(self, ) -> nn.Module:
+         return self._model
+
+     @model.setter
+     def model(self, m):
+         assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
+         self._model = m
+
+     @property
+     def postprocessor(self, ) -> nn.Module:
+         return self._postprocessor
+
+     @postprocessor.setter
+     def postprocessor(self, m):
+         assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
+         self._postprocessor = m
+
+     @property
+     def criterion(self, ) -> nn.Module:
+         return self._criterion
+
+     @criterion.setter
+     def criterion(self, m):
+         assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
+         self._criterion = m
+
+     @property
+     def optimizer(self, ) -> Optimizer:
+         return self._optimizer
+
+     @optimizer.setter
+     def optimizer(self, m):
+         assert isinstance(m, Optimizer), f'{type(m)} != optim.Optimizer, please check your model class'
+         self._optimizer = m
+
+     @property
+     def lr_scheduler(self, ) -> LRScheduler:
+         return self._lr_scheduler
+
+     @lr_scheduler.setter
+     def lr_scheduler(self, m):
+         assert isinstance(m, LRScheduler), f'{type(m)} != LRScheduler, please check your model class'
+         self._lr_scheduler = m
+
+
+     @property
+     def train_dataloader(self):
+         if self._train_dataloader is None and self.train_dataset is not None:
+             loader = DataLoader(self.train_dataset,
+                                 batch_size=self.train_batch_size,
+                                 num_workers=self.num_workers,
+                                 collate_fn=self.collate_fn,
+                                 shuffle=self.train_shuffle, )
+             loader.shuffle = self.train_shuffle
+             self._train_dataloader = loader
+
+         return self._train_dataloader
+
+     @train_dataloader.setter
+     def train_dataloader(self, loader):
+         self._train_dataloader = loader
+
+     @property
+     def val_dataloader(self):
+         if self._val_dataloader is None and self.val_dataset is not None:
+             loader = DataLoader(self.val_dataset,
+                                 batch_size=self.val_batch_size,
+                                 num_workers=self.num_workers,
+                                 drop_last=False,
+                                 collate_fn=self.collate_fn,
+                                 shuffle=self.val_shuffle)
+             loader.shuffle = self.val_shuffle
+             self._val_dataloader = loader
+
+         return self._val_dataloader
+
+     @val_dataloader.setter
+     def val_dataloader(self, loader):
+         self._val_dataloader = loader
+
+
+     # TODO method
+     # @property
+     # def ema(self, ) -> nn.Module:
+     #     if self._ema is None and self.use_ema and self.model is not None:
+     #         self._ema = ModelEMA(self.model, self.ema_decay)
+     #     return self._ema
+
+     @property
+     def ema(self, ) -> nn.Module:
+         return self._ema
+
+     @ema.setter
+     def ema(self, obj):
+         self._ema = obj
+
+
+     @property
+     def scaler(self) -> GradScaler:
+         if self._scaler is None and self.use_amp and torch.cuda.is_available():
+             self._scaler = GradScaler()
+         return self._scaler
+
+     @scaler.setter
+     def scaler(self, obj: GradScaler):
+         self._scaler = obj
+
+
+     @property
+     def val_shuffle(self):
+         if self._val_shuffle is None:
+             print('warning: set default val_shuffle=False')
+             return False
+         return self._val_shuffle
+
+     @val_shuffle.setter
+     def val_shuffle(self, shuffle):
+         assert isinstance(shuffle, bool), 'shuffle must be bool'
+         self._val_shuffle = shuffle
+
+     @property
+     def train_shuffle(self):
+         if self._train_shuffle is None:
+             print('warning: set default train_shuffle=True')
+             return True
+         return self._train_shuffle
+
+     @train_shuffle.setter
+     def train_shuffle(self, shuffle):
+         assert isinstance(shuffle, bool), 'shuffle must be bool'
+         self._train_shuffle = shuffle
+
+
+     @property
+     def train_batch_size(self):
+         if self._train_batch_size is None and isinstance(self.batch_size, int):
+             print(f'warning: set train_batch_size=batch_size={self.batch_size}')
+             return self.batch_size
+         return self._train_batch_size
+
+     @train_batch_size.setter
+     def train_batch_size(self, batch_size):
+         assert isinstance(batch_size, int), 'batch_size must be int'
+         self._train_batch_size = batch_size
+
+     @property
+     def val_batch_size(self):
+         if self._val_batch_size is None:
+             print(f'warning: set val_batch_size=batch_size={self.batch_size}')
+             return self.batch_size
+         return self._val_batch_size
+
+     @val_batch_size.setter
+     def val_batch_size(self, batch_size):
+         assert isinstance(batch_size, int), 'batch_size must be int'
+         self._val_batch_size = batch_size
+
+
+     @property
+     def output_dir(self):
+         if self._output_dir is None:
+             return self.log_dir
+         return self._output_dir
+
+     @output_dir.setter
+     def output_dir(self, root):
+         self._output_dir = root
+
+     @property
+     def print_freq(self):
+         if self._print_freq is None:
+             # self._print_freq = self.log_step
+             return self.log_step
+         return self._print_freq
+
+     @print_freq.setter
+     def print_freq(self, n):
+         assert isinstance(n, int), 'print_freq must be int'
+         self._print_freq = n
+
+
+     # def __repr__(self) -> str:
+     #     pass
+
+
+
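
A minimal usage sketch of the lazy defaults above (the toy dataset and sizes here are illustrative, not part of the upload): setting only `batch_size` and `train_dataset` is enough, since `train_batch_size` and `train_shuffle` fall back to warned defaults and the loader is built on first access.

    import torch
    from torch.utils.data import TensorDataset
    from src.core import BaseConfig

    cfg = BaseConfig()
    cfg.train_dataset = TensorDataset(torch.randn(8, 3), torch.zeros(8))
    cfg.batch_size = 4              # train_batch_size falls back to this, with a warning
    loader = cfg.train_dataloader   # built on first access, then cached in _train_dataloader
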
src/core/yaml_config.py ADDED
@@ -0,0 +1,152 @@
+ """by lyuwenyu
+ """
+
+ import torch
+ import torch.nn as nn
+
+ import re
+ import copy
+
+ from .config import BaseConfig
+ from .yaml_utils import load_config, merge_config, create, merge_dict
+
+
+ class YAMLConfig(BaseConfig):
+     def __init__(self, cfg_path: str, **kwargs) -> None:
+         super().__init__()
+
+         cfg = load_config(cfg_path)
+         merge_dict(cfg, kwargs)
+
+         # pprint(cfg)
+
+         self.yaml_cfg = cfg
+
+         self.log_step = cfg.get('log_step', 100)
+         self.checkpoint_step = cfg.get('checkpoint_step', 1)
+         self.epoches = cfg.get('epoches', -1)
+         self.resume = cfg.get('resume', '')
+         self.tuning = cfg.get('tuning', '')
+         self.sync_bn = cfg.get('sync_bn', False)
+         self.output_dir = cfg.get('output_dir', None)
+
+         self.use_ema = cfg.get('use_ema', False)
+         self.use_amp = cfg.get('use_amp', False)
+         self.autocast = cfg.get('autocast', dict())
+         self.find_unused_parameters = cfg.get('find_unused_parameters', None)
+         self.clip_max_norm = cfg.get('clip_max_norm', 0.)
+
+
+     @property
+     def model(self, ) -> torch.nn.Module:
+         if self._model is None and 'model' in self.yaml_cfg:
+             merge_config(self.yaml_cfg)
+             self._model = create(self.yaml_cfg['model'])
+         return self._model
+
+     @property
+     def postprocessor(self, ) -> torch.nn.Module:
+         if self._postprocessor is None and 'postprocessor' in self.yaml_cfg:
+             merge_config(self.yaml_cfg)
+             self._postprocessor = create(self.yaml_cfg['postprocessor'])
+         return self._postprocessor
+
+     @property
+     def criterion(self, ):
+         if self._criterion is None and 'criterion' in self.yaml_cfg:
+             merge_config(self.yaml_cfg)
+             self._criterion = create(self.yaml_cfg['criterion'])
+         return self._criterion
+
+
+     @property
+     def optimizer(self, ):
+         if self._optimizer is None and 'optimizer' in self.yaml_cfg:
+             merge_config(self.yaml_cfg)
+             params = self.get_optim_params(self.yaml_cfg['optimizer'], self.model)
+             self._optimizer = create('optimizer', params=params)
+
+         return self._optimizer
+
+     @property
+     def lr_scheduler(self, ):
+         if self._lr_scheduler is None and 'lr_scheduler' in self.yaml_cfg:
+             merge_config(self.yaml_cfg)
+             self._lr_scheduler = create('lr_scheduler', optimizer=self.optimizer)
+             print('Initial lr: ', self._lr_scheduler.get_last_lr())
+
+         return self._lr_scheduler
+
+     @property
+     def train_dataloader(self, ):
+         if self._train_dataloader is None and 'train_dataloader' in self.yaml_cfg:
+             merge_config(self.yaml_cfg)
+             self._train_dataloader = create('train_dataloader')
+             self._train_dataloader.shuffle = self.yaml_cfg['train_dataloader'].get('shuffle', False)
+
+         return self._train_dataloader
+
+     @property
+     def val_dataloader(self, ):
+         if self._val_dataloader is None and 'val_dataloader' in self.yaml_cfg:
+             merge_config(self.yaml_cfg)
+             self._val_dataloader = create('val_dataloader')
+             self._val_dataloader.shuffle = self.yaml_cfg['val_dataloader'].get('shuffle', False)
+
+         return self._val_dataloader
+
+
+     @property
+     def ema(self, ):
+         if self._ema is None and self.yaml_cfg.get('use_ema', False):
+             merge_config(self.yaml_cfg)
+             self._ema = create('ema', model=self.model)
+
+         return self._ema
+
+
+     @property
+     def scaler(self, ):
+         if self._scaler is None and self.yaml_cfg.get('use_amp', False):
+             merge_config(self.yaml_cfg)
+             self._scaler = create('scaler')
+
+         return self._scaler
+
+
+     @staticmethod
+     def get_optim_params(cfg: dict, model: nn.Module):
+         '''
+         E.g.:
+             ^(?=.*a)(?=.*b).*$        means including a and b
+             ^((?!b).)*a((?!b).)*$     means including a but not b
+             ^((?!b|c).)*a((?!b|c).)*$ means including a but not (b | c)
+         '''
+         assert 'type' in cfg, ''
+         cfg = copy.deepcopy(cfg)
+
+         if 'params' not in cfg:
+             return model.parameters()
+
+         assert isinstance(cfg['params'], list), ''
+
+         param_groups = []
+         visited = []
+         for pg in cfg['params']:
+             pattern = pg['params']
+             params = {k: v for k, v in model.named_parameters() if v.requires_grad and len(re.findall(pattern, k)) > 0}
+             pg['params'] = params.values()
+             param_groups.append(pg)
+             visited.extend(list(params.keys()))
+
+         names = [k for k, v in model.named_parameters() if v.requires_grad]
+
+         if len(visited) < len(names):
+             unseen = set(names) - set(visited)
+             params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
+             param_groups.append({'params': params.values()})
+             visited.extend(list(params.keys()))
+
+         assert len(visited) == len(names), ''
+
+         return param_groups
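
A runnable sketch of how get_optim_params turns the regex patterns described in its docstring into optimizer param groups (the tiny model and pattern below are illustrative):

    import torch.nn as nn
    from src.core import YAMLConfig

    # Parameter names are 'backbone.weight', 'backbone.bias', 'head.weight', 'head.bias'.
    model = nn.ModuleDict({'backbone': nn.Linear(4, 4), 'head': nn.Linear(4, 2)})

    cfg = {'type': 'AdamW', 'params': [{'params': 'backbone', 'lr': 1e-5}]}
    groups = YAMLConfig.get_optim_params(cfg, model)
    # -> one group with the two backbone tensors at lr=1e-5,
    #    plus a catch-all group for the unmatched head parameters.
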
src/core/yaml_utils.py ADDED
@@ -0,0 +1,208 @@
+ """by lyuwenyu
+ """
+
+ import os
+ import yaml
+ import inspect
+ import importlib
+
+ __all__ = ['GLOBAL_CONFIG', 'register', 'create', 'load_config', 'merge_config', 'merge_dict']
+
+
+ GLOBAL_CONFIG = dict()
+ INCLUDE_KEY = '__include__'
+
+
+ def register(cls: type):
+     '''
+     Args:
+         cls (type): Module class to be registered.
+     '''
+     if cls.__name__ in GLOBAL_CONFIG:
+         raise ValueError('{} already registered'.format(cls.__name__))
+
+     if inspect.isfunction(cls):
+         GLOBAL_CONFIG[cls.__name__] = cls
+
+     elif inspect.isclass(cls):
+         GLOBAL_CONFIG[cls.__name__] = extract_schema(cls)
+
+     else:
+         raise ValueError(f'register {cls}')
+
+     return cls
+
+
+ def extract_schema(cls: type):
+     '''
+     Args:
+         cls (type),
+     Return:
+         Dict,
+     '''
+     argspec = inspect.getfullargspec(cls.__init__)
+     arg_names = [arg for arg in argspec.args if arg != 'self']
+     num_defaults = len(argspec.defaults) if argspec.defaults is not None else 0
+     num_requires = len(arg_names) - num_defaults
+
+     schema = dict()
+     schema['_name'] = cls.__name__
+     schema['_pymodule'] = importlib.import_module(cls.__module__)
+     schema['_inject'] = getattr(cls, '__inject__', [])
+     schema['_share'] = getattr(cls, '__share__', [])
+
+     for i, name in enumerate(arg_names):
+         if name in schema['_share']:
+             assert i >= num_requires, 'share config must have default value.'
+             value = argspec.defaults[i - num_requires]
+
+         elif i >= num_requires:
+             value = argspec.defaults[i - num_requires]
+
+         else:
+             value = None
+
+         schema[name] = value
+
+     return schema
+
+
+
+ def create(type_or_name, **kwargs):
+     '''
+     '''
+     assert type(type_or_name) in (type, str), 'create should be class or name.'
+
+     name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__
+
+     if name in GLOBAL_CONFIG:
+         if hasattr(GLOBAL_CONFIG[name], '__dict__'):
+             return GLOBAL_CONFIG[name]
+     else:
+         raise ValueError('The module {} is not registered'.format(name))
+
+     cfg = GLOBAL_CONFIG[name]
+
+     if isinstance(cfg, dict) and 'type' in cfg:
+         _cfg: dict = GLOBAL_CONFIG[cfg['type']]
+         _cfg.update(cfg)     # update global cls default args
+         _cfg.update(kwargs)  # TODO
+         name = _cfg.pop('type')
+
+         return create(name)
+
+
+     cls = getattr(cfg['_pymodule'], name)
+     argspec = inspect.getfullargspec(cls.__init__)
+     arg_names = [arg for arg in argspec.args if arg != 'self']
+
+     cls_kwargs = {}
+     cls_kwargs.update(cfg)
+
+     # shared var
+     for k in cfg['_share']:
+         if k in GLOBAL_CONFIG:
+             cls_kwargs[k] = GLOBAL_CONFIG[k]
+         else:
+             cls_kwargs[k] = cfg[k]
+
+     # inject
+     for k in cfg['_inject']:
+         _k = cfg[k]
+
+         if _k is None:
+             continue
+
+         if isinstance(_k, str):
+             if _k not in GLOBAL_CONFIG:
+                 raise ValueError(f'Missing inject config of {_k}.')
+
+             _cfg = GLOBAL_CONFIG[_k]
+
+             if isinstance(_cfg, dict):
+                 cls_kwargs[k] = create(_cfg['_name'])
+             else:
+                 cls_kwargs[k] = _cfg
+
+         elif isinstance(_k, dict):
+             if 'type' not in _k.keys():
+                 raise ValueError(f'Missing inject for `type` style.')
+
+             _type = str(_k['type'])
+             if _type not in GLOBAL_CONFIG:
+                 raise ValueError(f'Missing {_type} in inspect stage.')
+
+             # TODO modified in place; may give a wrong result when used more than once
+             _cfg: dict = GLOBAL_CONFIG[_type]
+             # _cfg_copy = copy.deepcopy(_cfg)
+             _cfg.update(_k)  # update
+             cls_kwargs[k] = create(_type)
+             # _cfg.update(_cfg_copy)  # resume
+
+         else:
+             raise ValueError(f'Inject does not support {_k}')
+
+
+     cls_kwargs = {n: cls_kwargs[n] for n in arg_names}
+
+     return cls(**cls_kwargs)
+
+
+
+ def load_config(file_path, cfg=dict()):
+     '''load config
+     '''
+     _, ext = os.path.splitext(file_path)
+     assert ext in ['.yml', '.yaml'], "only support yaml files for now"
+
+     with open(file_path) as f:
+         file_cfg = yaml.load(f, Loader=yaml.Loader)
+         if file_cfg is None:
+             return {}
+
+     if INCLUDE_KEY in file_cfg:
+         base_yamls = list(file_cfg[INCLUDE_KEY])
+         for base_yaml in base_yamls:
+             if base_yaml.startswith('~'):
+                 base_yaml = os.path.expanduser(base_yaml)
+
+             if not base_yaml.startswith('/'):
+                 base_yaml = os.path.join(os.path.dirname(file_path), base_yaml)
+
+             with open(base_yaml) as f:
+                 base_cfg = load_config(base_yaml, cfg)
+                 merge_config(base_cfg, cfg)
+
+     return merge_config(file_cfg, cfg)
+
+
+
+ def merge_dict(dct, another_dct):
+     '''merge another_dct into dct
+     '''
+     for k in another_dct:
+         if (k in dct and isinstance(dct[k], dict) and isinstance(another_dct[k], dict)):
+             merge_dict(dct[k], another_dct[k])
+         else:
+             dct[k] = another_dct[k]
+
+     return dct
+
+
+
+ def merge_config(config, another_cfg=None):
+     """
+     Merge config into global config or another_cfg.
+
+     Args:
+         config (dict): Config to be merged.
+
+     Returns: global config
+     """
+     global GLOBAL_CONFIG
+     dct = GLOBAL_CONFIG if another_cfg is None else another_cfg
+
+     return merge_dict(dct, config)
+
+
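
A short sketch of the register/create round trip (the Dummy class is illustrative): register() stores a schema extracted from __init__, merge_config() overlays YAML values onto that schema, and create() instantiates from it.

    from src.core import register, create, GLOBAL_CONFIG

    @register
    class Dummy:
        def __init__(self, hidden_dim=256):
            self.hidden_dim = hidden_dim

    GLOBAL_CONFIG['Dummy']['hidden_dim'] = 512   # what merge_config does with YAML values
    obj = create('Dummy')
    assert obj.hidden_dim == 512
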
src/data/__init__.py ADDED
@@ -0,0 +1,7 @@
+
+ from .coco import *
+ from .cifar10 import CIFAR10
+
+ from .dataloader import *
+ from .transforms import *
+
src/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (270 Bytes).
 
src/data/__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (1.29 kB).
 
src/data/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (5.19 kB).
 
src/data/cifar10/__init__.py ADDED
@@ -0,0 +1,14 @@
+
+ import torchvision
+ from typing import Optional, Callable
+
+ from src.core import register
+
+
+ @register
+ class CIFAR10(torchvision.datasets.CIFAR10):
+     __inject__ = ['transform', 'target_transform']
+
+     def __init__(self, root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False) -> None:
+         super().__init__(root, train, transform, target_transform, download)
+
src/data/cifar10/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (823 Bytes).
 
src/data/coco/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .coco_dataset import (
+     CocoDetection,
+     mscoco_category2label,
+     mscoco_label2category,
+     mscoco_category2name,
+ )
+ from .coco_eval import *
+
+ from .coco_utils import get_coco_api_from_dataset
src/data/coco/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (402 Bytes).
 
src/data/coco/__pycache__/coco_dataset.cpython-310.pyc ADDED
Binary file (7.07 kB).
 
src/data/coco/__pycache__/coco_eval.cpython-310.pyc ADDED
Binary file (7.26 kB).
 
src/data/coco/__pycache__/coco_utils.cpython-310.pyc ADDED
Binary file (6.61 kB).
 
src/data/coco/coco_dataset.py ADDED
@@ -0,0 +1,238 @@
+ """
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+ COCO dataset which returns image_id for evaluation.
+ Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
+ """
+
+ import torch
+ import torch.utils.data
+
+ import torchvision
+ torchvision.disable_beta_transforms_warning()
+
+ from torchvision import datapoints
+
+ from pycocotools import mask as coco_mask
+
+ from src.core import register
+
+ __all__ = ['CocoDetection']
+
+
+ @register
+ class CocoDetection(torchvision.datasets.CocoDetection):
+     __inject__ = ['transforms']
+     __share__ = ['remap_mscoco_category']
+
+     def __init__(self, img_folder, ann_file, transforms, return_masks, remap_mscoco_category=False):
+         super(CocoDetection, self).__init__(img_folder, ann_file)
+         self._transforms = transforms
+         self.prepare = ConvertCocoPolysToMask(return_masks, remap_mscoco_category)
+         self.img_folder = img_folder
+         self.ann_file = ann_file
+         self.return_masks = return_masks
+         self.remap_mscoco_category = remap_mscoco_category
+
+     def __getitem__(self, idx):
+         img, target = super(CocoDetection, self).__getitem__(idx)
+         image_id = self.ids[idx]
+         target = {'image_id': image_id, 'annotations': target}
+         img, target = self.prepare(img, target)
+
+         # ['boxes', 'masks', 'labels']:
+         if 'boxes' in target:
+             target['boxes'] = datapoints.BoundingBox(
+                 target['boxes'],
+                 format=datapoints.BoundingBoxFormat.XYXY,
+                 spatial_size=img.size[::-1])  # h w
+
+         if 'masks' in target:
+             target['masks'] = datapoints.Mask(target['masks'])
+
+         if self._transforms is not None:
+             img, target = self._transforms(img, target)
+
+         return img, target
+
+     def extra_repr(self) -> str:
+         s = f' img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n'
+         s += f' return_masks: {self.return_masks}\n'
+         if hasattr(self, '_transforms') and self._transforms is not None:
+             s += f' transforms:\n   {repr(self._transforms)}'
+
+         return s
+
+
+ def convert_coco_poly_to_mask(segmentations, height, width):
+     masks = []
+     for polygons in segmentations:
+         rles = coco_mask.frPyObjects(polygons, height, width)
+         mask = coco_mask.decode(rles)
+         if len(mask.shape) < 3:
+             mask = mask[..., None]
+         mask = torch.as_tensor(mask, dtype=torch.uint8)
+         mask = mask.any(dim=2)
+         masks.append(mask)
+     if masks:
+         masks = torch.stack(masks, dim=0)
+     else:
+         masks = torch.zeros((0, height, width), dtype=torch.uint8)
+     return masks
+
+
+ class ConvertCocoPolysToMask(object):
+     def __init__(self, return_masks=False, remap_mscoco_category=False):
+         self.return_masks = return_masks
+         self.remap_mscoco_category = remap_mscoco_category
+
+     def __call__(self, image, target):
+         w, h = image.size
+
+         image_id = target["image_id"]
+         image_id = torch.tensor([image_id])
+
+         anno = target["annotations"]
+
+         anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
+
+         boxes = [obj["bbox"] for obj in anno]
+         # guard against no boxes via resizing
+         boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
+         boxes[:, 2:] += boxes[:, :2]
+         boxes[:, 0::2].clamp_(min=0, max=w)
+         boxes[:, 1::2].clamp_(min=0, max=h)
+
+         if self.remap_mscoco_category:
+             classes = [mscoco_category2label[obj["category_id"]] for obj in anno]
+         else:
+             classes = [obj["category_id"] for obj in anno]
+
+         classes = torch.tensor(classes, dtype=torch.int64)
+
+         if self.return_masks:
+             segmentations = [obj["segmentation"] for obj in anno]
+             masks = convert_coco_poly_to_mask(segmentations, h, w)
+
+         keypoints = None
+         if anno and "keypoints" in anno[0]:
+             keypoints = [obj["keypoints"] for obj in anno]
+             keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
+             num_keypoints = keypoints.shape[0]
+             if num_keypoints:
+                 keypoints = keypoints.view(num_keypoints, -1, 3)
+
+         keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+         boxes = boxes[keep]
+         classes = classes[keep]
+         if self.return_masks:
+             masks = masks[keep]
+         if keypoints is not None:
+             keypoints = keypoints[keep]
+
+         target = {}
+         target["boxes"] = boxes
+         target["labels"] = classes
+         if self.return_masks:
+             target["masks"] = masks
+         target["image_id"] = image_id
+         if keypoints is not None:
+             target["keypoints"] = keypoints
+
+         # for conversion to coco api
+         area = torch.tensor([obj["area"] for obj in anno])
+         iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
+         target["area"] = area[keep]
+         target["iscrowd"] = iscrowd[keep]
+
+         target["orig_size"] = torch.as_tensor([int(w), int(h)])
+         target["size"] = torch.as_tensor([int(w), int(h)])
+
+         return image, target
+
+
+ mscoco_category2name = {
+     1: 'person',
+     2: 'bicycle',
+     3: 'car',
+     4: 'motorcycle',
+     5: 'airplane',
+     6: 'bus',
+     7: 'train',
+     8: 'truck',
+     9: 'boat',
+     10: 'traffic light',
+     11: 'fire hydrant',
+     13: 'stop sign',
+     14: 'parking meter',
+     15: 'bench',
+     16: 'bird',
+     17: 'cat',
+     18: 'dog',
+     19: 'horse',
+     20: 'sheep',
+     21: 'cow',
+     22: 'elephant',
+     23: 'bear',
+     24: 'zebra',
+     25: 'giraffe',
+     27: 'backpack',
+     28: 'umbrella',
+     31: 'handbag',
+     32: 'tie',
+     33: 'suitcase',
+     34: 'frisbee',
+     35: 'skis',
+     36: 'snowboard',
+     37: 'sports ball',
+     38: 'kite',
+     39: 'baseball bat',
+     40: 'baseball glove',
+     41: 'skateboard',
+     42: 'surfboard',
+     43: 'tennis racket',
+     44: 'bottle',
+     46: 'wine glass',
+     47: 'cup',
+     48: 'fork',
+     49: 'knife',
+     50: 'spoon',
+     51: 'bowl',
+     52: 'banana',
+     53: 'apple',
+     54: 'sandwich',
+     55: 'orange',
+     56: 'broccoli',
+     57: 'carrot',
+     58: 'hot dog',
+     59: 'pizza',
+     60: 'donut',
+     61: 'cake',
+     62: 'chair',
+     63: 'couch',
+     64: 'potted plant',
+     65: 'bed',
+     67: 'dining table',
+     70: 'toilet',
+     72: 'tv',
+     73: 'laptop',
+     74: 'mouse',
+     75: 'remote',
+     76: 'keyboard',
+     77: 'cell phone',
+     78: 'microwave',
+     79: 'oven',
+     80: 'toaster',
+     81: 'sink',
+     82: 'refrigerator',
+     84: 'book',
+     85: 'clock',
+     86: 'vase',
+     87: 'scissors',
+     88: 'teddy bear',
+     89: 'hair drier',
+     90: 'toothbrush'
+ }
+
+ mscoco_category2label = {k: i for i, k in enumerate(mscoco_category2name.keys())}
+ mscoco_label2category = {v: k for k, v in mscoco_category2label.items()}
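
For reference, the remap_mscoco_category flag uses the two dicts above to convert between the sparse COCO category ids (1-90, with gaps) and contiguous training labels (0-79):

    from src.data.coco import mscoco_category2label, mscoco_label2category

    assert mscoco_category2label[1] == 0      # category 1 ('person') -> label 0
    assert mscoco_category2label[90] == 79    # category 90 ('toothbrush') -> label 79
    assert mscoco_label2category[79] == 90    # and back again
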
src/data/coco/coco_eval.py ADDED
@@ -0,0 +1,269 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+ COCO evaluator that works in distributed mode.
+
+ Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
+ The difference is that there is less copy-pasting from pycocotools
+ at the end of the file, as python3 can suppress prints with contextlib
+ """
+ import os
+ import contextlib
+ import copy
+ import numpy as np
+ import torch
+
+ from pycocotools.cocoeval import COCOeval
+ from pycocotools.coco import COCO
+ import pycocotools.mask as mask_util
+
+ from src.misc import dist
+
+
+ __all__ = ['CocoEvaluator',]
+
+
+ class CocoEvaluator(object):
+     def __init__(self, coco_gt, iou_types):
+         assert isinstance(iou_types, (list, tuple))
+         coco_gt = copy.deepcopy(coco_gt)
+         self.coco_gt = coco_gt
+
+         self.iou_types = iou_types
+         self.coco_eval = {}
+         for iou_type in iou_types:
+             self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
+
+         self.img_ids = []
+         self.eval_imgs = {k: [] for k in iou_types}
+
+     def update(self, predictions):
+         img_ids = list(np.unique(list(predictions.keys())))
+         self.img_ids.extend(img_ids)
+
+         for iou_type in self.iou_types:
+             results = self.prepare(predictions, iou_type)
+
+             # suppress pycocotools prints
+             with open(os.devnull, 'w') as devnull:
+                 with contextlib.redirect_stdout(devnull):
+                     coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
+             coco_eval = self.coco_eval[iou_type]
+
+             coco_eval.cocoDt = coco_dt
+             coco_eval.params.imgIds = list(img_ids)
+             img_ids, eval_imgs = evaluate(coco_eval)
+
+             self.eval_imgs[iou_type].append(eval_imgs)
+
+     def synchronize_between_processes(self):
+         for iou_type in self.iou_types:
+             self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
+             create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
+
+     def accumulate(self):
+         for coco_eval in self.coco_eval.values():
+             coco_eval.accumulate()
+
+     def summarize(self):
+         for iou_type, coco_eval in self.coco_eval.items():
+             print("IoU metric: {}".format(iou_type))
+             coco_eval.summarize()
+
+     def prepare(self, predictions, iou_type):
+         if iou_type == "bbox":
+             return self.prepare_for_coco_detection(predictions)
+         elif iou_type == "segm":
+             return self.prepare_for_coco_segmentation(predictions)
+         elif iou_type == "keypoints":
+             return self.prepare_for_coco_keypoint(predictions)
+         else:
+             raise ValueError("Unknown iou type {}".format(iou_type))
+
+     def prepare_for_coco_detection(self, predictions):
+         coco_results = []
+         for original_id, prediction in predictions.items():
+             if len(prediction) == 0:
+                 continue
+
+             boxes = prediction["boxes"]
+             boxes = convert_to_xywh(boxes).tolist()
+             scores = prediction["scores"].tolist()
+             labels = prediction["labels"].tolist()
+
+             coco_results.extend(
+                 [
+                     {
+                         "image_id": original_id,
+                         "category_id": labels[k],
+                         "bbox": box,
+                         "score": scores[k],
+                     }
+                     for k, box in enumerate(boxes)
+                 ]
+             )
+         return coco_results
+
+     def prepare_for_coco_segmentation(self, predictions):
+         coco_results = []
+         for original_id, prediction in predictions.items():
+             if len(prediction) == 0:
+                 continue
+
+             scores = prediction["scores"]
+             labels = prediction["labels"]
+             masks = prediction["masks"]
+
+             masks = masks > 0.5
+
+             scores = prediction["scores"].tolist()
+             labels = prediction["labels"].tolist()
+
+             rles = [
+                 mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
+                 for mask in masks
+             ]
+             for rle in rles:
+                 rle["counts"] = rle["counts"].decode("utf-8")
+
+             coco_results.extend(
+                 [
+                     {
+                         "image_id": original_id,
+                         "category_id": labels[k],
+                         "segmentation": rle,
+                         "score": scores[k],
+                     }
+                     for k, rle in enumerate(rles)
+                 ]
+             )
+         return coco_results
+
+     def prepare_for_coco_keypoint(self, predictions):
+         coco_results = []
+         for original_id, prediction in predictions.items():
+             if len(prediction) == 0:
+                 continue
+
+             boxes = prediction["boxes"]
+             boxes = convert_to_xywh(boxes).tolist()
+             scores = prediction["scores"].tolist()
+             labels = prediction["labels"].tolist()
+             keypoints = prediction["keypoints"]
+             keypoints = keypoints.flatten(start_dim=1).tolist()
+
+             coco_results.extend(
+                 [
+                     {
+                         "image_id": original_id,
+                         "category_id": labels[k],
+                         'keypoints': keypoint,
+                         "score": scores[k],
+                     }
+                     for k, keypoint in enumerate(keypoints)
+                 ]
+             )
+         return coco_results
+
+
+ def convert_to_xywh(boxes):
+     xmin, ymin, xmax, ymax = boxes.unbind(1)
+     return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
+
+
+ def merge(img_ids, eval_imgs):
+     all_img_ids = dist.all_gather(img_ids)
+     all_eval_imgs = dist.all_gather(eval_imgs)
+
+     merged_img_ids = []
+     for p in all_img_ids:
+         merged_img_ids.extend(p)
+
+     merged_eval_imgs = []
+     for p in all_eval_imgs:
+         merged_eval_imgs.append(p)
+
+     merged_img_ids = np.array(merged_img_ids)
+     merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
+
+     # keep only unique (and in sorted order) images
+     merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
+     merged_eval_imgs = merged_eval_imgs[..., idx]
+
+     return merged_img_ids, merged_eval_imgs
+
+
+ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
+     img_ids, eval_imgs = merge(img_ids, eval_imgs)
+     img_ids = list(img_ids)
+     eval_imgs = list(eval_imgs.flatten())
+
+     coco_eval.evalImgs = eval_imgs
+     coco_eval.params.imgIds = img_ids
+     coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
+
+
+ #################################################################
+ # From pycocotools, just removed the prints and fixed
+ # a Python3 bug about unicode not defined
+ #################################################################
+
+
+ # import io
+ # from contextlib import redirect_stdout
+ # def evaluate(imgs):
+ #     with redirect_stdout(io.StringIO()):
+ #         imgs.evaluate()
+ #     return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))
+
+
+ def evaluate(self):
+     '''
+     Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
+     :return: None
+     '''
+     # tic = time.time()
+     # print('Running per image evaluation...')
+     p = self.params
+     # add backward compatibility if useSegm is specified in params
+     if p.useSegm is not None:
+         p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
+         print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
+     # print('Evaluate annotation type *{}*'.format(p.iouType))
+     p.imgIds = list(np.unique(p.imgIds))
+     if p.useCats:
+         p.catIds = list(np.unique(p.catIds))
+     p.maxDets = sorted(p.maxDets)
+     self.params = p
+
+     self._prepare()
+     # loop through images, area range, max detection number
+     catIds = p.catIds if p.useCats else [-1]
+
+     if p.iouType == 'segm' or p.iouType == 'bbox':
+         computeIoU = self.computeIoU
+     elif p.iouType == 'keypoints':
+         computeIoU = self.computeOks
+     self.ious = {
+         (imgId, catId): computeIoU(imgId, catId)
+         for imgId in p.imgIds
+         for catId in catIds}
+
+     evaluateImg = self.evaluateImg
+     maxDet = p.maxDets[-1]
+     evalImgs = [
+         evaluateImg(imgId, catId, areaRng, maxDet)
+         for catId in catIds
+         for areaRng in p.areaRng
+         for imgId in p.imgIds
+     ]
+     # this is NOT in the pycocotools code, but could be done outside
+     evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
+     self._paramsEval = copy.deepcopy(self.params)
+     # toc = time.time()
+     # print('DONE (t={:0.2f}s).'.format(toc-tic))
+     return p.imgIds, evalImgs
+
+ #################################################################
+ # end of straight copy from pycocotools, just removing the prints
+ #################################################################
+
src/data/coco/coco_utils.py ADDED
@@ -0,0 +1,184 @@
+ import os
+
+ import torch
+ import torch.utils.data
+ import torchvision
+ from pycocotools import mask as coco_mask
+ from pycocotools.coco import COCO
+
+
+ def convert_coco_poly_to_mask(segmentations, height, width):
+     masks = []
+     for polygons in segmentations:
+         rles = coco_mask.frPyObjects(polygons, height, width)
+         mask = coco_mask.decode(rles)
+         if len(mask.shape) < 3:
+             mask = mask[..., None]
+         mask = torch.as_tensor(mask, dtype=torch.uint8)
+         mask = mask.any(dim=2)
+         masks.append(mask)
+     if masks:
+         masks = torch.stack(masks, dim=0)
+     else:
+         masks = torch.zeros((0, height, width), dtype=torch.uint8)
+     return masks
+
+
+ class ConvertCocoPolysToMask:
+     def __call__(self, image, target):
+         w, h = image.size
+
+         image_id = target["image_id"]
+
+         anno = target["annotations"]
+
+         anno = [obj for obj in anno if obj["iscrowd"] == 0]
+
+         boxes = [obj["bbox"] for obj in anno]
+         # guard against no boxes via resizing
+         boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
+         boxes[:, 2:] += boxes[:, :2]
+         boxes[:, 0::2].clamp_(min=0, max=w)
+         boxes[:, 1::2].clamp_(min=0, max=h)
+
+         classes = [obj["category_id"] for obj in anno]
+         classes = torch.tensor(classes, dtype=torch.int64)
+
+         segmentations = [obj["segmentation"] for obj in anno]
+         masks = convert_coco_poly_to_mask(segmentations, h, w)
+
+         keypoints = None
+         if anno and "keypoints" in anno[0]:
+             keypoints = [obj["keypoints"] for obj in anno]
+             keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
+             num_keypoints = keypoints.shape[0]
+             if num_keypoints:
+                 keypoints = keypoints.view(num_keypoints, -1, 3)
+
+         keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+         boxes = boxes[keep]
+         classes = classes[keep]
+         masks = masks[keep]
+         if keypoints is not None:
+             keypoints = keypoints[keep]
+
+         target = {}
+         target["boxes"] = boxes
+         target["labels"] = classes
+         target["masks"] = masks
+         target["image_id"] = image_id
+         if keypoints is not None:
+             target["keypoints"] = keypoints
+
+         # for conversion to coco api
+         area = torch.tensor([obj["area"] for obj in anno])
+         iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
+         target["area"] = area
+         target["iscrowd"] = iscrowd
+
+         return image, target
+
+
+ def _coco_remove_images_without_annotations(dataset, cat_list=None):
+     def _has_only_empty_bbox(anno):
+         return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
+
+     def _count_visible_keypoints(anno):
+         return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
+
+     min_keypoints_per_image = 10
+
+     def _has_valid_annotation(anno):
+         # if it's empty, there is no annotation
+         if len(anno) == 0:
+             return False
+         # if all boxes have close to zero area, there is no annotation
+         if _has_only_empty_bbox(anno):
+             return False
+         # keypoint tasks have slightly different criteria for deciding
+         # whether an annotation is valid
+         if "keypoints" not in anno[0]:
+             return True
+         # for keypoint detection tasks, only consider images valid if they
+         # contain at least min_keypoints_per_image
+         if _count_visible_keypoints(anno) >= min_keypoints_per_image:
+             return True
+         return False
+
+     ids = []
+     for ds_idx, img_id in enumerate(dataset.ids):
+         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
+         anno = dataset.coco.loadAnns(ann_ids)
+         if cat_list:
+             anno = [obj for obj in anno if obj["category_id"] in cat_list]
+         if _has_valid_annotation(anno):
+             ids.append(ds_idx)
+
+     dataset = torch.utils.data.Subset(dataset, ids)
+     return dataset
+
+
+ def convert_to_coco_api(ds):
+     coco_ds = COCO()
+     # annotation IDs need to start at 1, not 0, see torchvision issue #1530
+     ann_id = 1
+     dataset = {"images": [], "categories": [], "annotations": []}
+     categories = set()
+     for img_idx in range(len(ds)):
+         # find better way to get target
+         # targets = ds.get_annotations(img_idx)
+         img, targets = ds[img_idx]
+         image_id = targets["image_id"].item()
+         img_dict = {}
+         img_dict["id"] = image_id
+         img_dict["height"] = img.shape[-2]
+         img_dict["width"] = img.shape[-1]
+         dataset["images"].append(img_dict)
+         bboxes = targets["boxes"].clone()
+         bboxes[:, 2:] -= bboxes[:, :2]
+         bboxes = bboxes.tolist()
+         labels = targets["labels"].tolist()
+         areas = targets["area"].tolist()
+         iscrowd = targets["iscrowd"].tolist()
+         if "masks" in targets:
+             masks = targets["masks"]
+             # make masks Fortran contiguous for coco_mask
+             masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
+         if "keypoints" in targets:
+             keypoints = targets["keypoints"]
+             keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
+         num_objs = len(bboxes)
+         for i in range(num_objs):
+             ann = {}
+             ann["image_id"] = image_id
+             ann["bbox"] = bboxes[i]
+             ann["category_id"] = labels[i]
+             categories.add(labels[i])
+             ann["area"] = areas[i]
+             ann["iscrowd"] = iscrowd[i]
+             ann["id"] = ann_id
+             if "masks" in targets:
+                 ann["segmentation"] = coco_mask.encode(masks[i].numpy())
+             if "keypoints" in targets:
+                 ann["keypoints"] = keypoints[i]
+                 ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
+             dataset["annotations"].append(ann)
+             ann_id += 1
+     dataset["categories"] = [{"id": i} for i in sorted(categories)]
+     coco_ds.dataset = dataset
+     coco_ds.createIndex()
+     return coco_ds
+
+
+ def get_coco_api_from_dataset(dataset):
+     # FIXME: This is... awful?
+     for _ in range(10):
+         if isinstance(dataset, torchvision.datasets.CocoDetection):
+             break
+         if isinstance(dataset, torch.utils.data.Subset):
+             dataset = dataset.dataset
+     if isinstance(dataset, torchvision.datasets.CocoDetection):
+         return dataset.coco
+     return convert_to_coco_api(dataset)
+
+
src/data/dataloader.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+ import torch.utils.data as data
+
+ from src.core import register
+
+
+ __all__ = ['DataLoader']
+
+
+ @register
+ class DataLoader(data.DataLoader):
+     __inject__ = ['dataset', 'collate_fn']
+
+     def __repr__(self) -> str:
+         format_string = self.__class__.__name__ + "("
+         for n in ['dataset', 'batch_size', 'num_workers', 'drop_last', 'collate_fn']:
+             format_string += "\n"
+             format_string += "    {0}: {1}".format(n, getattr(self, n))
+         format_string += "\n)"
+         return format_string
+
+
+
+ @register
+ def default_collate_fn(items):
+     '''default collate_fn
+     '''
+     return torch.cat([x[0][None] for x in items], dim=0), [x[1] for x in items]
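
A small sketch of what default_collate_fn produces (the tensors are illustrative): images are stacked into one batch tensor, while the per-image target dicts stay a plain list.

    import torch
    from src.data.dataloader import default_collate_fn

    items = [(torch.rand(3, 640, 640), {'labels': torch.tensor([1])}) for _ in range(2)]
    images, targets = default_collate_fn(items)
    assert images.shape == (2, 3, 640, 640) and len(targets) == 2
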
src/data/functional.py ADDED
@@ -0,0 +1,169 @@
+ import torch
+ import torchvision.transforms.functional as F
+
+ from packaging import version
+ from typing import Optional, List
+ from torch import Tensor
+
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
+ import torchvision
+ if version.parse(torchvision.__version__) < version.parse('0.7'):
+     from torchvision.ops import _new_empty_tensor
+     from torchvision.ops.misc import _output_size
+
+
+ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
+     # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
+     """
+     Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
+     This will eventually be supported natively by PyTorch, and this
+     class can go away.
+     """
+     if version.parse(torchvision.__version__) < version.parse('0.7'):
+         if input.numel() > 0:
+             return torch.nn.functional.interpolate(
+                 input, size, scale_factor, mode, align_corners
+             )
+
+         output_shape = _output_size(2, input, size, scale_factor)
+         output_shape = list(input.shape[:-2]) + list(output_shape)
+         return _new_empty_tensor(input, output_shape)
+     else:
+         return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+
+ def crop(image, target, region):
+     cropped_image = F.crop(image, *region)
+
+     target = target.copy()
+     i, j, h, w = region
+
+     # should we do something wrt the original size?
+     target["size"] = torch.tensor([h, w])
+
+     fields = ["labels", "area", "iscrowd"]
+
+     if "boxes" in target:
+         boxes = target["boxes"]
+         max_size = torch.as_tensor([w, h], dtype=torch.float32)
+         cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
+         cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
+         cropped_boxes = cropped_boxes.clamp(min=0)
+         area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
+         target["boxes"] = cropped_boxes.reshape(-1, 4)
+         target["area"] = area
+         fields.append("boxes")
+
+     if "masks" in target:
+         # FIXME should we update the area here if there are no boxes?
+         target['masks'] = target['masks'][:, i:i + h, j:j + w]
+         fields.append("masks")
+
+     # remove elements for which the boxes or masks have zero area
+     if "boxes" in target or "masks" in target:
+         # favor boxes selection when defining which elements to keep
+         # this is compatible with previous implementation
+         if "boxes" in target:
+             cropped_boxes = target['boxes'].reshape(-1, 2, 2)
+             keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
+         else:
+             keep = target['masks'].flatten(1).any(1)
+
+         for field in fields:
+             target[field] = target[field][keep]
+
+     return cropped_image, target
+
+
+ def hflip(image, target):
+     flipped_image = F.hflip(image)
+
+     w, h = image.size
+
+     target = target.copy()
+     if "boxes" in target:
+         boxes = target["boxes"]
+         boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
+         target["boxes"] = boxes
+
+     if "masks" in target:
+         target['masks'] = target['masks'].flip(-1)
+
+     return flipped_image, target
+
+
+ def resize(image, target, size, max_size=None):
+     # size can be min_size (scalar) or (w, h) tuple
+
+     def get_size_with_aspect_ratio(image_size, size, max_size=None):
+         w, h = image_size
+         if max_size is not None:
+             min_original_size = float(min((w, h)))
+             max_original_size = float(max((w, h)))
+             if max_original_size / min_original_size * size > max_size:
+                 size = int(round(max_size * min_original_size / max_original_size))
+
+         if (w <= h and w == size) or (h <= w and h == size):
+             return (h, w)
+
+         if w < h:
+             ow = size
+             oh = int(size * h / w)
+         else:
+             oh = size
+             ow = int(size * w / h)
+
+         # r = min(size / min(h, w), max_size / max(h, w))
+         # ow = int(w * r)
+         # oh = int(h * r)
+
+         return (oh, ow)
+
+     def get_size(image_size, size, max_size=None):
+         if isinstance(size, (list, tuple)):
+             return size[::-1]
+         else:
+             return get_size_with_aspect_ratio(image_size, size, max_size)
+
+     size = get_size(image.size, size, max_size)
+     rescaled_image = F.resize(image, size)
+
+     if target is None:
+         return rescaled_image, None
+
+     ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
+     ratio_width, ratio_height = ratios
+
+     target = target.copy()
+     if "boxes" in target:
+         boxes = target["boxes"]
+         scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
+         target["boxes"] = scaled_boxes
+
+     if "area" in target:
+         area = target["area"]
+         scaled_area = area * (ratio_width * ratio_height)
+         target["area"] = scaled_area
+
+     h, w = size
+     target["size"] = torch.tensor([h, w])
+
+     if "masks" in target:
+         target['masks'] = interpolate(
+             target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5
+
+     return rescaled_image, target
+
+
+ def pad(image, target, padding):
+     # assumes that we only pad on the bottom right corners
+     padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
+     if target is None:
+         return padded_image, None
+     target = target.copy()
+     # should we do something wrt the original size?
+     target["size"] = torch.tensor(padded_image.size[::-1])
+     if "masks" in target:
+         target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
+     return padded_image, target
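
A worked example of the aspect-ratio logic in resize (the sizes are illustrative): the shorter side is scaled to `size`, and the target would only shrink further if the longer side exceeded max_size.

    from PIL import Image
    from src.data.functional import resize

    img = Image.new('RGB', (640, 480))
    out, _ = resize(img, None, size=800, max_size=1333)
    assert out.size == (1066, 800)   # 480 -> 800, 640 -> int(800 * 640 / 480) = 1066
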
src/data/transforms.py ADDED
@@ -0,0 +1,142 @@
+ """by lyuwenyu
+ """
+
+
+ import torch
+ import torch.nn as nn
+
+ import torchvision
+ torchvision.disable_beta_transforms_warning()
+ from torchvision import datapoints
+
+ import torchvision.transforms.v2 as T
+ import torchvision.transforms.v2.functional as F
+
+ from PIL import Image
+ from typing import Any, Dict, List, Optional
+
+ from src.core import register, GLOBAL_CONFIG
+
+
+ __all__ = ['Compose', ]
+
+
+ RandomPhotometricDistort = register(T.RandomPhotometricDistort)
+ RandomZoomOut = register(T.RandomZoomOut)
+ # RandomIoUCrop = register(T.RandomIoUCrop)
+ RandomHorizontalFlip = register(T.RandomHorizontalFlip)
+ Resize = register(T.Resize)
+ ToImageTensor = register(T.ToImageTensor)
+ ConvertDtype = register(T.ConvertDtype)
+ SanitizeBoundingBox = register(T.SanitizeBoundingBox)
+ RandomCrop = register(T.RandomCrop)
+ Normalize = register(T.Normalize)
+
+
+
+ @register
+ class Compose(T.Compose):
+     def __init__(self, ops) -> None:
+         transforms = []
+         if ops is not None:
+             for op in ops:
+                 if isinstance(op, dict):
+                     name = op.pop('type')
+                     transform = getattr(GLOBAL_CONFIG[name]['_pymodule'], name)(**op)
+                     transforms.append(transform)
+                     # op['type'] = name
+                 elif isinstance(op, nn.Module):
+                     transforms.append(op)
+
+                 else:
+                     raise ValueError('unsupported op type: {}'.format(type(op)))
+         else:
+             transforms = [EmptyTransform(), ]
+
+         super().__init__(transforms=transforms)
+
+
+ @register
+ class EmptyTransform(T.Transform):
+     def __init__(self, ) -> None:
+         super().__init__()
+
+     def forward(self, *inputs):
+         inputs = inputs if len(inputs) > 1 else inputs[0]
+         return inputs
+
+
+ @register
+ class PadToSize(T.Pad):
+     _transformed_types = (
+         Image.Image,
+         datapoints.Image,
+         datapoints.Video,
+         datapoints.Mask,
+         datapoints.BoundingBox,
+     )
+     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+         sz = F.get_spatial_size(flat_inputs[0])
+         h, w = self.spatial_size[0] - sz[0], self.spatial_size[1] - sz[1]
+         self.padding = [0, 0, w, h]
+         return dict(padding=self.padding)
+
+     def __init__(self, spatial_size, fill=0, padding_mode='constant') -> None:
+         if isinstance(spatial_size, int):
+             spatial_size = (spatial_size, spatial_size)
+
+         self.spatial_size = spatial_size
+         super().__init__(0, fill, padding_mode)
+
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         fill = self._fill[type(inpt)]
+         padding = params['padding']
+         return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
+
+     def __call__(self, *inputs: Any) -> Any:
+         outputs = super().forward(*inputs)
+         if len(outputs) > 1 and isinstance(outputs[1], dict):
+             outputs[1]['padding'] = torch.tensor(self.padding)
+         return outputs
+
+
+ @register
+ class RandomIoUCrop(T.RandomIoUCrop):
+     def __init__(self, min_scale: float = 0.3, max_scale: float = 1, min_aspect_ratio: float = 0.5, max_aspect_ratio: float = 2, sampler_options: Optional[List[float]] = None, trials: int = 40, p: float = 1.0):
+         super().__init__(min_scale, max_scale, min_aspect_ratio, max_aspect_ratio, sampler_options, trials)
+         self.p = p
+
+     def __call__(self, *inputs: Any) -> Any:
+         if torch.rand(1) >= self.p:
+             return inputs if len(inputs) > 1 else inputs[0]
+
+         return super().forward(*inputs)
+
+
+ @register
+ class ConvertBox(T.Transform):
+     _transformed_types = (
+         datapoints.BoundingBox,
+     )
+     def __init__(self, out_fmt='', normalize=False) -> None:
+         super().__init__()
+         self.out_fmt = out_fmt
+         self.normalize = normalize
+
+         self.data_fmt = {
+             'xyxy': datapoints.BoundingBoxFormat.XYXY,
+             'cxcywh': datapoints.BoundingBoxFormat.CXCYWH
+         }
+
+     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+         if self.out_fmt:
+             spatial_size = inpt.spatial_size
+             in_fmt = inpt.format.value.lower()
+             inpt = torchvision.ops.box_convert(inpt, in_fmt=in_fmt, out_fmt=self.out_fmt)
+             inpt = datapoints.BoundingBox(inpt, format=self.data_fmt[self.out_fmt], spatial_size=spatial_size)
+
+         if self.normalize:
+             inpt = inpt / torch.tensor(inpt.spatial_size[::-1]).tile(2)[None]
+
+         return inpt
+
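A hedged sketch of how `Compose` consumes a YAML-style op list (the keys and values below are illustrative; each `type` must name a transform already placed in `GLOBAL_CONFIG` by `register`, whose internals live in src/core/yaml_utils.py):

# Illustrative only: remaining dict keys are forwarded as the transform's kwargs.
from src.data.transforms import Compose

ops = [
    {'type': 'RandomHorizontalFlip', 'p': 0.5},
    {'type': 'Resize', 'size': [640, 640]},
    {'type': 'ToImageTensor'},
    {'type': 'ConvertDtype'},
]
pipeline = Compose(ops)   # note: Compose pops 'type' from each dict in place
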
src/misc/__init__.py ADDED
@@ -0,0 +1,3 @@
+
+ from .logger import *
+ from .visualizer import *
src/misc/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (211 Bytes). View file
 
src/misc/__pycache__/dist.cpython-310.pyc ADDED
Binary file (4.79 kB). View file
 
src/misc/__pycache__/logger.cpython-310.pyc ADDED
Binary file (7.77 kB). View file
 
src/misc/__pycache__/visualizer.cpython-310.pyc ADDED
Binary file (1.09 kB). View file
 
src/misc/dist.py ADDED
@@ -0,0 +1,190 @@
+ """
+ reference
+ - https://github.com/pytorch/vision/blob/main/references/detection/utils.py
+ - https://github.com/facebookresearch/detr/blob/master/util/misc.py#L406
+
+ by lyuwenyu
+ """
+
+ import random
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.distributed
+ import torch.distributed as tdist
+
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ from torch.utils.data import DistributedSampler
+ from torch.utils.data.dataloader import DataLoader
+
+
+ def init_distributed():
+     '''
+     distributed setup: initializes the default process group from the
+     standard torchrun environment variables (init_method='env://'),
+     then pins each rank to its own CUDA device.
+     '''
+     try:
+         # # https://pytorch.org/docs/stable/elastic/run.html
+         # LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
+         # RANK = int(os.getenv('RANK', -1))
+         # WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
+
+         tdist.init_process_group(init_method='env://', )
+         torch.distributed.barrier()
+
+         rank = get_rank()
+         device = torch.device(f'cuda:{rank}')
+         torch.cuda.set_device(device)
+
+         setup_print(rank == 0)
+         print('Initialized distributed mode...')
+
+         return True
+
+     except Exception:
+         print('Not init distributed mode.')
+         return False
+
+
+ def setup_print(is_main):
+     '''This function disables printing when not in the master process
+     '''
+     import builtins as __builtin__
+     builtin_print = __builtin__.print
+
+     def print(*args, **kwargs):
+         force = kwargs.pop('force', False)
+         if is_main or force:
+             builtin_print(*args, **kwargs)
+
+     __builtin__.print = print
+
+
+ def is_dist_available_and_initialized():
+     if not tdist.is_available():
+         return False
+     if not tdist.is_initialized():
+         return False
+     return True
+
+
+ def get_rank():
+     if not is_dist_available_and_initialized():
+         return 0
+     return tdist.get_rank()
+
+
+ def get_world_size():
+     if not is_dist_available_and_initialized():
+         return 1
+     return tdist.get_world_size()
+
+
+ def is_main_process():
+     return get_rank() == 0
+
+
+ def save_on_master(*args, **kwargs):
+     if is_main_process():
+         torch.save(*args, **kwargs)
+
+
+
+ def warp_model(model, find_unused_parameters=False, sync_bn=False,):
+     if is_dist_available_and_initialized():
+         rank = get_rank()
+         model = nn.SyncBatchNorm.convert_sync_batchnorm(model) if sync_bn else model
+         model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=find_unused_parameters)
+     return model
+
+
+ def warp_loader(loader, shuffle=False):
+     if is_dist_available_and_initialized():
+         sampler = DistributedSampler(loader.dataset, shuffle=shuffle)
+         loader = DataLoader(loader.dataset,
+                             loader.batch_size,
+                             sampler=sampler,
+                             drop_last=loader.drop_last,
+                             collate_fn=loader.collate_fn,
+                             pin_memory=loader.pin_memory,
+                             num_workers=loader.num_workers, )
+     return loader
+
+
+
+ def is_parallel(model) -> bool:
+     # Returns True if model is of type DP or DDP
+     return type(model) in (torch.nn.parallel.DataParallel, torch.nn.parallel.DistributedDataParallel)
+
+
+ def de_parallel(model) -> nn.Module:
+     # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+     return model.module if is_parallel(model) else model
+
+
+ def reduce_dict(data, avg=True):
+     '''reduce dict values across all processes
+     Args:
+         data (dict): {k: v, ...} with scalar-tensor values
+         avg (bool): average (True) or sum (False) after the all_reduce
+     '''
+     world_size = get_world_size()
+     if world_size < 2:
+         return data
+
+     with torch.no_grad():
+         keys, values = [], []
+         for k in sorted(data.keys()):
+             keys.append(k)
+             values.append(data[k])
+
+         values = torch.stack(values, dim=0)
+         tdist.all_reduce(values)
+
+         if avg is True:
+             values /= world_size
+
+         _data = {k: v for k, v in zip(keys, values)}
+
+     return _data
+
+
+
+ def all_gather(data):
+     """
+     Run all_gather on arbitrary picklable data (not necessarily tensors)
+     Args:
+         data: any picklable object
+     Returns:
+         list[data]: list of data gathered from each rank
+     """
+     world_size = get_world_size()
+     if world_size == 1:
+         return [data]
+     data_list = [None] * world_size
+     tdist.all_gather_object(data_list, data)
+     return data_list
+
+
+ import time
+ def sync_time():
+     '''CUDA-synchronized wall-clock time
+     '''
+     if torch.cuda.is_available():
+         torch.cuda.synchronize()
+
+     return time.time()
+
+
+
+ def set_seed(seed):
+     # fix the seed for reproducibility; offset by rank so workers differ
+     seed = seed + get_rank()
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+
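A hedged sketch of the intended call pattern under torchrun (outside torchrun every helper degrades to a single-process no-op; the model and seed below are illustrative):

# e.g. torchrun --nproc_per_node=2 train.py
import torch.nn as nn
from src.misc.dist import init_distributed, set_seed, warp_model

if init_distributed():
    set_seed(42)                                          # per-rank seed = 42 + rank
    model = warp_model(nn.Linear(8, 2).cuda(), sync_bn=True)
    # loader = warp_loader(loader, shuffle=True)          # re-wraps with DistributedSampler
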
src/misc/logger.py ADDED
@@ -0,0 +1,239 @@
+ """
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ https://github.com/facebookresearch/detr/blob/main/util/misc.py
+ Mostly copy-paste from torchvision references.
+ """
+
+ import time
+ import pickle
+ import datetime
+ from collections import defaultdict, deque
+ from typing import Dict
+
+ import torch
+ import torch.distributed as tdist
+
+ from .dist import is_dist_available_and_initialized, get_world_size
+
+
+ class SmoothedValue(object):
+     """Track a series of values and provide access to smoothed values over a
+     window or the global series average.
+     """
+
+     def __init__(self, window_size=20, fmt=None):
+         if fmt is None:
+             fmt = "{median:.4f} ({global_avg:.4f})"
+         self.deque = deque(maxlen=window_size)
+         self.total = 0.0
+         self.count = 0
+         self.fmt = fmt
+
+     def update(self, value, n=1):
+         self.deque.append(value)
+         self.count += n
+         self.total += value * n
+
+     def synchronize_between_processes(self):
+         """
+         Warning: does not synchronize the deque!
+         """
+         if not is_dist_available_and_initialized():
+             return
+         t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+         tdist.barrier()
+         tdist.all_reduce(t)
+         t = t.tolist()
+         self.count = int(t[0])
+         self.total = t[1]
+
+     @property
+     def median(self):
+         d = torch.tensor(list(self.deque))
+         return d.median().item()
+
+     @property
+     def avg(self):
+         d = torch.tensor(list(self.deque), dtype=torch.float32)
+         return d.mean().item()
+
+     @property
+     def global_avg(self):
+         return self.total / self.count
+
+     @property
+     def max(self):
+         return max(self.deque)
+
+     @property
+     def value(self):
+         return self.deque[-1]
+
+     def __str__(self):
+         return self.fmt.format(
+             median=self.median,
+             avg=self.avg,
+             global_avg=self.global_avg,
+             max=self.max,
+             value=self.value)
+
+
+ def all_gather(data):
+     """
+     Run all_gather on arbitrary picklable data (not necessarily tensors)
+     Args:
+         data: any picklable object
+     Returns:
+         list[data]: list of data gathered from each rank
+     """
+     world_size = get_world_size()
+     if world_size == 1:
+         return [data]
+
+     # serialized to a Tensor
+     buffer = pickle.dumps(data)
+     storage = torch.ByteStorage.from_buffer(buffer)
+     tensor = torch.ByteTensor(storage).to("cuda")
+
+     # obtain Tensor size of each rank
+     local_size = torch.tensor([tensor.numel()], device="cuda")
+     size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+     tdist.all_gather(size_list, local_size)
+     size_list = [int(size.item()) for size in size_list]
+     max_size = max(size_list)
+
+     # receiving Tensor from all ranks
+     # we pad the tensor because torch all_gather does not support
+     # gathering tensors of different shapes
+     tensor_list = []
+     for _ in size_list:
+         tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+     if local_size != max_size:
+         padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
+         tensor = torch.cat((tensor, padding), dim=0)
+     tdist.all_gather(tensor_list, tensor)
+
+     data_list = []
+     for size, tensor in zip(size_list, tensor_list):
+         buffer = tensor.cpu().numpy().tobytes()[:size]
+         data_list.append(pickle.loads(buffer))
+
+     return data_list
+
+
+ def reduce_dict(input_dict, average=True) -> Dict[str, torch.Tensor]:
+     """
+     Args:
+         input_dict (dict): all the values will be reduced
+         average (bool): whether to do average or sum
+     Reduce the values in the dictionary from all processes so that all processes
+     have the averaged results. Returns a dict with the same fields as
+     input_dict, after reduction.
+     """
+     world_size = get_world_size()
+     if world_size < 2:
+         return input_dict
+     with torch.no_grad():
+         names = []
+         values = []
+         # sort the keys so that they are consistent across processes
+         for k in sorted(input_dict.keys()):
+             names.append(k)
+             values.append(input_dict[k])
+         values = torch.stack(values, dim=0)
+         tdist.all_reduce(values)
+         if average:
+             values /= world_size
+         reduced_dict = {k: v for k, v in zip(names, values)}
+     return reduced_dict
+
+
+ class MetricLogger(object):
+     def __init__(self, delimiter="\t"):
+         self.meters = defaultdict(SmoothedValue)
+         self.delimiter = delimiter
+
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if isinstance(v, torch.Tensor):
+                 v = v.item()
+             assert isinstance(v, (float, int))
+             self.meters[k].update(v)
+
+     def __getattr__(self, attr):
+         if attr in self.meters:
+             return self.meters[attr]
+         if attr in self.__dict__:
+             return self.__dict__[attr]
+         raise AttributeError("'{}' object has no attribute '{}'".format(
+             type(self).__name__, attr))
+
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append(
+                 "{}: {}".format(name, str(meter))
+             )
+         return self.delimiter.join(loss_str)
+
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+
+     def add_meter(self, name, meter):
+         self.meters[name] = meter
+
+     def log_every(self, iterable, print_freq, header=None):
+         i = 0
+         if not header:
+             header = ''
+         start_time = time.time()
+         end = time.time()
+         iter_time = SmoothedValue(fmt='{avg:.4f}')
+         data_time = SmoothedValue(fmt='{avg:.4f}')
+         space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+         if torch.cuda.is_available():
+             log_msg = self.delimiter.join([
+                 header,
+                 '[{0' + space_fmt + '}/{1}]',
+                 'eta: {eta}',
+                 '{meters}',
+                 'time: {time}',
+                 'data: {data}',
+                 'max mem: {memory:.0f}'
+             ])
+         else:
+             log_msg = self.delimiter.join([
+                 header,
+                 '[{0' + space_fmt + '}/{1}]',
+                 'eta: {eta}',
+                 '{meters}',
+                 'time: {time}',
+                 'data: {data}'
+             ])
+         MB = 1024.0 * 1024.0
+         for obj in iterable:
+             data_time.update(time.time() - end)
+             yield obj
+             iter_time.update(time.time() - end)
+             if i % print_freq == 0 or i == len(iterable) - 1:
+                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                 if torch.cuda.is_available():
+                     print(log_msg.format(
+                         i, len(iterable), eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time),
+                         memory=torch.cuda.max_memory_allocated() / MB))
+                 else:
+                     print(log_msg.format(
+                         i, len(iterable), eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time)))
+             i += 1
+             end = time.time()
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         print('{} Total time: {} ({:.4f} s / it)'.format(
+             header, total_time_str, total_time / len(iterable)))
+
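A hedged sketch of the training-loop logging pattern these classes support (loss values below are stand-ins):

from src.misc.logger import MetricLogger

logger = MetricLogger(delimiter='  ')
for step in logger.log_every(range(100), print_freq=10, header='Epoch: [0]'):
    loss = 1.0 / (step + 1)                 # stand-in for a real loss value
    logger.update(loss=loss, lr=1e-4)       # median/avg smoothed over a 20-step window
logger.synchronize_between_processes()      # all-reduces counts/totals across ranks
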
src/misc/visualizer.py ADDED
@@ -0,0 +1,34 @@
+ """by lyuwenyu
+ """
+
+ import torch
+ import torch.utils.data
+
+ import torchvision
+ torchvision.disable_beta_transforms_warning()
+
+ import PIL
+
+ __all__ = ['show_sample']
+
+ def show_sample(sample):
+     """for coco dataset/dataloader
+     """
+     import matplotlib.pyplot as plt
+     from torchvision.transforms.v2 import functional as F
+     from torchvision.utils import draw_bounding_boxes
+
+     image, target = sample
+     if isinstance(image, PIL.Image.Image):
+         image = F.to_image_tensor(image)
+
+     image = F.convert_dtype(image, torch.uint8)
+     annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
+
+     fig, ax = plt.subplots()
+     ax.imshow(annotated_image.permute(1, 2, 0).numpy())
+     ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
+     fig.tight_layout()
+     fig.show()
+     plt.show()
+
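A hedged sketch of a self-contained call (a real COCO sample would come from the dataloader; the synthetic image and box here are illustrative):

import torch
from PIL import Image
from src.misc.visualizer import show_sample

image = Image.new('RGB', (320, 240), 'gray')
target = {'boxes': torch.tensor([[40., 30., 160., 120.]])}   # absolute xyxy pixels
show_sample((image, target))   # draws yellow boxes and opens a matplotlib window
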
src/nn/__init__.py ADDED
@@ -0,0 +1,7 @@
+
+ from .arch import *
+ from .criterion import *
+
+ #
+ from .backbone import *
+
src/nn/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (226 Bytes). View file
 
src/nn/arch/__init__.py ADDED
@@ -0,0 +1 @@
+ from .classification import *
src/nn/arch/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (200 Bytes). View file
 
src/nn/arch/__pycache__/classification.cpython-310.pyc ADDED
Binary file (1.55 kB). View file
 
src/nn/arch/classification.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ import torch.nn as nn
+
+ from src.core import register
+
+
+ __all__ = ['Classification', 'ClassHead']
+
+
+ @register
+ class Classification(nn.Module):
+     __inject__ = ['backbone', 'head']
+
+     def __init__(self, backbone: nn.Module, head: nn.Module=None):
+         super().__init__()
+
+         self.backbone = backbone
+         self.head = head
+
+     def forward(self, x):
+         x = self.backbone(x)
+
+         if self.head is not None:
+             x = self.head(x)
+
+         return x
+
+
+ @register
+ class ClassHead(nn.Module):
+     def __init__(self, hidden_dim, num_classes):
+         super().__init__()
+         self.pool = nn.AdaptiveAvgPool2d(1)
+         self.proj = nn.Linear(hidden_dim, num_classes)
+
+     def forward(self, x):
+         x = x[0] if isinstance(x, (list, tuple)) else x
+         x = self.pool(x)
+         x = x.reshape(x.shape[0], -1)
+         x = self.proj(x)
+         return x
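A hedged sketch of wiring a backbone into this arch directly (normally `__inject__` resolves these from the YAML config; `hidden_dim=512` matches PResNet-18's last-stage channel count, see src/nn/backbone/presnet.py below):

import torch
from src.nn.arch.classification import Classification, ClassHead
from src.nn.backbone.presnet import PResNet

backbone = PResNet(depth=18, return_idx=[3], pretrained=False)
model = Classification(backbone=backbone, head=ClassHead(hidden_dim=512, num_classes=10))
logits = model(torch.rand(2, 3, 224, 224))   # -> shape [2, 10]
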
src/nn/backbone/__init__.py ADDED
@@ -0,0 +1,5 @@
+
+ from .presnet import *
+ from .test_resnet import *
+
+ from .common import *
src/nn/backbone/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (238 Bytes). View file
 
src/nn/backbone/__pycache__/common.cpython-310.pyc ADDED
Binary file (3.28 kB). View file
 
src/nn/backbone/__pycache__/presnet.cpython-310.pyc ADDED
Binary file (6.45 kB). View file
 
src/nn/backbone/__pycache__/test_resnet.cpython-310.pyc ADDED
Binary file (3.06 kB). View file
 
src/nn/backbone/common.py ADDED
@@ -0,0 +1,102 @@
+ '''by lyuwenyu
+ '''
+
+ import torch
+ import torch.nn as nn
+
+
+
+ class ConvNormLayer(nn.Module):
+     def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
+         super().__init__()
+         self.conv = nn.Conv2d(
+             ch_in,
+             ch_out,
+             kernel_size,
+             stride,
+             padding=(kernel_size-1)//2 if padding is None else padding,
+             bias=bias)
+         self.norm = nn.BatchNorm2d(ch_out)
+         self.act = nn.Identity() if act is None else get_activation(act)
+
+     def forward(self, x):
+         return self.act(self.norm(self.conv(x)))
+
+
+ class FrozenBatchNorm2d(nn.Module):
+     """copy and modified from https://github.com/facebookresearch/detr/blob/master/models/backbone.py
+     BatchNorm2d where the batch statistics and the affine parameters are fixed.
+     Copy-paste from torchvision.misc.ops with added eps before rsqrt,
+     without which any other models than torchvision.models.resnet[18,34,50,101]
+     produce nans.
+     """
+     def __init__(self, num_features, eps=1e-5):
+         super(FrozenBatchNorm2d, self).__init__()
+         n = num_features
+         self.register_buffer("weight", torch.ones(n))
+         self.register_buffer("bias", torch.zeros(n))
+         self.register_buffer("running_mean", torch.zeros(n))
+         self.register_buffer("running_var", torch.ones(n))
+         self.eps = eps
+         self.num_features = n
+
+     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                               missing_keys, unexpected_keys, error_msgs):
+         num_batches_tracked_key = prefix + 'num_batches_tracked'
+         if num_batches_tracked_key in state_dict:
+             del state_dict[num_batches_tracked_key]
+
+         super(FrozenBatchNorm2d, self)._load_from_state_dict(
+             state_dict, prefix, local_metadata, strict,
+             missing_keys, unexpected_keys, error_msgs)
+
+     def forward(self, x):
+         # move reshapes to the beginning
+         # to make it fuser-friendly
+         w = self.weight.reshape(1, -1, 1, 1)
+         b = self.bias.reshape(1, -1, 1, 1)
+         rv = self.running_var.reshape(1, -1, 1, 1)
+         rm = self.running_mean.reshape(1, -1, 1, 1)
+         scale = w * (rv + self.eps).rsqrt()
+         bias = b - rm * scale
+         return x * scale + bias
+
+     def extra_repr(self):
+         return (
+             "{num_features}, eps={eps}".format(**self.__dict__)
+         )
+
+
+ def get_activation(act, inplace: bool=True):
+     '''get activation by name; also accepts None or an nn.Module instance
+     '''
+     if act is None:
+         return nn.Identity()
+
+     if isinstance(act, nn.Module):
+         m = act
+     else:
+         act = act.lower()
+         if act == 'silu':
+             m = nn.SiLU()
+         elif act == 'relu':
+             m = nn.ReLU()
+         elif act == 'leaky_relu':
+             m = nn.LeakyReLU()
+         elif act == 'gelu':
+             m = nn.GELU()
+         else:
+             raise RuntimeError(f'unsupported activation: {act}')
+
+     if hasattr(m, 'inplace'):
+         m.inplace = inplace
+
+     return m
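A quick shape check for `ConvNormLayer` (padding defaults to `(kernel_size-1)//2`, so spatial size halves under stride 2):

import torch
from src.nn.backbone.common import ConvNormLayer

layer = ConvNormLayer(64, 128, kernel_size=3, stride=2, act='silu')
y = layer(torch.rand(1, 64, 32, 32))   # -> torch.Size([1, 128, 16, 16])
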
src/nn/backbone/presnet.py ADDED
@@ -0,0 +1,225 @@
+ '''by lyuwenyu
+ '''
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from collections import OrderedDict
+
+ from .common import get_activation, ConvNormLayer, FrozenBatchNorm2d
+
+ from src.core import register
+
+
+ __all__ = ['PResNet']
+
+
+ ResNet_cfg = {
+     18: [2, 2, 2, 2],
+     34: [3, 4, 6, 3],
+     50: [3, 4, 6, 3],
+     101: [3, 4, 23, 3],
+     # 152: [3, 8, 36, 3],
+ }
+
+
+ download_url = {
+     18: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth',
+     34: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth',
+     50: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth',
+     101: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth',
+ }
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
+         super().__init__()
+
+         self.shortcut = shortcut
+
+         if not shortcut:
+             if variant == 'd' and stride == 2:
+                 self.short = nn.Sequential(OrderedDict([
+                     ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
+                     ('conv', ConvNormLayer(ch_in, ch_out, 1, 1))
+                 ]))
+             else:
+                 self.short = ConvNormLayer(ch_in, ch_out, 1, stride)
+
+         self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act)
+         self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None)
+         self.act = nn.Identity() if act is None else get_activation(act)
+
+
+     def forward(self, x):
+         out = self.branch2a(x)
+         out = self.branch2b(out)
+         if self.shortcut:
+             short = x
+         else:
+             short = self.short(x)
+
+         out = out + short
+         out = self.act(out)
+
+         return out
+
+
+ class BottleNeck(nn.Module):
+     expansion = 4
+
+     def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
+         super().__init__()
+
+         if variant == 'a':
+             stride1, stride2 = stride, 1
+         else:
+             stride1, stride2 = 1, stride
+
+         width = ch_out
+
+         self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act)
+         self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act)
+         self.branch2c = ConvNormLayer(width, ch_out * self.expansion, 1, 1)
+
+         self.shortcut = shortcut
+         if not shortcut:
+             if variant == 'd' and stride == 2:
+                 self.short = nn.Sequential(OrderedDict([
+                     ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
+                     ('conv', ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1))
+                 ]))
+             else:
+                 self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride)
+
+         self.act = nn.Identity() if act is None else get_activation(act)
+
+     def forward(self, x):
+         out = self.branch2a(x)
+         out = self.branch2b(out)
+         out = self.branch2c(out)
+
+         if self.shortcut:
+             short = x
+         else:
+             short = self.short(x)
+
+         out = out + short
+         out = self.act(out)
+
+         return out
+
+
+ class Blocks(nn.Module):
+     def __init__(self, block, ch_in, ch_out, count, stage_num, act='relu', variant='b'):
+         super().__init__()
+
+         self.blocks = nn.ModuleList()
+         for i in range(count):
+             self.blocks.append(
+                 block(
+                     ch_in,
+                     ch_out,
+                     stride=2 if i == 0 and stage_num != 2 else 1,
+                     shortcut=False if i == 0 else True,
+                     variant=variant,
+                     act=act)
+             )
+
+             if i == 0:
+                 ch_in = ch_out * block.expansion
+
+     def forward(self, x):
+         out = x
+         for block in self.blocks:
+             out = block(out)
+         return out
+
+
+ @register
+ class PResNet(nn.Module):
+     def __init__(
+         self,
+         depth,
+         variant='d',
+         num_stages=4,
+         return_idx=[0, 1, 2, 3],
+         act='relu',
+         freeze_at=-1,
+         freeze_norm=True,
+         pretrained=False):
+         super().__init__()
+
+         block_nums = ResNet_cfg[depth]
+         ch_in = 64
+         if variant in ['c', 'd']:
+             conv_def = [
+                 [3, ch_in // 2, 3, 2, "conv1_1"],
+                 [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
+                 [ch_in // 2, ch_in, 3, 1, "conv1_3"],
+             ]
+         else:
+             conv_def = [[3, ch_in, 7, 2, "conv1_1"]]
+
+         self.conv1 = nn.Sequential(OrderedDict([
+             (_name, ConvNormLayer(c_in, c_out, k, s, act=act)) for c_in, c_out, k, s, _name in conv_def
+         ]))
+
+         ch_out_list = [64, 128, 256, 512]
+         block = BottleNeck if depth >= 50 else BasicBlock
+
+         _out_channels = [block.expansion * v for v in ch_out_list]
+         _out_strides = [4, 8, 16, 32]
+
+         self.res_layers = nn.ModuleList()
+         for i in range(num_stages):
+             stage_num = i + 2
+             self.res_layers.append(
+                 Blocks(block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant)
+             )
+             ch_in = _out_channels[i]
+
+         self.return_idx = return_idx
+         self.out_channels = [_out_channels[_i] for _i in return_idx]
+         self.out_strides = [_out_strides[_i] for _i in return_idx]
+
+         if freeze_at >= 0:
+             self._freeze_parameters(self.conv1)
+             for i in range(min(freeze_at, num_stages)):
+                 self._freeze_parameters(self.res_layers[i])
+
+         if freeze_norm:
+             self._freeze_norm(self)
+
+         if pretrained:
+             state = torch.hub.load_state_dict_from_url(download_url[depth])
+             self.load_state_dict(state)
+             print(f'Load PResNet{depth} state_dict')
+
+     def _freeze_parameters(self, m: nn.Module):
+         for p in m.parameters():
+             p.requires_grad = False
+
+     def _freeze_norm(self, m: nn.Module):
+         if isinstance(m, nn.BatchNorm2d):
+             m = FrozenBatchNorm2d(m.num_features)
+         else:
+             for name, child in m.named_children():
+                 _child = self._freeze_norm(child)
+                 if _child is not child:
+                     setattr(m, name, _child)
+         return m
+
+     def forward(self, x):
+         conv1 = self.conv1(x)
+         x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1)
+         outs = []
+         for idx, stage in enumerate(self.res_layers):
+             x = stage(x)
+             if idx in self.return_idx:
+                 outs.append(x)
+         return outs
+
+
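A hedged sketch of PResNet as a detection backbone returning multi-scale features (the input size is illustrative; `pretrained=False` avoids the weight download):

import torch
from src.nn.backbone.presnet import PResNet

model = PResNet(depth=50, return_idx=[1, 2, 3], pretrained=False)
feats = model(torch.rand(1, 3, 640, 640))
# len(feats) == 3; model.out_channels == [512, 1024, 2048], model.out_strides == [8, 16, 32]
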
src/nn/backbone/test_resnet.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from collections import OrderedDict
+
+
+ from src.core import register
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, in_planes, planes, stride=1):
+         super(BasicBlock, self).__init__()
+
+         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion*planes:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(self.expansion*planes)
+             )
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += self.shortcut(x)
+         out = F.relu(out)
+         return out
+
+
+
+ class _ResNet(nn.Module):
+     def __init__(self, block, num_blocks, num_classes=10):
+         super().__init__()
+         self.in_planes = 64
+
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+
+         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+
+         self.linear = nn.Linear(512 * block.expansion, num_classes)
+
+     def _make_layer(self, block, planes, num_blocks, stride):
+         strides = [stride] + [1]*(num_blocks-1)
+         layers = []
+         for stride in strides:
+             layers.append(block(self.in_planes, planes, stride))
+             self.in_planes = planes * block.expansion
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.layer1(out)
+         out = self.layer2(out)
+         out = self.layer3(out)
+         out = self.layer4(out)
+         out = F.avg_pool2d(out, 4)
+         out = out.view(out.size(0), -1)
+         out = self.linear(out)
+         return out
+
+
+ @register
+ class MResNet(nn.Module):
+     def __init__(self, num_classes=10, num_blocks=[2, 2, 2, 2]) -> None:
+         super().__init__()
+         self.model = _ResNet(BasicBlock, num_blocks, num_classes)
+
+     def forward(self, x):
+         return self.model(x)
+
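A CIFAR-style sanity check for the registered MResNet (the stride-1 stem and final 4x4 average pool assume 32x32 inputs):

import torch
from src.nn.backbone.test_resnet import MResNet

net = MResNet(num_classes=10, num_blocks=[2, 2, 2, 2])
logits = net(torch.rand(4, 3, 32, 32))   # -> torch.Size([4, 10])
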
src/nn/backbone/utils.py ADDED
@@ -0,0 +1,58 @@
+ """
+ https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py
+
+ by lyuwenyu
+ """
+
+ from collections import OrderedDict
+ from typing import Dict, List
+
+
+ import torch.nn as nn
+
+
+ class IntermediateLayerGetter(nn.ModuleDict):
+     """
+     Module wrapper that returns intermediate layers from a model
+
+     It has a strong assumption that the modules have been registered
+     into the model in the same order as they are used.
+     This means that one should **not** reuse the same nn.Module
+     twice in the forward if you want this to work.
+
+     Additionally, it is only able to query submodules that are directly
+     assigned to the model. So if `model` is passed, `model.feature1` can
+     be returned, but not `model.feature1.layer2`.
+     """
+
+     _version = 3
+
+     def __init__(self, model: nn.Module, return_layers: List[str]) -> None:
+         if not set(return_layers).issubset([name for name, _ in model.named_children()]):
+             raise ValueError("return_layers are not present in model. {}"\
+                 .format([name for name, _ in model.named_children()]))
+         orig_return_layers = return_layers
+         return_layers = {str(k): str(k) for k in return_layers}
+         layers = OrderedDict()
+         for name, module in model.named_children():
+             layers[name] = module
+             if name in return_layers:
+                 del return_layers[name]
+             if not return_layers:
+                 break
+
+         super().__init__(layers)
+         self.return_layers = orig_return_layers
+
+     def forward(self, x):
+         # out = OrderedDict()
+         outputs = []
+         for name, module in self.items():
+             x = module(x)
+             if name in self.return_layers:
+                 # out_name = self.return_layers[name]
+                 # out[out_name] = x
+                 outputs.append(x)
+
+         return outputs
+
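A hedged sketch of extracting stage outputs from a torchvision ResNet by child name (works because resnet's top-level children run sequentially up to the last requested layer):

import torch
import torchvision
from src.nn.backbone.utils import IntermediateLayerGetter

body = IntermediateLayerGetter(torchvision.models.resnet18(), return_layers=['layer2', 'layer3', 'layer4'])
outs = body(torch.rand(1, 3, 224, 224))   # list of 3 feature maps at strides 8/16/32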