Upload 91 files
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- src/__init__.py +5 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/core/__init__.py +7 -0
- src/core/__pycache__/__init__.cpython-310.pyc +0 -0
- src/core/__pycache__/config.cpython-310.pyc +0 -0
- src/core/__pycache__/yaml_config.cpython-310.pyc +0 -0
- src/core/__pycache__/yaml_utils.cpython-310.pyc +0 -0
- src/core/config.py +264 -0
- src/core/yaml_config.py +152 -0
- src/core/yaml_utils.py +208 -0
- src/data/__init__.py +7 -0
- src/data/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__pycache__/dataloader.cpython-310.pyc +0 -0
- src/data/__pycache__/transforms.cpython-310.pyc +0 -0
- src/data/cifar10/__init__.py +14 -0
- src/data/cifar10/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/coco/__init__.py +9 -0
- src/data/coco/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/coco/__pycache__/coco_dataset.cpython-310.pyc +0 -0
- src/data/coco/__pycache__/coco_eval.cpython-310.pyc +0 -0
- src/data/coco/__pycache__/coco_utils.cpython-310.pyc +0 -0
- src/data/coco/coco_dataset.py +238 -0
- src/data/coco/coco_eval.py +269 -0
- src/data/coco/coco_utils.py +184 -0
- src/data/dataloader.py +28 -0
- src/data/functional.py +169 -0
- src/data/transforms.py +142 -0
- src/misc/__init__.py +3 -0
- src/misc/__pycache__/__init__.cpython-310.pyc +0 -0
- src/misc/__pycache__/dist.cpython-310.pyc +0 -0
- src/misc/__pycache__/logger.cpython-310.pyc +0 -0
- src/misc/__pycache__/visualizer.cpython-310.pyc +0 -0
- src/misc/dist.py +190 -0
- src/misc/logger.py +239 -0
- src/misc/visualizer.py +34 -0
- src/nn/__init__.py +7 -0
- src/nn/__pycache__/__init__.cpython-310.pyc +0 -0
- src/nn/arch/__init__.py +1 -0
- src/nn/arch/__pycache__/__init__.cpython-310.pyc +0 -0
- src/nn/arch/__pycache__/classification.cpython-310.pyc +0 -0
- src/nn/arch/classification.py +41 -0
- src/nn/backbone/__init__.py +5 -0
- src/nn/backbone/__pycache__/__init__.cpython-310.pyc +0 -0
- src/nn/backbone/__pycache__/common.cpython-310.pyc +0 -0
- src/nn/backbone/__pycache__/presnet.cpython-310.pyc +0 -0
- src/nn/backbone/__pycache__/test_resnet.cpython-310.pyc +0 -0
- src/nn/backbone/common.py +102 -0
- src/nn/backbone/presnet.py +225 -0
- src/nn/backbone/test_resnet.py +81 -0
- src/nn/backbone/utils.py +58 -0
src/__init__.py
ADDED
@@ -0,0 +1,5 @@

from . import data
from . import nn
from . import optim
from . import zoo
src/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (269 Bytes).
src/core/__init__.py
ADDED
@@ -0,0 +1,7 @@
"""by lyuwenyu
"""

# from .yaml_utils import register, create, load_config, merge_config, merge_dict
from .yaml_utils import *
from .config import BaseConfig
from .yaml_config import YAMLConfig
src/core/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (309 Bytes).

src/core/__pycache__/config.cpython-310.pyc
ADDED
Binary file (6.52 kB).

src/core/__pycache__/yaml_config.cpython-310.pyc
ADDED
Binary file (4.68 kB).

src/core/__pycache__/yaml_utils.cpython-310.pyc
ADDED
Binary file (4.24 kB).
src/core/config.py
ADDED
@@ -0,0 +1,264 @@
"""by lyuwenyu
"""

from pprint import pprint
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.cuda.amp.grad_scaler import GradScaler

from typing import Callable, List, Dict


__all__ = ['BaseConfig', ]


class BaseConfig(object):
    # TODO property

    def __init__(self) -> None:
        super().__init__()

        self.task :str = None

        self._model :nn.Module = None
        self._postprocessor :nn.Module = None
        self._criterion :nn.Module = None
        self._optimizer :Optimizer = None
        self._lr_scheduler :LRScheduler = None
        self._train_dataloader :DataLoader = None
        self._val_dataloader :DataLoader = None
        self._ema :nn.Module = None
        self._scaler :GradScaler = None

        self.train_dataset :Dataset = None
        self.val_dataset :Dataset = None
        self.num_workers :int = 0
        self.collate_fn :Callable = None

        self.batch_size :int = None
        self._train_batch_size :int = None
        self._val_batch_size :int = None
        self._train_shuffle: bool = None
        self._val_shuffle: bool = None

        self.evaluator :Callable[[nn.Module, DataLoader, str], ] = None

        # runtime
        self.resume :str = None
        self.tuning :str = None

        self.epoches :int = None
        self.last_epoch :int = -1
        self.end_epoch :int = None

        self.use_amp :bool = False
        self.use_ema :bool = False
        self.sync_bn :bool = False
        self.clip_max_norm : float = None
        self.find_unused_parameters :bool = None
        # self.ema_decay: float = 0.9999
        # self.grad_clip_: Callable = None

        self.log_dir :str = './logs/'
        self.log_step :int = 10
        self._output_dir :str = None
        self._print_freq :int = None
        self.checkpoint_step :int = 1

        # self.device :str = torch.device('cpu')
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)


    @property
    def model(self, ) -> nn.Module:
        return self._model

    @model.setter
    def model(self, m):
        assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
        self._model = m

    @property
    def postprocessor(self, ) -> nn.Module:
        return self._postprocessor

    @postprocessor.setter
    def postprocessor(self, m):
        assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
        self._postprocessor = m

    @property
    def criterion(self, ) -> nn.Module:
        return self._criterion

    @criterion.setter
    def criterion(self, m):
        assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
        self._criterion = m

    @property
    def optimizer(self, ) -> Optimizer:
        return self._optimizer

    @optimizer.setter
    def optimizer(self, m):
        assert isinstance(m, Optimizer), f'{type(m)} != optim.Optimizer, please check your model class'
        self._optimizer = m

    @property
    def lr_scheduler(self, ) -> LRScheduler:
        return self._lr_scheduler

    @lr_scheduler.setter
    def lr_scheduler(self, m):
        assert isinstance(m, LRScheduler), f'{type(m)} != LRScheduler, please check your model class'
        self._lr_scheduler = m


    @property
    def train_dataloader(self):
        if self._train_dataloader is None and self.train_dataset is not None:
            loader = DataLoader(self.train_dataset,
                                batch_size=self.train_batch_size,
                                num_workers=self.num_workers,
                                collate_fn=self.collate_fn,
                                shuffle=self.train_shuffle, )
            loader.shuffle = self.train_shuffle
            self._train_dataloader = loader

        return self._train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, loader):
        self._train_dataloader = loader

    @property
    def val_dataloader(self):
        if self._val_dataloader is None and self.val_dataset is not None:
            loader = DataLoader(self.val_dataset,
                                batch_size=self.val_batch_size,
                                num_workers=self.num_workers,
                                drop_last=False,
                                collate_fn=self.collate_fn,
                                shuffle=self.val_shuffle)
            loader.shuffle = self.val_shuffle
            self._val_dataloader = loader

        return self._val_dataloader

    @val_dataloader.setter
    def val_dataloader(self, loader):
        self._val_dataloader = loader


    # TODO method
    # @property
    # def ema(self, ) -> nn.Module:
    #     if self._ema is None and self.use_ema and self.model is not None:
    #         self._ema = ModelEMA(self.model, self.ema_decay)
    #     return self._ema

    @property
    def ema(self, ) -> nn.Module:
        return self._ema

    @ema.setter
    def ema(self, obj):
        self._ema = obj


    @property
    def scaler(self) -> GradScaler:
        if self._scaler is None and self.use_amp and torch.cuda.is_available():
            self._scaler = GradScaler()
        return self._scaler

    @scaler.setter
    def scaler(self, obj: GradScaler):
        self._scaler = obj


    @property
    def val_shuffle(self):
        if self._val_shuffle is None:
            print('warning: set default val_shuffle=False')
            return False
        return self._val_shuffle

    @val_shuffle.setter
    def val_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), 'shuffle must be bool'
        self._val_shuffle = shuffle

    @property
    def train_shuffle(self):
        if self._train_shuffle is None:
            print('warning: set default train_shuffle=True')
            return True
        return self._train_shuffle

    @train_shuffle.setter
    def train_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), 'shuffle must be bool'
        self._train_shuffle = shuffle


    @property
    def train_batch_size(self):
        if self._train_batch_size is None and isinstance(self.batch_size, int):
            print(f'warning: set train_batch_size=batch_size={self.batch_size}')
            return self.batch_size
        return self._train_batch_size

    @train_batch_size.setter
    def train_batch_size(self, batch_size):
        assert isinstance(batch_size, int), 'batch_size must be int'
        self._train_batch_size = batch_size

    @property
    def val_batch_size(self):
        if self._val_batch_size is None:
            print(f'warning: set val_batch_size=batch_size={self.batch_size}')
            return self.batch_size
        return self._val_batch_size

    @val_batch_size.setter
    def val_batch_size(self, batch_size):
        assert isinstance(batch_size, int), 'batch_size must be int'
        self._val_batch_size = batch_size


    @property
    def output_dir(self):
        if self._output_dir is None:
            return self.log_dir
        return self._output_dir

    @output_dir.setter
    def output_dir(self, root):
        self._output_dir = root

    @property
    def print_freq(self):
        if self._print_freq is None:
            # self._print_freq = self.log_step
            return self.log_step
        return self._print_freq

    @print_freq.setter
    def print_freq(self, n):
        assert isinstance(n, int), 'print_freq must be int'
        self._print_freq = n


    # def __repr__(self) -> str:
    #     pass
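The lazy properties above mean a trainer can read `cfg.train_dataloader` without caring whether the loader was set explicitly or must be built from `train_dataset`. A minimal sketch of that behavior; `ToyDataset` and every value below are hypothetical illustrations, not part of this commit:

# Sketch only: exercises BaseConfig's lazy dataloader construction.
# ToyDataset and all numbers below are hypothetical.
import torch
from torch.utils.data import Dataset

from src.core import BaseConfig


class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.zeros(3, 32, 32), idx


cfg = BaseConfig()
cfg.train_dataset = ToyDataset()
cfg.batch_size = 4  # no train_batch_size set, so the property falls back to this

loader = cfg.train_dataloader  # prints the train_batch_size / train_shuffle warnings
print(len(loader))             # 2 batches of 4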
src/core/yaml_config.py
ADDED
@@ -0,0 +1,152 @@
"""by lyuwenyu
"""

import torch
import torch.nn as nn

import re
import copy

from .config import BaseConfig
from .yaml_utils import load_config, merge_config, create, merge_dict


class YAMLConfig(BaseConfig):
    def __init__(self, cfg_path: str, **kwargs) -> None:
        super().__init__()

        cfg = load_config(cfg_path)
        merge_dict(cfg, kwargs)

        # pprint(cfg)

        self.yaml_cfg = cfg

        self.log_step = cfg.get('log_step', 100)
        self.checkpoint_step = cfg.get('checkpoint_step', 1)
        self.epoches = cfg.get('epoches', -1)
        self.resume = cfg.get('resume', '')
        self.tuning = cfg.get('tuning', '')
        self.sync_bn = cfg.get('sync_bn', False)
        self.output_dir = cfg.get('output_dir', None)

        self.use_ema = cfg.get('use_ema', False)
        self.use_amp = cfg.get('use_amp', False)
        self.autocast = cfg.get('autocast', dict())
        self.find_unused_parameters = cfg.get('find_unused_parameters', None)
        self.clip_max_norm = cfg.get('clip_max_norm', 0.)


    @property
    def model(self, ) -> torch.nn.Module:
        if self._model is None and 'model' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            self._model = create(self.yaml_cfg['model'])
        return self._model

    @property
    def postprocessor(self, ) -> torch.nn.Module:
        if self._postprocessor is None and 'postprocessor' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            self._postprocessor = create(self.yaml_cfg['postprocessor'])
        return self._postprocessor

    @property
    def criterion(self, ):
        if self._criterion is None and 'criterion' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            self._criterion = create(self.yaml_cfg['criterion'])
        return self._criterion


    @property
    def optimizer(self, ):
        if self._optimizer is None and 'optimizer' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            params = self.get_optim_params(self.yaml_cfg['optimizer'], self.model)
            self._optimizer = create('optimizer', params=params)

        return self._optimizer

    @property
    def lr_scheduler(self, ):
        if self._lr_scheduler is None and 'lr_scheduler' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            self._lr_scheduler = create('lr_scheduler', optimizer=self.optimizer)
            print('Initial lr: ', self._lr_scheduler.get_last_lr())

        return self._lr_scheduler

    @property
    def train_dataloader(self, ):
        if self._train_dataloader is None and 'train_dataloader' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            self._train_dataloader = create('train_dataloader')
            self._train_dataloader.shuffle = self.yaml_cfg['train_dataloader'].get('shuffle', False)

        return self._train_dataloader

    @property
    def val_dataloader(self, ):
        if self._val_dataloader is None and 'val_dataloader' in self.yaml_cfg:
            merge_config(self.yaml_cfg)
            self._val_dataloader = create('val_dataloader')
            self._val_dataloader.shuffle = self.yaml_cfg['val_dataloader'].get('shuffle', False)

        return self._val_dataloader


    @property
    def ema(self, ):
        if self._ema is None and self.yaml_cfg.get('use_ema', False):
            merge_config(self.yaml_cfg)
            self._ema = create('ema', model=self.model)

        return self._ema


    @property
    def scaler(self, ):
        if self._scaler is None and self.yaml_cfg.get('use_amp', False):
            merge_config(self.yaml_cfg)
            self._scaler = create('scaler')

        return self._scaler


    @staticmethod
    def get_optim_params(cfg: dict, model: nn.Module):
        '''
        E.g.:
            ^(?=.*a)(?=.*b).*$         means including a and b
            ^((?!b).)*a((?!b).)*$      means including a but not b
            ^((?!b|c).)*a((?!b|c).)*$  means including a but not (b | c)
        '''
        assert 'type' in cfg, ''
        cfg = copy.deepcopy(cfg)

        if 'params' not in cfg:
            return model.parameters()

        assert isinstance(cfg['params'], list), ''

        param_groups = []
        visited = []
        for pg in cfg['params']:
            pattern = pg['params']
            params = {k: v for k, v in model.named_parameters() if v.requires_grad and len(re.findall(pattern, k)) > 0}
            pg['params'] = params.values()
            param_groups.append(pg)
            visited.extend(list(params.keys()))

        names = [k for k, v in model.named_parameters() if v.requires_grad]

        if len(visited) < len(names):
            unseen = set(names) - set(visited)
            params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
            param_groups.append({'params': params.values()})
            visited.extend(list(params.keys()))

        assert len(visited) == len(names), ''

        return param_groups
src/core/yaml_utils.py
ADDED
@@ -0,0 +1,208 @@
"""by lyuwenyu
"""

import os
import yaml
import inspect
import importlib

__all__ = ['GLOBAL_CONFIG', 'register', 'create', 'load_config', 'merge_config', 'merge_dict']


GLOBAL_CONFIG = dict()
INCLUDE_KEY = '__include__'


def register(cls: type):
    '''
    Args:
        cls (type): Module class to be registered.
    '''
    if cls.__name__ in GLOBAL_CONFIG:
        raise ValueError('{} already registered'.format(cls.__name__))

    if inspect.isfunction(cls):
        GLOBAL_CONFIG[cls.__name__] = cls

    elif inspect.isclass(cls):
        GLOBAL_CONFIG[cls.__name__] = extract_schema(cls)

    else:
        raise ValueError(f'register {cls}')

    return cls


def extract_schema(cls: type):
    '''
    Args:
        cls (type),
    Return:
        Dict,
    '''
    argspec = inspect.getfullargspec(cls.__init__)
    arg_names = [arg for arg in argspec.args if arg != 'self']
    num_defualts = len(argspec.defaults) if argspec.defaults is not None else 0
    num_requires = len(arg_names) - num_defualts

    schame = dict()
    schame['_name'] = cls.__name__
    schame['_pymodule'] = importlib.import_module(cls.__module__)
    schame['_inject'] = getattr(cls, '__inject__', [])
    schame['_share'] = getattr(cls, '__share__', [])

    for i, name in enumerate(arg_names):
        if name in schame['_share']:
            assert i >= num_requires, 'share config must have default value.'
            value = argspec.defaults[i - num_requires]

        elif i >= num_requires:
            value = argspec.defaults[i - num_requires]

        else:
            value = None

        schame[name] = value

    return schame


def create(type_or_name, **kwargs):
    '''
    '''
    assert type(type_or_name) in (type, str), 'create should be class or name.'

    name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__

    if name in GLOBAL_CONFIG:
        if hasattr(GLOBAL_CONFIG[name], '__dict__'):
            return GLOBAL_CONFIG[name]
    else:
        raise ValueError('The module {} is not registered'.format(name))

    cfg = GLOBAL_CONFIG[name]

    if isinstance(cfg, dict) and 'type' in cfg:
        _cfg: dict = GLOBAL_CONFIG[cfg['type']]
        _cfg.update(cfg)     # update global cls default args
        _cfg.update(kwargs)  # TODO
        name = _cfg.pop('type')

        return create(name)

    cls = getattr(cfg['_pymodule'], name)
    argspec = inspect.getfullargspec(cls.__init__)
    arg_names = [arg for arg in argspec.args if arg != 'self']

    cls_kwargs = {}
    cls_kwargs.update(cfg)

    # shared var
    for k in cfg['_share']:
        if k in GLOBAL_CONFIG:
            cls_kwargs[k] = GLOBAL_CONFIG[k]
        else:
            cls_kwargs[k] = cfg[k]

    # inject
    for k in cfg['_inject']:
        _k = cfg[k]

        if _k is None:
            continue

        if isinstance(_k, str):
            if _k not in GLOBAL_CONFIG:
                raise ValueError(f'Missing inject config of {_k}.')

            _cfg = GLOBAL_CONFIG[_k]

            if isinstance(_cfg, dict):
                cls_kwargs[k] = create(_cfg['_name'])
            else:
                cls_kwargs[k] = _cfg

        elif isinstance(_k, dict):
            if 'type' not in _k.keys():
                raise ValueError(f'Missing inject for `type` style.')

            _type = str(_k['type'])
            if _type not in GLOBAL_CONFIG:
                raise ValueError(f'Missing {_type} in inspect stage.')

            # TODO modified inspace, maybe get wrong result for using `> 1`
            _cfg: dict = GLOBAL_CONFIG[_type]
            # _cfg_copy = copy.deepcopy(_cfg)
            _cfg.update(_k)  # update
            cls_kwargs[k] = create(_type)
            # _cfg.update(_cfg_copy)  # resume

        else:
            raise ValueError(f'Inject does not support {_k}')

    cls_kwargs = {n: cls_kwargs[n] for n in arg_names}

    return cls(**cls_kwargs)


def load_config(file_path, cfg=dict()):
    '''load config
    '''
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"

    with open(file_path) as f:
        file_cfg = yaml.load(f, Loader=yaml.Loader)
        if file_cfg is None:
            return {}

    if INCLUDE_KEY in file_cfg:
        base_yamls = list(file_cfg[INCLUDE_KEY])
        for base_yaml in base_yamls:
            if base_yaml.startswith('~'):
                base_yaml = os.path.expanduser(base_yaml)

            if not base_yaml.startswith('/'):
                base_yaml = os.path.join(os.path.dirname(file_path), base_yaml)

            with open(base_yaml) as f:
                base_cfg = load_config(base_yaml, cfg)
                merge_config(base_cfg, cfg)

    return merge_config(file_cfg, cfg)


def merge_dict(dct, another_dct):
    '''merge another_dct into dct
    '''
    for k in another_dct:
        if (k in dct and isinstance(dct[k], dict) and isinstance(another_dct[k], dict)):
            merge_dict(dct[k], another_dct[k])
        else:
            dct[k] = another_dct[k]

    return dct


def merge_config(config, another_cfg=None):
    """
    Merge config into global config or another_cfg.

    Args:
        config (dict): Config to be merged.

    Returns: global config
    """
    global GLOBAL_CONFIG
    dct = GLOBAL_CONFIG if another_cfg is None else another_cfg

    return merge_dict(dct, config)
src/data/__init__.py
ADDED
@@ -0,0 +1,7 @@

from .coco import *
from .cifar10 import CIFAR10

from .dataloader import *
from .transforms import *
src/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (270 Bytes).

src/data/__pycache__/dataloader.cpython-310.pyc
ADDED
Binary file (1.29 kB).

src/data/__pycache__/transforms.cpython-310.pyc
ADDED
Binary file (5.19 kB).
src/data/cifar10/__init__.py
ADDED
@@ -0,0 +1,14 @@

import torchvision
from typing import Optional, Callable

from src.core import register


@register
class CIFAR10(torchvision.datasets.CIFAR10):
    __inject__ = ['transform', 'target_transform']

    def __init__(self, root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False) -> None:
        super().__init__(root, train, transform, target_transform, download)
src/data/cifar10/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (823 Bytes).
src/data/coco/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .coco_dataset import (
    CocoDetection,
    mscoco_category2label,
    mscoco_label2category,
    mscoco_category2name,
)
from .coco_eval import *

from .coco_utils import get_coco_api_from_dataset
src/data/coco/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (402 Bytes).

src/data/coco/__pycache__/coco_dataset.cpython-310.pyc
ADDED
Binary file (7.07 kB).

src/data/coco/__pycache__/coco_eval.cpython-310.pyc
ADDED
Binary file (7.26 kB).

src/data/coco/__pycache__/coco_utils.cpython-310.pyc
ADDED
Binary file (6.61 kB).
src/data/coco/coco_dataset.py
ADDED
@@ -0,0 +1,238 @@
"""
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

COCO dataset which returns image_id for evaluation.
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""

import torch
import torch.utils.data

import torchvision
torchvision.disable_beta_transforms_warning()

from torchvision import datapoints

from pycocotools import mask as coco_mask

from src.core import register

__all__ = ['CocoDetection']


@register
class CocoDetection(torchvision.datasets.CocoDetection):
    __inject__ = ['transforms']
    __share__ = ['remap_mscoco_category']

    def __init__(self, img_folder, ann_file, transforms, return_masks, remap_mscoco_category=False):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks, remap_mscoco_category)
        self.img_folder = img_folder
        self.ann_file = ann_file
        self.return_masks = return_masks
        self.remap_mscoco_category = remap_mscoco_category

    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)

        # ['boxes', 'masks', 'labels']:
        if 'boxes' in target:
            target['boxes'] = datapoints.BoundingBox(
                target['boxes'],
                format=datapoints.BoundingBoxFormat.XYXY,
                spatial_size=img.size[::-1])  # h w

        if 'masks' in target:
            target['masks'] = datapoints.Mask(target['masks'])

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        return img, target

    def extra_repr(self) -> str:
        s = f' img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n'
        s += f' return_masks: {self.return_masks}\n'
        if hasattr(self, '_transforms') and self._transforms is not None:
            s += f' transforms:\n   {repr(self._transforms)}'

        return s


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask(object):
    def __init__(self, return_masks=False, remap_mscoco_category=False):
        self.return_masks = return_masks
        self.remap_mscoco_category = remap_mscoco_category

    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]
        image_id = torch.tensor([image_id])

        anno = target["annotations"]

        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        if self.remap_mscoco_category:
            classes = [mscoco_category2label[obj["category_id"]] for obj in anno]
        else:
            classes = [obj["category_id"] for obj in anno]

        classes = torch.tensor(classes, dtype=torch.int64)

        if self.return_masks:
            segmentations = [obj["segmentation"] for obj in anno]
            masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        if self.return_masks:
            masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        if self.return_masks:
            target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]

        target["orig_size"] = torch.as_tensor([int(w), int(h)])
        target["size"] = torch.as_tensor([int(w), int(h)])

        return image, target


mscoco_category2name = {
    1: 'person',
    2: 'bicycle',
    3: 'car',
    4: 'motorcycle',
    5: 'airplane',
    6: 'bus',
    7: 'train',
    8: 'truck',
    9: 'boat',
    10: 'traffic light',
    11: 'fire hydrant',
    13: 'stop sign',
    14: 'parking meter',
    15: 'bench',
    16: 'bird',
    17: 'cat',
    18: 'dog',
    19: 'horse',
    20: 'sheep',
    21: 'cow',
    22: 'elephant',
    23: 'bear',
    24: 'zebra',
    25: 'giraffe',
    27: 'backpack',
    28: 'umbrella',
    31: 'handbag',
    32: 'tie',
    33: 'suitcase',
    34: 'frisbee',
    35: 'skis',
    36: 'snowboard',
    37: 'sports ball',
    38: 'kite',
    39: 'baseball bat',
    40: 'baseball glove',
    41: 'skateboard',
    42: 'surfboard',
    43: 'tennis racket',
    44: 'bottle',
    46: 'wine glass',
    47: 'cup',
    48: 'fork',
    49: 'knife',
    50: 'spoon',
    51: 'bowl',
    52: 'banana',
    53: 'apple',
    54: 'sandwich',
    55: 'orange',
    56: 'broccoli',
    57: 'carrot',
    58: 'hot dog',
    59: 'pizza',
    60: 'donut',
    61: 'cake',
    62: 'chair',
    63: 'couch',
    64: 'potted plant',
    65: 'bed',
    67: 'dining table',
    70: 'toilet',
    72: 'tv',
    73: 'laptop',
    74: 'mouse',
    75: 'remote',
    76: 'keyboard',
    77: 'cell phone',
    78: 'microwave',
    79: 'oven',
    80: 'toaster',
    81: 'sink',
    82: 'refrigerator',
    84: 'book',
    85: 'clock',
    86: 'vase',
    87: 'scissors',
    88: 'teddy bear',
    89: 'hair drier',
    90: 'toothbrush'
}

mscoco_category2label = {k: i for i, k in enumerate(mscoco_category2name.keys())}
mscoco_label2category = {v: k for k, v in mscoco_category2label.items()}
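For orientation, a hedged sketch of using the dataset directly; the paths are placeholders, and `transforms` is left as `None` (in normal use it is injected from the registry). With `remap_mscoco_category=True` the sparse COCO category ids (1-90) become contiguous labels 0-79 via `mscoco_category2label`:

# Placeholder paths; transforms=None skips the injected transform pipeline.
from src.data.coco import CocoDetection, mscoco_label2category

dataset = CocoDetection(
    img_folder='data/coco/val2017',
    ann_file='data/coco/annotations/instances_val2017.json',
    transforms=None,
    return_masks=False,
    remap_mscoco_category=True,  # labels become contiguous 0..79
)

img, target = dataset[0]
print(target['boxes'].shape)  # (N, 4) XYXY BoundingBox datapoints
print(mscoco_label2category[target['labels'][0].item()])  # back to a COCO category id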
src/data/coco/coco_eval.py
ADDED
@@ -0,0 +1,269 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO evaluator that works in distributed mode.

Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
The difference is that there is less copy-pasting from pycocotools
in the end of the file, as python3 can suppress prints with contextlib
"""
import os
import contextlib
import copy
import numpy as np
import torch

from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import pycocotools.mask as mask_util

from src.misc import dist


__all__ = ['CocoEvaluator',]


class CocoEvaluator(object):
    def __init__(self, coco_gt, iou_types):
        assert isinstance(iou_types, (list, tuple))
        coco_gt = copy.deepcopy(coco_gt)
        self.coco_gt = coco_gt

        self.iou_types = iou_types
        self.coco_eval = {}
        for iou_type in iou_types:
            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)

        self.img_ids = []
        self.eval_imgs = {k: [] for k in iou_types}

    def update(self, predictions):
        img_ids = list(np.unique(list(predictions.keys())))
        self.img_ids.extend(img_ids)

        for iou_type in self.iou_types:
            results = self.prepare(predictions, iou_type)

            # suppress pycocotools prints
            with open(os.devnull, 'w') as devnull:
                with contextlib.redirect_stdout(devnull):
                    coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
            coco_eval = self.coco_eval[iou_type]

            coco_eval.cocoDt = coco_dt
            coco_eval.params.imgIds = list(img_ids)
            img_ids, eval_imgs = evaluate(coco_eval)

            self.eval_imgs[iou_type].append(eval_imgs)

    def synchronize_between_processes(self):
        for iou_type in self.iou_types:
            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])

    def accumulate(self):
        for coco_eval in self.coco_eval.values():
            coco_eval.accumulate()

    def summarize(self):
        for iou_type, coco_eval in self.coco_eval.items():
            print("IoU metric: {}".format(iou_type))
            coco_eval.summarize()

    def prepare(self, predictions, iou_type):
        if iou_type == "bbox":
            return self.prepare_for_coco_detection(predictions)
        elif iou_type == "segm":
            return self.prepare_for_coco_segmentation(predictions)
        elif iou_type == "keypoints":
            return self.prepare_for_coco_keypoint(predictions)
        else:
            raise ValueError("Unknown iou type {}".format(iou_type))

    def prepare_for_coco_detection(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "bbox": box,
                        "score": scores[k],
                    }
                    for k, box in enumerate(boxes)
                ]
            )
        return coco_results

    def prepare_for_coco_segmentation(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            scores = prediction["scores"]
            labels = prediction["labels"]
            masks = prediction["masks"]

            masks = masks > 0.5

            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()

            rles = [
                mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
                for mask in masks
            ]
            for rle in rles:
                rle["counts"] = rle["counts"].decode("utf-8")

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        "segmentation": rle,
                        "score": scores[k],
                    }
                    for k, rle in enumerate(rles)
                ]
            )
        return coco_results

    def prepare_for_coco_keypoint(self, predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            if len(prediction) == 0:
                continue

            boxes = prediction["boxes"]
            boxes = convert_to_xywh(boxes).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            keypoints = prediction["keypoints"]
            keypoints = keypoints.flatten(start_dim=1).tolist()

            coco_results.extend(
                [
                    {
                        "image_id": original_id,
                        "category_id": labels[k],
                        'keypoints': keypoint,
                        "score": scores[k],
                    }
                    for k, keypoint in enumerate(keypoints)
                ]
            )
        return coco_results


def convert_to_xywh(boxes):
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)


def merge(img_ids, eval_imgs):
    all_img_ids = dist.all_gather(img_ids)
    all_eval_imgs = dist.all_gather(eval_imgs)

    merged_img_ids = []
    for p in all_img_ids:
        merged_img_ids.extend(p)

    merged_eval_imgs = []
    for p in all_eval_imgs:
        merged_eval_imgs.append(p)

    merged_img_ids = np.array(merged_img_ids)
    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)

    # keep only unique (and in sorted order) images
    merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
    merged_eval_imgs = merged_eval_imgs[..., idx]

    return merged_img_ids, merged_eval_imgs


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
    img_ids, eval_imgs = merge(img_ids, eval_imgs)
    img_ids = list(img_ids)
    eval_imgs = list(eval_imgs.flatten())

    coco_eval.evalImgs = eval_imgs
    coco_eval.params.imgIds = img_ids
    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################


# import io
# from contextlib import redirect_stdout
# def evaluate(imgs):
#     with redirect_stdout(io.StringIO()):
#         imgs.evaluate()
#     return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))


def evaluate(self):
    '''
    Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
    :return: None
    '''
    # tic = time.time()
    # print('Running per image evaluation...')
    p = self.params
    # add backward compatibility if useSegm is specified in params
    if p.useSegm is not None:
        p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
        print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
    # print('Evaluate annotation type *{}*'.format(p.iouType))
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)
    self.params = p

    self._prepare()
    # loop through images, area range, max detection number
    catIds = p.catIds if p.useCats else [-1]

    if p.iouType == 'segm' or p.iouType == 'bbox':
        computeIoU = self.computeIoU
    elif p.iouType == 'keypoints':
        computeIoU = self.computeOks
    self.ious = {
        (imgId, catId): computeIoU(imgId, catId)
        for imgId in p.imgIds
        for catId in catIds}

    evaluateImg = self.evaluateImg
    maxDet = p.maxDets[-1]
    evalImgs = [
        evaluateImg(imgId, catId, areaRng, maxDet)
        for catId in catIds
        for areaRng in p.areaRng
        for imgId in p.imgIds
    ]
    # this is NOT in the pycocotools code, but could be done outside
    evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
    self._paramsEval = copy.deepcopy(self.params)
    # toc = time.time()
    # print('DONE (t={:0.2f}s).'.format(toc-tic))
    return p.imgIds, evalImgs

#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
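The class above assumes a driver loop that keys per-image predictions by `image_id`. A sketch of that loop, wrapped as a function because the model, postprocessor, and dataloader are not part of this file; the postprocessor call signature here is an assumption:

# Sketch of the evaluation loop CocoEvaluator expects; model, postprocessor,
# and dataloader are stand-in parameters, and the postprocessor is assumed
# to return one dict of boxes/scores/labels per image.
import torch

from src.data.coco import get_coco_api_from_dataset
from src.data.coco.coco_eval import CocoEvaluator


@torch.no_grad()
def evaluate_bbox(model, postprocessor, dataloader):
    coco_gt = get_coco_api_from_dataset(dataloader.dataset)
    evaluator = CocoEvaluator(coco_gt, iou_types=['bbox'])

    for images, targets in dataloader:
        outputs = postprocessor(model(images), targets)  # assumed signature
        res = {t['image_id'].item(): out for t, out in zip(targets, outputs)}
        evaluator.update(res)

    evaluator.synchronize_between_processes()
    evaluator.accumulate()
    evaluator.summarize()
    return evaluator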
src/data/coco/coco_utils.py
ADDED
@@ -0,0 +1,184 @@
import os

import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO


def convert_coco_poly_to_mask(segmentations, height, width):
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = mask.any(dim=2)
        masks.append(mask)
    if masks:
        masks = torch.stack(masks, dim=0)
    else:
        masks = torch.zeros((0, height, width), dtype=torch.uint8)
    return masks


class ConvertCocoPolysToMask:
    def __call__(self, image, target):
        w, h = image.size

        image_id = target["image_id"]

        anno = target["annotations"]

        anno = [obj for obj in anno if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        segmentations = [obj["segmentation"] for obj in anno]
        masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        masks = masks[keep]
        if keypoints is not None:
            keypoints = keypoints[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["masks"] = masks
        target["image_id"] = image_id
        if keypoints is not None:
            target["keypoints"] = keypoints

        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
        target["area"] = area
        target["iscrowd"] = iscrowd

        return image, target


def _coco_remove_images_without_annotations(dataset, cat_list=None):
    def _has_only_empty_bbox(anno):
        return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

    def _count_visible_keypoints(anno):
        return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

    min_keypoints_per_image = 10

    def _has_valid_annotation(anno):
        # if it's empty, there is no annotation
        if len(anno) == 0:
            return False
        # if all boxes have close to zero area, there is no annotation
        if _has_only_empty_bbox(anno):
            return False
        # keypoints task have a slight different criteria for considering
        # if an annotation is valid
        if "keypoints" not in anno[0]:
            return True
        # for keypoint detection tasks, only consider valid images those
        # containing at least min_keypoints_per_image
        if _count_visible_keypoints(anno) >= min_keypoints_per_image:
            return True
        return False

    ids = []
    for ds_idx, img_id in enumerate(dataset.ids):
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if cat_list:
            anno = [obj for obj in anno if obj["category_id"] in cat_list]
        if _has_valid_annotation(anno):
            ids.append(ds_idx)

    dataset = torch.utils.data.Subset(dataset, ids)
    return dataset


def convert_to_coco_api(ds):
    coco_ds = COCO()
    # annotation IDs need to start at 1, not 0, see torchvision issue #1530
    ann_id = 1
    dataset = {"images": [], "categories": [], "annotations": []}
    categories = set()
    for img_idx in range(len(ds)):
        # find better way to get target
        # targets = ds.get_annotations(img_idx)
        img, targets = ds[img_idx]
        image_id = targets["image_id"].item()
        img_dict = {}
        img_dict["id"] = image_id
        img_dict["height"] = img.shape[-2]
        img_dict["width"] = img.shape[-1]
        dataset["images"].append(img_dict)
        bboxes = targets["boxes"].clone()
        bboxes[:, 2:] -= bboxes[:, :2]
        bboxes = bboxes.tolist()
        labels = targets["labels"].tolist()
        areas = targets["area"].tolist()
        iscrowd = targets["iscrowd"].tolist()
        if "masks" in targets:
            masks = targets["masks"]
            # make masks Fortran contiguous for coco_mask
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if "keypoints" in targets:
            keypoints = targets["keypoints"]
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
        num_objs = len(bboxes)
        for i in range(num_objs):
            ann = {}
            ann["image_id"] = image_id
            ann["bbox"] = bboxes[i]
            ann["category_id"] = labels[i]
            categories.add(labels[i])
            ann["area"] = areas[i]
            ann["iscrowd"] = iscrowd[i]
            ann["id"] = ann_id
            if "masks" in targets:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if "keypoints" in targets:
                ann["keypoints"] = keypoints[i]
                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
            dataset["annotations"].append(ann)
            ann_id += 1
    dataset["categories"] = [{"id": i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds


def get_coco_api_from_dataset(dataset):
    # FIXME: This is... awful?
    for _ in range(10):
        if isinstance(dataset, torchvision.datasets.CocoDetection):
            break
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco
    return convert_to_coco_api(dataset)
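`get_coco_api_from_dataset` exists because `_coco_remove_images_without_annotations` (and user code) wraps datasets in `torch.utils.data.Subset`; the loop peels up to ten layers of wrapping before falling back to `convert_to_coco_api`. A hedged illustration with placeholder paths:

# Placeholder paths; with a real CocoDetection underneath, the helper
# unwraps the Subset and returns dataset.coco directly instead of
# rebuilding an index with convert_to_coco_api.
import torch.utils.data

from src.data.coco import CocoDetection, get_coco_api_from_dataset

ds = CocoDetection('data/coco/val2017',
                   'data/coco/annotations/instances_val2017.json',
                   transforms=None, return_masks=False)
subset = torch.utils.data.Subset(ds, list(range(100)))

coco_api = get_coco_api_from_dataset(subset)
print(coco_api is ds.coco)  # True: the wrapped dataset's COCO index is reused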
src/data/dataloader.py
ADDED
@@ -0,0 +1,28 @@
import torch
import torch.utils.data as data

from src.core import register


__all__ = ['DataLoader']


@register
class DataLoader(data.DataLoader):
    __inject__ = ['dataset', 'collate_fn']

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        for n in ['dataset', 'batch_size', 'num_workers', 'drop_last', 'collate_fn']:
            format_string += "\n"
            format_string += "    {0}: {1}".format(n, getattr(self, n))
        format_string += "\n)"
        return format_string


@register
def default_collate_fn(items):
    '''default collate_fn
    '''
    return torch.cat([x[0][None] for x in items], dim=0), [x[1] for x in items]
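Editor's note — a minimal sketch of what default_collate_fn produces, assuming every sample is an (image_tensor, target_dict) pair of identical image shape and that @register returns the function unchanged:

import torch

items = [(torch.rand(3, 224, 224), {'labels': torch.tensor([1])}) for _ in range(4)]
images, targets = default_collate_fn(items)
print(images.shape)  # torch.Size([4, 3, 224, 224]) -- images are stacked
print(len(targets))  # 4 -- targets stay a list of dicts, not a tensor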
src/data/functional.py
ADDED
@@ -0,0 +1,169 @@
import torch
import torchvision.transforms.functional as F

from packaging import version
from typing import Optional, List
from torch import Tensor

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
if version.parse(torchvision.__version__) < version.parse('0.7'):
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size


def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    function can go away.
    """
    if version.parse(torchvision.__version__) < version.parse('0.7'):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)


def crop(image, target, region):
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    fields = ["labels", "area", "iscrowd"]

    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target['masks'] = target['masks'][:, i:i + h, j:j + w]
        fields.append("masks")

    # remove elements whose boxes or masks have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target['boxes'].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target['masks'].flatten(1).any(1)

        for field in fields:
            target[field] = target[field][keep]

    return cropped_image, target


def hflip(image, target):
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
        target["boxes"] = boxes

    if "masks" in target:
        target['masks'] = target['masks'].flip(-1)

    return flipped_image, target


def resize(image, target, size, max_size=None):
    # size can be min_size (scalar) or (w, h) tuple

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))

        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        # r = min(size / min(h, w), max_size / max(h, w))
        # ow = int(w * r)
        # oh = int(h * r)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
        target["boxes"] = scaled_boxes

    if "area" in target:
        area = target["area"]
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        target['masks'] = interpolate(
            target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5

    return rescaled_image, target


def pad(image, target, padding):
    # assumes that we only pad on the bottom right corners
    padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    if target is None:
        return padded_image, None
    target = target.copy()
    # should we do something wrt the original size?
    target["size"] = torch.tensor(padded_image.size[::-1])
    if "masks" in target:
        target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
    return padded_image, target
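Editor's note — a small sketch of the (image, target) convention these helpers share, assuming a PIL image and a target dict with pixel-space xyxy 'boxes':

import torch
from PIL import Image

img = Image.new('RGB', (640, 480))
target = {'boxes': torch.tensor([[10., 20., 110., 220.]]),
          'labels': torch.tensor([1]),
          'area': torch.tensor([20000.]),
          'iscrowd': torch.tensor([0])}

# shorter side becomes 320 (capped by max_size); boxes/area scale with it
img2, target2 = resize(img, target, size=320, max_size=640)
print(img2.size, target2['boxes'])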
src/data/transforms.py
ADDED
@@ -0,0 +1,142 @@
"""by lyuwenyu
"""


import torch
import torch.nn as nn

import torchvision
torchvision.disable_beta_transforms_warning()
from torchvision import datapoints

import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as F

from PIL import Image
from typing import Any, Dict, List, Optional

from src.core import register, GLOBAL_CONFIG


__all__ = ['Compose', ]


RandomPhotometricDistort = register(T.RandomPhotometricDistort)
RandomZoomOut = register(T.RandomZoomOut)
# RandomIoUCrop = register(T.RandomIoUCrop)
RandomHorizontalFlip = register(T.RandomHorizontalFlip)
Resize = register(T.Resize)
ToImageTensor = register(T.ToImageTensor)
ConvertDtype = register(T.ConvertDtype)
SanitizeBoundingBox = register(T.SanitizeBoundingBox)
RandomCrop = register(T.RandomCrop)
Normalize = register(T.Normalize)


@register
class Compose(T.Compose):
    def __init__(self, ops) -> None:
        transforms = []
        if ops is not None:
            for op in ops:
                if isinstance(op, dict):
                    name = op.pop('type')
                    transform = getattr(GLOBAL_CONFIG[name]['_pymodule'], name)(**op)
                    transforms.append(transform)
                    # op['type'] = name
                elif isinstance(op, nn.Module):
                    transforms.append(op)
                else:
                    raise ValueError('each op must be a dict or an nn.Module')
        else:
            transforms = [EmptyTransform(), ]

        super().__init__(transforms=transforms)


@register
class EmptyTransform(T.Transform):
    def __init__(self, ) -> None:
        super().__init__()

    def forward(self, *inputs):
        inputs = inputs if len(inputs) > 1 else inputs[0]
        return inputs


@register
class PadToSize(T.Pad):
    _transformed_types = (
        Image.Image,
        datapoints.Image,
        datapoints.Video,
        datapoints.Mask,
        datapoints.BoundingBox,
    )
    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        sz = F.get_spatial_size(flat_inputs[0])
        h, w = self.spatial_size[0] - sz[0], self.spatial_size[1] - sz[1]
        self.padding = [0, 0, w, h]
        return dict(padding=self.padding)

    def __init__(self, spatial_size, fill=0, padding_mode='constant') -> None:
        if isinstance(spatial_size, int):
            spatial_size = (spatial_size, spatial_size)

        self.spatial_size = spatial_size
        super().__init__(0, fill, padding_mode)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        fill = self._fill[type(inpt)]
        padding = params['padding']
        return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]

    def __call__(self, *inputs: Any) -> Any:
        outputs = super().forward(*inputs)
        if len(outputs) > 1 and isinstance(outputs[1], dict):
            outputs[1]['padding'] = torch.tensor(self.padding)
        return outputs


@register
class RandomIoUCrop(T.RandomIoUCrop):
    def __init__(self, min_scale: float = 0.3, max_scale: float = 1, min_aspect_ratio: float = 0.5, max_aspect_ratio: float = 2, sampler_options: Optional[List[float]] = None, trials: int = 40, p: float = 1.0):
        super().__init__(min_scale, max_scale, min_aspect_ratio, max_aspect_ratio, sampler_options, trials)
        self.p = p

    def __call__(self, *inputs: Any) -> Any:
        if torch.rand(1) >= self.p:
            return inputs if len(inputs) > 1 else inputs[0]

        return super().forward(*inputs)


@register
class ConvertBox(T.Transform):
    _transformed_types = (
        datapoints.BoundingBox,
    )
    def __init__(self, out_fmt='', normalize=False) -> None:
        super().__init__()
        self.out_fmt = out_fmt
        self.normalize = normalize

        self.data_fmt = {
            'xyxy': datapoints.BoundingBoxFormat.XYXY,
            'cxcywh': datapoints.BoundingBoxFormat.CXCYWH
        }

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if self.out_fmt:
            spatial_size = inpt.spatial_size
            in_fmt = inpt.format.value.lower()
            inpt = torchvision.ops.box_convert(inpt, in_fmt=in_fmt, out_fmt=self.out_fmt)
            inpt = datapoints.BoundingBox(inpt, format=self.data_fmt[self.out_fmt], spatial_size=spatial_size)

        if self.normalize:
            inpt = inpt / torch.tensor(inpt.spatial_size[::-1]).tile(2)[None]

        return inpt
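Editor's note — a hedged sketch of how Compose is meant to consume a YAML-style op list: each dict names a registered transform via 'type' and passes the remaining keys as constructor kwargs, looked up through GLOBAL_CONFIG. The kwargs below are illustrative, not taken from this commit's config files:

ops = [
    {'type': 'RandomPhotometricDistort', 'p': 0.5},
    {'type': 'RandomHorizontalFlip'},
    {'type': 'Resize', 'size': [640, 640]},
    {'type': 'ToImageTensor'},
    {'type': 'ConvertDtype'},
    {'type': 'ConvertBox', 'out_fmt': 'cxcywh', 'normalize': True},
]
pipeline = Compose(ops)  # instantiates each registered transform by name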
src/misc/__init__.py
ADDED
@@ -0,0 +1,3 @@

from .logger import *
from .visualizer import *
src/misc/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (211 Bytes)
src/misc/__pycache__/dist.cpython-310.pyc
ADDED
Binary file (4.79 kB)
src/misc/__pycache__/logger.cpython-310.pyc
ADDED
Binary file (7.77 kB)
src/misc/__pycache__/visualizer.cpython-310.pyc
ADDED
Binary file (1.09 kB)
src/misc/dist.py
ADDED
@@ -0,0 +1,190 @@
"""
reference
- https://github.com/pytorch/vision/blob/main/references/detection/utils.py
- https://github.com/facebookresearch/detr/blob/master/util/misc.py#L406

by lyuwenyu
"""

import time
import random
import numpy as np

import torch
import torch.nn as nn
import torch.distributed
import torch.distributed as tdist

from torch.nn.parallel import DistributedDataParallel as DDP

from torch.utils.data import DistributedSampler
from torch.utils.data.dataloader import DataLoader


def init_distributed():
    '''
    distributed setup
    args:
        backend (str), ('nccl', 'gloo')
    '''
    try:
        # # https://pytorch.org/docs/stable/elastic/run.html
        # LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
        # RANK = int(os.getenv('RANK', -1))
        # WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))

        tdist.init_process_group(init_method='env://', )
        torch.distributed.barrier()

        rank = get_rank()
        device = torch.device(f'cuda:{rank}')
        torch.cuda.set_device(device)

        setup_print(rank == 0)
        print('Initialized distributed mode...')

        return True

    except Exception:
        print('Not init distributed mode.')
        return False


def setup_print(is_main):
    '''This function disables printing when not in the master process
    '''
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_main or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_available_and_initialized():
    if not tdist.is_available():
        return False
    if not tdist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_available_and_initialized():
        return 0
    return tdist.get_rank()


def get_world_size():
    if not is_dist_available_and_initialized():
        return 1
    return tdist.get_world_size()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def warp_model(model, find_unused_parameters=False, sync_bn=False, ):
    if is_dist_available_and_initialized():
        rank = get_rank()
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model) if sync_bn else model
        model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=find_unused_parameters)
    return model


def warp_loader(loader, shuffle=False):
    if is_dist_available_and_initialized():
        sampler = DistributedSampler(loader.dataset, shuffle=shuffle)
        loader = DataLoader(loader.dataset,
                            loader.batch_size,
                            sampler=sampler,
                            drop_last=loader.drop_last,
                            collate_fn=loader.collate_fn,
                            pin_memory=loader.pin_memory,
                            num_workers=loader.num_workers, )
    return loader


def is_parallel(model) -> bool:
    # Returns True if model is of type DP or DDP
    return type(model) in (torch.nn.parallel.DataParallel, torch.nn.parallel.DistributedDataParallel)


def de_parallel(model) -> nn.Module:
    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
    return model.module if is_parallel(model) else model


def reduce_dict(data, avg=True):
    '''
    Args
        data dict: input, {k: v, ...}
        avg bool: true
    '''
    world_size = get_world_size()
    if world_size < 2:
        return data

    with torch.no_grad():
        keys, values = [], []
        for k in sorted(data.keys()):
            keys.append(k)
            values.append(data[k])

        values = torch.stack(values, dim=0)
        tdist.all_reduce(values)

        if avg is True:
            values /= world_size

        _data = {k: v for k, v in zip(keys, values)}

    return _data


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]
    data_list = [None] * world_size
    tdist.all_gather_object(data_list, data)
    return data_list


def sync_time():
    '''sync_time
    '''
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    return time.time()


def set_seed(seed):
    # fix the seed for reproducibility
    seed = seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
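Editor's note — a hedged sketch of the intended call order under torchrun, assuming the launcher sets the usual env:// variables (RANK, WORLD_SIZE, LOCAL_RANK) and CUDA is available; the nn.Linear stands in for a real model:

import torch
import torch.nn as nn

init_distributed()                # falls back to single-process on failure
set_seed(42)                      # seed is offset by rank inside set_seed

model = nn.Linear(10, 2).cuda()   # placeholder model
model = warp_model(model, sync_bn=True)       # DDP wrap when distributed
# loader = warp_loader(loader, shuffle=True)  # would add a DistributedSampler

stats = reduce_dict({'loss': torch.tensor(1.0, device='cuda')})
print(stats, 'world size:', get_world_size())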
src/misc/logger.py
ADDED
@@ -0,0 +1,239 @@
"""
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
https://github.com/facebookresearch/detr/blob/main/util/misc.py
Mostly copy-paste from torchvision references.
"""

import time
import pickle
import datetime
from collections import defaultdict, deque
from typing import Dict

import torch
import torch.distributed as tdist

from .dist import is_dist_available_and_initialized, get_world_size


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_available_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        tdist.barrier()
        tdist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    tdist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    tdist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_dict(input_dict, average=True) -> Dict[str, torch.Tensor]:
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        tdist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}',
                'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header,
                '[{0' + space_fmt + '}/{1}]',
                'eta: {eta}',
                '{meters}',
                'time: {time}',
                'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
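Editor's note — a minimal sketch of the MetricLogger loop; log_every only needs an iterable with len(), so a range stands in for a real dataloader here:

import torch

logger = MetricLogger(delimiter="  ")
logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

for step in logger.log_every(range(100), print_freq=10, header='Epoch: [0]'):
    loss = torch.rand(1).item()      # placeholder for the real loss value
    logger.update(loss=loss, lr=1e-4)

logger.synchronize_between_processes()  # all-reduce meter totals across ranks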
src/misc/visualizer.py
ADDED
@@ -0,0 +1,34 @@
"""by lyuwenyu
"""

import torch
import torch.utils.data

import torchvision
torchvision.disable_beta_transforms_warning()

import PIL

__all__ = ['show_sample']

def show_sample(sample):
    """for coco dataset/dataloader
    """
    import matplotlib.pyplot as plt
    from torchvision.transforms.v2 import functional as F
    from torchvision.utils import draw_bounding_boxes

    image, target = sample
    if isinstance(image, PIL.Image.Image):
        image = F.to_image_tensor(image)

    image = F.convert_dtype(image, torch.uint8)
    annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3)

    fig, ax = plt.subplots()
    ax.imshow(annotated_image.permute(1, 2, 0).numpy())
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    fig.tight_layout()
    fig.show()
    plt.show()
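Editor's note — a hedged, self-contained sketch of show_sample on a synthetic (image, target) pair with pixel-space xyxy boxes:

import torch
from PIL import Image

img = Image.new('RGB', (320, 240), 'gray')
target = {'boxes': torch.tensor([[40., 30., 200., 180.]])}
show_sample((img, target))  # opens a matplotlib window with one yellow box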
src/nn/__init__.py
ADDED
@@ -0,0 +1,7 @@

from .arch import *
from .criterion import *

#
from .backbone import *
src/nn/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (226 Bytes)
src/nn/arch/__init__.py
ADDED
@@ -0,0 +1 @@
from .classification import *
src/nn/arch/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (200 Bytes)
src/nn/arch/__pycache__/classification.cpython-310.pyc
ADDED
Binary file (1.55 kB)
src/nn/arch/classification.py
ADDED
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn

from src.core import register


__all__ = ['Classification', 'ClassHead']


@register
class Classification(nn.Module):
    __inject__ = ['backbone', 'head']

    def __init__(self, backbone: nn.Module, head: nn.Module = None):
        super().__init__()

        self.backbone = backbone
        self.head = head

    def forward(self, x):
        x = self.backbone(x)

        if self.head is not None:
            x = self.head(x)

        return x


@register
class ClassHead(nn.Module):
    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = x[0] if isinstance(x, (list, tuple)) else x
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.proj(x)
        return x
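Editor's note — a hedged sketch wiring Classification by hand (the YAML config normally fills backbone/head via __inject__); PResNet comes from src/nn/backbone later in this diff:

import torch

backbone = PResNet(depth=18, return_idx=[3], pretrained=False)  # 512-ch C5 map
head = ClassHead(hidden_dim=512, num_classes=10)
model = Classification(backbone=backbone, head=head)
print(model(torch.rand(2, 3, 224, 224)).shape)  # torch.Size([2, 10])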
src/nn/backbone/__init__.py
ADDED
@@ -0,0 +1,5 @@

from .presnet import *
from .test_resnet import *

from .common import *
src/nn/backbone/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (238 Bytes)
src/nn/backbone/__pycache__/common.cpython-310.pyc
ADDED
Binary file (3.28 kB)
src/nn/backbone/__pycache__/presnet.cpython-310.pyc
ADDED
Binary file (6.45 kB)
src/nn/backbone/__pycache__/test_resnet.cpython-310.pyc
ADDED
Binary file (3.06 kB)
src/nn/backbone/common.py
ADDED
@@ -0,0 +1,102 @@
'''by lyuwenyu
'''

import torch
import torch.nn as nn


class ConvNormLayer(nn.Module):
    def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
        super().__init__()
        self.conv = nn.Conv2d(
            ch_in,
            ch_out,
            kernel_size,
            stride,
            padding=(kernel_size - 1) // 2 if padding is None else padding,
            bias=bias)
        self.norm = nn.BatchNorm2d(ch_out)
        self.act = nn.Identity() if act is None else get_activation(act)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))


class FrozenBatchNorm2d(nn.Module):
    """copy and modified from https://github.com/facebookresearch/detr/blob/master/models/backbone.py
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """
    def __init__(self, num_features, eps=1e-5):
        super(FrozenBatchNorm2d, self).__init__()
        n = num_features
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))
        self.eps = eps
        self.num_features = n

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}".format(**self.__dict__)
        )


def get_activation(act, inplace: bool = True):
    '''get activation
    '''
    # handle the non-string cases before lowercasing
    if act is None:
        return nn.Identity()

    if isinstance(act, nn.Module):
        return act

    act = act.lower()

    if act == 'silu':
        m = nn.SiLU()

    elif act == 'relu':
        m = nn.ReLU()

    elif act == 'leaky_relu':
        m = nn.LeakyReLU()

    elif act == 'gelu':
        m = nn.GELU()

    else:
        raise RuntimeError(f'Unsupported activation: {act}')

    if hasattr(m, 'inplace'):
        m.inplace = inplace

    return m
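Editor's note — a quick check of ConvNormLayer: padding defaults to (kernel_size - 1) // 2, so a stride-2 3x3 conv exactly halves the spatial size:

import torch

layer = ConvNormLayer(ch_in=3, ch_out=64, kernel_size=3, stride=2, act='silu')
print(layer(torch.rand(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 112, 112])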
src/nn/backbone/presnet.py
ADDED
@@ -0,0 +1,225 @@
'''by lyuwenyu
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict

from .common import get_activation, ConvNormLayer, FrozenBatchNorm2d

from src.core import register


__all__ = ['PResNet']


ResNet_cfg = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    # 152: [3, 8, 36, 3],
}


download_url = {
    18: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth',
    34: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth',
    50: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth',
    101: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth',
}


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
        super().__init__()

        self.shortcut = shortcut

        if not shortcut:
            if variant == 'd' and stride == 2:
                self.short = nn.Sequential(OrderedDict([
                    ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
                    ('conv', ConvNormLayer(ch_in, ch_out, 1, 1))
                ]))
            else:
                self.short = ConvNormLayer(ch_in, ch_out, 1, stride)

        self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act)
        self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None)
        self.act = nn.Identity() if act is None else get_activation(act)

    def forward(self, x):
        out = self.branch2a(x)
        out = self.branch2b(out)
        if self.shortcut:
            short = x
        else:
            short = self.short(x)

        out = out + short
        out = self.act(out)

        return out


class BottleNeck(nn.Module):
    expansion = 4

    def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'):
        super().__init__()

        if variant == 'a':
            stride1, stride2 = stride, 1
        else:
            stride1, stride2 = 1, stride

        width = ch_out

        self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act)
        self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act)
        self.branch2c = ConvNormLayer(width, ch_out * self.expansion, 1, 1)

        self.shortcut = shortcut
        if not shortcut:
            if variant == 'd' and stride == 2:
                self.short = nn.Sequential(OrderedDict([
                    ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
                    ('conv', ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1))
                ]))
            else:
                self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride)

        self.act = nn.Identity() if act is None else get_activation(act)

    def forward(self, x):
        out = self.branch2a(x)
        out = self.branch2b(out)
        out = self.branch2c(out)

        if self.shortcut:
            short = x
        else:
            short = self.short(x)

        out = out + short
        out = self.act(out)

        return out


class Blocks(nn.Module):
    def __init__(self, block, ch_in, ch_out, count, stage_num, act='relu', variant='b'):
        super().__init__()

        self.blocks = nn.ModuleList()
        for i in range(count):
            self.blocks.append(
                block(
                    ch_in,
                    ch_out,
                    stride=2 if i == 0 and stage_num != 2 else 1,
                    shortcut=False if i == 0 else True,
                    variant=variant,
                    act=act)
            )

            if i == 0:
                ch_in = ch_out * block.expansion

    def forward(self, x):
        out = x
        for block in self.blocks:
            out = block(out)
        return out


@register
class PResNet(nn.Module):
    def __init__(
            self,
            depth,
            variant='d',
            num_stages=4,
            return_idx=[0, 1, 2, 3],
            act='relu',
            freeze_at=-1,
            freeze_norm=True,
            pretrained=False):
        super().__init__()

        block_nums = ResNet_cfg[depth]
        ch_in = 64
        if variant in ['c', 'd']:
            conv_def = [
                [3, ch_in // 2, 3, 2, "conv1_1"],
                [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
                [ch_in // 2, ch_in, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[3, ch_in, 7, 2, "conv1_1"]]

        self.conv1 = nn.Sequential(OrderedDict([
            (_name, ConvNormLayer(c_in, c_out, k, s, act=act)) for c_in, c_out, k, s, _name in conv_def
        ]))

        ch_out_list = [64, 128, 256, 512]
        block = BottleNeck if depth >= 50 else BasicBlock

        _out_channels = [block.expansion * v for v in ch_out_list]
        _out_strides = [4, 8, 16, 32]

        self.res_layers = nn.ModuleList()
        for i in range(num_stages):
            stage_num = i + 2
            self.res_layers.append(
                Blocks(block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant)
            )
            ch_in = _out_channels[i]

        self.return_idx = return_idx
        self.out_channels = [_out_channels[_i] for _i in return_idx]
        self.out_strides = [_out_strides[_i] for _i in return_idx]

        if freeze_at >= 0:
            self._freeze_parameters(self.conv1)
            for i in range(min(freeze_at, num_stages)):
                self._freeze_parameters(self.res_layers[i])

        if freeze_norm:
            self._freeze_norm(self)

        if pretrained:
            state = torch.hub.load_state_dict_from_url(download_url[depth])
            self.load_state_dict(state)
            print(f'Load PResNet{depth} state_dict')

    def _freeze_parameters(self, m: nn.Module):
        for p in m.parameters():
            p.requires_grad = False

    def _freeze_norm(self, m: nn.Module):
        if isinstance(m, nn.BatchNorm2d):
            m = FrozenBatchNorm2d(m.num_features)
        else:
            for name, child in m.named_children():
                _child = self._freeze_norm(child)
                if _child is not child:
                    setattr(m, name, _child)
        return m

    def forward(self, x):
        conv1 = self.conv1(x)
        x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1)
        outs = []
        for idx, stage in enumerate(self.res_layers):
            x = stage(x)
            if idx in self.return_idx:
                outs.append(x)
        return outs
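Editor's note — a sketch of PResNet as a detection backbone: it returns the feature maps selected by return_idx, at strides 4/8/16/32 (channels scale by the block's expansion, 4x for depth >= 50):

import torch

m = PResNet(depth=50, return_idx=[1, 2, 3], pretrained=False)
feats = m(torch.rand(1, 3, 640, 640))
print([f.shape for f in feats])
# [1, 512, 80, 80], [1, 1024, 40, 40], [1, 2048, 20, 20]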
src/nn/backbone/test_resnet.py
ADDED
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict


from src.core import register


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class _ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


@register
class MResNet(nn.Module):
    def __init__(self, num_classes=10, num_blocks=[2, 2, 2, 2]) -> None:
        super().__init__()
        self.model = _ResNet(BasicBlock, num_blocks, num_classes)

    def forward(self, x):
        return self.model(x)
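Editor's note — MResNet is a CIFAR-sized sanity-check backbone (3x32x32 input, class logits out); a quick shape check:

import torch

net = MResNet(num_classes=10)
print(net(torch.rand(8, 3, 32, 32)).shape)  # torch.Size([8, 10])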
src/nn/backbone/utils.py
ADDED
@@ -0,0 +1,58 @@
"""
https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py

by lyuwenyu
"""

from collections import OrderedDict
from typing import Dict, List


import torch.nn as nn


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.
    """

    _version = 3

    def __init__(self, model: nn.Module, return_layers: List[str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model. {}"
                             .format([name for name, _ in model.named_children()]))
        orig_return_layers = return_layers
        return_layers = {str(k): str(k) for k in return_layers}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        # out = OrderedDict()
        outputs = []
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                # out_name = self.return_layers[name]
                # out[out_name] = x
                outputs.append(x)

        return outputs
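Editor's note — a hedged sketch of IntermediateLayerGetter on a torchvision ResNet, assuming only top-level children (layer2/layer3/layer4) are requested; the wrapper runs children sequentially and stops collecting after the last requested one:

import torch
import torchvision

resnet = torchvision.models.resnet18(weights=None)
body = IntermediateLayerGetter(resnet, return_layers=['layer2', 'layer3', 'layer4'])
feats = body(torch.rand(1, 3, 224, 224))
print([f.shape for f in feats])
# [1, 128, 28, 28], [1, 256, 14, 14], [1, 512, 7, 7]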