cheng-hust committed on
Commit d0b3ae7 · verified · 1 parent: ed23be2

Delete src2

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. src2/__init__.py +0 -5
  2. src2/__pycache__/__init__.cpython-310.pyc +0 -0
  3. src2/core/__init__.py +0 -7
  4. src2/core/__pycache__/__init__.cpython-310.pyc +0 -0
  5. src2/core/__pycache__/config.cpython-310.pyc +0 -0
  6. src2/core/__pycache__/yaml_config.cpython-310.pyc +0 -0
  7. src2/core/__pycache__/yaml_utils.cpython-310.pyc +0 -0
  8. src2/core/config.py +0 -264
  9. src2/core/yaml_config.py +0 -152
  10. src2/core/yaml_utils.py +0 -208
  11. src2/data/__init__.py +0 -7
  12. src2/data/__pycache__/__init__.cpython-310.pyc +0 -0
  13. src2/data/__pycache__/dataloader.cpython-310.pyc +0 -0
  14. src2/data/__pycache__/transforms.cpython-310.pyc +0 -0
  15. src2/data/cifar10/__init__.py +0 -14
  16. src2/data/cifar10/__pycache__/__init__.cpython-310.pyc +0 -0
  17. src2/data/coco/__init__.py +0 -9
  18. src2/data/coco/__pycache__/__init__.cpython-310.pyc +0 -0
  19. src2/data/coco/__pycache__/coco_dataset.cpython-310.pyc +0 -0
  20. src2/data/coco/__pycache__/coco_eval.cpython-310.pyc +0 -0
  21. src2/data/coco/__pycache__/coco_utils.cpython-310.pyc +0 -0
  22. src2/data/coco/coco_dataset.py +0 -238
  23. src2/data/coco/coco_eval.py +0 -269
  24. src2/data/coco/coco_utils.py +0 -184
  25. src2/data/dataloader.py +0 -28
  26. src2/data/functional.py +0 -169
  27. src2/data/transforms.py +0 -142
  28. src2/misc/__init__.py +0 -3
  29. src2/misc/__pycache__/__init__.cpython-310.pyc +0 -0
  30. src2/misc/__pycache__/dist.cpython-310.pyc +0 -0
  31. src2/misc/__pycache__/logger.cpython-310.pyc +0 -0
  32. src2/misc/__pycache__/visualizer.cpython-310.pyc +0 -0
  33. src2/misc/dist.py +0 -190
  34. src2/misc/logger.py +0 -239
  35. src2/misc/visualizer.py +0 -34
  36. src2/nn/__init__.py +0 -7
  37. src2/nn/__pycache__/__init__.cpython-310.pyc +0 -0
  38. src2/nn/arch/__init__.py +0 -1
  39. src2/nn/arch/__pycache__/__init__.cpython-310.pyc +0 -0
  40. src2/nn/arch/__pycache__/classification.cpython-310.pyc +0 -0
  41. src2/nn/arch/classification.py +0 -41
  42. src2/nn/backbone/__init__.py +0 -5
  43. src2/nn/backbone/__pycache__/__init__.cpython-310.pyc +0 -0
  44. src2/nn/backbone/__pycache__/common.cpython-310.pyc +0 -0
  45. src2/nn/backbone/__pycache__/presnet.cpython-310.pyc +0 -0
  46. src2/nn/backbone/__pycache__/test_resnet.cpython-310.pyc +0 -0
  47. src2/nn/backbone/common.py +0 -102
  48. src2/nn/backbone/presnet.py +0 -225
  49. src2/nn/backbone/test_resnet.py +0 -81
  50. src2/nn/backbone/utils.py +0 -58
src2/__init__.py DELETED
@@ -1,5 +0,0 @@
-
- from . import data
- from . import nn
- from . import optim
- from . import zoo
 
src2/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (269 Bytes)
 
src2/core/__init__.py DELETED
@@ -1,7 +0,0 @@
- """by lyuwenyu
- """
-
- # from .yaml_utils import register, create, load_config, merge_config, merge_dict
- from .yaml_utils import *
- from .config import BaseConfig
- from .yaml_config import YAMLConfig
 
src2/core/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (309 Bytes)
 
src2/core/__pycache__/config.cpython-310.pyc DELETED
Binary file (6.52 kB)
 
src2/core/__pycache__/yaml_config.cpython-310.pyc DELETED
Binary file (4.68 kB)
 
src2/core/__pycache__/yaml_utils.cpython-310.pyc DELETED
Binary file (4.24 kB)
 
src2/core/config.py DELETED
@@ -1,264 +0,0 @@
- """by lyuwenyu
- """
-
- from pprint import pprint
- import torch
- import torch.nn as nn
- from torch.utils.data import Dataset, DataLoader
- from torch.optim import Optimizer
- from torch.optim.lr_scheduler import LRScheduler
- from torch.cuda.amp.grad_scaler import GradScaler
-
- from typing import Callable, List, Dict
-
-
- __all__ = ['BaseConfig', ]
-
-
-
- class BaseConfig(object):
-     # TODO property
-
-
-     def __init__(self) -> None:
-         super().__init__()
-
-         self.task :str = None
-
-         self._model :nn.Module = None
-         self._postprocessor :nn.Module = None
-         self._criterion :nn.Module = None
-         self._optimizer :Optimizer = None
-         self._lr_scheduler :LRScheduler = None
-         self._train_dataloader :DataLoader = None
-         self._val_dataloader :DataLoader = None
-         self._ema :nn.Module = None
-         self._scaler :GradScaler = None
-
-         self.train_dataset :Dataset = None
-         self.val_dataset :Dataset = None
-         self.num_workers :int = 0
-         self.collate_fn :Callable = None
-
-         self.batch_size :int = None
-         self._train_batch_size :int = None
-         self._val_batch_size :int = None
-         self._train_shuffle: bool = None
-         self._val_shuffle: bool = None
-
-         self.evaluator :Callable[[nn.Module, DataLoader, str], ] = None
-
-         # runtime
-         self.resume :str = None
-         self.tuning :str = None
-
-         self.epoches :int = None
-         self.last_epoch :int = -1
-         self.end_epoch :int = None
-
-         self.use_amp :bool = False
-         self.use_ema :bool = False
-         self.sync_bn :bool = False
-         self.clip_max_norm : float = None
-         self.find_unused_parameters :bool = None
-         # self.ema_decay: float = 0.9999
-         # self.grad_clip_: Callable = None
-
-         self.log_dir :str = './logs/'
-         self.log_step :int = 10
-         self._output_dir :str = None
-         self._print_freq :int = None
-         self.checkpoint_step :int = 1
-
-         # self.device :str = torch.device('cpu')
-         device = 'cuda' if torch.cuda.is_available() else 'cpu'
-         self.device = torch.device(device)
-
-
-     @property
-     def model(self, ) -> nn.Module:
-         return self._model
-
-     @model.setter
-     def model(self, m):
-         assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
-         self._model = m
-
-     @property
-     def postprocessor(self, ) -> nn.Module:
-         return self._postprocessor
-
-     @postprocessor.setter
-     def postprocessor(self, m):
-         assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
-         self._postprocessor = m
-
-     @property
-     def criterion(self, ) -> nn.Module:
-         return self._criterion
-
-     @criterion.setter
-     def criterion(self, m):
-         assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
-         self._criterion = m
-
-     @property
-     def optimizer(self, ) -> Optimizer:
-         return self._optimizer
-
-     @optimizer.setter
-     def optimizer(self, m):
-         assert isinstance(m, Optimizer), f'{type(m)} != optim.Optimizer, please check your model class'
-         self._optimizer = m
-
-     @property
-     def lr_scheduler(self, ) -> LRScheduler:
-         return self._lr_scheduler
-
-     @lr_scheduler.setter
-     def lr_scheduler(self, m):
-         assert isinstance(m, LRScheduler), f'{type(m)} != LRScheduler, please check your model class'
-         self._lr_scheduler = m
-
-
-     @property
-     def train_dataloader(self):
-         if self._train_dataloader is None and self.train_dataset is not None:
-             loader = DataLoader(self.train_dataset,
-                                 batch_size=self.train_batch_size,
-                                 num_workers=self.num_workers,
-                                 collate_fn=self.collate_fn,
-                                 shuffle=self.train_shuffle, )
-             loader.shuffle = self.train_shuffle
-             self._train_dataloader = loader
-
-         return self._train_dataloader
-
-     @train_dataloader.setter
-     def train_dataloader(self, loader):
-         self._train_dataloader = loader
-
-     @property
-     def val_dataloader(self):
-         if self._val_dataloader is None and self.val_dataset is not None:
-             loader = DataLoader(self.val_dataset,
-                                 batch_size=self.val_batch_size,
-                                 num_workers=self.num_workers,
-                                 drop_last=False,
-                                 collate_fn=self.collate_fn,
-                                 shuffle=self.val_shuffle)
-             loader.shuffle = self.val_shuffle
-             self._val_dataloader = loader
-
-         return self._val_dataloader
-
-     @val_dataloader.setter
-     def val_dataloader(self, loader):
-         self._val_dataloader = loader
-
-
-     # TODO method
-     # @property
-     # def ema(self, ) -> nn.Module:
-     #     if self._ema is None and self.use_ema and self.model is not None:
-     #         self._ema = ModelEMA(self.model, self.ema_decay)
-     #     return self._ema
-
-     @property
-     def ema(self, ) -> nn.Module:
-         return self._ema
-
-     @ema.setter
-     def ema(self, obj):
-         self._ema = obj
-
-
-     @property
-     def scaler(self) -> GradScaler:
-         if self._scaler is None and self.use_amp and torch.cuda.is_available():
-             self._scaler = GradScaler()
-         return self._scaler
-
-     @scaler.setter
-     def scaler(self, obj: GradScaler):
-         self._scaler = obj
-
-
-     @property
-     def val_shuffle(self):
-         if self._val_shuffle is None:
-             print('warning: set default val_shuffle=False')
-             return False
-         return self._val_shuffle
-
-     @val_shuffle.setter
-     def val_shuffle(self, shuffle):
-         assert isinstance(shuffle, bool), 'shuffle must be bool'
-         self._val_shuffle = shuffle
-
-     @property
-     def train_shuffle(self):
-         if self._train_shuffle is None:
-             print('warning: set default train_shuffle=True')
-             return True
-         return self._train_shuffle
-
-     @train_shuffle.setter
-     def train_shuffle(self, shuffle):
-         assert isinstance(shuffle, bool), 'shuffle must be bool'
-         self._train_shuffle = shuffle
-
-
-     @property
-     def train_batch_size(self):
-         if self._train_batch_size is None and isinstance(self.batch_size, int):
-             print(f'warning: set train_batch_size=batch_size={self.batch_size}')
-             return self.batch_size
-         return self._train_batch_size
-
-     @train_batch_size.setter
-     def train_batch_size(self, batch_size):
-         assert isinstance(batch_size, int), 'batch_size must be int'
-         self._train_batch_size = batch_size
-
-     @property
-     def val_batch_size(self):
-         if self._val_batch_size is None:
-             print(f'warning: set val_batch_size=batch_size={self.batch_size}')
-             return self.batch_size
-         return self._val_batch_size
-
-     @val_batch_size.setter
-     def val_batch_size(self, batch_size):
-         assert isinstance(batch_size, int), 'batch_size must be int'
-         self._val_batch_size = batch_size
-
-
-     @property
-     def output_dir(self):
-         if self._output_dir is None:
-             return self.log_dir
-         return self._output_dir
-
-     @output_dir.setter
-     def output_dir(self, root):
-         self._output_dir = root
-
-     @property
-     def print_freq(self):
-         if self._print_freq is None:
-             # self._print_freq = self.log_step
-             return self.log_step
-         return self._print_freq
-
-     @print_freq.setter
-     def print_freq(self, n):
-         assert isinstance(n, int), 'print_freq must be int'
-         self._print_freq = n
-
-
-     # def __repr__(self) -> str:
-     #     pass
-
-
 
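The deleted BaseConfig is a lazy container: the first read of train_dataloader builds a DataLoader from train_dataset, falling back to batch_size and the default shuffle flags (with the warnings above). A minimal sketch of that behavior, assuming the class is importable; ToyDataset is a hypothetical stand-in, not part of this commit:

    import torch
    from torch.utils.data import Dataset

    class ToyDataset(Dataset):            # hypothetical dataset for illustration
        def __len__(self):
            return 8
        def __getitem__(self, idx):
            return torch.zeros(3, 32, 32), idx

    cfg = BaseConfig()
    cfg.train_dataset = ToyDataset()
    cfg.batch_size = 4                    # train/val batch sizes fall back to this
    loader = cfg.train_dataloader         # first access constructs the DataLoader
    # prints 'warning: set train_batch_size=batch_size=4'
    # and 'warning: set default train_shuffle=True'
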
src2/core/yaml_config.py DELETED
@@ -1,152 +0,0 @@
- """by lyuwenyu
- """
-
- import torch
- import torch.nn as nn
-
- import re
- import copy
-
- from .config import BaseConfig
- from .yaml_utils import load_config, merge_config, create, merge_dict
-
-
- class YAMLConfig(BaseConfig):
-     def __init__(self, cfg_path: str, **kwargs) -> None:
-         super().__init__()
-
-         cfg = load_config(cfg_path)
-         merge_dict(cfg, kwargs)
-
-         # pprint(cfg)
-
-         self.yaml_cfg = cfg
-
-         self.log_step = cfg.get('log_step', 100)
-         self.checkpoint_step = cfg.get('checkpoint_step', 1)
-         self.epoches = cfg.get('epoches', -1)
-         self.resume = cfg.get('resume', '')
-         self.tuning = cfg.get('tuning', '')
-         self.sync_bn = cfg.get('sync_bn', False)
-         self.output_dir = cfg.get('output_dir', None)
-
-         self.use_ema = cfg.get('use_ema', False)
-         self.use_amp = cfg.get('use_amp', False)
-         self.autocast = cfg.get('autocast', dict())
-         self.find_unused_parameters = cfg.get('find_unused_parameters', None)
-         self.clip_max_norm = cfg.get('clip_max_norm', 0.)
-
-
-     @property
-     def model(self, ) -> torch.nn.Module:
-         if self._model is None and 'model' in self.yaml_cfg:
-             merge_config(self.yaml_cfg)
-             self._model = create(self.yaml_cfg['model'])
-         return self._model
-
-     @property
-     def postprocessor(self, ) -> torch.nn.Module:
-         if self._postprocessor is None and 'postprocessor' in self.yaml_cfg:
-             merge_config(self.yaml_cfg)
-             self._postprocessor = create(self.yaml_cfg['postprocessor'])
-         return self._postprocessor
-
-     @property
-     def criterion(self, ):
-         if self._criterion is None and 'criterion' in self.yaml_cfg:
-             merge_config(self.yaml_cfg)
-             self._criterion = create(self.yaml_cfg['criterion'])
-         return self._criterion
-
-
-     @property
-     def optimizer(self, ):
-         if self._optimizer is None and 'optimizer' in self.yaml_cfg:
-             merge_config(self.yaml_cfg)
-             params = self.get_optim_params(self.yaml_cfg['optimizer'], self.model)
-             self._optimizer = create('optimizer', params=params)
-
-         return self._optimizer
-
-     @property
-     def lr_scheduler(self, ):
-         if self._lr_scheduler is None and 'lr_scheduler' in self.yaml_cfg:
-             merge_config(self.yaml_cfg)
-             self._lr_scheduler = create('lr_scheduler', optimizer=self.optimizer)
-             print('Initial lr: ', self._lr_scheduler.get_last_lr())
-
-         return self._lr_scheduler
-
-     @property
-     def train_dataloader(self, ):
-         if self._train_dataloader is None and 'train_dataloader' in self.yaml_cfg:
-             merge_config(self.yaml_cfg)
-             self._train_dataloader = create('train_dataloader')
-             self._train_dataloader.shuffle = self.yaml_cfg['train_dataloader'].get('shuffle', False)
-
-         return self._train_dataloader
-
-     @property
-     def val_dataloader(self, ):
-         if self._val_dataloader is None and 'val_dataloader' in self.yaml_cfg:
-             merge_config(self.yaml_cfg)
-             self._val_dataloader = create('val_dataloader')
-             self._val_dataloader.shuffle = self.yaml_cfg['val_dataloader'].get('shuffle', False)
-
-         return self._val_dataloader
-
-
-     @property
-     def ema(self, ):
-         if self._ema is None and self.yaml_cfg.get('use_ema', False):
-             merge_config(self.yaml_cfg)
-             self._ema = create('ema', model=self.model)
-
-         return self._ema
-
-
-     @property
-     def scaler(self, ):
-         if self._scaler is None and self.yaml_cfg.get('use_amp', False):
-             merge_config(self.yaml_cfg)
-             self._scaler = create('scaler')
-
-         return self._scaler
-
-
-     @staticmethod
-     def get_optim_params(cfg: dict, model: nn.Module):
-         '''
-         E.g.:
-             ^(?=.*a)(?=.*b).*$         means including a and b
-             ^((?!b.)*a((?!b).)*$       means including a but not b
-             ^((?!b|c).)*a((?!b|c).)*$  means including a but not (b | c)
-         '''
-         assert 'type' in cfg, ''
-         cfg = copy.deepcopy(cfg)
-
-         if 'params' not in cfg:
-             return model.parameters()
-
-         assert isinstance(cfg['params'], list), ''
-
-         param_groups = []
-         visited = []
-         for pg in cfg['params']:
-             pattern = pg['params']
-             params = {k: v for k, v in model.named_parameters() if v.requires_grad and len(re.findall(pattern, k)) > 0}
-             pg['params'] = params.values()
-             param_groups.append(pg)
-             visited.extend(list(params.keys()))
-
-         names = [k for k, v in model.named_parameters() if v.requires_grad]
-
-         if len(visited) < len(names):
-             unseen = set(names) - set(visited)
-             params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
-             param_groups.append({'params': params.values()})
-             visited.extend(list(params.keys()))
-
-         assert len(visited) == len(names), ''
-
-         return param_groups
 
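get_optim_params turns a list of regex-keyed groups into optimizer param_groups: each pattern is matched against model.named_parameters(), and any parameters left unmatched are swept into a final catch-all group. A hedged sketch of the config shape it consumed, given some model; the pattern and learning rates are illustrative only:

    optimizer_cfg = {
        'type': 'AdamW',
        'params': [
            # parameter names containing 'backbone' get their own group / lr
            {'params': '^(?=.*backbone).*$', 'lr': 1e-5},
        ],
        'lr': 1e-4,  # optimizer default, applies to the catch-all group
    }
    param_groups = YAMLConfig.get_optim_params(optimizer_cfg, model)
    # -> [{'params': <backbone params>, 'lr': 1e-5}, {'params': <the rest>}]
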
src2/core/yaml_utils.py DELETED
@@ -1,208 +0,0 @@
- """"by lyuwenyu
- """
-
- import os
- import yaml
- import inspect
- import importlib
-
- __all__ = ['GLOBAL_CONFIG', 'register', 'create', 'load_config', 'merge_config', 'merge_dict']
-
-
- GLOBAL_CONFIG = dict()
- INCLUDE_KEY = '__include__'
-
-
- def register(cls: type):
-     '''
-     Args:
-         cls (type): Module class to be registered.
-     '''
-     if cls.__name__ in GLOBAL_CONFIG:
-         raise ValueError('{} already registered'.format(cls.__name__))
-
-     if inspect.isfunction(cls):
-         GLOBAL_CONFIG[cls.__name__] = cls
-
-     elif inspect.isclass(cls):
-         GLOBAL_CONFIG[cls.__name__] = extract_schema(cls)
-
-     else:
-         raise ValueError(f'register {cls}')
-
-     return cls
-
-
- def extract_schema(cls: type):
-     '''
-     Args:
-         cls (type),
-     Return:
-         Dict,
-     '''
-     argspec = inspect.getfullargspec(cls.__init__)
-     arg_names = [arg for arg in argspec.args if arg != 'self']
-     num_defualts = len(argspec.defaults) if argspec.defaults is not None else 0
-     num_requires = len(arg_names) - num_defualts
-
-     schame = dict()
-     schame['_name'] = cls.__name__
-     schame['_pymodule'] = importlib.import_module(cls.__module__)
-     schame['_inject'] = getattr(cls, '__inject__', [])
-     schame['_share'] = getattr(cls, '__share__', [])
-
-     for i, name in enumerate(arg_names):
-         if name in schame['_share']:
-             assert i >= num_requires, 'share config must have default value.'
-             value = argspec.defaults[i - num_requires]
-
-         elif i >= num_requires:
-             value = argspec.defaults[i - num_requires]
-
-         else:
-             value = None
-
-         schame[name] = value
-
-     return schame
-
-
-
- def create(type_or_name, **kwargs):
-     '''
-     '''
-     assert type(type_or_name) in (type, str), 'create should be class or name.'
-
-     name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__
-
-     if name in GLOBAL_CONFIG:
-         if hasattr(GLOBAL_CONFIG[name], '__dict__'):
-             return GLOBAL_CONFIG[name]
-     else:
-         raise ValueError('The module {} is not registered'.format(name))
-
-     cfg = GLOBAL_CONFIG[name]
-
-     if isinstance(cfg, dict) and 'type' in cfg:
-         _cfg: dict = GLOBAL_CONFIG[cfg['type']]
-         _cfg.update(cfg)  # update global cls default args
-         _cfg.update(kwargs)  # TODO
-         name = _cfg.pop('type')
-
-         return create(name)
-
-
-     cls = getattr(cfg['_pymodule'], name)
-     argspec = inspect.getfullargspec(cls.__init__)
-     arg_names = [arg for arg in argspec.args if arg != 'self']
-
-     cls_kwargs = {}
-     cls_kwargs.update(cfg)
-
-     # shared var
-     for k in cfg['_share']:
-         if k in GLOBAL_CONFIG:
-             cls_kwargs[k] = GLOBAL_CONFIG[k]
-         else:
-             cls_kwargs[k] = cfg[k]
-
-     # inject
-     for k in cfg['_inject']:
-         _k = cfg[k]
-
-         if _k is None:
-             continue
-
-         if isinstance(_k, str):
-             if _k not in GLOBAL_CONFIG:
-                 raise ValueError(f'Missing inject config of {_k}.')
-
-             _cfg = GLOBAL_CONFIG[_k]
-
-             if isinstance(_cfg, dict):
-                 cls_kwargs[k] = create(_cfg['_name'])
-             else:
-                 cls_kwargs[k] = _cfg
-
-         elif isinstance(_k, dict):
-             if 'type' not in _k.keys():
-                 raise ValueError(f'Missing inject for `type` style.')
-
-             _type = str(_k['type'])
-             if _type not in GLOBAL_CONFIG:
-                 raise ValueError(f'Missing {_type} in inspect stage.')
-
-             # TODO modified inspace, maybe get wrong result for using `> 1`
-             _cfg: dict = GLOBAL_CONFIG[_type]
-             # _cfg_copy = copy.deepcopy(_cfg)
-             _cfg.update(_k)  # update
-             cls_kwargs[k] = create(_type)
-             # _cfg.update(_cfg_copy)  # resume
-
-         else:
-             raise ValueError(f'Inject does not support {_k}')
-
-
-     cls_kwargs = {n: cls_kwargs[n] for n in arg_names}
-
-     return cls(**cls_kwargs)
-
-
-
- def load_config(file_path, cfg=dict()):
-     '''load config
-     '''
-     _, ext = os.path.splitext(file_path)
-     assert ext in ['.yml', '.yaml'], "only support yaml files for now"
-
-     with open(file_path) as f:
-         file_cfg = yaml.load(f, Loader=yaml.Loader)
-         if file_cfg is None:
-             return {}
-
-     if INCLUDE_KEY in file_cfg:
-         base_yamls = list(file_cfg[INCLUDE_KEY])
-         for base_yaml in base_yamls:
-             if base_yaml.startswith('~'):
-                 base_yaml = os.path.expanduser(base_yaml)
-
-             if not base_yaml.startswith('/'):
-                 base_yaml = os.path.join(os.path.dirname(file_path), base_yaml)
-
-             with open(base_yaml) as f:
-                 base_cfg = load_config(base_yaml, cfg)
-                 merge_config(base_cfg, cfg)
-
-     return merge_config(file_cfg, cfg)
-
-
-
- def merge_dict(dct, another_dct):
-     '''merge another_dct into dct
-     '''
-     for k in another_dct:
-         if (k in dct and isinstance(dct[k], dict) and isinstance(another_dct[k], dict)):
-             merge_dict(dct[k], another_dct[k])
-         else:
-             dct[k] = another_dct[k]
-
-     return dct
-
-
-
- def merge_config(config, another_cfg=None):
-     """
-     Merge config into global config or another_cfg.
-
-     Args:
-         config (dict): Config to be merged.
-
-     Returns: global config
-     """
-     global GLOBAL_CONFIG
-     dct = GLOBAL_CONFIG if another_cfg is None else another_cfg
-
-     return merge_dict(dct, config)
 
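Together, register, extract_schema and create implement the config-driven factory: @register stores a schema (class name, defining module, __inject__/__share__ hooks, __init__ defaults) in GLOBAL_CONFIG, merge_config overlays YAML values onto that schema, and create(name) instantiates the class from the merged arguments. A self-contained sketch under those semantics; Backbone is an invented example class, not from this repo:

    @register
    class Backbone:
        def __init__(self, depth=50, pretrained=False):
            self.depth = depth
            self.pretrained = pretrained

    # normally the overrides arrive from a YAML file via load_config
    merge_config({'Backbone': {'depth': 101}})

    net = create('Backbone')   # -> Backbone(depth=101, pretrained=False)
    assert net.depth == 101
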
src2/data/__init__.py DELETED
@@ -1,7 +0,0 @@
-
- from .coco import *
- from .cifar10 import CIFAR10
-
- from .dataloader import *
- from .transforms import *
-
 
src2/data/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (270 Bytes)
 
src2/data/__pycache__/dataloader.cpython-310.pyc DELETED
Binary file (1.29 kB)
 
src2/data/__pycache__/transforms.cpython-310.pyc DELETED
Binary file (5.19 kB)
 
src2/data/cifar10/__init__.py DELETED
@@ -1,14 +0,0 @@
-
- import torchvision
- from typing import Optional, Callable
-
- from src.core import register
-
-
- @register
- class CIFAR10(torchvision.datasets.CIFAR10):
-     __inject__ = ['transform', 'target_transform']
-
-     def __init__(self, root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False) -> None:
-         super().__init__(root, train, transform, target_transform, download)
-
 
src2/data/cifar10/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (823 Bytes)
 
src2/data/coco/__init__.py DELETED
@@ -1,9 +0,0 @@
- from .coco_dataset import (
-     CocoDetection,
-     mscoco_category2label,
-     mscoco_label2category,
-     mscoco_category2name,
- )
- from .coco_eval import *
-
- from .coco_utils import get_coco_api_from_dataset
 
src2/data/coco/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (402 Bytes)
 
src2/data/coco/__pycache__/coco_dataset.cpython-310.pyc DELETED
Binary file (7.07 kB)
 
src2/data/coco/__pycache__/coco_eval.cpython-310.pyc DELETED
Binary file (7.26 kB)
 
src2/data/coco/__pycache__/coco_utils.cpython-310.pyc DELETED
Binary file (6.61 kB)
 
src2/data/coco/coco_dataset.py DELETED
@@ -1,238 +0,0 @@
- """
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-
- COCO dataset which returns image_id for evaluation.
- Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
- """
-
- import torch
- import torch.utils.data
-
- import torchvision
- torchvision.disable_beta_transforms_warning()
-
- from torchvision import datapoints
-
- from pycocotools import mask as coco_mask
-
- from src.core import register
-
- __all__ = ['CocoDetection']
-
-
- @register
- class CocoDetection(torchvision.datasets.CocoDetection):
-     __inject__ = ['transforms']
-     __share__ = ['remap_mscoco_category']
-
-     def __init__(self, img_folder, ann_file, transforms, return_masks, remap_mscoco_category=False):
-         super(CocoDetection, self).__init__(img_folder, ann_file)
-         self._transforms = transforms
-         self.prepare = ConvertCocoPolysToMask(return_masks, remap_mscoco_category)
-         self.img_folder = img_folder
-         self.ann_file = ann_file
-         self.return_masks = return_masks
-         self.remap_mscoco_category = remap_mscoco_category
-
-     def __getitem__(self, idx):
-         img, target = super(CocoDetection, self).__getitem__(idx)
-         image_id = self.ids[idx]
-         target = {'image_id': image_id, 'annotations': target}
-         img, target = self.prepare(img, target)
-
-         # ['boxes', 'masks', 'labels']:
-         if 'boxes' in target:
-             target['boxes'] = datapoints.BoundingBox(
-                 target['boxes'],
-                 format=datapoints.BoundingBoxFormat.XYXY,
-                 spatial_size=img.size[::-1])  # h w
-
-         if 'masks' in target:
-             target['masks'] = datapoints.Mask(target['masks'])
-
-         if self._transforms is not None:
-             img, target = self._transforms(img, target)
-
-         return img, target
-
-     def extra_repr(self) -> str:
-         s = f' img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n'
-         s += f' return_masks: {self.return_masks}\n'
-         if hasattr(self, '_transforms') and self._transforms is not None:
-             s += f' transforms:\n   {repr(self._transforms)}'
-
-         return s
-
-
- def convert_coco_poly_to_mask(segmentations, height, width):
-     masks = []
-     for polygons in segmentations:
-         rles = coco_mask.frPyObjects(polygons, height, width)
-         mask = coco_mask.decode(rles)
-         if len(mask.shape) < 3:
-             mask = mask[..., None]
-         mask = torch.as_tensor(mask, dtype=torch.uint8)
-         mask = mask.any(dim=2)
-         masks.append(mask)
-     if masks:
-         masks = torch.stack(masks, dim=0)
-     else:
-         masks = torch.zeros((0, height, width), dtype=torch.uint8)
-     return masks
-
-
- class ConvertCocoPolysToMask(object):
-     def __init__(self, return_masks=False, remap_mscoco_category=False):
-         self.return_masks = return_masks
-         self.remap_mscoco_category = remap_mscoco_category
-
-     def __call__(self, image, target):
-         w, h = image.size
-
-         image_id = target["image_id"]
-         image_id = torch.tensor([image_id])
-
-         anno = target["annotations"]
-
-         anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
-
-         boxes = [obj["bbox"] for obj in anno]
-         # guard against no boxes via resizing
-         boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
-         boxes[:, 2:] += boxes[:, :2]
-         boxes[:, 0::2].clamp_(min=0, max=w)
-         boxes[:, 1::2].clamp_(min=0, max=h)
-
-         if self.remap_mscoco_category:
-             classes = [mscoco_category2label[obj["category_id"]] for obj in anno]
-         else:
-             classes = [obj["category_id"] for obj in anno]
-
-         classes = torch.tensor(classes, dtype=torch.int64)
-
-         if self.return_masks:
-             segmentations = [obj["segmentation"] for obj in anno]
-             masks = convert_coco_poly_to_mask(segmentations, h, w)
-
-         keypoints = None
-         if anno and "keypoints" in anno[0]:
-             keypoints = [obj["keypoints"] for obj in anno]
-             keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
-             num_keypoints = keypoints.shape[0]
-             if num_keypoints:
-                 keypoints = keypoints.view(num_keypoints, -1, 3)
-
-         keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
-         boxes = boxes[keep]
-         classes = classes[keep]
-         if self.return_masks:
-             masks = masks[keep]
-         if keypoints is not None:
-             keypoints = keypoints[keep]
-
-         target = {}
-         target["boxes"] = boxes
-         target["labels"] = classes
-         if self.return_masks:
-             target["masks"] = masks
-         target["image_id"] = image_id
-         if keypoints is not None:
-             target["keypoints"] = keypoints
-
-         # for conversion to coco api
-         area = torch.tensor([obj["area"] for obj in anno])
-         iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
-         target["area"] = area[keep]
-         target["iscrowd"] = iscrowd[keep]
-
-         target["orig_size"] = torch.as_tensor([int(w), int(h)])
-         target["size"] = torch.as_tensor([int(w), int(h)])
-
-         return image, target
-
-
- mscoco_category2name = {
-     1: 'person',
-     2: 'bicycle',
-     3: 'car',
-     4: 'motorcycle',
-     5: 'airplane',
-     6: 'bus',
-     7: 'train',
-     8: 'truck',
-     9: 'boat',
-     10: 'traffic light',
-     11: 'fire hydrant',
-     13: 'stop sign',
-     14: 'parking meter',
-     15: 'bench',
-     16: 'bird',
-     17: 'cat',
-     18: 'dog',
-     19: 'horse',
-     20: 'sheep',
-     21: 'cow',
-     22: 'elephant',
-     23: 'bear',
-     24: 'zebra',
-     25: 'giraffe',
-     27: 'backpack',
-     28: 'umbrella',
-     31: 'handbag',
-     32: 'tie',
-     33: 'suitcase',
-     34: 'frisbee',
-     35: 'skis',
-     36: 'snowboard',
-     37: 'sports ball',
-     38: 'kite',
-     39: 'baseball bat',
-     40: 'baseball glove',
-     41: 'skateboard',
-     42: 'surfboard',
-     43: 'tennis racket',
-     44: 'bottle',
-     46: 'wine glass',
-     47: 'cup',
-     48: 'fork',
-     49: 'knife',
-     50: 'spoon',
-     51: 'bowl',
-     52: 'banana',
-     53: 'apple',
-     54: 'sandwich',
-     55: 'orange',
-     56: 'broccoli',
-     57: 'carrot',
-     58: 'hot dog',
-     59: 'pizza',
-     60: 'donut',
-     61: 'cake',
-     62: 'chair',
-     63: 'couch',
-     64: 'potted plant',
-     65: 'bed',
-     67: 'dining table',
-     70: 'toilet',
-     72: 'tv',
-     73: 'laptop',
-     74: 'mouse',
-     75: 'remote',
-     76: 'keyboard',
-     77: 'cell phone',
-     78: 'microwave',
-     79: 'oven',
-     80: 'toaster',
-     81: 'sink',
-     82: 'refrigerator',
-     84: 'book',
-     85: 'clock',
-     86: 'vase',
-     87: 'scissors',
-     88: 'teddy bear',
-     89: 'hair drier',
-     90: 'toothbrush'
- }
-
- mscoco_category2label = {k: i for i, k in enumerate(mscoco_category2name.keys())}
- mscoco_label2category = {v: k for k, v in mscoco_category2label.items()}
 
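Since MSCOCO category ids are non-contiguous (80 classes spread over ids 1-90), remap_mscoco_category=True routes annotations through the tables above to get dense 0-based training labels. For example:

    assert mscoco_category2label[1] == 0     # 'person': id 1 -> label 0
    assert mscoco_category2label[90] == 79   # 'toothbrush': id 90 -> label 79
    assert mscoco_category2name[mscoco_label2category[79]] == 'toothbrush'
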
src2/data/coco/coco_eval.py DELETED
@@ -1,269 +0,0 @@
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- """
- COCO evaluator that works in distributed mode.
-
- Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
- The difference is that there is less copy-pasting from pycocotools
- in the end of the file, as python3 can suppress prints with contextlib
- """
- import os
- import contextlib
- import copy
- import numpy as np
- import torch
-
- from pycocotools.cocoeval import COCOeval
- from pycocotools.coco import COCO
- import pycocotools.mask as mask_util
-
- from src.misc import dist
-
-
- __all__ = ['CocoEvaluator',]
-
-
- class CocoEvaluator(object):
-     def __init__(self, coco_gt, iou_types):
-         assert isinstance(iou_types, (list, tuple))
-         coco_gt = copy.deepcopy(coco_gt)
-         self.coco_gt = coco_gt
-
-         self.iou_types = iou_types
-         self.coco_eval = {}
-         for iou_type in iou_types:
-             self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
-
-         self.img_ids = []
-         self.eval_imgs = {k: [] for k in iou_types}
-
-     def update(self, predictions):
-         img_ids = list(np.unique(list(predictions.keys())))
-         self.img_ids.extend(img_ids)
-
-         for iou_type in self.iou_types:
-             results = self.prepare(predictions, iou_type)
-
-             # suppress pycocotools prints
-             with open(os.devnull, 'w') as devnull:
-                 with contextlib.redirect_stdout(devnull):
-                     coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
-             coco_eval = self.coco_eval[iou_type]
-
-             coco_eval.cocoDt = coco_dt
-             coco_eval.params.imgIds = list(img_ids)
-             img_ids, eval_imgs = evaluate(coco_eval)
-
-             self.eval_imgs[iou_type].append(eval_imgs)
-
-     def synchronize_between_processes(self):
-         for iou_type in self.iou_types:
-             self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
-             create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
-
-     def accumulate(self):
-         for coco_eval in self.coco_eval.values():
-             coco_eval.accumulate()
-
-     def summarize(self):
-         for iou_type, coco_eval in self.coco_eval.items():
-             print("IoU metric: {}".format(iou_type))
-             coco_eval.summarize()
-
-     def prepare(self, predictions, iou_type):
-         if iou_type == "bbox":
-             return self.prepare_for_coco_detection(predictions)
-         elif iou_type == "segm":
-             return self.prepare_for_coco_segmentation(predictions)
-         elif iou_type == "keypoints":
-             return self.prepare_for_coco_keypoint(predictions)
-         else:
-             raise ValueError("Unknown iou type {}".format(iou_type))
-
-     def prepare_for_coco_detection(self, predictions):
-         coco_results = []
-         for original_id, prediction in predictions.items():
-             if len(prediction) == 0:
-                 continue
-
-             boxes = prediction["boxes"]
-             boxes = convert_to_xywh(boxes).tolist()
-             scores = prediction["scores"].tolist()
-             labels = prediction["labels"].tolist()
-
-             coco_results.extend(
-                 [
-                     {
-                         "image_id": original_id,
-                         "category_id": labels[k],
-                         "bbox": box,
-                         "score": scores[k],
-                     }
-                     for k, box in enumerate(boxes)
-                 ]
-             )
-         return coco_results
-
-     def prepare_for_coco_segmentation(self, predictions):
-         coco_results = []
-         for original_id, prediction in predictions.items():
-             if len(prediction) == 0:
-                 continue
-
-             scores = prediction["scores"]
-             labels = prediction["labels"]
-             masks = prediction["masks"]
-
-             masks = masks > 0.5
-
-             scores = prediction["scores"].tolist()
-             labels = prediction["labels"].tolist()
-
-             rles = [
-                 mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
-                 for mask in masks
-             ]
-             for rle in rles:
-                 rle["counts"] = rle["counts"].decode("utf-8")
-
-             coco_results.extend(
-                 [
-                     {
-                         "image_id": original_id,
-                         "category_id": labels[k],
-                         "segmentation": rle,
-                         "score": scores[k],
-                     }
-                     for k, rle in enumerate(rles)
-                 ]
-             )
-         return coco_results
-
-     def prepare_for_coco_keypoint(self, predictions):
-         coco_results = []
-         for original_id, prediction in predictions.items():
-             if len(prediction) == 0:
-                 continue
-
-             boxes = prediction["boxes"]
-             boxes = convert_to_xywh(boxes).tolist()
-             scores = prediction["scores"].tolist()
-             labels = prediction["labels"].tolist()
-             keypoints = prediction["keypoints"]
-             keypoints = keypoints.flatten(start_dim=1).tolist()
-
-             coco_results.extend(
-                 [
-                     {
-                         "image_id": original_id,
-                         "category_id": labels[k],
-                         'keypoints': keypoint,
-                         "score": scores[k],
-                     }
-                     for k, keypoint in enumerate(keypoints)
-                 ]
-             )
-         return coco_results
-
-
- def convert_to_xywh(boxes):
-     xmin, ymin, xmax, ymax = boxes.unbind(1)
-     return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
-
-
- def merge(img_ids, eval_imgs):
-     all_img_ids = dist.all_gather(img_ids)
-     all_eval_imgs = dist.all_gather(eval_imgs)
-
-     merged_img_ids = []
-     for p in all_img_ids:
-         merged_img_ids.extend(p)
-
-     merged_eval_imgs = []
-     for p in all_eval_imgs:
-         merged_eval_imgs.append(p)
-
-     merged_img_ids = np.array(merged_img_ids)
-     merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
-
-     # keep only unique (and in sorted order) images
-     merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
-     merged_eval_imgs = merged_eval_imgs[..., idx]
-
-     return merged_img_ids, merged_eval_imgs
-
-
- def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
-     img_ids, eval_imgs = merge(img_ids, eval_imgs)
-     img_ids = list(img_ids)
-     eval_imgs = list(eval_imgs.flatten())
-
-     coco_eval.evalImgs = eval_imgs
-     coco_eval.params.imgIds = img_ids
-     coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
-
-
- #################################################################
- # From pycocotools, just removed the prints and fixed
- # a Python3 bug about unicode not defined
- #################################################################
-
-
- # import io
- # from contextlib import redirect_stdout
- # def evaluate(imgs):
- #     with redirect_stdout(io.StringIO()):
- #         imgs.evaluate()
- #     return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))
-
-
- def evaluate(self):
-     '''
-     Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
-     :return: None
-     '''
-     # tic = time.time()
-     # print('Running per image evaluation...')
-     p = self.params
-     # add backward compatibility if useSegm is specified in params
-     if p.useSegm is not None:
-         p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
-         print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
-     # print('Evaluate annotation type *{}*'.format(p.iouType))
-     p.imgIds = list(np.unique(p.imgIds))
-     if p.useCats:
-         p.catIds = list(np.unique(p.catIds))
-     p.maxDets = sorted(p.maxDets)
-     self.params = p
-
-     self._prepare()
-     # loop through images, area range, max detection number
-     catIds = p.catIds if p.useCats else [-1]
-
-     if p.iouType == 'segm' or p.iouType == 'bbox':
-         computeIoU = self.computeIoU
-     elif p.iouType == 'keypoints':
-         computeIoU = self.computeOks
-     self.ious = {
-         (imgId, catId): computeIoU(imgId, catId)
-         for imgId in p.imgIds
-         for catId in catIds}
-
-     evaluateImg = self.evaluateImg
-     maxDet = p.maxDets[-1]
-     evalImgs = [
-         evaluateImg(imgId, catId, areaRng, maxDet)
-         for catId in catIds
-         for areaRng in p.areaRng
-         for imgId in p.imgIds
-     ]
-     # this is NOT in the pycocotools code, but could be done outside
-     evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
-     self._paramsEval = copy.deepcopy(self.params)
-     # toc = time.time()
-     # print('DONE (t={:0.2f}s).'.format(toc-tic))
-     return p.imgIds, evalImgs
-
- #################################################################
- # end of straight copy from pycocotools, just removing the prints
- #################################################################
-
 
src2/data/coco/coco_utils.py DELETED
@@ -1,184 +0,0 @@
- import os
-
- import torch
- import torch.utils.data
- import torchvision
- from pycocotools import mask as coco_mask
- from pycocotools.coco import COCO
-
-
- def convert_coco_poly_to_mask(segmentations, height, width):
-     masks = []
-     for polygons in segmentations:
-         rles = coco_mask.frPyObjects(polygons, height, width)
-         mask = coco_mask.decode(rles)
-         if len(mask.shape) < 3:
-             mask = mask[..., None]
-         mask = torch.as_tensor(mask, dtype=torch.uint8)
-         mask = mask.any(dim=2)
-         masks.append(mask)
-     if masks:
-         masks = torch.stack(masks, dim=0)
-     else:
-         masks = torch.zeros((0, height, width), dtype=torch.uint8)
-     return masks
-
-
- class ConvertCocoPolysToMask:
-     def __call__(self, image, target):
-         w, h = image.size
-
-         image_id = target["image_id"]
-
-         anno = target["annotations"]
-
-         anno = [obj for obj in anno if obj["iscrowd"] == 0]
-
-         boxes = [obj["bbox"] for obj in anno]
-         # guard against no boxes via resizing
-         boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
-         boxes[:, 2:] += boxes[:, :2]
-         boxes[:, 0::2].clamp_(min=0, max=w)
-         boxes[:, 1::2].clamp_(min=0, max=h)
-
-         classes = [obj["category_id"] for obj in anno]
-         classes = torch.tensor(classes, dtype=torch.int64)
-
-         segmentations = [obj["segmentation"] for obj in anno]
-         masks = convert_coco_poly_to_mask(segmentations, h, w)
-
-         keypoints = None
-         if anno and "keypoints" in anno[0]:
-             keypoints = [obj["keypoints"] for obj in anno]
-             keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
-             num_keypoints = keypoints.shape[0]
-             if num_keypoints:
-                 keypoints = keypoints.view(num_keypoints, -1, 3)
-
-         keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
-         boxes = boxes[keep]
-         classes = classes[keep]
-         masks = masks[keep]
-         if keypoints is not None:
-             keypoints = keypoints[keep]
-
-         target = {}
-         target["boxes"] = boxes
-         target["labels"] = classes
-         target["masks"] = masks
-         target["image_id"] = image_id
-         if keypoints is not None:
-             target["keypoints"] = keypoints
-
-         # for conversion to coco api
-         area = torch.tensor([obj["area"] for obj in anno])
-         iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
-         target["area"] = area
-         target["iscrowd"] = iscrowd
-
-         return image, target
-
-
- def _coco_remove_images_without_annotations(dataset, cat_list=None):
-     def _has_only_empty_bbox(anno):
-         return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
-
-     def _count_visible_keypoints(anno):
-         return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
-
-     min_keypoints_per_image = 10
-
-     def _has_valid_annotation(anno):
-         # if it's empty, there is no annotation
-         if len(anno) == 0:
-             return False
-         # if all boxes have close to zero area, there is no annotation
-         if _has_only_empty_bbox(anno):
-             return False
-         # keypoints task have a slight different criteria for considering
-         # if an annotation is valid
-         if "keypoints" not in anno[0]:
-             return True
-         # for keypoint detection tasks, only consider valid images those
-         # containing at least min_keypoints_per_image
-         if _count_visible_keypoints(anno) >= min_keypoints_per_image:
-             return True
-         return False
-
-     ids = []
-     for ds_idx, img_id in enumerate(dataset.ids):
-         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
-         anno = dataset.coco.loadAnns(ann_ids)
-         if cat_list:
-             anno = [obj for obj in anno if obj["category_id"] in cat_list]
-         if _has_valid_annotation(anno):
-             ids.append(ds_idx)
-
-     dataset = torch.utils.data.Subset(dataset, ids)
-     return dataset
-
-
- def convert_to_coco_api(ds):
-     coco_ds = COCO()
-     # annotation IDs need to start at 1, not 0, see torchvision issue #1530
-     ann_id = 1
-     dataset = {"images": [], "categories": [], "annotations": []}
-     categories = set()
-     for img_idx in range(len(ds)):
-         # find better way to get target
-         # targets = ds.get_annotations(img_idx)
-         img, targets = ds[img_idx]
-         image_id = targets["image_id"].item()
-         img_dict = {}
-         img_dict["id"] = image_id
-         img_dict["height"] = img.shape[-2]
-         img_dict["width"] = img.shape[-1]
-         dataset["images"].append(img_dict)
-         bboxes = targets["boxes"].clone()
-         bboxes[:, 2:] -= bboxes[:, :2]
-         bboxes = bboxes.tolist()
-         labels = targets["labels"].tolist()
-         areas = targets["area"].tolist()
-         iscrowd = targets["iscrowd"].tolist()
-         if "masks" in targets:
-             masks = targets["masks"]
-             # make masks Fortran contiguous for coco_mask
-             masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
-         if "keypoints" in targets:
-             keypoints = targets["keypoints"]
-             keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
-         num_objs = len(bboxes)
-         for i in range(num_objs):
-             ann = {}
-             ann["image_id"] = image_id
-             ann["bbox"] = bboxes[i]
-             ann["category_id"] = labels[i]
-             categories.add(labels[i])
-             ann["area"] = areas[i]
-             ann["iscrowd"] = iscrowd[i]
-             ann["id"] = ann_id
-             if "masks" in targets:
-                 ann["segmentation"] = coco_mask.encode(masks[i].numpy())
-             if "keypoints" in targets:
-                 ann["keypoints"] = keypoints[i]
-                 ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
-             dataset["annotations"].append(ann)
-             ann_id += 1
-     dataset["categories"] = [{"id": i} for i in sorted(categories)]
-     coco_ds.dataset = dataset
-     coco_ds.createIndex()
-     return coco_ds
-
-
- def get_coco_api_from_dataset(dataset):
-     # FIXME: This is... awful?
-     for _ in range(10):
-         if isinstance(dataset, torchvision.datasets.CocoDetection):
-             break
-         if isinstance(dataset, torch.utils.data.Subset):
-             dataset = dataset.dataset
-     if isinstance(dataset, torchvision.datasets.CocoDetection):
-         return dataset.coco
-     return convert_to_coco_api(dataset)
-
 
src2/data/dataloader.py DELETED
@@ -1,28 +0,0 @@
- import torch
- import torch.utils.data as data
-
- from src.core import register
-
-
- __all__ = ['DataLoader']
-
-
- @register
- class DataLoader(data.DataLoader):
-     __inject__ = ['dataset', 'collate_fn']
-
-     def __repr__(self) -> str:
-         format_string = self.__class__.__name__ + "("
-         for n in ['dataset', 'batch_size', 'num_workers', 'drop_last', 'collate_fn']:
-             format_string += "\n"
-             format_string += "    {0}: {1}".format(n, getattr(self, n))
-         format_string += "\n)"
-         return format_string
-
-
-
- @register
- def default_collate_fn(items):
-     '''default collate_fn
-     '''
-     return torch.cat([x[0][None] for x in items], dim=0), [x[1] for x in items]
 
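default_collate_fn stacks the image tensors along a new batch dimension but leaves the targets as a plain list, since detection targets are variable-length dicts that cannot be batched into one tensor. Roughly:

    import torch

    items = [(torch.rand(3, 640, 640), {'boxes': torch.rand(2, 4)}),
             (torch.rand(3, 640, 640), {'boxes': torch.rand(5, 4)})]
    images, targets = default_collate_fn(items)
    print(images.shape)   # torch.Size([2, 3, 640, 640])
    print(len(targets))   # 2 -- one target dict per image
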
src2/data/functional.py DELETED
@@ -1,169 +0,0 @@
- import torch
- import torchvision.transforms.functional as F
-
- from packaging import version
- from typing import Optional, List
- from torch import Tensor
-
- # needed due to empty tensor bug in pytorch and torchvision 0.5
- import torchvision
- if version.parse(torchvision.__version__) < version.parse('0.7'):
-     from torchvision.ops import _new_empty_tensor
-     from torchvision.ops.misc import _output_size
-
-
- def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
-     # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
-     """
-     Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
-     This will eventually be supported natively by PyTorch, and this
-     class can go away.
-     """
-     if version.parse(torchvision.__version__) < version.parse('0.7'):
-         if input.numel() > 0:
-             return torch.nn.functional.interpolate(
-                 input, size, scale_factor, mode, align_corners
-             )
-
-         output_shape = _output_size(2, input, size, scale_factor)
-         output_shape = list(input.shape[:-2]) + list(output_shape)
-         return _new_empty_tensor(input, output_shape)
-     else:
-         return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
-
-
-
- def crop(image, target, region):
-     cropped_image = F.crop(image, *region)
-
-     target = target.copy()
-     i, j, h, w = region
-
-     # should we do something wrt the original size?
-     target["size"] = torch.tensor([h, w])
-
-     fields = ["labels", "area", "iscrowd"]
-
-     if "boxes" in target:
-         boxes = target["boxes"]
-         max_size = torch.as_tensor([w, h], dtype=torch.float32)
-         cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
-         cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
-         cropped_boxes = cropped_boxes.clamp(min=0)
-         area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
-         target["boxes"] = cropped_boxes.reshape(-1, 4)
-         target["area"] = area
-         fields.append("boxes")
-
-     if "masks" in target:
-         # FIXME should we update the area here if there are no boxes?
-         target['masks'] = target['masks'][:, i:i + h, j:j + w]
-         fields.append("masks")
-
-     # remove elements for which the boxes or masks that have zero area
-     if "boxes" in target or "masks" in target:
-         # favor boxes selection when defining which elements to keep
-         # this is compatible with previous implementation
-         if "boxes" in target:
-             cropped_boxes = target['boxes'].reshape(-1, 2, 2)
-             keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
-         else:
-             keep = target['masks'].flatten(1).any(1)
-
-         for field in fields:
-             target[field] = target[field][keep]
-
-     return cropped_image, target
-
-
- def hflip(image, target):
-     flipped_image = F.hflip(image)
-
-     w, h = image.size
-
-     target = target.copy()
-     if "boxes" in target:
-         boxes = target["boxes"]
-         boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
-         target["boxes"] = boxes
-
-     if "masks" in target:
-         target['masks'] = target['masks'].flip(-1)
-
-     return flipped_image, target
-
-
- def resize(image, target, size, max_size=None):
-     # size can be min_size (scalar) or (w, h) tuple
-
-     def get_size_with_aspect_ratio(image_size, size, max_size=None):
-         w, h = image_size
-         if max_size is not None:
-             min_original_size = float(min((w, h)))
-             max_original_size = float(max((w, h)))
-             if max_original_size / min_original_size * size > max_size:
-                 size = int(round(max_size * min_original_size / max_original_size))
-
-         if (w <= h and w == size) or (h <= w and h == size):
-             return (h, w)
-
-         if w < h:
-             ow = size
-             oh = int(size * h / w)
-         else:
-             oh = size
-             ow = int(size * w / h)
-
-         # r = min(size / min(h, w), max_size / max(h, w))
-         # ow = int(w * r)
-         # oh = int(h * r)
-
-         return (oh, ow)
-
-     def get_size(image_size, size, max_size=None):
-         if isinstance(size, (list, tuple)):
-             return size[::-1]
-         else:
-             return get_size_with_aspect_ratio(image_size, size, max_size)
-
-     size = get_size(image.size, size, max_size)
-     rescaled_image = F.resize(image, size)
-
-     if target is None:
-         return rescaled_image, None
-
-     ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
-     ratio_width, ratio_height = ratios
-
-     target = target.copy()
-     if "boxes" in target:
-         boxes = target["boxes"]
-         scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
-         target["boxes"] = scaled_boxes
-
-     if "area" in target:
-         area = target["area"]
-         scaled_area = area * (ratio_width * ratio_height)
-         target["area"] = scaled_area
-
-     h, w = size
-     target["size"] = torch.tensor([h, w])
-
-     if "masks" in target:
-         target['masks'] = interpolate(
-             target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5
-
-     return rescaled_image, target
-
-
- def pad(image, target, padding):
-     # assumes that we only pad on the bottom right corners
-     padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
-     if target is None:
-         return padded_image, None
-     target = target.copy()
-     # should we do something wrt the original size?
-     target["size"] = torch.tensor(padded_image.size[::-1])
-     if "masks" in target:
-         target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
-     return padded_image, target
 
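In resize above, get_size_with_aspect_ratio scales the short side to the requested size while capping the long side at max_size, shrinking the request when the cap would be exceeded. A worked example of that arithmetic (values chosen for illustration):

    # 1280x720 image (w x h), size=800, max_size=1333:
    #   1280/720 * 800 ~= 1422 > 1333, so size -> round(1333 * 720 / 1280) = 750
    #   then oh = 750, ow = int(750 * 1280 / 720) = 1333
    # get_size_with_aspect_ratio((1280, 720), 800, 1333) == (750, 1333)
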
src2/data/transforms.py DELETED
@@ -1,142 +0,0 @@
- """by lyuwenyu
- """
-
-
- import torch
- import torch.nn as nn
-
- import torchvision
- torchvision.disable_beta_transforms_warning()
- from torchvision import datapoints
-
- import torchvision.transforms.v2 as T
- import torchvision.transforms.v2.functional as F
-
- from PIL import Image
- from typing import Any, Dict, List, Optional
-
- from src.core import register, GLOBAL_CONFIG
-
-
- __all__ = ['Compose', ]
-
-
- RandomPhotometricDistort = register(T.RandomPhotometricDistort)
- RandomZoomOut = register(T.RandomZoomOut)
- # RandomIoUCrop = register(T.RandomIoUCrop)
- RandomHorizontalFlip = register(T.RandomHorizontalFlip)
- Resize = register(T.Resize)
- ToImageTensor = register(T.ToImageTensor)
- ConvertDtype = register(T.ConvertDtype)
- SanitizeBoundingBox = register(T.SanitizeBoundingBox)
- RandomCrop = register(T.RandomCrop)
- Normalize = register(T.Normalize)
-
-
- @register
- class Compose(T.Compose):
-     def __init__(self, ops) -> None:
-         transforms = []
-         if ops is not None:
-             for op in ops:
-                 if isinstance(op, dict):
-                     # instantiate a registered transform from a config dict
-                     name = op.pop('type')
-                     transform = getattr(GLOBAL_CONFIG[name]['_pymodule'], name)(**op)
-                     transforms.append(transform)
-                 elif isinstance(op, nn.Module):
-                     transforms.append(op)
-                 else:
-                     raise ValueError('ops must contain config dicts or nn.Module instances')
-         else:
-             transforms = [EmptyTransform(), ]
-
-         super().__init__(transforms=transforms)
-
-
- @register
- class EmptyTransform(T.Transform):
-     def __init__(self, ) -> None:
-         super().__init__()
-
-     def forward(self, *inputs):
-         inputs = inputs if len(inputs) > 1 else inputs[0]
-         return inputs
-
-
- @register
- class PadToSize(T.Pad):
-     _transformed_types = (
-         Image.Image,
-         datapoints.Image,
-         datapoints.Video,
-         datapoints.Mask,
-         datapoints.BoundingBox,
-     )
-
-     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-         # pad only on the bottom and right, up to the target spatial size
-         sz = F.get_spatial_size(flat_inputs[0])
-         h, w = self.spatial_size[0] - sz[0], self.spatial_size[1] - sz[1]
-         self.padding = [0, 0, w, h]
-         return dict(padding=self.padding)
-
-     def __init__(self, spatial_size, fill=0, padding_mode='constant') -> None:
-         if isinstance(spatial_size, int):
-             spatial_size = (spatial_size, spatial_size)
-
-         self.spatial_size = spatial_size
-         super().__init__(0, fill, padding_mode)
-
-     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-         fill = self._fill[type(inpt)]
-         padding = params['padding']
-         return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
-
-     def __call__(self, *inputs: Any) -> Any:
-         outputs = super().forward(*inputs)
-         if len(outputs) > 1 and isinstance(outputs[1], dict):
-             outputs[1]['padding'] = torch.tensor(self.padding)
-         return outputs
-
-
- @register
- class RandomIoUCrop(T.RandomIoUCrop):
-     def __init__(self, min_scale: float = 0.3, max_scale: float = 1, min_aspect_ratio: float = 0.5,
-                  max_aspect_ratio: float = 2, sampler_options: Optional[List[float]] = None,
-                  trials: int = 40, p: float = 1.0):
-         super().__init__(min_scale, max_scale, min_aspect_ratio, max_aspect_ratio, sampler_options, trials)
-         self.p = p
-
-     def __call__(self, *inputs: Any) -> Any:
-         # apply the crop with probability p, otherwise pass inputs through
-         if torch.rand(1) >= self.p:
-             return inputs if len(inputs) > 1 else inputs[0]
-
-         return super().forward(*inputs)
-
-
- @register
- class ConvertBox(T.Transform):
-     _transformed_types = (
-         datapoints.BoundingBox,
-     )
-
-     def __init__(self, out_fmt='', normalize=False) -> None:
-         super().__init__()
-         self.out_fmt = out_fmt
-         self.normalize = normalize
-
-         self.data_fmt = {
-             'xyxy': datapoints.BoundingBoxFormat.XYXY,
-             'cxcywh': datapoints.BoundingBoxFormat.CXCYWH
-         }
-
-     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-         if self.out_fmt:
-             spatial_size = inpt.spatial_size
-             in_fmt = inpt.format.value.lower()
-             inpt = torchvision.ops.box_convert(inpt, in_fmt=in_fmt, out_fmt=self.out_fmt)
-             inpt = datapoints.BoundingBox(inpt, format=self.data_fmt[self.out_fmt], spatial_size=spatial_size)
-
-         if self.normalize:
-             # divide by (w, h, w, h); spatial_size is stored as (h, w)
-             inpt = inpt / torch.tensor(inpt.spatial_size[::-1]).tile(2)[None]
-
-         return inpt
-
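
A sketch of how Compose is driven from config: each dict's 'type' names a registered transform, and the remaining keys become its constructor kwargs (the values below are illustrative):

    ops = [
        {'type': 'RandomHorizontalFlip', 'p': 0.5},
        {'type': 'Resize', 'size': [640, 640]},
        {'type': 'ToImageTensor'},
        {'type': 'ConvertDtype'},
    ]
    pipeline = Compose(ops)   # each 'type' is looked up via GLOBAL_CONFIG, then applied in order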
 
src2/misc/__init__.py DELETED
@@ -1,3 +0,0 @@
-
- from .logger import *
- from .visualizer import *
 
src2/misc/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (211 Bytes)
 
src2/misc/__pycache__/dist.cpython-310.pyc DELETED
Binary file (4.79 kB)
 
src2/misc/__pycache__/logger.cpython-310.pyc DELETED
Binary file (7.77 kB)
 
src2/misc/__pycache__/visualizer.cpython-310.pyc DELETED
Binary file (1.09 kB)
 
src2/misc/dist.py DELETED
@@ -1,190 +0,0 @@
- """
- reference
- - https://github.com/pytorch/vision/blob/main/references/detection/utils.py
- - https://github.com/facebookresearch/detr/blob/master/util/misc.py#L406
-
- by lyuwenyu
- """
-
- import time
- import random
- import numpy as np
-
- import torch
- import torch.nn as nn
- import torch.distributed as tdist
-
- from torch.nn.parallel import DistributedDataParallel as DDP
-
- from torch.utils.data import DistributedSampler
- from torch.utils.data.dataloader import DataLoader
-
-
- def init_distributed():
-     '''Initialize torch.distributed from the env:// variables set by torchrun
-     (LOCAL_RANK / RANK / WORLD_SIZE); see https://pytorch.org/docs/stable/elastic/run.html.
-     Falls back to single-process mode if initialization fails.
-     '''
-     try:
-         tdist.init_process_group(init_method='env://', )
-         tdist.barrier()
-
-         rank = get_rank()
-         device = torch.device(f'cuda:{rank}')
-         torch.cuda.set_device(device)
-
-         setup_print(rank == 0)
-         print('Initialized distributed mode...')
-
-         return True
-
-     except Exception:
-         print('Not init distributed mode.')
-         return False
-
-
- def setup_print(is_main):
-     '''This function disables printing when not in the master process.
-     '''
-     import builtins as __builtin__
-     builtin_print = __builtin__.print
-
-     def print(*args, **kwargs):
-         force = kwargs.pop('force', False)
-         if is_main or force:
-             builtin_print(*args, **kwargs)
-
-     __builtin__.print = print
-
-
- def is_dist_available_and_initialized():
-     if not tdist.is_available():
-         return False
-     if not tdist.is_initialized():
-         return False
-     return True
-
-
- def get_rank():
-     if not is_dist_available_and_initialized():
-         return 0
-     return tdist.get_rank()
-
-
- def get_world_size():
-     if not is_dist_available_and_initialized():
-         return 1
-     return tdist.get_world_size()
-
-
- def is_main_process():
-     return get_rank() == 0
-
-
- def save_on_master(*args, **kwargs):
-     if is_main_process():
-         torch.save(*args, **kwargs)
-
-
- def warp_model(model, find_unused_parameters=False, sync_bn=False,):
-     # wraps the model in DistributedDataParallel (optionally converting to SyncBatchNorm);
-     # a no-op outside distributed mode
-     if is_dist_available_and_initialized():
-         rank = get_rank()
-         model = nn.SyncBatchNorm.convert_sync_batchnorm(model) if sync_bn else model
-         model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=find_unused_parameters)
-     return model
-
-
- def warp_loader(loader, shuffle=False):
-     # rebuilds the loader with a DistributedSampler so each rank sees its own shard
-     if is_dist_available_and_initialized():
-         sampler = DistributedSampler(loader.dataset, shuffle=shuffle)
-         loader = DataLoader(loader.dataset,
-                             loader.batch_size,
-                             sampler=sampler,
-                             drop_last=loader.drop_last,
-                             collate_fn=loader.collate_fn,
-                             pin_memory=loader.pin_memory,
-                             num_workers=loader.num_workers, )
-     return loader
-
-
- def is_parallel(model) -> bool:
-     # Returns True if model is of type DP or DDP
-     return type(model) in (torch.nn.parallel.DataParallel, torch.nn.parallel.DistributedDataParallel)
-
-
- def de_parallel(model) -> nn.Module:
-     # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
-     return model.module if is_parallel(model) else model
-
-
- def reduce_dict(data, avg=True):
-     '''All-reduce the values of a dict across ranks.
-     Args:
-         data (dict): {k: tensor, ...}
-         avg (bool): average (True) or sum (False) over the world size
-     '''
-     world_size = get_world_size()
-     if world_size < 2:
-         return data
-
-     with torch.no_grad():
-         keys, values = [], []
-         for k in sorted(data.keys()):
-             keys.append(k)
-             values.append(data[k])
-
-         values = torch.stack(values, dim=0)
-         tdist.all_reduce(values)
-
-         if avg:
-             values /= world_size
-
-         _data = {k: v for k, v in zip(keys, values)}
-
-     return _data
-
-
- def all_gather(data):
-     """
-     Run all_gather on arbitrary picklable data (not necessarily tensors)
-     Args:
-         data: any picklable object
-     Returns:
-         list[data]: list of data gathered from each rank
-     """
-     world_size = get_world_size()
-     if world_size == 1:
-         return [data]
-     data_list = [None] * world_size
-     tdist.all_gather_object(data_list, data)
-     return data_list
-
-
- def sync_time():
-     '''Synchronize CUDA before reading the wall clock, for accurate timing.
-     '''
-     if torch.cuda.is_available():
-         torch.cuda.synchronize()
-
-     return time.time()
-
-
- def set_seed(seed):
-     # fix the seed for reproducibility (offset by rank so each process differs)
-     seed = seed + get_rank()
-     torch.manual_seed(seed)
-     np.random.seed(seed)
-     random.seed(seed)
-
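
A typical launch pattern for these helpers, as a sketch; train_loader stands in for an existing DataLoader, and in real use the model must already live on the rank's GPU before warp_model wraps it:

    # launched as: torchrun --nproc_per_node=2 train.py
    import torch
    import torch.nn as nn

    init_distributed()                            # single-process fallback if env:// is absent
    set_seed(42)                                  # per-rank offset decorrelates the ranks
    model = warp_model(nn.Linear(8, 2), sync_bn=True)    # DDP-wrapped only when distributed
    loader = warp_loader(train_loader, shuffle=True)     # train_loader: any existing DataLoader
    stats = reduce_dict({'loss': torch.tensor(0.5)})     # averaged across ranks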
 
src2/misc/logger.py DELETED
@@ -1,239 +0,0 @@
- """
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- https://github.com/facebookresearch/detr/blob/main/util/misc.py
- Mostly copy-paste from torchvision references.
- """
-
- import time
- import pickle
- import datetime
- from collections import defaultdict, deque
- from typing import Dict
-
- import torch
- import torch.distributed as tdist
-
- from .dist import is_dist_available_and_initialized, get_world_size
-
-
- class SmoothedValue(object):
-     """Track a series of values and provide access to smoothed values over a
-     window or the global series average.
-     """
-
-     def __init__(self, window_size=20, fmt=None):
-         if fmt is None:
-             fmt = "{median:.4f} ({global_avg:.4f})"
-         self.deque = deque(maxlen=window_size)
-         self.total = 0.0
-         self.count = 0
-         self.fmt = fmt
-
-     def update(self, value, n=1):
-         self.deque.append(value)
-         self.count += n
-         self.total += value * n
-
-     def synchronize_between_processes(self):
-         """
-         Warning: does not synchronize the deque!
-         """
-         if not is_dist_available_and_initialized():
-             return
-         t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
-         tdist.barrier()
-         tdist.all_reduce(t)
-         t = t.tolist()
-         self.count = int(t[0])
-         self.total = t[1]
-
-     @property
-     def median(self):
-         d = torch.tensor(list(self.deque))
-         return d.median().item()
-
-     @property
-     def avg(self):
-         d = torch.tensor(list(self.deque), dtype=torch.float32)
-         return d.mean().item()
-
-     @property
-     def global_avg(self):
-         return self.total / self.count
-
-     @property
-     def max(self):
-         return max(self.deque)
-
-     @property
-     def value(self):
-         return self.deque[-1]
-
-     def __str__(self):
-         return self.fmt.format(
-             median=self.median,
-             avg=self.avg,
-             global_avg=self.global_avg,
-             max=self.max,
-             value=self.value)
-
-
- def all_gather(data):
-     """
-     Run all_gather on arbitrary picklable data (not necessarily tensors)
-     Args:
-         data: any picklable object
-     Returns:
-         list[data]: list of data gathered from each rank
-     """
-     world_size = get_world_size()
-     if world_size == 1:
-         return [data]
-
-     # serialize to a ByteTensor
-     buffer = pickle.dumps(data)
-     storage = torch.ByteStorage.from_buffer(buffer)
-     tensor = torch.ByteTensor(storage).to("cuda")
-
-     # obtain Tensor size of each rank
-     local_size = torch.tensor([tensor.numel()], device="cuda")
-     size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
-     tdist.all_gather(size_list, local_size)
-     size_list = [int(size.item()) for size in size_list]
-     max_size = max(size_list)
-
-     # receiving Tensor from all ranks
-     # we pad the tensor because torch all_gather does not support
-     # gathering tensors of different shapes
-     tensor_list = []
-     for _ in size_list:
-         tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
-     if local_size != max_size:
-         padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
-         tensor = torch.cat((tensor, padding), dim=0)
-     tdist.all_gather(tensor_list, tensor)
-
-     data_list = []
-     for size, tensor in zip(size_list, tensor_list):
-         buffer = tensor.cpu().numpy().tobytes()[:size]
-         data_list.append(pickle.loads(buffer))
-
-     return data_list
-
-
- def reduce_dict(input_dict, average=True) -> Dict[str, torch.Tensor]:
-     """
-     Args:
-         input_dict (dict): all the values will be reduced
-         average (bool): whether to do average or sum
-     Reduce the values in the dictionary from all processes so that all processes
-     have the averaged results. Returns a dict with the same fields as
-     input_dict, after reduction.
-     """
-     world_size = get_world_size()
-     if world_size < 2:
-         return input_dict
-     with torch.no_grad():
-         names = []
-         values = []
-         # sort the keys so that they are consistent across processes
-         for k in sorted(input_dict.keys()):
-             names.append(k)
-             values.append(input_dict[k])
-         values = torch.stack(values, dim=0)
-         tdist.all_reduce(values)
-         if average:
-             values /= world_size
-         reduced_dict = {k: v for k, v in zip(names, values)}
-     return reduced_dict
-
-
- class MetricLogger(object):
-     def __init__(self, delimiter="\t"):
-         self.meters = defaultdict(SmoothedValue)
-         self.delimiter = delimiter
-
-     def update(self, **kwargs):
-         for k, v in kwargs.items():
-             if isinstance(v, torch.Tensor):
-                 v = v.item()
-             assert isinstance(v, (float, int))
-             self.meters[k].update(v)
-
-     def __getattr__(self, attr):
-         if attr in self.meters:
-             return self.meters[attr]
-         if attr in self.__dict__:
-             return self.__dict__[attr]
-         raise AttributeError("'{}' object has no attribute '{}'".format(
-             type(self).__name__, attr))
-
-     def __str__(self):
-         loss_str = []
-         for name, meter in self.meters.items():
-             loss_str.append(
-                 "{}: {}".format(name, str(meter))
-             )
-         return self.delimiter.join(loss_str)
-
-     def synchronize_between_processes(self):
-         for meter in self.meters.values():
-             meter.synchronize_between_processes()
-
-     def add_meter(self, name, meter):
-         self.meters[name] = meter
-
-     def log_every(self, iterable, print_freq, header=None):
-         i = 0
-         if not header:
-             header = ''
-         start_time = time.time()
-         end = time.time()
-         iter_time = SmoothedValue(fmt='{avg:.4f}')
-         data_time = SmoothedValue(fmt='{avg:.4f}')
-         space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
-         if torch.cuda.is_available():
-             log_msg = self.delimiter.join([
-                 header,
-                 '[{0' + space_fmt + '}/{1}]',
-                 'eta: {eta}',
-                 '{meters}',
-                 'time: {time}',
-                 'data: {data}',
-                 'max mem: {memory:.0f}'
-             ])
-         else:
-             log_msg = self.delimiter.join([
-                 header,
-                 '[{0' + space_fmt + '}/{1}]',
-                 'eta: {eta}',
-                 '{meters}',
-                 'time: {time}',
-                 'data: {data}'
-             ])
-         MB = 1024.0 * 1024.0
-         for obj in iterable:
-             data_time.update(time.time() - end)
-             yield obj
-             iter_time.update(time.time() - end)
-             if i % print_freq == 0 or i == len(iterable) - 1:
-                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
-                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
-                 if torch.cuda.is_available():
-                     print(log_msg.format(
-                         i, len(iterable), eta=eta_string,
-                         meters=str(self),
-                         time=str(iter_time), data=str(data_time),
-                         memory=torch.cuda.max_memory_allocated() / MB))
-                 else:
-                     print(log_msg.format(
-                         i, len(iterable), eta=eta_string,
-                         meters=str(self),
-                         time=str(iter_time), data=str(data_time)))
-             i += 1
-             end = time.time()
-         total_time = time.time() - start_time
-         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
-         print('{} Total time: {} ({:.4f} s / it)'.format(
-             header, total_time_str, total_time / len(iterable)))
-
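
SmoothedValue and MetricLogger compose as below; a self-contained sketch where the dummy range stands in for a dataloader:

    logger = MetricLogger(delimiter='  ')
    for i, _ in enumerate(logger.log_every(range(100), print_freq=10, header='Epoch [0]')):
        logger.update(loss=1.0 / (i + 1))   # floats, ints, or 0-dim tensors
    logger.synchronize_between_processes()  # no-op outside distributed mode
    print(str(logger.loss))                 # '{median:.4f} ({global_avg:.4f})' per SmoothedValue.fmt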
 
src2/misc/visualizer.py DELETED
@@ -1,34 +0,0 @@
- """by lyuwenyu
- """
-
- import torch
- import torch.utils.data
-
- import torchvision
- torchvision.disable_beta_transforms_warning()
-
- import PIL
-
- __all__ = ['show_sample']
-
-
- def show_sample(sample):
-     """for coco dataset/dataloader
-     """
-     import matplotlib.pyplot as plt
-     from torchvision.transforms.v2 import functional as F
-     from torchvision.utils import draw_bounding_boxes
-
-     image, target = sample
-     if isinstance(image, PIL.Image.Image):
-         image = F.to_image_tensor(image)
-
-     image = F.convert_dtype(image, torch.uint8)
-     annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)
-
-     fig, ax = plt.subplots()
-     ax.imshow(annotated_image.permute(1, 2, 0).numpy())
-     ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
-     fig.tight_layout()
-     fig.show()
-     plt.show()
-
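
show_sample expects an (image, target) pair with absolute xyxy boxes, as draw_bounding_boxes requires; a self-contained sketch:

    import torch
    import PIL.Image

    img = PIL.Image.new('RGB', (320, 240), 'gray')
    target = {'boxes': torch.tensor([[40., 30., 200., 180.]])}
    show_sample((img, target))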
 
src2/nn/__init__.py DELETED
@@ -1,7 +0,0 @@
-
- from .arch import *
- from .criterion import *
-
- from .backbone import *
-
 
src2/nn/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (226 Bytes)
 
src2/nn/arch/__init__.py DELETED
@@ -1 +0,0 @@
- from .classification import *
 
 
src2/nn/arch/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (200 Bytes)
 
src2/nn/arch/__pycache__/classification.cpython-310.pyc DELETED
Binary file (1.55 kB)
 
src2/nn/arch/classification.py DELETED
@@ -1,41 +0,0 @@
- import torch
- import torch.nn as nn
-
- from src.core import register
-
-
- __all__ = ['Classification', 'ClassHead']
-
-
- @register
- class Classification(nn.Module):
-     # submodule names resolved and injected by the YAML config system
-     __inject__ = ['backbone', 'head']
-
-     def __init__(self, backbone: nn.Module, head: nn.Module = None):
-         super().__init__()
-
-         self.backbone = backbone
-         self.head = head
-
-     def forward(self, x):
-         x = self.backbone(x)
-
-         if self.head is not None:
-             x = self.head(x)
-
-         return x
-
-
- @register
- class ClassHead(nn.Module):
-     def __init__(self, hidden_dim, num_classes):
-         super().__init__()
-         self.pool = nn.AdaptiveAvgPool2d(1)
-         self.proj = nn.Linear(hidden_dim, num_classes)
-
-     def forward(self, x):
-         # accept either a feature map or a list/tuple of maps (use the first)
-         x = x[0] if isinstance(x, (list, tuple)) else x
-         x = self.pool(x)
-         x = x.reshape(x.shape[0], -1)
-         x = self.proj(x)
-         return x
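
Wiring this together with the backbones below, as a sketch; PResNet and ClassHead come from this tree, and with return_idx=[3] a depth-18 backbone emits a single 512-channel map:

    import torch

    backbone = PResNet(depth=18, return_idx=[3], pretrained=False)
    head = ClassHead(hidden_dim=512, num_classes=10)
    model = Classification(backbone, head)
    logits = model(torch.rand(2, 3, 224, 224))   # -> shape (2, 10)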
 
src2/nn/backbone/__init__.py DELETED
@@ -1,5 +0,0 @@
-
- from .presnet import *
- from .test_resnet import *
-
- from .common import *
 
src2/nn/backbone/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (238 Bytes)
 
src2/nn/backbone/__pycache__/common.cpython-310.pyc DELETED
Binary file (3.28 kB)
 
src2/nn/backbone/__pycache__/presnet.cpython-310.pyc DELETED
Binary file (6.45 kB)
 
src2/nn/backbone/__pycache__/test_resnet.cpython-310.pyc DELETED
Binary file (3.06 kB)
 
src2/nn/backbone/common.py DELETED
@@ -1,102 +0,0 @@
- '''by lyuwenyu
- '''
-
- import torch
- import torch.nn as nn
-
-
- class ConvNormLayer(nn.Module):
-     def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
-         super().__init__()
-         self.conv = nn.Conv2d(
-             ch_in,
-             ch_out,
-             kernel_size,
-             stride,
-             padding=(kernel_size - 1) // 2 if padding is None else padding,
-             bias=bias)
-         self.norm = nn.BatchNorm2d(ch_out)
-         self.act = nn.Identity() if act is None else get_activation(act)
-
-     def forward(self, x):
-         return self.act(self.norm(self.conv(x)))
-
-
- class FrozenBatchNorm2d(nn.Module):
-     """copied and modified from https://github.com/facebookresearch/detr/blob/master/models/backbone.py
-     BatchNorm2d where the batch statistics and the affine parameters are fixed.
-     Copy-paste from torchvision.misc.ops with added eps before rsqrt,
-     without which any model other than torchvision.models.resnet[18,34,50,101]
-     produces nans.
-     """
-     def __init__(self, num_features, eps=1e-5):
-         super(FrozenBatchNorm2d, self).__init__()
-         n = num_features
-         self.register_buffer("weight", torch.ones(n))
-         self.register_buffer("bias", torch.zeros(n))
-         self.register_buffer("running_mean", torch.zeros(n))
-         self.register_buffer("running_var", torch.ones(n))
-         self.eps = eps
-         self.num_features = n
-
-     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                               missing_keys, unexpected_keys, error_msgs):
-         num_batches_tracked_key = prefix + 'num_batches_tracked'
-         if num_batches_tracked_key in state_dict:
-             del state_dict[num_batches_tracked_key]
-
-         super(FrozenBatchNorm2d, self)._load_from_state_dict(
-             state_dict, prefix, local_metadata, strict,
-             missing_keys, unexpected_keys, error_msgs)
-
-     def forward(self, x):
-         # move reshapes to the beginning
-         # to make it fuser-friendly
-         w = self.weight.reshape(1, -1, 1, 1)
-         b = self.bias.reshape(1, -1, 1, 1)
-         rv = self.running_var.reshape(1, -1, 1, 1)
-         rm = self.running_mean.reshape(1, -1, 1, 1)
-         scale = w * (rv + self.eps).rsqrt()
-         bias = b - rm * scale
-         return x * scale + bias
-
-     def extra_repr(self):
-         return (
-             "{num_features}, eps={eps}".format(**self.__dict__)
-         )
-
-
- def get_activation(act, inplace: bool = True):
-     '''Map an activation name (or module, or None) to an nn.Module.
-     '''
-     if act is None:
-         return nn.Identity()
-
-     if isinstance(act, nn.Module):
-         m = act
-     else:
-         act = act.lower()
-         if act == 'silu':
-             m = nn.SiLU()
-         elif act == 'relu':
-             m = nn.ReLU()
-         elif act == 'leaky_relu':
-             m = nn.LeakyReLU()
-         elif act == 'gelu':
-             m = nn.GELU()
-         else:
-             raise RuntimeError(f'unsupported activation: {act}')
-
-     if hasattr(m, 'inplace'):
-         m.inplace = inplace
-
-     return m
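
A usage sketch for the two building blocks above (the output shape follows from kernel_size=3, stride=2 with same padding):

    import torch

    act = get_activation('silu')                       # nn.SiLU with inplace set
    layer = ConvNormLayer(3, 64, kernel_size=3, stride=2, act='relu')
    y = layer(torch.rand(1, 3, 224, 224))              # -> (1, 64, 112, 112)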
 
src2/nn/backbone/presnet.py DELETED
@@ -1,225 +0,0 @@
- '''by lyuwenyu
- '''
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from collections import OrderedDict
-
- from .common import get_activation, ConvNormLayer, FrozenBatchNorm2d
-
- from src.core import register
-
-
- __all__ = ['PResNet']
-
-
- ResNet_cfg = {
-     18: [2, 2, 2, 2],
-     34: [3, 4, 6, 3],
-     50: [3, 4, 6, 3],
-     101: [3, 4, 23, 3],
-     # 152: [3, 8, 36, 3],
- }
-
-
- download_url = {
-     18: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth',
-     34: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth',
-     50: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth',
-     101: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth',
- }
-
-
- class BasicBlock(nn.Module):
-     expansion = 1
-
-     def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
-         super().__init__()
-
-         self.shortcut = shortcut
-
-         if not shortcut:
-             if variant == 'd' and stride == 2:
-                 # ResNet-D style downsampling: avg-pool then 1x1 conv
-                 self.short = nn.Sequential(OrderedDict([
-                     ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
-                     ('conv', ConvNormLayer(ch_in, ch_out, 1, 1))
-                 ]))
-             else:
-                 self.short = ConvNormLayer(ch_in, ch_out, 1, stride)
-
-         self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act)
-         self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None)
-         self.act = nn.Identity() if act is None else get_activation(act)
-
-     def forward(self, x):
-         out = self.branch2a(x)
-         out = self.branch2b(out)
-         if self.shortcut:
-             short = x
-         else:
-             short = self.short(x)
-
-         out = out + short
-         out = self.act(out)
-
-         return out
-
-
- class BottleNeck(nn.Module):
-     expansion = 4
-
-     def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
-         super().__init__()
-
-         if variant == 'a':
-             stride1, stride2 = stride, 1
-         else:
-             stride1, stride2 = 1, stride
-
-         width = ch_out
-
-         self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act)
-         self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act)
-         self.branch2c = ConvNormLayer(width, ch_out * self.expansion, 1, 1)
-
-         self.shortcut = shortcut
-         if not shortcut:
-             if variant == 'd' and stride == 2:
-                 self.short = nn.Sequential(OrderedDict([
-                     ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
-                     ('conv', ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1))
-                 ]))
-             else:
-                 self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride)
-
-         self.act = nn.Identity() if act is None else get_activation(act)
-
-     def forward(self, x):
-         out = self.branch2a(x)
-         out = self.branch2b(out)
-         out = self.branch2c(out)
-
-         if self.shortcut:
-             short = x
-         else:
-             short = self.short(x)
-
-         out = out + short
-         out = self.act(out)
-
-         return out
-
-
- class Blocks(nn.Module):
-     def __init__(self, block, ch_in, ch_out, count, stage_num, act='relu', variant='b'):
-         super().__init__()
-
-         self.blocks = nn.ModuleList()
-         for i in range(count):
-             self.blocks.append(
-                 block(
-                     ch_in,
-                     ch_out,
-                     stride=2 if i == 0 and stage_num != 2 else 1,
-                     shortcut=False if i == 0 else True,
-                     variant=variant,
-                     act=act)
-             )
-
-             if i == 0:
-                 ch_in = ch_out * block.expansion
-
-     def forward(self, x):
-         out = x
-         for block in self.blocks:
-             out = block(out)
-         return out
-
-
- @register
- class PResNet(nn.Module):
-     def __init__(
-             self,
-             depth,
-             variant='d',
-             num_stages=4,
-             return_idx=[0, 1, 2, 3],
-             act='relu',
-             freeze_at=-1,
-             freeze_norm=True,
-             pretrained=False):
-         super().__init__()
-
-         block_nums = ResNet_cfg[depth]
-         ch_in = 64
-         if variant in ['c', 'd']:
-             # deep stem: three 3x3 convs instead of a single 7x7
-             conv_def = [
-                 [3, ch_in // 2, 3, 2, "conv1_1"],
-                 [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
-                 [ch_in // 2, ch_in, 3, 1, "conv1_3"],
-             ]
-         else:
-             conv_def = [[3, ch_in, 7, 2, "conv1_1"]]
-
-         self.conv1 = nn.Sequential(OrderedDict([
-             (_name, ConvNormLayer(c_in, c_out, k, s, act=act)) for c_in, c_out, k, s, _name in conv_def
-         ]))
-
-         ch_out_list = [64, 128, 256, 512]
-         block = BottleNeck if depth >= 50 else BasicBlock
-
-         _out_channels = [block.expansion * v for v in ch_out_list]
-         _out_strides = [4, 8, 16, 32]
-
-         self.res_layers = nn.ModuleList()
-         for i in range(num_stages):
-             stage_num = i + 2
-             self.res_layers.append(
-                 Blocks(block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant)
-             )
-             ch_in = _out_channels[i]
-
-         self.return_idx = return_idx
-         self.out_channels = [_out_channels[_i] for _i in return_idx]
-         self.out_strides = [_out_strides[_i] for _i in return_idx]
-
-         if freeze_at >= 0:
-             self._freeze_parameters(self.conv1)
-             for i in range(min(freeze_at, num_stages)):
-                 self._freeze_parameters(self.res_layers[i])
-
-         if freeze_norm:
-             self._freeze_norm(self)
-
-         if pretrained:
-             state = torch.hub.load_state_dict_from_url(download_url[depth])
-             self.load_state_dict(state)
-             print(f'Load PResNet{depth} state_dict')
-
-     def _freeze_parameters(self, m: nn.Module):
-         for p in m.parameters():
-             p.requires_grad = False
-
-     def _freeze_norm(self, m: nn.Module):
-         # recursively replace BatchNorm2d with FrozenBatchNorm2d
-         if isinstance(m, nn.BatchNorm2d):
-             m = FrozenBatchNorm2d(m.num_features)
-         else:
-             for name, child in m.named_children():
-                 _child = self._freeze_norm(child)
-                 if _child is not child:
-                     setattr(m, name, _child)
-         return m
-
-     def forward(self, x):
-         conv1 = self.conv1(x)
-         x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1)
-         outs = []
-         for idx, stage in enumerate(self.res_layers):
-             x = stage(x)
-             if idx in self.return_idx:
-                 outs.append(x)
-         return outs
-
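
A construction sketch: the arguments below ask for a ResNet-50-D trunk returning C3/C4/C5 features (channels 512/1024/2048 at strides 8/16/32):

    import torch

    model = PResNet(depth=50, return_idx=[1, 2, 3], freeze_at=0, pretrained=False)
    feats = model(torch.rand(1, 3, 640, 640))
    # [(1, 512, 80, 80), (1, 1024, 40, 40), (1, 2048, 20, 20)]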
 
src2/nn/backbone/test_resnet.py DELETED
@@ -1,81 +0,0 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from src.core import register
-
-
- class BasicBlock(nn.Module):
-     expansion = 1
-
-     def __init__(self, in_planes, planes, stride=1):
-         super(BasicBlock, self).__init__()
-
-         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
-         self.bn1 = nn.BatchNorm2d(planes)
-
-         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
-         self.bn2 = nn.BatchNorm2d(planes)
-
-         self.shortcut = nn.Sequential()
-         if stride != 1 or in_planes != self.expansion * planes:
-             self.shortcut = nn.Sequential(
-                 nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
-                 nn.BatchNorm2d(self.expansion * planes)
-             )
-
-     def forward(self, x):
-         out = F.relu(self.bn1(self.conv1(x)))
-         out = self.bn2(self.conv2(out))
-         out += self.shortcut(x)
-         out = F.relu(out)
-         return out
-
-
- class _ResNet(nn.Module):
-     def __init__(self, block, num_blocks, num_classes=10):
-         super().__init__()
-         self.in_planes = 64
-
-         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
-         self.bn1 = nn.BatchNorm2d(64)
-
-         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
-         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
-         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
-         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
-
-         self.linear = nn.Linear(512 * block.expansion, num_classes)
-
-     def _make_layer(self, block, planes, num_blocks, stride):
-         strides = [stride] + [1] * (num_blocks - 1)
-         layers = []
-         for stride in strides:
-             layers.append(block(self.in_planes, planes, stride))
-             self.in_planes = planes * block.expansion
-         return nn.Sequential(*layers)
-
-     def forward(self, x):
-         out = F.relu(self.bn1(self.conv1(x)))
-         out = self.layer1(out)
-         out = self.layer2(out)
-         out = self.layer3(out)
-         out = self.layer4(out)
-         out = F.avg_pool2d(out, 4)
-         out = out.view(out.size(0), -1)
-         out = self.linear(out)
-         return out
-
-
- @register
- class MResNet(nn.Module):
-     def __init__(self, num_classes=10, num_blocks=[2, 2, 2, 2]) -> None:
-         super().__init__()
-         self.model = _ResNet(BasicBlock, num_blocks, num_classes)
-
-     def forward(self, x):
-         return self.model(x)
-
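
This CIFAR-style trunk (3x3 stem, no initial max-pool) assumes 32x32 inputs because of the fixed avg_pool2d(out, 4); a sketch:

    import torch

    model = MResNet(num_classes=10)              # ResNet-18-style block layout
    logits = model(torch.rand(4, 3, 32, 32))     # -> shape (4, 10)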
 
src2/nn/backbone/utils.py DELETED
@@ -1,58 +0,0 @@
- """
- https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py
-
- by lyuwenyu
- """
-
- from collections import OrderedDict
- from typing import List
-
- import torch.nn as nn
-
-
- class IntermediateLayerGetter(nn.ModuleDict):
-     """
-     Module wrapper that returns intermediate layers from a model.
-
-     It has a strong assumption that the modules have been registered
-     into the model in the same order as they are used.
-     This means that one should **not** reuse the same nn.Module
-     twice in the forward if you want this to work.
-
-     Additionally, it is only able to query submodules that are directly
-     assigned to the model. So if `model` is passed, `model.feature1` can
-     be returned, but not `model.feature1.layer2`.
-     """
-
-     _version = 3
-
-     def __init__(self, model: nn.Module, return_layers: List[str]) -> None:
-         if not set(return_layers).issubset([name for name, _ in model.named_children()]):
-             raise ValueError("return_layers are not present in model. {}"
-                              .format([name for name, _ in model.named_children()]))
-         orig_return_layers = return_layers
-         return_layers = {str(k): str(k) for k in return_layers}
-         # keep children up to (and including) the last requested layer
-         layers = OrderedDict()
-         for name, module in model.named_children():
-             layers[name] = module
-             if name in return_layers:
-                 del return_layers[name]
-             if not return_layers:
-                 break
-
-         super().__init__(layers)
-         self.return_layers = orig_return_layers
-
-     def forward(self, x):
-         outputs = []
-         for name, module in self.items():
-             x = module(x)
-             if name in self.return_layers:
-                 outputs.append(x)
-
-         return outputs
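
A usage sketch against a torchvision trunk (resnet18 is just a convenient model with named children):

    import torch
    import torchvision

    body = torchvision.models.resnet18()
    getter = IntermediateLayerGetter(body, return_layers=['layer2', 'layer3'])
    c3, c4 = getter(torch.rand(1, 3, 224, 224))  # (1, 128, 28, 28), (1, 256, 14, 14)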