diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb1033d22edfa3c5dcb167c5ee21fd4ed99523f --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,5 @@ + +from . import data +from . import nn +from . import optim +from . import zoo diff --git a/src/__pycache__/__init__.cpython-310.pyc b/src/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53a43d914de1be5f40ef8ae2d6ecf4c641c4bf0f Binary files /dev/null and b/src/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..35c455c63d4fbb2bbf85a83bc3cadec9913335a8 --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1,7 @@ +"""by lyuwenyu +""" + +# from .yaml_utils import register, create, load_config, merge_config, merge_dict +from .yaml_utils import * +from .config import BaseConfig +from .yaml_config import YAMLConfig diff --git a/src/core/__pycache__/__init__.cpython-310.pyc b/src/core/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..421eea2819987662712b8e4c2be66c9011c5bf76 Binary files /dev/null and b/src/core/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/core/__pycache__/config.cpython-310.pyc b/src/core/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..682320aaef33dca015a95e529f1c95f9ab84468a Binary files /dev/null and b/src/core/__pycache__/config.cpython-310.pyc differ diff --git a/src/core/__pycache__/yaml_config.cpython-310.pyc b/src/core/__pycache__/yaml_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20b460caa67bd38069fed297121aa346aed910f0 Binary files /dev/null and b/src/core/__pycache__/yaml_config.cpython-310.pyc differ diff --git a/src/core/__pycache__/yaml_utils.cpython-310.pyc b/src/core/__pycache__/yaml_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6be3540a559cba8a35ca7ca575642e4a56b3f8cf Binary files /dev/null and b/src/core/__pycache__/yaml_utils.cpython-310.pyc differ diff --git a/src/core/config.py b/src/core/config.py new file mode 100644 index 0000000000000000000000000000000000000000..cf803ef56702c00091fa6aa009e0c1367b992e44 --- /dev/null +++ b/src/core/config.py @@ -0,0 +1,264 @@ +"""by lyuwenyu +""" + +from pprint import pprint +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from torch.cuda.amp.grad_scaler import GradScaler + +from typing import Callable, List, Dict + + +__all__ = ['BaseConfig', ] + + + +class BaseConfig(object): + # TODO property + + + def __init__(self) -> None: + super().__init__() + + self.task :str = None + + self._model :nn.Module = None + self._postprocessor :nn.Module = None + self._criterion :nn.Module = None + self._optimizer :Optimizer = None + self._lr_scheduler :LRScheduler = None + self._train_dataloader :DataLoader = None + self._val_dataloader :DataLoader = None + self._ema :nn.Module = None + self._scaler :GradScaler = None + + self.train_dataset :Dataset = None + self.val_dataset :Dataset = None + self.num_workers :int = 0 + self.collate_fn :Callable = None + + self.batch_size :int = None + self._train_batch_size :int = None + self._val_batch_size :int = None + self._train_shuffle: bool = None + self._val_shuffle: bool = None + + self.evaluator :Callable[[nn.Module, DataLoader, str], ] = None + + # runtime + self.resume :str = None + self.tuning :str = None + + self.epoches :int = None + self.last_epoch :int = -1 + self.end_epoch :int = None + + self.use_amp :bool = False + self.use_ema :bool = False + self.sync_bn :bool = False + self.clip_max_norm : float = None + self.find_unused_parameters :bool = None + # self.ema_decay: float = 0.9999 + # self.grad_clip_: Callable = None + + self.log_dir :str = './logs/' + self.log_step :int = 10 + self._output_dir :str = None + self._print_freq :int = None + self.checkpoint_step :int = 1 + + # self.device :str = torch.device('cpu') + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = torch.device(device) + + + @property + def model(self, ) -> nn.Module: + return self._model + + @model.setter + def model(self, m): + assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class' + self._model = m + + @property + def postprocessor(self, ) -> nn.Module: + return self._postprocessor + + @postprocessor.setter + def postprocessor(self, m): + assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class' + self._postprocessor = m + + @property + def criterion(self, ) -> nn.Module: + return self._criterion + + @criterion.setter + def criterion(self, m): + assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class' + self._criterion = m + + @property + def optimizer(self, ) -> Optimizer: + return self._optimizer + + @optimizer.setter + def optimizer(self, m): + assert isinstance(m, Optimizer), f'{type(m)} != optim.Optimizer, please check your model class' + self._optimizer = m + + @property + def lr_scheduler(self, ) -> LRScheduler: + return self._lr_scheduler + + @lr_scheduler.setter + def lr_scheduler(self, m): + assert isinstance(m, LRScheduler), f'{type(m)} != LRScheduler, please check your model class' + self._lr_scheduler = m + + + @property + def train_dataloader(self): + if self._train_dataloader is None and self.train_dataset is not None: + loader = DataLoader(self.train_dataset, + batch_size=self.train_batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + shuffle=self.train_shuffle, ) + loader.shuffle = self.train_shuffle + self._train_dataloader = loader + + return self._train_dataloader + + @train_dataloader.setter + def train_dataloader(self, loader): + self._train_dataloader = loader + + @property + def val_dataloader(self): + if self._val_dataloader is None and self.val_dataset is not None: + loader = DataLoader(self.val_dataset, + batch_size=self.val_batch_size, + num_workers=self.num_workers, + drop_last=False, + collate_fn=self.collate_fn, + shuffle=self.val_shuffle) + loader.shuffle = self.val_shuffle + self._val_dataloader = loader + + return self._val_dataloader + + @val_dataloader.setter + def val_dataloader(self, loader): + self._val_dataloader = loader + + + # TODO method + # @property + # def ema(self, ) -> nn.Module: + # if self._ema is None and self.use_ema and self.model is not None: + # self._ema = ModelEMA(self.model, self.ema_decay) + # return self._ema + + @property + def ema(self, ) -> nn.Module: + return self._ema + + @ema.setter + def ema(self, obj): + self._ema = obj + + + @property + def scaler(self) -> GradScaler: + if self._scaler is None and self.use_amp and torch.cuda.is_available(): + self._scaler = GradScaler() + return self._scaler + + @scaler.setter + def scaler(self, obj: GradScaler): + self._scaler = obj + + + @property + def val_shuffle(self): + if self._val_shuffle is None: + print('warning: set default val_shuffle=False') + return False + return self._val_shuffle + + @val_shuffle.setter + def val_shuffle(self, shuffle): + assert isinstance(shuffle, bool), 'shuffle must be bool' + self._val_shuffle = shuffle + + @property + def train_shuffle(self): + if self._train_shuffle is None: + print('warning: set default train_shuffle=True') + return True + return self._train_shuffle + + @train_shuffle.setter + def train_shuffle(self, shuffle): + assert isinstance(shuffle, bool), 'shuffle must be bool' + self._train_shuffle = shuffle + + + @property + def train_batch_size(self): + if self._train_batch_size is None and isinstance(self.batch_size, int): + print(f'warning: set train_batch_size=batch_size={self.batch_size}') + return self.batch_size + return self._train_batch_size + + @train_batch_size.setter + def train_batch_size(self, batch_size): + assert isinstance(batch_size, int), 'batch_size must be int' + self._train_batch_size = batch_size + + @property + def val_batch_size(self): + if self._val_batch_size is None: + print(f'warning: set val_batch_size=batch_size={self.batch_size}') + return self.batch_size + return self._val_batch_size + + @val_batch_size.setter + def val_batch_size(self, batch_size): + assert isinstance(batch_size, int), 'batch_size must be int' + self._val_batch_size = batch_size + + + @property + def output_dir(self): + if self._output_dir is None: + return self.log_dir + return self._output_dir + + @output_dir.setter + def output_dir(self, root): + self._output_dir = root + + @property + def print_freq(self): + if self._print_freq is None: + # self._print_freq = self.log_step + return self.log_step + return self._print_freq + + @print_freq.setter + def print_freq(self, n): + assert isinstance(n, int), 'print_freq must be int' + self._print_freq = n + + + # def __repr__(self) -> str: + # pass + + + diff --git a/src/core/yaml_config.py b/src/core/yaml_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6f8f7ef108e48b730bbff18ddcc299a925a8a5bf --- /dev/null +++ b/src/core/yaml_config.py @@ -0,0 +1,152 @@ +"""by lyuwenyu +""" + +import torch +import torch.nn as nn + +import re +import copy + +from .config import BaseConfig +from .yaml_utils import load_config, merge_config, create, merge_dict + + +class YAMLConfig(BaseConfig): + def __init__(self, cfg_path: str, **kwargs) -> None: + super().__init__() + + cfg = load_config(cfg_path) + merge_dict(cfg, kwargs) + + # pprint(cfg) + + self.yaml_cfg = cfg + + self.log_step = cfg.get('log_step', 100) + self.checkpoint_step = cfg.get('checkpoint_step', 1) + self.epoches = cfg.get('epoches', -1) + self.resume = cfg.get('resume', '') + self.tuning = cfg.get('tuning', '') + self.sync_bn = cfg.get('sync_bn', False) + self.output_dir = cfg.get('output_dir', None) + + self.use_ema = cfg.get('use_ema', False) + self.use_amp = cfg.get('use_amp', False) + self.autocast = cfg.get('autocast', dict()) + self.find_unused_parameters = cfg.get('find_unused_parameters', None) + self.clip_max_norm = cfg.get('clip_max_norm', 0.) + + + @property + def model(self, ) -> torch.nn.Module: + if self._model is None and 'model' in self.yaml_cfg: + merge_config(self.yaml_cfg) + self._model = create(self.yaml_cfg['model']) + return self._model + + @property + def postprocessor(self, ) -> torch.nn.Module: + if self._postprocessor is None and 'postprocessor' in self.yaml_cfg: + merge_config(self.yaml_cfg) + self._postprocessor = create(self.yaml_cfg['postprocessor']) + return self._postprocessor + + @property + def criterion(self, ): + if self._criterion is None and 'criterion' in self.yaml_cfg: + merge_config(self.yaml_cfg) + self._criterion = create(self.yaml_cfg['criterion']) + return self._criterion + + + @property + def optimizer(self, ): + if self._optimizer is None and 'optimizer' in self.yaml_cfg: + merge_config(self.yaml_cfg) + params = self.get_optim_params(self.yaml_cfg['optimizer'], self.model) + self._optimizer = create('optimizer', params=params) + + return self._optimizer + + @property + def lr_scheduler(self, ): + if self._lr_scheduler is None and 'lr_scheduler' in self.yaml_cfg: + merge_config(self.yaml_cfg) + self._lr_scheduler = create('lr_scheduler', optimizer=self.optimizer) + print('Initial lr: ', self._lr_scheduler.get_last_lr()) + + return self._lr_scheduler + + @property + def train_dataloader(self, ): + if self._train_dataloader is None and 'train_dataloader' in self.yaml_cfg: + merge_config(self.yaml_cfg) + self._train_dataloader = create('train_dataloader') + self._train_dataloader.shuffle = self.yaml_cfg['train_dataloader'].get('shuffle', False) + + return self._train_dataloader + + @property + def val_dataloader(self, ): + if self._val_dataloader is None and 'val_dataloader' in self.yaml_cfg: + merge_config(self.yaml_cfg) + self._val_dataloader = create('val_dataloader') + self._val_dataloader.shuffle = self.yaml_cfg['val_dataloader'].get('shuffle', False) + + return self._val_dataloader + + + @property + def ema(self, ): + if self._ema is None and self.yaml_cfg.get('use_ema', False): + merge_config(self.yaml_cfg) + self._ema = create('ema', model=self.model) + + return self._ema + + + @property + def scaler(self, ): + if self._scaler is None and self.yaml_cfg.get('use_amp', False): + merge_config(self.yaml_cfg) + self._scaler = create('scaler') + + return self._scaler + + + @staticmethod + def get_optim_params(cfg: dict, model: nn.Module): + ''' + E.g.: + ^(?=.*a)(?=.*b).*$ means including a and b + ^((?!b.)*a((?!b).)*$ means including a but not b + ^((?!b|c).)*a((?!b|c).)*$ means including a but not (b | c) + ''' + assert 'type' in cfg, '' + cfg = copy.deepcopy(cfg) + + if 'params' not in cfg: + return model.parameters() + + assert isinstance(cfg['params'], list), '' + + param_groups = [] + visited = [] + for pg in cfg['params']: + pattern = pg['params'] + params = {k: v for k, v in model.named_parameters() if v.requires_grad and len(re.findall(pattern, k)) > 0} + pg['params'] = params.values() + param_groups.append(pg) + visited.extend(list(params.keys())) + + names = [k for k, v in model.named_parameters() if v.requires_grad] + + if len(visited) < len(names): + unseen = set(names) - set(visited) + params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen} + param_groups.append({'params': params.values()}) + visited.extend(list(params.keys())) + + assert len(visited) == len(names), '' + + return param_groups diff --git a/src/core/yaml_utils.py b/src/core/yaml_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c9ed25902cfbec49567dbf36dd99554787cd8b14 --- /dev/null +++ b/src/core/yaml_utils.py @@ -0,0 +1,208 @@ +""""by lyuwenyu +""" + +import os +import yaml +import inspect +import importlib + +__all__ = ['GLOBAL_CONFIG', 'register', 'create', 'load_config', 'merge_config', 'merge_dict'] + + +GLOBAL_CONFIG = dict() +INCLUDE_KEY = '__include__' + + +def register(cls: type): + ''' + Args: + cls (type): Module class to be registered. + ''' + if cls.__name__ in GLOBAL_CONFIG: + raise ValueError('{} already registered'.format(cls.__name__)) + + if inspect.isfunction(cls): + GLOBAL_CONFIG[cls.__name__] = cls + + elif inspect.isclass(cls): + GLOBAL_CONFIG[cls.__name__] = extract_schema(cls) + + else: + raise ValueError(f'register {cls}') + + return cls + + +def extract_schema(cls: type): + ''' + Args: + cls (type), + Return: + Dict, + ''' + argspec = inspect.getfullargspec(cls.__init__) + arg_names = [arg for arg in argspec.args if arg != 'self'] + num_defualts = len(argspec.defaults) if argspec.defaults is not None else 0 + num_requires = len(arg_names) - num_defualts + + schame = dict() + schame['_name'] = cls.__name__ + schame['_pymodule'] = importlib.import_module(cls.__module__) + schame['_inject'] = getattr(cls, '__inject__', []) + schame['_share'] = getattr(cls, '__share__', []) + + for i, name in enumerate(arg_names): + if name in schame['_share']: + assert i >= num_requires, 'share config must have default value.' + value = argspec.defaults[i - num_requires] + + elif i >= num_requires: + value = argspec.defaults[i - num_requires] + + else: + value = None + + schame[name] = value + + return schame + + + +def create(type_or_name, **kwargs): + ''' + ''' + assert type(type_or_name) in (type, str), 'create should be class or name.' + + name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__ + + if name in GLOBAL_CONFIG: + if hasattr(GLOBAL_CONFIG[name], '__dict__'): + return GLOBAL_CONFIG[name] + else: + raise ValueError('The module {} is not registered'.format(name)) + + cfg = GLOBAL_CONFIG[name] + + if isinstance(cfg, dict) and 'type' in cfg: + _cfg: dict = GLOBAL_CONFIG[cfg['type']] + _cfg.update(cfg) # update global cls default args + _cfg.update(kwargs) # TODO + name = _cfg.pop('type') + + return create(name) + + + cls = getattr(cfg['_pymodule'], name) + argspec = inspect.getfullargspec(cls.__init__) + arg_names = [arg for arg in argspec.args if arg != 'self'] + + cls_kwargs = {} + cls_kwargs.update(cfg) + + # shared var + for k in cfg['_share']: + if k in GLOBAL_CONFIG: + cls_kwargs[k] = GLOBAL_CONFIG[k] + else: + cls_kwargs[k] = cfg[k] + + # inject + for k in cfg['_inject']: + _k = cfg[k] + + if _k is None: + continue + + if isinstance(_k, str): + if _k not in GLOBAL_CONFIG: + raise ValueError(f'Missing inject config of {_k}.') + + _cfg = GLOBAL_CONFIG[_k] + + if isinstance(_cfg, dict): + cls_kwargs[k] = create(_cfg['_name']) + else: + cls_kwargs[k] = _cfg + + elif isinstance(_k, dict): + if 'type' not in _k.keys(): + raise ValueError(f'Missing inject for `type` style.') + + _type = str(_k['type']) + if _type not in GLOBAL_CONFIG: + raise ValueError(f'Missing {_type} in inspect stage.') + + # TODO modified inspace, maybe get wrong result for using `> 1` + _cfg: dict = GLOBAL_CONFIG[_type] + # _cfg_copy = copy.deepcopy(_cfg) + _cfg.update(_k) # update + cls_kwargs[k] = create(_type) + # _cfg.update(_cfg_copy) # resume + + else: + raise ValueError(f'Inject does not support {_k}') + + + cls_kwargs = {n: cls_kwargs[n] for n in arg_names} + + return cls(**cls_kwargs) + + + +def load_config(file_path, cfg=dict()): + '''load config + ''' + _, ext = os.path.splitext(file_path) + assert ext in ['.yml', '.yaml'], "only support yaml files for now" + + with open(file_path) as f: + file_cfg = yaml.load(f, Loader=yaml.Loader) + if file_cfg is None: + return {} + + if INCLUDE_KEY in file_cfg: + base_yamls = list(file_cfg[INCLUDE_KEY]) + for base_yaml in base_yamls: + if base_yaml.startswith('~'): + base_yaml = os.path.expanduser(base_yaml) + + if not base_yaml.startswith('/'): + base_yaml = os.path.join(os.path.dirname(file_path), base_yaml) + + with open(base_yaml) as f: + base_cfg = load_config(base_yaml, cfg) + merge_config(base_cfg, cfg) + + return merge_config(file_cfg, cfg) + + + +def merge_dict(dct, another_dct): + '''merge another_dct into dct + ''' + for k in another_dct: + if (k in dct and isinstance(dct[k], dict) and isinstance(another_dct[k], dict)): + merge_dict(dct[k], another_dct[k]) + else: + dct[k] = another_dct[k] + + return dct + + + +def merge_config(config, another_cfg=None): + """ + Merge config into global config or another_cfg. + + Args: + config (dict): Config to be merged. + + Returns: global config + """ + global GLOBAL_CONFIG + dct = GLOBAL_CONFIG if another_cfg is None else another_cfg + + return merge_dict(dct, config) + + + diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95715f8a76937758b2c5ec9d121fc069fddcbabb --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1,7 @@ + +from .coco import * +from .cifar10 import CIFAR10 + +from .dataloader import * +from .transforms import * + diff --git a/src/data/__pycache__/__init__.cpython-310.pyc b/src/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..728b72a8f227c671ddcb702953e6919a8df26edb Binary files /dev/null and b/src/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/data/__pycache__/dataloader.cpython-310.pyc b/src/data/__pycache__/dataloader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc43a3505bcb61a2ba33772cf2659e8c7fd108f6 Binary files /dev/null and b/src/data/__pycache__/dataloader.cpython-310.pyc differ diff --git a/src/data/__pycache__/transforms.cpython-310.pyc b/src/data/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f0ded5607da300be686c025c918cd5f74e60ab9 Binary files /dev/null and b/src/data/__pycache__/transforms.cpython-310.pyc differ diff --git a/src/data/cifar10/__init__.py b/src/data/cifar10/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5267dccd21c6c4371c14ff0a04b064f608dfc14 --- /dev/null +++ b/src/data/cifar10/__init__.py @@ -0,0 +1,14 @@ + +import torchvision +from typing import Optional, Callable + +from src.core import register + + +@register +class CIFAR10(torchvision.datasets.CIFAR10): + __inject__ = ['transform', 'target_transform'] + + def __init__(self, root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False) -> None: + super().__init__(root, train, transform, target_transform, download) + diff --git a/src/data/cifar10/__pycache__/__init__.cpython-310.pyc b/src/data/cifar10/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc91d6c52e59340ccace02e97c6671d55112270c Binary files /dev/null and b/src/data/cifar10/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/data/coco/__init__.py b/src/data/coco/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c83b002187885f1571556b16e1c3632f03d68a0a --- /dev/null +++ b/src/data/coco/__init__.py @@ -0,0 +1,9 @@ +from .coco_dataset import ( + CocoDetection, + mscoco_category2label, + mscoco_label2category, + mscoco_category2name, +) +from .coco_eval import * + +from .coco_utils import get_coco_api_from_dataset \ No newline at end of file diff --git a/src/data/coco/__pycache__/__init__.cpython-310.pyc b/src/data/coco/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59afeb4587966d6aad4c2ae35c2c3685de291606 Binary files /dev/null and b/src/data/coco/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/data/coco/__pycache__/coco_dataset.cpython-310.pyc b/src/data/coco/__pycache__/coco_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d629862dfd2b09122141c3e08b8880e12de3cc0c Binary files /dev/null and b/src/data/coco/__pycache__/coco_dataset.cpython-310.pyc differ diff --git a/src/data/coco/__pycache__/coco_eval.cpython-310.pyc b/src/data/coco/__pycache__/coco_eval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4440587b2f0a6dee0dbe69ca178e87b8376520ce Binary files /dev/null and b/src/data/coco/__pycache__/coco_eval.cpython-310.pyc differ diff --git a/src/data/coco/__pycache__/coco_utils.cpython-310.pyc b/src/data/coco/__pycache__/coco_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b56f5294a0d8a67969e06c1c27417e30ba043624 Binary files /dev/null and b/src/data/coco/__pycache__/coco_utils.cpython-310.pyc differ diff --git a/src/data/coco/coco_dataset.py b/src/data/coco/coco_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0ef78498d753a76538651148f0692b1515173149 --- /dev/null +++ b/src/data/coco/coco_dataset.py @@ -0,0 +1,238 @@ +""" +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +COCO dataset which returns image_id for evaluation. +Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py +""" + +import torch +import torch.utils.data + +import torchvision +torchvision.disable_beta_transforms_warning() + +from torchvision import datapoints + +from pycocotools import mask as coco_mask + +from src.core import register + +__all__ = ['CocoDetection'] + + +@register +class CocoDetection(torchvision.datasets.CocoDetection): + __inject__ = ['transforms'] + __share__ = ['remap_mscoco_category'] + + def __init__(self, img_folder, ann_file, transforms, return_masks, remap_mscoco_category=False): + super(CocoDetection, self).__init__(img_folder, ann_file) + self._transforms = transforms + self.prepare = ConvertCocoPolysToMask(return_masks, remap_mscoco_category) + self.img_folder = img_folder + self.ann_file = ann_file + self.return_masks = return_masks + self.remap_mscoco_category = remap_mscoco_category + + def __getitem__(self, idx): + img, target = super(CocoDetection, self).__getitem__(idx) + image_id = self.ids[idx] + target = {'image_id': image_id, 'annotations': target} + img, target = self.prepare(img, target) + + # ['boxes', 'masks', 'labels']: + if 'boxes' in target: + target['boxes'] = datapoints.BoundingBox( + target['boxes'], + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=img.size[::-1]) # h w + + if 'masks' in target: + target['masks'] = datapoints.Mask(target['masks']) + + if self._transforms is not None: + img, target = self._transforms(img, target) + + return img, target + + def extra_repr(self) -> str: + s = f' img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n' + s += f' return_masks: {self.return_masks}\n' + if hasattr(self, '_transforms') and self._transforms is not None: + s += f' transforms:\n {repr(self._transforms)}' + + return s + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +class ConvertCocoPolysToMask(object): + def __init__(self, return_masks=False, remap_mscoco_category=False): + self.return_masks = return_masks + self.remap_mscoco_category = remap_mscoco_category + + def __call__(self, image, target): + w, h = image.size + + image_id = target["image_id"] + image_id = torch.tensor([image_id]) + + anno = target["annotations"] + + anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + if self.remap_mscoco_category: + classes = [mscoco_category2label[obj["category_id"]] for obj in anno] + else: + classes = [obj["category_id"] for obj in anno] + + classes = torch.tensor(classes, dtype=torch.int64) + + if self.return_masks: + segmentations = [obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + keypoints = None + if anno and "keypoints" in anno[0]: + keypoints = [obj["keypoints"] for obj in anno] + keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + num_keypoints = keypoints.shape[0] + if num_keypoints: + keypoints = keypoints.view(num_keypoints, -1, 3) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + if self.return_masks: + masks = masks[keep] + if keypoints is not None: + keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + target["labels"] = classes + if self.return_masks: + target["masks"] = masks + target["image_id"] = image_id + if keypoints is not None: + target["keypoints"] = keypoints + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) + target["area"] = area[keep] + target["iscrowd"] = iscrowd[keep] + + target["orig_size"] = torch.as_tensor([int(w), int(h)]) + target["size"] = torch.as_tensor([int(w), int(h)]) + + return image, target + + +mscoco_category2name = { + 1: 'person', + 2: 'bicycle', + 3: 'car', + 4: 'motorcycle', + 5: 'airplane', + 6: 'bus', + 7: 'train', + 8: 'truck', + 9: 'boat', + 10: 'traffic light', + 11: 'fire hydrant', + 13: 'stop sign', + 14: 'parking meter', + 15: 'bench', + 16: 'bird', + 17: 'cat', + 18: 'dog', + 19: 'horse', + 20: 'sheep', + 21: 'cow', + 22: 'elephant', + 23: 'bear', + 24: 'zebra', + 25: 'giraffe', + 27: 'backpack', + 28: 'umbrella', + 31: 'handbag', + 32: 'tie', + 33: 'suitcase', + 34: 'frisbee', + 35: 'skis', + 36: 'snowboard', + 37: 'sports ball', + 38: 'kite', + 39: 'baseball bat', + 40: 'baseball glove', + 41: 'skateboard', + 42: 'surfboard', + 43: 'tennis racket', + 44: 'bottle', + 46: 'wine glass', + 47: 'cup', + 48: 'fork', + 49: 'knife', + 50: 'spoon', + 51: 'bowl', + 52: 'banana', + 53: 'apple', + 54: 'sandwich', + 55: 'orange', + 56: 'broccoli', + 57: 'carrot', + 58: 'hot dog', + 59: 'pizza', + 60: 'donut', + 61: 'cake', + 62: 'chair', + 63: 'couch', + 64: 'potted plant', + 65: 'bed', + 67: 'dining table', + 70: 'toilet', + 72: 'tv', + 73: 'laptop', + 74: 'mouse', + 75: 'remote', + 76: 'keyboard', + 77: 'cell phone', + 78: 'microwave', + 79: 'oven', + 80: 'toaster', + 81: 'sink', + 82: 'refrigerator', + 84: 'book', + 85: 'clock', + 86: 'vase', + 87: 'scissors', + 88: 'teddy bear', + 89: 'hair drier', + 90: 'toothbrush' +} + +mscoco_category2label = {k: i for i, k in enumerate(mscoco_category2name.keys())} +mscoco_label2category = {v: k for k, v in mscoco_category2label.items()} \ No newline at end of file diff --git a/src/data/coco/coco_eval.py b/src/data/coco/coco_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..2d629f5aab011357918ef7303a3dab39e6be4b49 --- /dev/null +++ b/src/data/coco/coco_eval.py @@ -0,0 +1,269 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +COCO evaluator that works in distributed mode. + +Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py +The difference is that there is less copy-pasting from pycocotools +in the end of the file, as python3 can suppress prints with contextlib +""" +import os +import contextlib +import copy +import numpy as np +import torch + +from pycocotools.cocoeval import COCOeval +from pycocotools.coco import COCO +import pycocotools.mask as mask_util + +from src.misc import dist + + +__all__ = ['CocoEvaluator',] + + +class CocoEvaluator(object): + def __init__(self, coco_gt, iou_types): + assert isinstance(iou_types, (list, tuple)) + coco_gt = copy.deepcopy(coco_gt) + self.coco_gt = coco_gt + + self.iou_types = iou_types + self.coco_eval = {} + for iou_type in iou_types: + self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) + + self.img_ids = [] + self.eval_imgs = {k: [] for k in iou_types} + + def update(self, predictions): + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for iou_type in self.iou_types: + results = self.prepare(predictions, iou_type) + + # suppress pycocotools prints + with open(os.devnull, 'w') as devnull: + with contextlib.redirect_stdout(devnull): + coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() + coco_eval = self.coco_eval[iou_type] + + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = list(img_ids) + img_ids, eval_imgs = evaluate(coco_eval) + + self.eval_imgs[iou_type].append(eval_imgs) + + def synchronize_between_processes(self): + for iou_type in self.iou_types: + self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) + create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) + + def accumulate(self): + for coco_eval in self.coco_eval.values(): + coco_eval.accumulate() + + def summarize(self): + for iou_type, coco_eval in self.coco_eval.items(): + print("IoU metric: {}".format(iou_type)) + coco_eval.summarize() + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_coco_segmentation(predictions) + elif iou_type == "keypoints": + return self.prepare_for_coco_keypoint(predictions) + else: + raise ValueError("Unknown iou type {}".format(iou_type)) + + def prepare_for_coco_detection(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + + def prepare_for_coco_segmentation(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"] + labels = prediction["labels"] + masks = prediction["masks"] + + masks = masks > 0.5 + + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + return coco_results + + def prepare_for_coco_keypoint(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + keypoints = prediction["keypoints"] + keypoints = keypoints.flatten(start_dim=1).tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + 'keypoints': keypoint, + "score": scores[k], + } + for k, keypoint in enumerate(keypoints) + ] + ) + return coco_results + + +def convert_to_xywh(boxes): + xmin, ymin, xmax, ymax = boxes.unbind(1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) + + +def merge(img_ids, eval_imgs): + all_img_ids = dist.all_gather(img_ids) + all_eval_imgs = dist.all_gather(eval_imgs) + + merged_img_ids = [] + for p in all_img_ids: + merged_img_ids.extend(p) + + merged_eval_imgs = [] + for p in all_eval_imgs: + merged_eval_imgs.append(p) + + merged_img_ids = np.array(merged_img_ids) + merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) + + # keep only unique (and in sorted order) images + merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) + merged_eval_imgs = merged_eval_imgs[..., idx] + + return merged_img_ids, merged_eval_imgs + + +def create_common_coco_eval(coco_eval, img_ids, eval_imgs): + img_ids, eval_imgs = merge(img_ids, eval_imgs) + img_ids = list(img_ids) + eval_imgs = list(eval_imgs.flatten()) + + coco_eval.evalImgs = eval_imgs + coco_eval.params.imgIds = img_ids + coco_eval._paramsEval = copy.deepcopy(coco_eval.params) + + +################################################################# +# From pycocotools, just removed the prints and fixed +# a Python3 bug about unicode not defined +################################################################# + + +# import io +# from contextlib import redirect_stdout +# def evaluate(imgs): +# with redirect_stdout(io.StringIO()): +# imgs.evaluate() +# return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds)) + + +def evaluate(self): + ''' + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + :return: None + ''' + # tic = time.time() + # print('Running per image evaluation...') + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = 'segm' if p.useSegm == 1 else 'bbox' + print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType)) + # print('Evaluate annotation type *{}*'.format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == 'segm' or p.iouType == 'bbox': + computeIoU = self.computeIoU + elif p.iouType == 'keypoints': + computeIoU = self.computeOks + self.ious = { + (imgId, catId): computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds} + + evaluateImg = self.evaluateImg + maxDet = p.maxDets[-1] + evalImgs = [ + evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + # this is NOT in the pycocotools code, but could be done outside + evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) + self._paramsEval = copy.deepcopy(self.params) + # toc = time.time() + # print('DONE (t={:0.2f}s).'.format(toc-tic)) + return p.imgIds, evalImgs + +################################################################# +# end of straight copy from pycocotools, just removing the prints +################################################################# + diff --git a/src/data/coco/coco_utils.py b/src/data/coco/coco_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..48c099474c63e08a30c124fbfc07082edf9feb49 --- /dev/null +++ b/src/data/coco/coco_utils.py @@ -0,0 +1,184 @@ +import os + +import torch +import torch.utils.data +import torchvision +from pycocotools import mask as coco_mask +from pycocotools.coco import COCO + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +class ConvertCocoPolysToMask: + def __call__(self, image, target): + w, h = image.size + + image_id = target["image_id"] + + anno = target["annotations"] + + anno = [obj for obj in anno if obj["iscrowd"] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + classes = [obj["category_id"] for obj in anno] + classes = torch.tensor(classes, dtype=torch.int64) + + segmentations = [obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + keypoints = None + if anno and "keypoints" in anno[0]: + keypoints = [obj["keypoints"] for obj in anno] + keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + num_keypoints = keypoints.shape[0] + if num_keypoints: + keypoints = keypoints.view(num_keypoints, -1, 3) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + masks = masks[keep] + if keypoints is not None: + keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + target["labels"] = classes + target["masks"] = masks + target["image_id"] = image_id + if keypoints is not None: + target["keypoints"] = keypoints + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) + target["area"] = area + target["iscrowd"] = iscrowd + + return image, target + + +def _coco_remove_images_without_annotations(dataset, cat_list=None): + def _has_only_empty_bbox(anno): + return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) + + def _count_visible_keypoints(anno): + return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) + + min_keypoints_per_image = 10 + + def _has_valid_annotation(anno): + # if it's empty, there is no annotation + if len(anno) == 0: + return False + # if all boxes have close to zero area, there is no annotation + if _has_only_empty_bbox(anno): + return False + # keypoints task have a slight different criteria for considering + # if an annotation is valid + if "keypoints" not in anno[0]: + return True + # for keypoint detection tasks, only consider valid images those + # containing at least min_keypoints_per_image + if _count_visible_keypoints(anno) >= min_keypoints_per_image: + return True + return False + + ids = [] + for ds_idx, img_id in enumerate(dataset.ids): + ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = dataset.coco.loadAnns(ann_ids) + if cat_list: + anno = [obj for obj in anno if obj["category_id"] in cat_list] + if _has_valid_annotation(anno): + ids.append(ds_idx) + + dataset = torch.utils.data.Subset(dataset, ids) + return dataset + + +def convert_to_coco_api(ds): + coco_ds = COCO() + # annotation IDs need to start at 1, not 0, see torchvision issue #1530 + ann_id = 1 + dataset = {"images": [], "categories": [], "annotations": []} + categories = set() + for img_idx in range(len(ds)): + # find better way to get target + # targets = ds.get_annotations(img_idx) + img, targets = ds[img_idx] + image_id = targets["image_id"].item() + img_dict = {} + img_dict["id"] = image_id + img_dict["height"] = img.shape[-2] + img_dict["width"] = img.shape[-1] + dataset["images"].append(img_dict) + bboxes = targets["boxes"].clone() + bboxes[:, 2:] -= bboxes[:, :2] + bboxes = bboxes.tolist() + labels = targets["labels"].tolist() + areas = targets["area"].tolist() + iscrowd = targets["iscrowd"].tolist() + if "masks" in targets: + masks = targets["masks"] + # make masks Fortran contiguous for coco_mask + masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1) + if "keypoints" in targets: + keypoints = targets["keypoints"] + keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() + num_objs = len(bboxes) + for i in range(num_objs): + ann = {} + ann["image_id"] = image_id + ann["bbox"] = bboxes[i] + ann["category_id"] = labels[i] + categories.add(labels[i]) + ann["area"] = areas[i] + ann["iscrowd"] = iscrowd[i] + ann["id"] = ann_id + if "masks" in targets: + ann["segmentation"] = coco_mask.encode(masks[i].numpy()) + if "keypoints" in targets: + ann["keypoints"] = keypoints[i] + ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3]) + dataset["annotations"].append(ann) + ann_id += 1 + dataset["categories"] = [{"id": i} for i in sorted(categories)] + coco_ds.dataset = dataset + coco_ds.createIndex() + return coco_ds + + +def get_coco_api_from_dataset(dataset): + # FIXME: This is... awful? + for _ in range(10): + if isinstance(dataset, torchvision.datasets.CocoDetection): + break + if isinstance(dataset, torch.utils.data.Subset): + dataset = dataset.dataset + if isinstance(dataset, torchvision.datasets.CocoDetection): + return dataset.coco + return convert_to_coco_api(dataset) + + diff --git a/src/data/dataloader.py b/src/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..4db7cadf307780d946bd082bf35aba455cac6816 --- /dev/null +++ b/src/data/dataloader.py @@ -0,0 +1,28 @@ +import torch +import torch.utils.data as data + +from src.core import register + + +__all__ = ['DataLoader'] + + +@register +class DataLoader(data.DataLoader): + __inject__ = ['dataset', 'collate_fn'] + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + for n in ['dataset', 'batch_size', 'num_workers', 'drop_last', 'collate_fn']: + format_string += "\n" + format_string += " {0}: {1}".format(n, getattr(self, n)) + format_string += "\n)" + return format_string + + + +@register +def default_collate_fn(items): + '''default collate_fn + ''' + return torch.cat([x[0][None] for x in items], dim=0), [x[1] for x in items] diff --git a/src/data/functional.py b/src/data/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..336baa2ee632591a00db3733c389979f3b454348 --- /dev/null +++ b/src/data/functional.py @@ -0,0 +1,169 @@ +import torch +import torchvision.transforms.functional as F + +from packaging import version +from typing import Optional, List +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +if version.parse(torchvision.__version__) < version.parse('0.7'): + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if version.parse(torchvision.__version__) < version.parse('0.7'): + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) + + + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target["size"] = torch.tensor([h, w]) + + fields = ["labels", "area", "iscrowd"] + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target["boxes"] = cropped_boxes.reshape(-1, 4) + target["area"] = area + fields.append("boxes") + + if "masks" in target: + # FIXME should we update the area here if there are no boxes? + target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append("masks") + + # remove elements for which the boxes or masks that have zero area + if "boxes" in target or "masks" in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if "boxes" in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) + target["boxes"] = boxes + + if "masks" in target: + target['masks'] = target['masks'].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + # r = min(size / min(h, w), max_size / max(h, w)) + # ow = int(w * r) + # oh = int(h * r) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + if "masks" in target: + target['masks'] = interpolate( + target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 + + return rescaled_image, target + + +def pad(image, target, padding): + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? + target["size"] = torch.tensor(padded_image.size[::-1]) + if "masks" in target: + target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) + return padded_image, target diff --git a/src/data/transforms.py b/src/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..3fd3945cb5b7444c5e41bbe68f290e2b7e0781ce --- /dev/null +++ b/src/data/transforms.py @@ -0,0 +1,142 @@ +""""by lyuwenyu +""" + + +import torch +import torch.nn as nn + +import torchvision +torchvision.disable_beta_transforms_warning() +from torchvision import datapoints + +import torchvision.transforms.v2 as T +import torchvision.transforms.v2.functional as F + +from PIL import Image +from typing import Any, Dict, List, Optional + +from src.core import register, GLOBAL_CONFIG + + +__all__ = ['Compose', ] + + +RandomPhotometricDistort = register(T.RandomPhotometricDistort) +RandomZoomOut = register(T.RandomZoomOut) +# RandomIoUCrop = register(T.RandomIoUCrop) +RandomHorizontalFlip = register(T.RandomHorizontalFlip) +Resize = register(T.Resize) +ToImageTensor = register(T.ToImageTensor) +ConvertDtype = register(T.ConvertDtype) +SanitizeBoundingBox = register(T.SanitizeBoundingBox) +RandomCrop = register(T.RandomCrop) +Normalize = register(T.Normalize) + + + +@register +class Compose(T.Compose): + def __init__(self, ops) -> None: + transforms = [] + if ops is not None: + for op in ops: + if isinstance(op, dict): + name = op.pop('type') + transfom = getattr(GLOBAL_CONFIG[name]['_pymodule'], name)(**op) + transforms.append(transfom) + # op['type'] = name + elif isinstance(op, nn.Module): + transforms.append(op) + + else: + raise ValueError('') + else: + transforms =[EmptyTransform(), ] + + super().__init__(transforms=transforms) + + +@register +class EmptyTransform(T.Transform): + def __init__(self, ) -> None: + super().__init__() + + def forward(self, *inputs): + inputs = inputs if len(inputs) > 1 else inputs[0] + return inputs + + +@register +class PadToSize(T.Pad): + _transformed_types = ( + Image.Image, + datapoints.Image, + datapoints.Video, + datapoints.Mask, + datapoints.BoundingBox, + ) + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + sz = F.get_spatial_size(flat_inputs[0]) + h, w = self.spatial_size[0] - sz[0], self.spatial_size[1] - sz[1] + self.padding = [0, 0, w, h] + return dict(padding=self.padding) + + def __init__(self, spatial_size, fill=0, padding_mode='constant') -> None: + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + self.spatial_size = spatial_size + super().__init__(0, fill, padding_mode) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + fill = self._fill[type(inpt)] + padding = params['padding'] + return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] + + def __call__(self, *inputs: Any) -> Any: + outputs = super().forward(*inputs) + if len(outputs) > 1 and isinstance(outputs[1], dict): + outputs[1]['padding'] = torch.tensor(self.padding) + return outputs + + +@register +class RandomIoUCrop(T.RandomIoUCrop): + def __init__(self, min_scale: float = 0.3, max_scale: float = 1, min_aspect_ratio: float = 0.5, max_aspect_ratio: float = 2, sampler_options: Optional[List[float]] = None, trials: int = 40, p: float = 1.0): + super().__init__(min_scale, max_scale, min_aspect_ratio, max_aspect_ratio, sampler_options, trials) + self.p = p + + def __call__(self, *inputs: Any) -> Any: + if torch.rand(1) >= self.p: + return inputs if len(inputs) > 1 else inputs[0] + + return super().forward(*inputs) + + +@register +class ConvertBox(T.Transform): + _transformed_types = ( + datapoints.BoundingBox, + ) + def __init__(self, out_fmt='', normalize=False) -> None: + super().__init__() + self.out_fmt = out_fmt + self.normalize = normalize + + self.data_fmt = { + 'xyxy': datapoints.BoundingBoxFormat.XYXY, + 'cxcywh': datapoints.BoundingBoxFormat.CXCYWH + } + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if self.out_fmt: + spatial_size = inpt.spatial_size + in_fmt = inpt.format.value.lower() + inpt = torchvision.ops.box_convert(inpt, in_fmt=in_fmt, out_fmt=self.out_fmt) + inpt = datapoints.BoundingBox(inpt, format=self.data_fmt[self.out_fmt], spatial_size=spatial_size) + + if self.normalize: + inpt = inpt / torch.tensor(inpt.spatial_size[::-1]).tile(2)[None] + + return inpt + diff --git a/src/misc/__init__.py b/src/misc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..802b61ebff92ca70d0106ecbadfa206f0b79d964 --- /dev/null +++ b/src/misc/__init__.py @@ -0,0 +1,3 @@ + +from .logger import * +from .visualizer import * diff --git a/src/misc/__pycache__/__init__.cpython-310.pyc b/src/misc/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..595da0f5dd749727cc81626c0d4dfa2b7b49173a Binary files /dev/null and b/src/misc/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/misc/__pycache__/dist.cpython-310.pyc b/src/misc/__pycache__/dist.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4afc70211adbda5971bb1b254563ecee00229774 Binary files /dev/null and b/src/misc/__pycache__/dist.cpython-310.pyc differ diff --git a/src/misc/__pycache__/logger.cpython-310.pyc b/src/misc/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20bd9badb2f979692e66e50157a268e06ec26297 Binary files /dev/null and b/src/misc/__pycache__/logger.cpython-310.pyc differ diff --git a/src/misc/__pycache__/visualizer.cpython-310.pyc b/src/misc/__pycache__/visualizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7b9c6cdea69134b19cd12ad1b0e13b89e48f013 Binary files /dev/null and b/src/misc/__pycache__/visualizer.cpython-310.pyc differ diff --git a/src/misc/dist.py b/src/misc/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..4c547c001b9ccf62d386b18e877ee3a034a11d92 --- /dev/null +++ b/src/misc/dist.py @@ -0,0 +1,190 @@ +""" +reference +- https://github.com/pytorch/vision/blob/main/references/detection/utils.py +- https://github.com/facebookresearch/detr/blob/master/util/misc.py#L406 + +by lyuwenyu +""" + +import random +import numpy as np + +import torch +import torch.nn as nn +import torch.distributed +import torch.distributed as tdist + +from torch.nn.parallel import DistributedDataParallel as DDP + +from torch.utils.data import DistributedSampler +from torch.utils.data.dataloader import DataLoader + + +def init_distributed(): + ''' + distributed setup + args: + backend (str), ('nccl', 'gloo') + ''' + try: + # # https://pytorch.org/docs/stable/elastic/run.html + # LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) + # RANK = int(os.getenv('RANK', -1)) + # WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) + + tdist.init_process_group(init_method='env://', ) + torch.distributed.barrier() + + rank = get_rank() + device = torch.device(f'cuda:{rank}') + torch.cuda.set_device(device) + + setup_print(rank == 0) + print('Initialized distributed mode...') + + return True + + except: + print('Not init distributed mode.') + return False + + +def setup_print(is_main): + '''This function disables printing when not in master process + ''' + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_main or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_available_and_initialized(): + if not tdist.is_available(): + return False + if not tdist.is_initialized(): + return False + return True + + +def get_rank(): + if not is_dist_available_and_initialized(): + return 0 + return tdist.get_rank() + + +def get_world_size(): + if not is_dist_available_and_initialized(): + return 1 + return tdist.get_world_size() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + + +def warp_model(model, find_unused_parameters=False, sync_bn=False,): + if is_dist_available_and_initialized(): + rank = get_rank() + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) if sync_bn else model + model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=find_unused_parameters) + return model + + +def warp_loader(loader, shuffle=False): + if is_dist_available_and_initialized(): + sampler = DistributedSampler(loader.dataset, shuffle=shuffle) + loader = DataLoader(loader.dataset, + loader.batch_size, + sampler=sampler, + drop_last=loader.drop_last, + collate_fn=loader.collate_fn, + pin_memory=loader.pin_memory, + num_workers=loader.num_workers, ) + return loader + + + +def is_parallel(model) -> bool: + # Returns True if model is of type DP or DDP + return type(model) in (torch.nn.parallel.DataParallel, torch.nn.parallel.DistributedDataParallel) + + +def de_parallel(model) -> nn.Module: + # De-parallelize a model: returns single-GPU model if model is of type DP or DDP + return model.module if is_parallel(model) else model + + +def reduce_dict(data, avg=True): + ''' + Args + data dict: input, {k: v, ...} + avg bool: true + ''' + world_size = get_world_size() + if world_size < 2: + return data + + with torch.no_grad(): + keys, values = [], [] + for k in sorted(data.keys()): + keys.append(k) + values.append(data[k]) + + values = torch.stack(values, dim=0) + tdist.all_reduce(values) + + if avg is True: + values /= world_size + + _data = {k: v for k, v in zip(keys, values)} + + return _data + + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + data_list = [None] * world_size + tdist.all_gather_object(data_list, data) + return data_list + + +import time +def sync_time(): + '''sync_time + ''' + if torch.cuda.is_available(): + torch.cuda.synchronize() + + return time.time() + + + +def set_seed(seed): + # fix the seed for reproducibility + seed = seed + get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + diff --git a/src/misc/logger.py b/src/misc/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..67405304dd29738a82866b0af0803d1007a661d6 --- /dev/null +++ b/src/misc/logger.py @@ -0,0 +1,239 @@ +""" +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +https://github.com/facebookresearch/detr/blob/main/util/misc.py +Mostly copy-paste from torchvision references. +""" + +import time +import pickle +import datetime +from collections import defaultdict, deque +from typing import Dict + +import torch +import torch.distributed as tdist + +from .dist import is_dist_available_and_initialized, get_world_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_available_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + tdist.barrier() + tdist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + tdist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + tdist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True) -> Dict[str, torch.Tensor]: + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + tdist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + diff --git a/src/misc/visualizer.py b/src/misc/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..843f8eb4ed5090602d9facdd9180d182d7e4f74e --- /dev/null +++ b/src/misc/visualizer.py @@ -0,0 +1,34 @@ +""""by lyuwenyu +""" + +import torch +import torch.utils.data + +import torchvision +torchvision.disable_beta_transforms_warning() + +import PIL + +__all__ = ['show_sample'] + +def show_sample(sample): + """for coco dataset/dataloader + """ + import matplotlib.pyplot as plt + from torchvision.transforms.v2 import functional as F + from torchvision.utils import draw_bounding_boxes + + image, target = sample + if isinstance(image, PIL.Image.Image): + image = F.to_image_tensor(image) + + image = F.convert_dtype(image, torch.uint8) + annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3) + + fig, ax = plt.subplots() + ax.imshow(annotated_image.permute(1, 2, 0).numpy()) + ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) + fig.tight_layout() + fig.show() + plt.show() + diff --git a/src/nn/__init__.py b/src/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7df8a1c0891e690a53162f53fc4a14b90a1351dd --- /dev/null +++ b/src/nn/__init__.py @@ -0,0 +1,7 @@ + +from .arch import * +from .criterion import * + +# +from .backbone import * + diff --git a/src/nn/__pycache__/__init__.cpython-310.pyc b/src/nn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02efe3bd64fc3037a7a5f51f7b39e38943a827c7 Binary files /dev/null and b/src/nn/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/nn/arch/__init__.py b/src/nn/arch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..070f19b2f344b67a194424e4f4cf9b5d824ee8f2 --- /dev/null +++ b/src/nn/arch/__init__.py @@ -0,0 +1 @@ +from .classification import * diff --git a/src/nn/arch/__pycache__/__init__.cpython-310.pyc b/src/nn/arch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..035fbdd7dd72670e38a655b33378fba40869d39c Binary files /dev/null and b/src/nn/arch/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/nn/arch/__pycache__/classification.cpython-310.pyc b/src/nn/arch/__pycache__/classification.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4af4b32605d35dbaa20718dfe10569db00d98ecf Binary files /dev/null and b/src/nn/arch/__pycache__/classification.cpython-310.pyc differ diff --git a/src/nn/arch/classification.py b/src/nn/arch/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..2f1fa568ff12517b3e9d47bf464180b16964b088 --- /dev/null +++ b/src/nn/arch/classification.py @@ -0,0 +1,41 @@ +import torch +import torch.nn as nn + +from src.core import register + + +__all__ = ['Classification', 'ClassHead'] + + +@register +class Classification(nn.Module): + __inject__ = ['backbone', 'head'] + + def __init__(self, backbone: nn.Module, head: nn.Module=None): + super().__init__() + + self.backbone = backbone + self.head = head + + def forward(self, x): + x = self.backbone(x) + + if self.head is not None: + x = self.head(x) + + return x + + +@register +class ClassHead(nn.Module): + def __init__(self, hidden_dim, num_classes): + super().__init__() + self.pool = nn.AdaptiveAvgPool2d(1) + self.proj = nn.Linear(hidden_dim, num_classes) + + def forward(self, x): + x = x[0] if isinstance(x, (list, tuple)) else x + x = self.pool(x) + x = x.reshape(x.shape[0], -1) + x = self.proj(x) + return x diff --git a/src/nn/backbone/__init__.py b/src/nn/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea44c6b430e4d814607d4a0f986463a357d781d0 --- /dev/null +++ b/src/nn/backbone/__init__.py @@ -0,0 +1,5 @@ + +from .presnet import * +from .test_resnet import * + +from .common import * \ No newline at end of file diff --git a/src/nn/backbone/__pycache__/__init__.cpython-310.pyc b/src/nn/backbone/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e2121957befc8519c57baf1104c2f73c3698601 Binary files /dev/null and b/src/nn/backbone/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/nn/backbone/__pycache__/common.cpython-310.pyc b/src/nn/backbone/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d09935ee09ef7be5511d22a7357e00fde496c05 Binary files /dev/null and b/src/nn/backbone/__pycache__/common.cpython-310.pyc differ diff --git a/src/nn/backbone/__pycache__/presnet.cpython-310.pyc b/src/nn/backbone/__pycache__/presnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed61188ae41893df48d61d02f69bf5b10dda4dbd Binary files /dev/null and b/src/nn/backbone/__pycache__/presnet.cpython-310.pyc differ diff --git a/src/nn/backbone/__pycache__/test_resnet.cpython-310.pyc b/src/nn/backbone/__pycache__/test_resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cd30af8b5f734fcec2ad02351b403f17bbad72c Binary files /dev/null and b/src/nn/backbone/__pycache__/test_resnet.cpython-310.pyc differ diff --git a/src/nn/backbone/common.py b/src/nn/backbone/common.py new file mode 100644 index 0000000000000000000000000000000000000000..72e38d7d8e9f1460ad09eea8d02e16c133a6e054 --- /dev/null +++ b/src/nn/backbone/common.py @@ -0,0 +1,102 @@ +'''by lyuwenyu +''' + +import torch +import torch.nn as nn + + + +class ConvNormLayer(nn.Module): + def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None): + super().__init__() + self.conv = nn.Conv2d( + ch_in, + ch_out, + kernel_size, + stride, + padding=(kernel_size-1)//2 if padding is None else padding, + bias=bias) + self.norm = nn.BatchNorm2d(ch_out) + self.act = nn.Identity() if act is None else get_activation(act) + + def forward(self, x): + return self.act(self.norm(self.conv(x))) + + +class FrozenBatchNorm2d(nn.Module): + """copy and modified from https://github.com/facebookresearch/detr/blob/master/models/backbone.py + BatchNorm2d where the batch statistics and the affine parameters are fixed. + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + def __init__(self, num_features, eps=1e-5): + super(FrozenBatchNorm2d, self).__init__() + n = num_features + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + self.eps = eps + self.num_features = n + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + scale = w * (rv + self.eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + def extra_repr(self): + return ( + "{num_features}, eps={eps}".format(**self.__dict__) + ) + + +def get_activation(act: str, inpace: bool=True): + '''get activation + ''' + act = act.lower() + + if act == 'silu': + m = nn.SiLU() + + elif act == 'relu': + m = nn.ReLU() + + elif act == 'leaky_relu': + m = nn.LeakyReLU() + + elif act == 'silu': + m = nn.SiLU() + + elif act == 'gelu': + m = nn.GELU() + + elif act is None: + m = nn.Identity() + + elif isinstance(act, nn.Module): + m = act + + else: + raise RuntimeError('') + + if hasattr(m, 'inplace'): + m.inplace = inpace + + return m diff --git a/src/nn/backbone/presnet.py b/src/nn/backbone/presnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6b4baa86432bf89809c65ac28743d21c7ceb38 --- /dev/null +++ b/src/nn/backbone/presnet.py @@ -0,0 +1,225 @@ +'''by lyuwenyu +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + +from collections import OrderedDict + +from .common import get_activation, ConvNormLayer, FrozenBatchNorm2d + +from src.core import register + + +__all__ = ['PResNet'] + + +ResNet_cfg = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + # 152: [3, 8, 36, 3], +} + + +donwload_url = { + 18: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet18_vd_pretrained_from_paddle.pth', + 34: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet34_vd_pretrained_from_paddle.pth', + 50: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet50_vd_ssld_v2_pretrained_from_paddle.pth', + 101: 'https://github.com/lyuwenyu/storage/releases/download/v0.1/ResNet101_vd_ssld_pretrained_from_paddle.pth', +} + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'): + super().__init__() + + self.shortcut = shortcut + + if not shortcut: + if variant == 'd' and stride == 2: + self.short = nn.Sequential(OrderedDict([ + ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)), + ('conv', ConvNormLayer(ch_in, ch_out, 1, 1)) + ])) + else: + self.short = ConvNormLayer(ch_in, ch_out, 1, stride) + + self.branch2a = ConvNormLayer(ch_in, ch_out, 3, stride, act=act) + self.branch2b = ConvNormLayer(ch_out, ch_out, 3, 1, act=None) + self.act = nn.Identity() if act is None else get_activation(act) + + + def forward(self, x): + out = self.branch2a(x) + out = self.branch2b(out) + if self.shortcut: + short = x + else: + short = self.short(x) + + out = out + short + out = self.act(out) + + return out + + +class BottleNeck(nn.Module): + expansion = 4 + + def __init__(self, ch_in, ch_out, stride, shortcut, act='relu', variant='b'): + super().__init__() + + if variant == 'a': + stride1, stride2 = stride, 1 + else: + stride1, stride2 = 1, stride + + width = ch_out + + self.branch2a = ConvNormLayer(ch_in, width, 1, stride1, act=act) + self.branch2b = ConvNormLayer(width, width, 3, stride2, act=act) + self.branch2c = ConvNormLayer(width, ch_out * self.expansion, 1, 1) + + self.shortcut = shortcut + if not shortcut: + if variant == 'd' and stride == 2: + self.short = nn.Sequential(OrderedDict([ + ('pool', nn.AvgPool2d(2, 2, 0, ceil_mode=True)), + ('conv', ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1)) + ])) + else: + self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride) + + self.act = nn.Identity() if act is None else get_activation(act) + + def forward(self, x): + out = self.branch2a(x) + out = self.branch2b(out) + out = self.branch2c(out) + + if self.shortcut: + short = x + else: + short = self.short(x) + + out = out + short + out = self.act(out) + + return out + + +class Blocks(nn.Module): + def __init__(self, block, ch_in, ch_out, count, stage_num, act='relu', variant='b'): + super().__init__() + + self.blocks = nn.ModuleList() + for i in range(count): + self.blocks.append( + block( + ch_in, + ch_out, + stride=2 if i == 0 and stage_num != 2 else 1, + shortcut=False if i == 0 else True, + variant=variant, + act=act) + ) + + if i == 0: + ch_in = ch_out * block.expansion + + def forward(self, x): + out = x + for block in self.blocks: + out = block(out) + return out + + +@register +class PResNet(nn.Module): + def __init__( + self, + depth, + variant='d', + num_stages=4, + return_idx=[0, 1, 2, 3], + act='relu', + freeze_at=-1, + freeze_norm=True, + pretrained=False): + super().__init__() + + block_nums = ResNet_cfg[depth] + ch_in = 64 + if variant in ['c', 'd']: + conv_def = [ + [3, ch_in // 2, 3, 2, "conv1_1"], + [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"], + [ch_in // 2, ch_in, 3, 1, "conv1_3"], + ] + else: + conv_def = [[3, ch_in, 7, 2, "conv1_1"]] + + self.conv1 = nn.Sequential(OrderedDict([ + (_name, ConvNormLayer(c_in, c_out, k, s, act=act)) for c_in, c_out, k, s, _name in conv_def + ])) + + ch_out_list = [64, 128, 256, 512] + block = BottleNeck if depth >= 50 else BasicBlock + + _out_channels = [block.expansion * v for v in ch_out_list] + _out_strides = [4, 8, 16, 32] + + self.res_layers = nn.ModuleList() + for i in range(num_stages): + stage_num = i + 2 + self.res_layers.append( + Blocks(block, ch_in, ch_out_list[i], block_nums[i], stage_num, act=act, variant=variant) + ) + ch_in = _out_channels[i] + + self.return_idx = return_idx + self.out_channels = [_out_channels[_i] for _i in return_idx] + self.out_strides = [_out_strides[_i] for _i in return_idx] + + if freeze_at >= 0: + self._freeze_parameters(self.conv1) + for i in range(min(freeze_at, num_stages)): + self._freeze_parameters(self.res_layers[i]) + + if freeze_norm: + self._freeze_norm(self) + + if pretrained: + state = torch.hub.load_state_dict_from_url(donwload_url[depth]) + self.load_state_dict(state) + print(f'Load PResNet{depth} state_dict') + + def _freeze_parameters(self, m: nn.Module): + for p in m.parameters(): + p.requires_grad = False + + def _freeze_norm(self, m: nn.Module): + if isinstance(m, nn.BatchNorm2d): + m = FrozenBatchNorm2d(m.num_features) + else: + for name, child in m.named_children(): + _child = self._freeze_norm(child) + if _child is not child: + setattr(m, name, _child) + return m + + def forward(self, x): + conv1 = self.conv1(x) + x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1) + outs = [] + for idx, stage in enumerate(self.res_layers): + x = stage(x) + if idx in self.return_idx: + outs.append(x) + return outs + + diff --git a/src/nn/backbone/test_resnet.py b/src/nn/backbone/test_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6639d79ec6b9a11fdc756cd94db211d38d566b61 --- /dev/null +++ b/src/nn/backbone/test_resnet.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from collections import OrderedDict + + +from src.core import register + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1): + super(BasicBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion*planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion*planes,kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion*planes) + ) + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = F.relu(out) + return out + + + +class _ResNet(nn.Module): + def __init__(self, block, num_blocks, num_classes=10): + super().__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + + self.linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1]*(num_blocks-1) + layers = [] + for stride in strides: + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + out = self.linear(out) + return out + + +@register +class MResNet(nn.Module): + def __init__(self, num_classes=10, num_blocks=[2, 2, 2, 2]) -> None: + super().__init__() + self.model = _ResNet(BasicBlock, num_blocks, num_classes) + + def forward(self, x): + return self.model(x) + diff --git a/src/nn/backbone/utils.py b/src/nn/backbone/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ee250b1c9389cc11d0f553b28f23567e2c3b6860 --- /dev/null +++ b/src/nn/backbone/utils.py @@ -0,0 +1,58 @@ +""" +https://github.com/pytorch/vision/blob/main/torchvision/models/_utils.py + +by lyuwenyu +""" + +from collections import OrderedDict +from typing import Dict, List + + +import torch.nn as nn + + +class IntermediateLayerGetter(nn.ModuleDict): + """ + Module wrapper that returns intermediate layers from a model + + It has a strong assumption that the modules have been registered + into the model in the same order as they are used. + This means that one should **not** reuse the same nn.Module + twice in the forward if you want this to work. + + Additionally, it is only able to query submodules that are directly + assigned to the model. So if `model` is passed, `model.feature1` can + be returned, but not `model.feature1.layer2`. + """ + + _version = 3 + + def __init__(self, model: nn.Module, return_layers: List[str]) -> None: + if not set(return_layers).issubset([name for name, _ in model.named_children()]): + raise ValueError("return_layers are not present in model. {}"\ + .format([name for name, _ in model.named_children()])) + orig_return_layers = return_layers + return_layers = {str(k): str(k) for k in return_layers} + layers = OrderedDict() + for name, module in model.named_children(): + layers[name] = module + if name in return_layers: + del return_layers[name] + if not return_layers: + break + + super().__init__(layers) + self.return_layers = orig_return_layers + + def forward(self, x): + # out = OrderedDict() + outputs = [] + for name, module in self.items(): + x = module(x) + if name in self.return_layers: + # out_name = self.return_layers[name] + # out[out_name] = x + outputs.append(x) + + return outputs + diff --git a/src/nn/criterion/__init__.py b/src/nn/criterion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9804569a11aab6016ee0b00b46f3776ab759d63a --- /dev/null +++ b/src/nn/criterion/__init__.py @@ -0,0 +1,6 @@ + +import torch.nn as nn +from src.core import register + +CrossEntropyLoss = register(nn.CrossEntropyLoss) + diff --git a/src/nn/criterion/__pycache__/__init__.cpython-310.pyc b/src/nn/criterion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b40472db141be1e9066464773da63303efced6d Binary files /dev/null and b/src/nn/criterion/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/nn/criterion/utils.py b/src/nn/criterion/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7d8833e30bb6ccf948abf838fca7bf7cdf012e16 --- /dev/null +++ b/src/nn/criterion/utils.py @@ -0,0 +1,20 @@ +import torch +import torchvision + + + +def format_target(targets): + ''' + Args: + targets (List[Dict]), + Return: + tensor (Tensor), [im_id, label, bbox,] + ''' + outputs = [] + for i, tgt in enumerate(targets): + boxes = torchvision.ops.box_convert(tgt['boxes'], in_fmt='xyxy', out_fmt='cxcywh') + labels = tgt['labels'].reshape(-1, 1) + im_ids = torch.ones_like(labels) * i + outputs.append(torch.cat([im_ids, labels, boxes], dim=1)) + + return torch.cat(outputs, dim=0) diff --git a/src/optim/__init__.py b/src/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd7c81f9d09a57ef502c716a6f42566d9c17bae --- /dev/null +++ b/src/optim/__init__.py @@ -0,0 +1,4 @@ + +from .ema import * +from .optim import * +from .amp import * \ No newline at end of file diff --git a/src/optim/__pycache__/__init__.cpython-310.pyc b/src/optim/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f6159973a9f12b08b4adcc3e7429e207d8d00a2 Binary files /dev/null and b/src/optim/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/optim/__pycache__/amp.cpython-310.pyc b/src/optim/__pycache__/amp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dcd62279cfd81c973ce4d1bda89a7f1f95b4a08 Binary files /dev/null and b/src/optim/__pycache__/amp.cpython-310.pyc differ diff --git a/src/optim/__pycache__/ema.cpython-310.pyc b/src/optim/__pycache__/ema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edd7ed7fefcb8a0ae1bc1d06b6894533e636e91f Binary files /dev/null and b/src/optim/__pycache__/ema.cpython-310.pyc differ diff --git a/src/optim/__pycache__/optim.cpython-310.pyc b/src/optim/__pycache__/optim.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bd5a90f9289a79a3d2a9d7993d8ecc43413412d Binary files /dev/null and b/src/optim/__pycache__/optim.cpython-310.pyc differ diff --git a/src/optim/amp.py b/src/optim/amp.py new file mode 100644 index 0000000000000000000000000000000000000000..e43d0212e445213b658ead34dd047ec17b74e541 --- /dev/null +++ b/src/optim/amp.py @@ -0,0 +1,12 @@ +import torch +import torch.nn as nn +import torch.cuda.amp as amp + + +from src.core import register +import src.misc.dist as dist + + +__all__ = ['GradScaler'] + +GradScaler = register(amp.grad_scaler.GradScaler) diff --git a/src/optim/ema.py b/src/optim/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..bf962b3a7a8ef34a600053d3346444b3d17bcae1 --- /dev/null +++ b/src/optim/ema.py @@ -0,0 +1,115 @@ +""" +reference: +https://github.com/ultralytics/yolov5/blob/master/utils/torch_utils.py#L404 + +by lyuwenyu +""" + +import torch +import torch.nn as nn + +import math +from copy import deepcopy + + + +from src.core import register +import src.misc.dist as dist + + +__all__ = ['ModelEMA'] + + + +@register +class ModelEMA(object): + """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models + Keep a moving average of everything in the model state_dict (parameters and buffers). + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + A smoothed version of the weights is necessary for some training schemes to perform well. + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + def __init__(self, model: nn.Module, decay: float=0.9999, warmups: int=2000): + super().__init__() + + # Create EMA + self.module = deepcopy(dist.de_parallel(model)).eval() # FP32 EMA + + # if next(model.parameters()).device.type != 'cpu': + # self.module.half() # FP16 EMA + + self.decay = decay + self.warmups = warmups + self.updates = 0 # number of EMA updates + # self.filter_no_grad = filter_no_grad + self.decay_fn = lambda x: decay * (1 - math.exp(-x / warmups)) # decay exponential ramp (to help early epochs) + + for p in self.module.parameters(): + p.requires_grad_(False) + + def update(self, model: nn.Module): + # Update EMA parameters + with torch.no_grad(): + self.updates += 1 + d = self.decay_fn(self.updates) + + msd = dist.de_parallel(model).state_dict() + for k, v in self.module.state_dict().items(): + if v.dtype.is_floating_point: + v *= d + v += (1 - d) * msd[k].detach() + + def to(self, *args, **kwargs): + self.module = self.module.to(*args, **kwargs) + return self + + def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): + # Update EMA attributes + self.copy_attr(self.module, model, include, exclude) + + @staticmethod + def copy_attr(a, b, include=(), exclude=()): + # Copy attributes from b to a, options to only include [...] and to exclude [...] + for k, v in b.__dict__.items(): + if (len(include) and k not in include) or k.startswith('_') or k in exclude: + continue + else: + setattr(a, k, v) + + def state_dict(self, ): + return dict(module=self.module.state_dict(), updates=self.updates, warmups=self.warmups) + + def load_state_dict(self, state): + self.module.load_state_dict(state['module']) + if 'updates' in state: + self.updates = state['updates'] + + def forwad(self, ): + raise RuntimeError('ema...') + + def extra_repr(self) -> str: + return f'decay={self.decay}, warmups={self.warmups}' + + + + +class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): + """Maintains moving averages of model parameters using an exponential decay. + ``ema_avg = decay * avg_model_param + (1 - decay) * model_param`` + `torch.optim.swa_utils.AveragedModel `_ + is used to compute the EMA. + """ + def __init__(self, model, decay, device="cpu", use_buffers=True): + + self.decay_fn = lambda x: decay * (1 - math.exp(-x / 2000)) + + def ema_avg(avg_model_param, model_param, num_averaged): + decay = self.decay_fn(num_averaged) + return decay * avg_model_param + (1 - decay) * model_param + + super().__init__(model, device, ema_avg, use_buffers=use_buffers) + + + diff --git a/src/optim/optim.py b/src/optim/optim.py new file mode 100644 index 0000000000000000000000000000000000000000..b10bd82926b3f40dd63e9545ba24bfc3d8a3f651 --- /dev/null +++ b/src/optim/optim.py @@ -0,0 +1,22 @@ + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.optim.lr_scheduler as lr_scheduler + +from src.core import register + + +__all__ = ['AdamW', 'SGD', 'Adam', 'MultiStepLR', 'CosineAnnealingLR', 'OneCycleLR', 'LambdaLR'] + + + +SGD = register(optim.SGD) +Adam = register(optim.Adam) +AdamW = register(optim.AdamW) + + +MultiStepLR = register(lr_scheduler.MultiStepLR) +CosineAnnealingLR = register(lr_scheduler.CosineAnnealingLR) +OneCycleLR = register(lr_scheduler.OneCycleLR) +LambdaLR = register(lr_scheduler.LambdaLR) diff --git a/src/solver/__init__.py b/src/solver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eddab7bf7b9a33bfb28f837b8018fb6bd4690614 --- /dev/null +++ b/src/solver/__init__.py @@ -0,0 +1,12 @@ +"""by lyuwenyu +""" + +from .solver import BaseSolver +from .det_solver import DetSolver + + +from typing import Dict + +TASKS :Dict[str, BaseSolver] = { + 'detection': DetSolver, +} \ No newline at end of file diff --git a/src/solver/__pycache__/__init__.cpython-310.pyc b/src/solver/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcb55b231b24dc158401a9610e1aec1ab57b5d15 Binary files /dev/null and b/src/solver/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/solver/__pycache__/det_engine.cpython-310.pyc b/src/solver/__pycache__/det_engine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dffcd12597fc7384380da4375cca959c9311b3c Binary files /dev/null and b/src/solver/__pycache__/det_engine.cpython-310.pyc differ diff --git a/src/solver/__pycache__/det_solver.cpython-310.pyc b/src/solver/__pycache__/det_solver.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67614f0eee063d3a78ae8157d9fe83152e3d8330 Binary files /dev/null and b/src/solver/__pycache__/det_solver.cpython-310.pyc differ diff --git a/src/solver/__pycache__/solver.cpython-310.pyc b/src/solver/__pycache__/solver.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bc49ec567891a95ca3b281a2757c60c297ad5ad Binary files /dev/null and b/src/solver/__pycache__/solver.cpython-310.pyc differ diff --git a/src/solver/det_engine.py b/src/solver/det_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..7de6b03ae2e4bcaa9167d634a46848c1c16416d9 --- /dev/null +++ b/src/solver/det_engine.py @@ -0,0 +1,190 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +https://github.com/facebookresearch/detr/blob/main/engine.py + +by lyuwenyu +""" + +import math +import os +import sys +import pathlib +from typing import Iterable + +import torch +import torch.amp + +from src.data import CocoEvaluator +from src.misc import (MetricLogger, SmoothedValue, reduce_dict) + + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, max_norm: float = 0, **kwargs): + model.train() + criterion.train() + metric_logger = MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}')) + # metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}')) + header = 'Epoch: [{}]'.format(epoch) + print_freq = kwargs.get('print_freq', 10) + + ema = kwargs.get('ema', None) + scaler = kwargs.get('scaler', None) + + for samples, targets in metric_logger.log_every(data_loader, print_freq, header): + samples = samples.to(device) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + if scaler is not None: + with torch.autocast(device_type=str(device), cache_enabled=True): + outputs = model(samples, targets) + + with torch.autocast(device_type=str(device), enabled=False): + loss_dict = criterion(outputs, targets) + + loss = sum(loss_dict.values()) + scaler.scale(loss).backward() + + if max_norm > 0: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + else: + outputs = model(samples, targets) + loss_dict = criterion(outputs, targets) + + loss = sum(loss_dict.values()) + optimizer.zero_grad() + loss.backward() + + if max_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + + optimizer.step() + + # ema + if ema is not None: + ema.update(model) + + loss_dict_reduced = reduce_dict(loss_dict) + loss_value = sum(loss_dict_reduced.values()) + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + print(loss_dict_reduced) + sys.exit(1) + + metric_logger.update(loss=loss_value, **loss_dict_reduced) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + + +@torch.no_grad() +def evaluate(model: torch.nn.Module, criterion: torch.nn.Module, postprocessors, data_loader, base_ds, device, output_dir): + model.eval() + criterion.eval() + + metric_logger = MetricLogger(delimiter=" ") + # metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}')) + header = 'Test:' + + # iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys()) + iou_types = postprocessors.iou_types + coco_evaluator = CocoEvaluator(base_ds, iou_types) + # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75] + + panoptic_evaluator = None + # if 'panoptic' in postprocessors.keys(): + # panoptic_evaluator = PanopticEvaluator( + # data_loader.dataset.ann_file, + # data_loader.dataset.ann_folder, + # output_dir=os.path.join(output_dir, "panoptic_eval"), + # ) + + for samples, targets in metric_logger.log_every(data_loader, 10, header): + samples = samples.to(device) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + # with torch.autocast(device_type=str(device)): + # outputs = model(samples) + + outputs = model(samples) + print(outputs) + # loss_dict = criterion(outputs, targets) + # weight_dict = criterion.weight_dict + # # reduce losses over all GPUs for logging purposes + # loss_dict_reduced = reduce_dict(loss_dict) + # loss_dict_reduced_scaled = {k: v * weight_dict[k] + # for k, v in loss_dict_reduced.items() if k in weight_dict} + # loss_dict_reduced_unscaled = {f'{k}_unscaled': v + # for k, v in loss_dict_reduced.items()} + # metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), + # **loss_dict_reduced_scaled, + # **loss_dict_reduced_unscaled) + # metric_logger.update(class_error=loss_dict_reduced['class_error']) + + orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + results = postprocessors(outputs, orig_target_sizes) + # results = postprocessors(outputs, targets) + + # if 'segm' in postprocessors.keys(): + # target_sizes = torch.stack([t["size"] for t in targets], dim=0) + # results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes) + + res = {target['image_id'].item(): output for target, output in zip(targets, results)} + if coco_evaluator is not None: + coco_evaluator.update(res) + + # if panoptic_evaluator is not None: + # res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes) + # for i, target in enumerate(targets): + # image_id = target["image_id"].item() + # file_name = f"{image_id:012d}.png" + # res_pano[i]["image_id"] = image_id + # res_pano[i]["file_name"] = file_name + # panoptic_evaluator.update(res_pano) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + if coco_evaluator is not None: + coco_evaluator.synchronize_between_processes() + if panoptic_evaluator is not None: + panoptic_evaluator.synchronize_between_processes() + + # accumulate predictions from all images + if coco_evaluator is not None: + coco_evaluator.accumulate() + coco_evaluator.summarize() + + # panoptic_res = None + # if panoptic_evaluator is not None: + # panoptic_res = panoptic_evaluator.summarize() + + stats = {} + # stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + if coco_evaluator is not None: + if 'bbox' in iou_types: + stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() + if 'segm' in iou_types: + stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist() + + # if panoptic_res is not None: + # stats['PQ_all'] = panoptic_res["All"] + # stats['PQ_th'] = panoptic_res["Things"] + # stats['PQ_st'] = panoptic_res["Stuff"] + + return stats, coco_evaluator + + + diff --git a/src/solver/det_solver.py b/src/solver/det_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..d0a0a8400cf851ccc641f97310059edf56db78ea --- /dev/null +++ b/src/solver/det_solver.py @@ -0,0 +1,104 @@ +''' +by lyuwenyu +''' +import time +import json +import datetime + +import torch + +from src.misc import dist +from src.data import get_coco_api_from_dataset + +from .solver import BaseSolver +from .det_engine import train_one_epoch, evaluate + + +class DetSolver(BaseSolver): + + def fit(self, ): + print("Start training") + self.train() + + args = self.cfg + + n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + print('number of params:', n_parameters) + + base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset) + # best_stat = {'coco_eval_bbox': 0, 'coco_eval_masks': 0, 'epoch': -1, } + best_stat = {'epoch': -1, } + + start_time = time.time() + for epoch in range(self.last_epoch + 1, args.epoches): + if dist.is_dist_available_and_initialized(): + self.train_dataloader.sampler.set_epoch(epoch) + + train_stats = train_one_epoch( + self.model, self.criterion, self.train_dataloader, self.optimizer, self.device, epoch, + args.clip_max_norm, print_freq=args.log_step, ema=self.ema, scaler=self.scaler) + + self.lr_scheduler.step() + + if self.output_dir: + checkpoint_paths = [self.output_dir / 'checkpoint.pth'] + # extra checkpoint before LR drop and every 100 epochs + if (epoch + 1) % args.checkpoint_step == 0: + checkpoint_paths.append(self.output_dir / f'checkpoint{epoch:04}.pth') + for checkpoint_path in checkpoint_paths: + dist.save_on_master(self.state_dict(epoch), checkpoint_path) + + module = self.ema.module if self.ema else self.model + test_stats, coco_evaluator = evaluate( + module, self.criterion, self.postprocessor, self.val_dataloader, base_ds, self.device, self.output_dir + ) + + # TODO + for k in test_stats.keys(): + if k in best_stat: + best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch'] + best_stat[k] = max(best_stat[k], test_stats[k][0]) + else: + best_stat['epoch'] = epoch + best_stat[k] = test_stats[k][0] + print('best_stat: ', best_stat) + + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'test_{k}': v for k, v in test_stats.items()}, + 'epoch': epoch, + 'n_parameters': n_parameters} + + if self.output_dir and dist.is_main_process(): + with (self.output_dir / "log.txt").open("a") as f: + f.write(json.dumps(log_stats) + "\n") + + # for evaluation logs + if coco_evaluator is not None: + (self.output_dir / 'eval').mkdir(exist_ok=True) + if "bbox" in coco_evaluator.coco_eval: + filenames = ['latest.pth'] + if epoch % 50 == 0: + filenames.append(f'{epoch:03}.pth') + for name in filenames: + torch.save(coco_evaluator.coco_eval["bbox"].eval, + self.output_dir / "eval" / name) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + + def val(self, ): + self.eval() + + base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset) + + module = self.ema.module if self.ema else self.model + test_stats, coco_evaluator = evaluate(module, self.criterion, self.postprocessor, + self.val_dataloader, base_ds, self.device, self.output_dir) + + if self.output_dir: + dist.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth") + + return diff --git a/src/solver/solver.py b/src/solver/solver.py new file mode 100644 index 0000000000000000000000000000000000000000..55452f28ff9d43b5cece8879e762017246f0a5f0 --- /dev/null +++ b/src/solver/solver.py @@ -0,0 +1,182 @@ +"""by lyuwenyu +""" + +import torch +import torch.nn as nn + +from datetime import datetime +from pathlib import Path +from typing import Dict + +from src.misc import dist +from src.core import BaseConfig + + +class BaseSolver(object): + def __init__(self, cfg: BaseConfig) -> None: + + self.cfg = cfg + + def setup(self, ): + '''Avoid instantiating unnecessary classes + ''' + cfg = self.cfg + device = cfg.device + self.device = device + self.last_epoch = cfg.last_epoch + + self.model = dist.warp_model(cfg.model.to(device), cfg.find_unused_parameters, cfg.sync_bn) + self.criterion = cfg.criterion.to(device) + self.postprocessor = cfg.postprocessor + + # NOTE (lvwenyu): should load_tuning_state before ema instance building + if self.cfg.tuning: + print(f'Tuning checkpoint from {self.cfg.tuning}') + self.load_tuning_state(self.cfg.tuning) + + self.scaler = cfg.scaler + self.ema = cfg.ema.to(device) if cfg.ema is not None else None + + self.output_dir = Path(cfg.output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + + + def train(self, ): + self.setup() + self.optimizer = self.cfg.optimizer + self.lr_scheduler = self.cfg.lr_scheduler + + # NOTE instantiating order + if self.cfg.resume: + print(f'Resume checkpoint from {self.cfg.resume}') + self.resume(self.cfg.resume) + + self.train_dataloader = dist.warp_loader(self.cfg.train_dataloader, \ + shuffle=self.cfg.train_dataloader.shuffle) + self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, \ + shuffle=self.cfg.val_dataloader.shuffle) + + + def eval(self, ): + self.setup() + self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, \ + shuffle=self.cfg.val_dataloader.shuffle) + + if self.cfg.resume: + print(f'resume from {self.cfg.resume}') + self.resume(self.cfg.resume) + + + def state_dict(self, last_epoch): + '''state dict + ''' + state = {} + state['model'] = dist.de_parallel(self.model).state_dict() + state['date'] = datetime.now().isoformat() + + # TODO + state['last_epoch'] = last_epoch + + if self.optimizer is not None: + state['optimizer'] = self.optimizer.state_dict() + + if self.lr_scheduler is not None: + state['lr_scheduler'] = self.lr_scheduler.state_dict() + # state['last_epoch'] = self.lr_scheduler.last_epoch + + if self.ema is not None: + state['ema'] = self.ema.state_dict() + + if self.scaler is not None: + state['scaler'] = self.scaler.state_dict() + + return state + + + def load_state_dict(self, state): + '''load state dict + ''' + # TODO + if getattr(self, 'last_epoch', None) and 'last_epoch' in state: + self.last_epoch = state['last_epoch'] + print('Loading last_epoch') + + if getattr(self, 'model', None) and 'model' in state: + if dist.is_parallel(self.model): + self.model.module.load_state_dict(state['model']) + else: + self.model.load_state_dict(state['model']) + print('Loading model.state_dict') + + if getattr(self, 'ema', None) and 'ema' in state: + self.ema.load_state_dict(state['ema']) + print('Loading ema.state_dict') + + if getattr(self, 'optimizer', None) and 'optimizer' in state: + self.optimizer.load_state_dict(state['optimizer']) + print('Loading optimizer.state_dict') + + if getattr(self, 'lr_scheduler', None) and 'lr_scheduler' in state: + self.lr_scheduler.load_state_dict(state['lr_scheduler']) + print('Loading lr_scheduler.state_dict') + + if getattr(self, 'scaler', None) and 'scaler' in state: + self.scaler.load_state_dict(state['scaler']) + print('Loading scaler.state_dict') + + + def save(self, path): + '''save state + ''' + state = self.state_dict() + dist.save_on_master(state, path) + + + def resume(self, path): + '''load resume + ''' + # for cuda:0 memory + state = torch.load(path, map_location='cpu') + self.load_state_dict(state) + + def load_tuning_state(self, path,): + """only load model for tuning and skip missed/dismatched keys + """ + if 'http' in path: + state = torch.hub.load_state_dict_from_url(path, map_location='cpu') + else: + state = torch.load(path, map_location='cpu') + + module = dist.de_parallel(self.model) + + # TODO hard code + if 'ema' in state: + stat, infos = self._matched_state(module.state_dict(), state['ema']['module']) + else: + stat, infos = self._matched_state(module.state_dict(), state['model']) + + module.load_state_dict(stat, strict=False) + print(f'Load model.state_dict, {infos}') + + @staticmethod + def _matched_state(state: Dict[str, torch.Tensor], params: Dict[str, torch.Tensor]): + missed_list = [] + unmatched_list = [] + matched_state = {} + for k, v in state.items(): + if k in params: + if v.shape == params[k].shape: + matched_state[k] = params[k] + else: + unmatched_list.append(k) + else: + missed_list.append(k) + + return matched_state, {'missed': missed_list, 'unmatched': unmatched_list} + + + def fit(self, ): + raise NotImplementedError('') + + def val(self, ): + raise NotImplementedError('') diff --git a/src/zoo/__init__.py b/src/zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e6c56d9a47b56332e968dad230c1feeff5a6d7c7 --- /dev/null +++ b/src/zoo/__init__.py @@ -0,0 +1,2 @@ + +from .rtdetr import * diff --git a/src/zoo/__pycache__/__init__.cpython-310.pyc b/src/zoo/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f87e946803139840579325f6a731195415f7bf1a Binary files /dev/null and b/src/zoo/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/zoo/rtdetr/__init__.py b/src/zoo/rtdetr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b4583b3b86e7c2ba4044f73d51e1254ca327fe5 --- /dev/null +++ b/src/zoo/rtdetr/__init__.py @@ -0,0 +1,12 @@ +"""by lyuwenyu +""" + + +from .rtdetr import * + +from .hybrid_encoder import * +from .rtdetr_decoder import * +from .rtdetr_postprocessor import * +from .rtdetr_criterion import * + +from .matcher import * diff --git a/src/zoo/rtdetr/__pycache__/__init__.cpython-310.pyc b/src/zoo/rtdetr/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55915be0c74facacf9316e8f7a9d37b1bab7fa6e Binary files /dev/null and b/src/zoo/rtdetr/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/zoo/rtdetr/__pycache__/box_ops.cpython-310.pyc b/src/zoo/rtdetr/__pycache__/box_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e7bc53ff302d7c443a72130257a90f6b059f166 Binary files /dev/null and b/src/zoo/rtdetr/__pycache__/box_ops.cpython-310.pyc differ diff --git a/src/zoo/rtdetr/__pycache__/denoising.cpython-310.pyc b/src/zoo/rtdetr/__pycache__/denoising.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb7a0f7ce07129eb92c92be9675511bcf318ddb4 Binary files /dev/null and b/src/zoo/rtdetr/__pycache__/denoising.cpython-310.pyc differ diff --git a/src/zoo/rtdetr/__pycache__/hybrid_encoder.cpython-310.pyc b/src/zoo/rtdetr/__pycache__/hybrid_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dd43758edf71ff8d71b33cc70f64111cc9cd0bf Binary files /dev/null and b/src/zoo/rtdetr/__pycache__/hybrid_encoder.cpython-310.pyc differ diff --git a/src/zoo/rtdetr/__pycache__/matcher.cpython-310.pyc b/src/zoo/rtdetr/__pycache__/matcher.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a1580e44731c7a09b699073a6d5f13df66e7bba Binary files /dev/null and b/src/zoo/rtdetr/__pycache__/matcher.cpython-310.pyc differ diff --git a/src/zoo/rtdetr/__pycache__/rtdetr.cpython-310.pyc b/src/zoo/rtdetr/__pycache__/rtdetr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97cfebdf5fb434e4899051354fab3f201632c3a3 Binary files /dev/null and b/src/zoo/rtdetr/__pycache__/rtdetr.cpython-310.pyc differ diff --git a/src/zoo/rtdetr/__pycache__/rtdetr_criterion.cpython-310.pyc b/src/zoo/rtdetr/__pycache__/rtdetr_criterion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8da511973e6f9fe191a4a139b121bbb08becbf6 Binary files /dev/null and b/src/zoo/rtdetr/__pycache__/rtdetr_criterion.cpython-310.pyc differ diff --git a/src/zoo/rtdetr/__pycache__/rtdetr_decoder.cpython-310.pyc b/src/zoo/rtdetr/__pycache__/rtdetr_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d22961502e5e256188ed0bf9b69ab2e8bc49060f Binary files /dev/null and b/src/zoo/rtdetr/__pycache__/rtdetr_decoder.cpython-310.pyc differ diff --git a/src/zoo/rtdetr/__pycache__/rtdetr_postprocessor.cpython-310.pyc b/src/zoo/rtdetr/__pycache__/rtdetr_postprocessor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c7780f0e9daa9166fd5da6e776432348092f583 Binary files /dev/null and b/src/zoo/rtdetr/__pycache__/rtdetr_postprocessor.cpython-310.pyc differ diff --git a/src/zoo/rtdetr/__pycache__/utils.cpython-310.pyc b/src/zoo/rtdetr/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f39417e3c51c33bf1e9e688f7386cefd3c03aa57 Binary files /dev/null and b/src/zoo/rtdetr/__pycache__/utils.cpython-310.pyc differ diff --git a/src/zoo/rtdetr/box_ops.py b/src/zoo/rtdetr/box_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5d65866556d14cc866eac5d597d8a191528c65dc --- /dev/null +++ b/src/zoo/rtdetr/box_ops.py @@ -0,0 +1,89 @@ +''' +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +https://github.com/facebookresearch/detr/blob/main/util/box_ops.py +''' + +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = (masks * x.unsqueeze(0)) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = (masks * y.unsqueeze(0)) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) \ No newline at end of file diff --git a/src/zoo/rtdetr/denoising.py b/src/zoo/rtdetr/denoising.py new file mode 100644 index 0000000000000000000000000000000000000000..68307522f8152c8c224a2c38b9d48abeea2b046a --- /dev/null +++ b/src/zoo/rtdetr/denoising.py @@ -0,0 +1,125 @@ +"""by lyuwenyu +""" + +import torch + +from .utils import inverse_sigmoid +from .box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh + + + +def get_contrastive_denoising_training_group(targets, + num_classes, + num_queries, + class_embed, + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0,): + """cnd""" + if num_denoising <= 0: + return None, None, None, None + + num_gts = [len(t['labels']) for t in targets] + device = targets[0]['labels'].device + + max_gt_num = max(num_gts) + if max_gt_num == 0: + return None, None, None, None + + num_group = num_denoising // max_gt_num + num_group = 1 if num_group == 0 else num_group + # pad gt to max_num of a batch + bs = len(num_gts) + + input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device) + input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device) + pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device) + + for i in range(bs): + num_gt = num_gts[i] + if num_gt > 0: + input_query_class[i, :num_gt] = targets[i]['labels'] + input_query_bbox[i, :num_gt] = targets[i]['boxes'] + pad_gt_mask[i, :num_gt] = 1 + # each group has positive and negative queries. + input_query_class = input_query_class.tile([1, 2 * num_group]) + input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1]) + pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group]) + # positive and negative mask + negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=device) + negative_gt_mask[:, max_gt_num:] = 1 + negative_gt_mask = negative_gt_mask.tile([1, num_group, 1]) + positive_gt_mask = 1 - negative_gt_mask + # contrastive denoising training positive index + positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask + dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1] + dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts]) + # total denoising queries + num_denoising = int(max_gt_num * 2 * num_group) + + if label_noise_ratio > 0: + mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5) + # randomly put a new one here + new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype) + input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class) + + # if label_noise_ratio > 0: + # input_query_class = input_query_class.flatten() + # pad_gt_mask = pad_gt_mask.flatten() + # # half of bbox prob + # # mask = torch.rand(input_query_class.shape, device=device) < (label_noise_ratio * 0.5) + # mask = torch.rand_like(input_query_class) < (label_noise_ratio * 0.5) + # chosen_idx = torch.nonzero(mask * pad_gt_mask).squeeze(-1) + # # randomly put a new one here + # new_label = torch.randint_like(chosen_idx, 0, num_classes, dtype=input_query_class.dtype) + # # input_query_class.scatter_(dim=0, index=chosen_idx, value=new_label) + # input_query_class[chosen_idx] = new_label + # input_query_class = input_query_class.reshape(bs, num_denoising) + # pad_gt_mask = pad_gt_mask.reshape(bs, num_denoising) + + if box_noise_scale > 0: + known_bbox = box_cxcywh_to_xyxy(input_query_bbox) + diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale + rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0 + rand_part = torch.rand_like(input_query_bbox) + rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask) + rand_part *= rand_sign + known_bbox += rand_part * diff + known_bbox.clip_(min=0.0, max=1.0) + input_query_bbox = box_xyxy_to_cxcywh(known_bbox) + input_query_bbox = inverse_sigmoid(input_query_bbox) + + # class_embed = torch.concat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=device)]) + # input_query_class = torch.gather( + # class_embed, input_query_class.flatten(), + # axis=0).reshape(bs, num_denoising, -1) + # input_query_class = class_embed(input_query_class.flatten()).reshape(bs, num_denoising, -1) + input_query_class = class_embed(input_query_class) + + tgt_size = num_denoising + num_queries + # attn_mask = torch.ones([tgt_size, tgt_size], device=device) < 0 + attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device) + # match query cannot see the reconstruction + attn_mask[num_denoising:, :num_denoising] = True + + # reconstruct cannot see each other + for i in range(num_group): + if i == 0: + attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True + if i == num_group - 1: + attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * i * 2] = True + else: + attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True + attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * 2 * i] = True + + dn_meta = { + "dn_positive_idx": dn_positive_idx, + "dn_num_group": num_group, + "dn_num_split": [num_denoising, num_queries] + } + + # print(input_query_class.shape) # torch.Size([4, 196, 256]) + # print(input_query_bbox.shape) # torch.Size([4, 196, 4]) + # print(attn_mask.shape) # torch.Size([496, 496]) + + return input_query_class, input_query_bbox, attn_mask, dn_meta diff --git a/src/zoo/rtdetr/hybrid_encoder.py b/src/zoo/rtdetr/hybrid_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..804db69c120bc41c8e9c1b9f81e436c87323609f --- /dev/null +++ b/src/zoo/rtdetr/hybrid_encoder.py @@ -0,0 +1,322 @@ +'''by lyuwenyu +''' + +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import get_activation + +from src.core import register + + +__all__ = ['HybridEncoder'] + + + +class ConvNormLayer(nn.Module): + def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None): + super().__init__() + self.conv = nn.Conv2d( + ch_in, + ch_out, + kernel_size, + stride, + padding=(kernel_size-1)//2 if padding is None else padding, + bias=bias) + self.norm = nn.BatchNorm2d(ch_out) + self.act = nn.Identity() if act is None else get_activation(act) + + def forward(self, x): + return self.act(self.norm(self.conv(x))) + + +class RepVggBlock(nn.Module): + def __init__(self, ch_in, ch_out, act='relu'): + super().__init__() + self.ch_in = ch_in + self.ch_out = ch_out + self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None) + self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None) + self.act = nn.Identity() if act is None else get_activation(act) + + def forward(self, x): + if hasattr(self, 'conv'): + y = self.conv(x) + else: + y = self.conv1(x) + self.conv2(x) + + return self.act(y) + + def convert_to_deploy(self): + if not hasattr(self, 'conv'): + self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1) + + kernel, bias = self.get_equivalent_kernel_bias() + self.conv.weight.data = kernel + self.conv.bias.data = bias + # self.__delattr__('conv1') + # self.__delattr__('conv2') + + def get_equivalent_kernel_bias(self): + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2) + + return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1 + + def _pad_1x1_to_3x3_tensor(self, kernel1x1): + if kernel1x1 is None: + return 0 + else: + return F.pad(kernel1x1, [1, 1, 1, 1]) + + def _fuse_bn_tensor(self, branch: ConvNormLayer): + if branch is None: + return 0, 0 + kernel = branch.conv.weight + running_mean = branch.norm.running_mean + running_var = branch.norm.running_var + gamma = branch.norm.weight + beta = branch.norm.bias + eps = branch.norm.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + +class CSPRepLayer(nn.Module): + def __init__(self, + in_channels, + out_channels, + num_blocks=3, + expansion=1.0, + bias=None, + act="silu"): + super(CSPRepLayer, self).__init__() + hidden_channels = int(out_channels * expansion) + self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act) + self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act) + self.bottlenecks = nn.Sequential(*[ + RepVggBlock(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks) + ]) + if hidden_channels != out_channels: + self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act) + else: + self.conv3 = nn.Identity() + + def forward(self, x): + x_1 = self.conv1(x) + x_1 = self.bottlenecks(x_1) + x_2 = self.conv2(x) + return self.conv3(x_1 + x_2) + + +# transformer +class TransformerEncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False): + super().__init__() + self.normalize_before = normalize_before + + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout, batch_first=True) + + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.activation = get_activation(activation) + + @staticmethod + def with_pos_embed(tensor, pos_embed): + return tensor if pos_embed is None else tensor + pos_embed + + def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor: + residual = src + if self.normalize_before: + src = self.norm1(src) + q = k = self.with_pos_embed(src, pos_embed) + src, _ = self.self_attn(q, k, value=src, attn_mask=src_mask) + + src = residual + self.dropout1(src) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src) + if not self.normalize_before: + src = self.norm2(src) + return src + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)]) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, src_mask=None, pos_embed=None) -> torch.Tensor: + output = src + for layer in self.layers: + output = layer(output, src_mask=src_mask, pos_embed=pos_embed) + + if self.norm is not None: + output = self.norm(output) + + return output + + +@register +class HybridEncoder(nn.Module): + def __init__(self, + in_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + hidden_dim=256, + nhead=8, + dim_feedforward = 1024, + dropout=0.0, + enc_act='gelu', + use_encoder_idx=[2], + num_encoder_layers=1, + pe_temperature=10000, + expansion=1.0, + depth_mult=1.0, + act='silu', + eval_spatial_size=None): + super().__init__() + self.in_channels = in_channels + self.feat_strides = feat_strides + self.hidden_dim = hidden_dim + self.use_encoder_idx = use_encoder_idx + self.num_encoder_layers = num_encoder_layers + self.pe_temperature = pe_temperature + self.eval_spatial_size = eval_spatial_size + + self.out_channels = [hidden_dim for _ in range(len(in_channels))] + self.out_strides = feat_strides + + # channel projection + self.input_proj = nn.ModuleList() + for in_channel in in_channels: + self.input_proj.append( + nn.Sequential( + nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(hidden_dim) + ) + ) + + # encoder transformer + encoder_layer = TransformerEncoderLayer( + hidden_dim, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=enc_act) + + self.encoder = nn.ModuleList([ + TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx)) + ]) + + # top-down fpn + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(len(in_channels) - 1, 0, -1): + self.lateral_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act)) + self.fpn_blocks.append( + CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion) + ) + + # bottom-up pan + self.downsample_convs = nn.ModuleList() + self.pan_blocks = nn.ModuleList() + for _ in range(len(in_channels) - 1): + self.downsample_convs.append( + ConvNormLayer(hidden_dim, hidden_dim, 3, 2, act=act) + ) + self.pan_blocks.append( + CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion) + ) + + self._reset_parameters() + + def _reset_parameters(self): + if self.eval_spatial_size: + for idx in self.use_encoder_idx: + stride = self.feat_strides[idx] + pos_embed = self.build_2d_sincos_position_embedding( + self.eval_spatial_size[1] // stride, self.eval_spatial_size[0] // stride, + self.hidden_dim, self.pe_temperature) + setattr(self, f'pos_embed{idx}', pos_embed) + # self.register_buffer(f'pos_embed{idx}', pos_embed) + + @staticmethod + def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.): + ''' + ''' + grid_w = torch.arange(int(w), dtype=torch.float32) + grid_h = torch.arange(int(h), dtype=torch.float32) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij') + assert embed_dim % 4 == 0, \ + 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim + omega = 1. / (temperature ** omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return torch.concat([out_w.sin(), out_w.cos(), out_h.sin(), out_h.cos()], dim=1)[None, :, :] + + def forward(self, feats): + assert len(feats) == len(self.in_channels) + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + + # encoder + if self.num_encoder_layers > 0: + for i, enc_ind in enumerate(self.use_encoder_idx): + h, w = proj_feats[enc_ind].shape[2:] + # flatten [B, C, H, W] to [B, HxW, C] + src_flatten = proj_feats[enc_ind].flatten(2).permute(0, 2, 1) + if self.training or self.eval_spatial_size is None: + pos_embed = self.build_2d_sincos_position_embedding( + w, h, self.hidden_dim, self.pe_temperature).to(src_flatten.device) + else: + pos_embed = getattr(self, f'pos_embed{enc_ind}', None).to(src_flatten.device) + + memory = self.encoder[i](src_flatten, pos_embed=pos_embed) + proj_feats[enc_ind] = memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous() + # print([x.is_contiguous() for x in proj_feats ]) + + # broadcasting and fusion + inner_outs = [proj_feats[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_high = inner_outs[0] + feat_low = proj_feats[idx - 1] + feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high) + inner_outs[0] = feat_high + upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest') + inner_out = self.fpn_blocks[len(self.in_channels)-1-idx](torch.concat([upsample_feat, feat_low], dim=1)) + inner_outs.insert(0, inner_out) + + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_high = inner_outs[idx + 1] + downsample_feat = self.downsample_convs[idx](feat_low) + out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_high], dim=1)) + outs.append(out) + + return outs diff --git a/src/zoo/rtdetr/matcher.py b/src/zoo/rtdetr/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..cf9dec1f8e030258f74d3198423186325d5f3201 --- /dev/null +++ b/src/zoo/rtdetr/matcher.py @@ -0,0 +1,108 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +Modules to compute the matching cost and solve the corresponding LSAP. + +by lyuwenyu +""" + +import torch +import torch.nn.functional as F + +from scipy.optimize import linear_sum_assignment +from torch import nn + +from .box_ops import box_cxcywh_to_xyxy, generalized_box_iou + +from src.core import register + + +@register +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + __share__ = ['use_focal_loss', ] + + def __init__(self, weight_dict, use_focal_loss=False, alpha=0.25, gamma=2.0): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = weight_dict['cost_class'] + self.cost_bbox = weight_dict['cost_bbox'] + self.cost_giou = weight_dict['cost_giou'] + + self.use_focal_loss = use_focal_loss + self.alpha = alpha + self.gamma = gamma + + assert self.cost_class != 0 or self.cost_bbox != 0 or self.cost_giou != 0, "all costs cant be 0" + + @torch.no_grad() + def forward(self, outputs, targets): + """ Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + if self.use_focal_loss: + out_prob = F.sigmoid(outputs["pred_logits"].flatten(0, 1)) + else: + out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] + + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_ids = torch.cat([v["labels"] for v in targets]) + tgt_bbox = torch.cat([v["boxes"] for v in targets]) + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + if self.use_focal_loss: + out_prob = out_prob[:, tgt_ids] + neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-(1 - out_prob + 1e-8).log()) + pos_cost_class = self.alpha * ((1 - out_prob)**self.gamma) * (-(out_prob + 1e-8).log()) + cost_class = pos_cost_class - neg_cost_class + else: + cost_class = -out_prob[:, tgt_ids] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) + + # Final cost matrix + C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + C = C.view(bs, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] + + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] diff --git a/src/zoo/rtdetr/rtdetr.py b/src/zoo/rtdetr/rtdetr.py new file mode 100644 index 0000000000000000000000000000000000000000..851d4f74bc58d38135499a94427dc707f1726013 --- /dev/null +++ b/src/zoo/rtdetr/rtdetr.py @@ -0,0 +1,44 @@ +"""by lyuwenyu +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import random +import numpy as np + +from src.core import register + + +__all__ = ['RTDETR', ] + + +@register +class RTDETR(nn.Module): + __inject__ = ['backbone', 'encoder', 'decoder', ] + + def __init__(self, backbone: nn.Module, encoder, decoder, multi_scale=None): + super().__init__() + self.backbone = backbone + self.decoder = decoder + self.encoder = encoder + self.multi_scale = multi_scale + + def forward(self, x, targets=None): + if self.multi_scale and self.training: + sz = np.random.choice(self.multi_scale) + x = F.interpolate(x, size=[sz, sz]) + + x = self.backbone(x) + x = self.encoder(x) + x = self.decoder(x, targets) + + return x + + def deploy(self, ): + self.eval() + for m in self.modules(): + if hasattr(m, 'convert_to_deploy'): + m.convert_to_deploy() + return self diff --git a/src/zoo/rtdetr/rtdetr_criterion.py b/src/zoo/rtdetr/rtdetr_criterion.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce77c0f160a5f2e6d6cfa1b94f943193d022306 --- /dev/null +++ b/src/zoo/rtdetr/rtdetr_criterion.py @@ -0,0 +1,341 @@ +""" +reference: +https://github.com/facebookresearch/detr/blob/main/models/detr.py + +by lyuwenyu +""" + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +# from torchvision.ops import box_convert, generalized_box_iou +from .box_ops import box_cxcywh_to_xyxy, box_iou, generalized_box_iou + +from src.misc.dist import get_world_size, is_dist_available_and_initialized +from src.core import register + + + +@register +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + __share__ = ['num_classes', ] + __inject__ = ['matcher', ] + + def __init__(self, matcher, weight_dict, losses, alpha=0.2, gamma=2.0, eos_coef=1e-4, num_classes=80): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + + empty_weight = torch.ones(self.num_classes + 1) + empty_weight[-1] = eos_coef + self.register_buffer('empty_weight', empty_weight) + + self.alpha = alpha + self.gamma = gamma + + + def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) + losses = {'loss_ce': loss_ce} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + def loss_labels_bce(self, outputs, targets, indices, num_boxes, log=True): + src_logits = outputs['pred_logits'] + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] + loss = F.binary_cross_entropy_with_logits(src_logits, target * 1., reduction='none') + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + return {'loss_bce': loss} + + def loss_labels_focal(self, outputs, targets, indices, num_boxes, log=True): + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + target = F.one_hot(target_classes, num_classes=self.num_classes+1)[..., :-1] + # ce_loss = F.binary_cross_entropy_with_logits(src_logits, target * 1., reduction="none") + # prob = F.sigmoid(src_logits) # TODO .detach() + # p_t = prob * target + (1 - prob) * (1 - target) + # alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target) + # loss = alpha_t * ce_loss * ((1 - p_t) ** self.gamma) + # loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + loss = torchvision.ops.sigmoid_focal_loss(src_logits, target, self.alpha, self.gamma, reduction='none') + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + + return {'loss_focal': loss} + + def loss_labels_vfl(self, outputs, targets, indices, num_boxes, log=True): + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes)) + ious = torch.diag(ious).detach() + + src_logits = outputs['pred_logits'] + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1] + + target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype) + target_score_o[idx] = ious.to(target_score_o.dtype) + target_score = target_score_o.unsqueeze(-1) * target + + pred_score = F.sigmoid(src_logits).detach() + weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score + + loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction='none') + loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes + return {'loss_vfl': loss} + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients + """ + pred_logits = outputs['pred_logits'] + device = pred_logits.device + tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. + """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + losses = {} + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + losses['loss_bbox'] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag(generalized_box_iou( + box_cxcywh_to_xyxy(src_boxes), + box_cxcywh_to_xyxy(target_boxes))) + losses['loss_giou'] = loss_giou.sum() / num_boxes + return losses + + def loss_masks(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the masks: the focal loss and the dice loss. + targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + src_masks = outputs["pred_masks"] + src_masks = src_masks[src_idx] + masks = [t["masks"] for t in targets] + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() + target_masks = target_masks.to(src_masks) + target_masks = target_masks[tgt_idx] + + # upsample predictions to the target size + src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:], + mode="bilinear", align_corners=False) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks.flatten(1) + target_masks = target_masks.view(src_masks.shape) + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), + "loss_dice": dice_loss(src_masks, target_masks, num_boxes), + } + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels, + 'cardinality': self.loss_cardinality, + 'boxes': self.loss_boxes, + 'masks': self.loss_masks, + + 'bce': self.loss_labels_bce, + 'focal': self.loss_labels_focal, + 'vfl': self.loss_labels_vfl, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + if is_dist_available_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + losses.update(l_dict) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs = {'log': False} + + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + l_dict = {k + f'_aux_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + # In case of cdn auxiliary losses. For rtdetr + if 'dn_aux_outputs' in outputs: + assert 'dn_meta' in outputs, '' + indices = self.get_cdn_matched_indices(outputs['dn_meta'], targets) + num_boxes = num_boxes * outputs['dn_meta']['dn_num_group'] + + for i, aux_outputs in enumerate(outputs['dn_aux_outputs']): + # indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs = {'log': False} + + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict} + l_dict = {k + f'_dn_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + return losses + + @staticmethod + def get_cdn_matched_indices(dn_meta, targets): + '''get_cdn_matched_indices + ''' + dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"] + num_gts = [len(t['labels']) for t in targets] + device = targets[0]['labels'].device + + dn_match_indices = [] + for i, num_gt in enumerate(num_gts): + if num_gt > 0: + gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device) + gt_idx = gt_idx.tile(dn_num_group) + assert len(dn_positive_idx[i]) == len(gt_idx) + dn_match_indices.append((dn_positive_idx[i], gt_idx)) + else: + dn_match_indices.append((torch.zeros(0, dtype=torch.int64, device=device), \ + torch.zeros(0, dtype=torch.int64, device=device))) + + return dn_match_indices + + + + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + + + diff --git a/src/zoo/rtdetr/rtdetr_decoder.py b/src/zoo/rtdetr/rtdetr_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..300b50ca3a73d56451be43e58d7f872ec101b314 --- /dev/null +++ b/src/zoo/rtdetr/rtdetr_decoder.py @@ -0,0 +1,627 @@ +"""by lyuwenyu +""" + +import math +import copy +from collections import OrderedDict +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ +from torch.nn.parameter import Parameter + +from .denoising import get_contrastive_denoising_training_group +from .utils import deformable_attention_core_func, get_activation, inverse_sigmoid +from .utils import bias_init_with_prob +from torch.nn.modules.linear import NonDynamicallyQuantizableLinear + +from src.core import register + +import numpy as np + +import scipy.linalg as sl + +__all__ = ['RTDETRTransformer'] + + + +class MLP(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act='relu'): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.act = nn.Identity() if act is None else get_activation(act) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class CoPE(nn.Module): + def __init__(self,npos_max,head_dim): + super(CoPE, self).__init__() + self.npos_max = npos_max #? + self.pos_emb = nn.parameter.Parameter(torch.zeros(1,head_dim,npos_max)) + + def forward(self,query,attn_logits): + #compute positions + gates = torch.sigmoid(attn_logits) #sig(qk) + pos = gates.flip(-1).cumsum(dim=-1).flip(-1) + pos = pos.clamp(max=self.npos_max-1) + #interpolate from integer positions + pos_ceil = pos.ceil().long() + pos_floor = pos.floor().long() + logits_int = torch.matmul(query,self.pos_emb) + logits_ceil = logits_int.gather(-1,pos_ceil) + logits_floor = logits_int.gather(-1,pos_floor) + w = pos-pos_floor + return logits_ceil*w+logits_floor*(1-w) + + + + +class MSDeformableAttention(nn.Module): + def __init__(self, embed_dim=256, num_heads=8, num_levels=4, num_points=4,): + """ + Multi-Scale Deformable Attention Module + """ + super(MSDeformableAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_levels = num_levels + self.num_points = num_points + self.total_points = num_heads * num_levels * num_points + + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2,) + self.attention_weights = nn.Linear(embed_dim, self.total_points) + self.value_proj = nn.Linear(embed_dim, embed_dim) + self.output_proj = nn.Linear(embed_dim, embed_dim) + + self.ms_deformable_attn_core = deformable_attention_core_func + + self._reset_parameters() + + + def _reset_parameters(self): + # sampling_offsets + init.constant_(self.sampling_offsets.weight, 0) + thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values + grid_init = grid_init.reshape(self.num_heads, 1, 1, 2).tile([1, self.num_levels, self.num_points, 1]) + scaling = torch.arange(1, self.num_points + 1, dtype=torch.float32).reshape(1, 1, -1, 1) + grid_init *= scaling + self.sampling_offsets.bias.data[...] = grid_init.flatten() + + # attention_weights + init.constant_(self.attention_weights.weight, 0) + init.constant_(self.attention_weights.bias, 0) + + # proj + init.xavier_uniform_(self.value_proj.weight) + init.constant_(self.value_proj.bias, 0) + init.xavier_uniform_(self.output_proj.weight) + init.constant_(self.output_proj.bias, 0) + + + def forward(self, + query, + reference_points, + value, + value_spatial_shapes, + value_mask=None): + """ + Args: + query (Tensor): [bs, query_length, C] + reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area + value (Tensor): [bs, value_length, C] + value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + value_level_start_index (List): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...] + value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, Len_q = query.shape[:2] + Len_v = value.shape[1] + + value = self.value_proj(value) + if value_mask is not None: + value_mask = value_mask.astype(value.dtype).unsqueeze(-1) + value *= value_mask + value = value.reshape(bs, Len_v, self.num_heads, self.head_dim) + + sampling_offsets = self.sampling_offsets(query).reshape( + bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2) + attention_weights = self.attention_weights(query).reshape( + bs, Len_q, self.num_heads, self.num_levels * self.num_points) + attention_weights = F.softmax(attention_weights, dim=-1).reshape( + bs, Len_q, self.num_heads, self.num_levels, self.num_points) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.tensor(value_spatial_shapes) + offset_normalizer = offset_normalizer.flip([1]).reshape( + 1, 1, 1, self.num_levels, 1, 2) + sampling_locations = reference_points.reshape( + bs, Len_q, 1, self.num_levels, 1, 2 + ) + sampling_offsets / offset_normalizer + elif reference_points.shape[-1] == 4: + sampling_locations = ( + reference_points[:, :, None, :, None, :2] + sampling_offsets / + self.num_points * reference_points[:, :, None, :, None, 2:] * 0.5) + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.". + format(reference_points.shape[-1])) + + output = self.ms_deformable_attn_core(value, value_spatial_shapes, sampling_locations, attention_weights) + + output = self.output_proj(output) + + return output + + +class TransformerDecoderLayer(nn.Module): + def __init__(self, + d_model=256, + n_head=8, + dim_feedforward=1024, + dropout=0., + activation="relu", + n_levels=4, + n_points=4,): + super(TransformerDecoderLayer, self).__init__() + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # cross attention + self.cross_attn = MSDeformableAttention(d_model, n_head, n_levels, n_points) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.activation = getattr(F, activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + self.cope = CoPE(12,d_model) + + # self._reset_parameters() + + # def _reset_parameters(self): + # linear_init_(self.linear1) + # linear_init_(self.linear2) + # xavier_uniform_(self.linear1.weight) + # xavier_uniform_(self.linear2.weight) + + def with_pos_embed(self, tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + return self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + + def forward(self, + tgt, + reference_points, + memory, + memory_spatial_shapes, + memory_level_start_index, + attn_mask=None, + memory_mask=None, + query_pos_embed=None): + # self attention + #print(query_pos_embed.shape) + qk = torch.bmm (tgt ,tgt.transpose(-1 ,-2)) + mask = torch.tril(torch.ones_like(qk),diagonal=0) + mask = torch.log(mask) + query_pos_embed = self.cope(tgt,qk+mask) #position_embedding + + + n_tgt = tgt.cpu().detach().numpy() + + itgt = tgt.new_tensor(np.array([sl.pinv(i) for i in n_tgt])) #inv_tgt + +# print('qk:',qk.shape) +# print('tgt:',tgt.shape) +# print((query_pos_embed@itgt.transpose(-1,-2)).shape) +# print('ik:',itgt.shape) + + # print(torch.round(itgt@tgt)) + # print(tgt@tgt.transpose(-1,-2)) + + k = tgt + q = tgt + (query_pos_embed@itgt.transpose(-1,-2)) + + # print((q@(k.transpose(-1,-2))-query_pos_embed)) + + # if attn_mask is not None: + # attn_mask = torch.where( + # attn_mask.to(torch.bool), + # torch.zeros_like(attn_mask), + # torch.full_like(attn_mask, float('-inf'), dtype=tgt.dtype)) + + # q = k = self.with_pos_embed(tgt, query_pos_embed) + tgt2, _ = self.self_attn(q, k, value=tgt, attn_mask=attn_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # cross attention + tgt2 = self.cross_attn(\ + self.with_pos_embed(tgt, (query_pos_embed@itgt.transpose(-1,-2))), #self.with_pos_embed(tgt, query_pos_embed), + reference_points, + memory, + memory_spatial_shapes, + memory_mask) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # ffn + tgt2 = self.forward_ffn(tgt) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + + return tgt + + +class TransformerDecoder(nn.Module): + def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1): + super(TransformerDecoder, self).__init__() + self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)]) + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx + + def forward(self, + tgt, + ref_points_unact, + memory, + memory_spatial_shapes, + memory_level_start_index, + bbox_head, + score_head, + query_pos_head, + attn_mask=None, + memory_mask=None): + output = tgt + dec_out_bboxes = [] + dec_out_logits = [] + ref_points_detach = F.sigmoid(ref_points_unact) + + for i, layer in enumerate(self.layers): + ref_points_input = ref_points_detach.unsqueeze(2) + query_pos_embed = query_pos_head(ref_points_detach) + + output = layer(output, ref_points_input, memory, + memory_spatial_shapes, memory_level_start_index, + attn_mask, memory_mask, query_pos_embed) + + inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach)) + + if self.training: + dec_out_logits.append(score_head[i](output)) + if i == 0: + dec_out_bboxes.append(inter_ref_bbox) + else: + dec_out_bboxes.append(F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points))) + + elif i == self.eval_idx: + dec_out_logits.append(score_head[i](output)) + dec_out_bboxes.append(inter_ref_bbox) + break + + ref_points = inter_ref_bbox + ref_points_detach = inter_ref_bbox.detach( + ) if self.training else inter_ref_bbox + + return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits) + + +@register +class RTDETRTransformer(nn.Module): + __share__ = ['num_classes'] + def __init__(self, + num_classes=80, + hidden_dim=256, + num_queries=300, + position_embed_type='sine', + feat_channels=[512, 1024, 2048], + feat_strides=[8, 16, 32], + num_levels=3, + num_decoder_points=4, + nhead=8, + num_decoder_layers=6, + dim_feedforward=1024, + dropout=0., + activation="relu", + num_denoising=100, + label_noise_ratio=0.5, + box_noise_scale=1.0, + learnt_init_query=False, + eval_spatial_size=None, + eval_idx=-1, + eps=1e-2, + aux_loss=True): + + super(RTDETRTransformer, self).__init__() + assert position_embed_type in ['sine', 'learned'], \ + f'ValueError: position_embed_type not supported {position_embed_type}!' + assert len(feat_channels) <= num_levels + assert len(feat_strides) == len(feat_channels) + for _ in range(num_levels - len(feat_strides)): + feat_strides.append(feat_strides[-1] * 2) + + self.hidden_dim = hidden_dim + self.nhead = nhead + self.feat_strides = feat_strides + self.num_levels = num_levels + self.num_classes = num_classes + self.num_queries = num_queries + self.eps = eps + self.num_decoder_layers = num_decoder_layers + self.eval_spatial_size = eval_spatial_size + self.aux_loss = aux_loss + + # backbone feature projection + self._build_input_proj_layer(feat_channels) + + # Transformer module + decoder_layer = TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, activation, num_levels, num_decoder_points) + self.decoder = TransformerDecoder(hidden_dim, decoder_layer, num_decoder_layers, eval_idx) + + self.num_denoising = num_denoising + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + # denoising part + if num_denoising > 0: + # self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim, padding_idx=num_classes-1) # TODO for load paddle weights + self.denoising_class_embed = nn.Embedding(num_classes+1, hidden_dim, padding_idx=num_classes) + + # decoder embedding + self.learnt_init_query = learnt_init_query + if learnt_init_query: + self.tgt_embed = nn.Embedding(num_queries, hidden_dim) + self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2) + + # encoder head + self.enc_output = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim,) + ) + self.enc_score_head = nn.Linear(hidden_dim, num_classes) + self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3) + + # decoder head + self.dec_score_head = nn.ModuleList([ + nn.Linear(hidden_dim, num_classes) + for _ in range(num_decoder_layers) + ]) + self.dec_bbox_head = nn.ModuleList([ + MLP(hidden_dim, hidden_dim, 4, num_layers=3) + for _ in range(num_decoder_layers) + ]) + + # init encoder output anchors and valid_mask + if self.eval_spatial_size: + self.anchors, self.valid_mask = self._generate_anchors() + + self._reset_parameters() + + def _reset_parameters(self): + bias = bias_init_with_prob(0.01) + + init.constant_(self.enc_score_head.bias, bias) + init.constant_(self.enc_bbox_head.layers[-1].weight, 0) + init.constant_(self.enc_bbox_head.layers[-1].bias, 0) + + for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): + init.constant_(cls_.bias, bias) + init.constant_(reg_.layers[-1].weight, 0) + init.constant_(reg_.layers[-1].bias, 0) + + # linear_init_(self.enc_output[0]) + init.xavier_uniform_(self.enc_output[0].weight) + if self.learnt_init_query: + init.xavier_uniform_(self.tgt_embed.weight) + init.xavier_uniform_(self.query_pos_head.layers[0].weight) + init.xavier_uniform_(self.query_pos_head.layers[1].weight) + + + def _build_input_proj_layer(self, feat_channels): + self.input_proj = nn.ModuleList() + for in_channels in feat_channels: + self.input_proj.append( + nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d(in_channels, self.hidden_dim, 1, bias=False)), + ('norm', nn.BatchNorm2d(self.hidden_dim,))]) + ) + ) + + in_channels = feat_channels[-1] + + for _ in range(self.num_levels - len(feat_channels)): + self.input_proj.append( + nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d(in_channels, self.hidden_dim, 3, 2, padding=1, bias=False)), + ('norm', nn.BatchNorm2d(self.hidden_dim))]) + ) + ) + in_channels = self.hidden_dim + + def _get_encoder_input(self, feats): + # get projection features + proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] + if self.num_levels > len(proj_feats): + len_srcs = len(proj_feats) + for i in range(len_srcs, self.num_levels): + if i == len_srcs: + proj_feats.append(self.input_proj[i](feats[-1])) + else: + proj_feats.append(self.input_proj[i](proj_feats[-1])) + + # get encoder inputs + feat_flatten = [] + spatial_shapes = [] + level_start_index = [0, ] + for i, feat in enumerate(proj_feats): + _, _, h, w = feat.shape + # [b, c, h, w] -> [b, h*w, c] + feat_flatten.append(feat.flatten(2).permute(0, 2, 1)) + # [num_levels, 2] + spatial_shapes.append([h, w]) + # [l], start index of each level + level_start_index.append(h * w + level_start_index[-1]) + + # [b, l, c] + feat_flatten = torch.concat(feat_flatten, 1) + level_start_index.pop() + return (feat_flatten, spatial_shapes, level_start_index) + + def _generate_anchors(self, + spatial_shapes=None, + grid_size=0.05, + dtype=torch.float32, + device='cpu'): + if spatial_shapes is None: + spatial_shapes = [[int(self.eval_spatial_size[0] / s), int(self.eval_spatial_size[1] / s)] + for s in self.feat_strides + ] + anchors = [] + for lvl, (h, w) in enumerate(spatial_shapes): + grid_y, grid_x = torch.meshgrid(\ + torch.arange(end=h, dtype=dtype), \ + torch.arange(end=w, dtype=dtype), indexing='ij') + grid_xy = torch.stack([grid_x, grid_y], -1) + valid_WH = torch.tensor([w, h]).to(dtype) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH + wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl) + anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, h * w, 4)) + + anchors = torch.concat(anchors, 1).to(device) + valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True) + anchors = torch.log(anchors / (1 - anchors)) + # anchors = torch.where(valid_mask, anchors, float('inf')) + # anchors[valid_mask] = torch.inf # valid_mask [1, 8400, 1] + anchors = torch.where(valid_mask, anchors, torch.inf) + + return anchors, valid_mask + + + def _get_decoder_input(self, + memory, + spatial_shapes, + denoising_class=None, + denoising_bbox_unact=None): + bs, _, _ = memory.shape + # prepare input for decoder + if self.training or self.eval_spatial_size is None: + anchors, valid_mask = self._generate_anchors(spatial_shapes, device=memory.device) + else: + anchors, valid_mask = self.anchors.to(memory.device), self.valid_mask.to(memory.device) + + # memory = torch.where(valid_mask, memory, 0) + memory = valid_mask.to(memory.dtype) * memory # TODO fix type error for onnx export + + output_memory = self.enc_output(memory) + + enc_outputs_class = self.enc_score_head(output_memory) + enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors + + _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1) + + reference_points_unact = enc_outputs_coord_unact.gather(dim=1, \ + index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_unact.shape[-1])) + + enc_topk_bboxes = F.sigmoid(reference_points_unact) + if denoising_bbox_unact is not None: + reference_points_unact = torch.concat( + [denoising_bbox_unact, reference_points_unact], 1) + + enc_topk_logits = enc_outputs_class.gather(dim=1, \ + index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])) + + # extract region features + if self.learnt_init_query: + target = self.tgt_embed.weight.unsqueeze(0).tile([bs, 1, 1]) + else: + target = output_memory.gather(dim=1, \ + index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])) + target = target.detach() + + if denoising_class is not None: + target = torch.concat([denoising_class, target], 1) + + return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits + + + def forward(self, feats, targets=None): + + # input projection and embedding + (memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats) + + # prepare denoising training + if self.training and self.num_denoising > 0: + denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \ + get_contrastive_denoising_training_group(targets, \ + self.num_classes, + self.num_queries, + self.denoising_class_embed, + num_denoising=self.num_denoising, + label_noise_ratio=self.label_noise_ratio, + box_noise_scale=self.box_noise_scale, ) + else: + denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None + + target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \ + self._get_decoder_input(memory, spatial_shapes, denoising_class, denoising_bbox_unact) + + # decoder + out_bboxes, out_logits = self.decoder( + target, + init_ref_points_unact, + memory, + spatial_shapes, + level_start_index, + self.dec_bbox_head, + self.dec_score_head, + self.query_pos_head, + attn_mask=attn_mask) + + if self.training and dn_meta is not None: + dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2) + dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2) + + out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]} + + if self.training and self.aux_loss: + out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1]) + out['aux_outputs'].extend(self._set_aux_loss([enc_topk_logits], [enc_topk_bboxes])) + + if self.training and dn_meta is not None: + out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes) + out['dn_meta'] = dn_meta + + return out + + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class, outputs_coord)] diff --git a/src/zoo/rtdetr/rtdetr_postprocessor.py b/src/zoo/rtdetr/rtdetr_postprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..344d69ac3b24f431a1c06cf222e2f6e540f36552 --- /dev/null +++ b/src/zoo/rtdetr/rtdetr_postprocessor.py @@ -0,0 +1,80 @@ +"""by lyuwenyu +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import torchvision + +from src.core import register + + +__all__ = ['RTDETRPostProcessor'] + + +@register +class RTDETRPostProcessor(nn.Module): + __share__ = ['num_classes', 'use_focal_loss', 'num_top_queries', 'remap_mscoco_category'] + + def __init__(self, num_classes=80, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False) -> None: + super().__init__() + self.use_focal_loss = use_focal_loss + self.num_top_queries = num_top_queries + self.num_classes = num_classes + self.remap_mscoco_category = remap_mscoco_category + self.deploy_mode = False + + def extra_repr(self) -> str: + return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}' + + # def forward(self, outputs, orig_target_sizes): + def forward(self, outputs, orig_target_sizes): + + logits, boxes = outputs['pred_logits'], outputs['pred_boxes'] + # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) + + bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy') + bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1) + + if self.use_focal_loss: + scores = F.sigmoid(logits) + scores, index = torch.topk(scores.flatten(1), self.num_top_queries, axis=-1) + labels = index % self.num_classes + index = index // self.num_classes + boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1])) + + else: + scores = F.softmax(logits)[:, :, :-1] + scores, labels = scores.max(dim=-1) + if scores.shape[1] > self.num_top_queries: + scores, index = torch.topk(scores, self.num_top_queries, dim=-1) + labels = torch.gather(labels, dim=1, index=index) + boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1])) + + # TODO for onnx export + if self.deploy_mode: + return labels, boxes, scores + + # TODO + if self.remap_mscoco_category: + from ...data.coco import mscoco_label2category + labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])\ + .to(boxes.device).reshape(labels.shape) + + results = [] + for lab, box, sco in zip(labels, boxes, scores): + result = dict(labels=lab, boxes=box, scores=sco) + results.append(result) + + return results + + + def deploy(self, ): + self.eval() + self.deploy_mode = True + return self + + @property + def iou_types(self, ): + return ('bbox', ) diff --git a/src/zoo/rtdetr/utils.py b/src/zoo/rtdetr/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4f44cc52c2f0b8ebed01cbb7c49b2317954a7582 --- /dev/null +++ b/src/zoo/rtdetr/utils.py @@ -0,0 +1,101 @@ +"""by lyuwenyu +""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def inverse_sigmoid(x: torch.Tensor, eps: float=1e-5) -> torch.Tensor: + x = x.clip(min=0., max=1.) + return torch.log(x.clip(min=eps) / (1 - x).clip(min=eps)) + + +def deformable_attention_core_func(value, value_spatial_shapes, sampling_locations, attention_weights): + """ + Args: + value (Tensor): [bs, value_length, n_head, c] + value_spatial_shapes (Tensor|List): [n_levels, 2] + value_level_start_index (Tensor|List): [n_levels] + sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2] + attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points] + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, _, n_head, c = value.shape + _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape + + split_shape = [h * w for h, w in value_spatial_shapes] + value_list = value.split(split_shape, dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level, (h, w) in enumerate(value_spatial_shapes): + # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ + value_l_ = value_list[level].flatten(2).permute( + 0, 2, 1).reshape(bs * n_head, c, h, w) + # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level].permute( + 0, 2, 1, 3, 4).flatten(0, 1) + # N_*M_, D_, Lq_, P_ + sampling_value_l_ = F.grid_sample( + value_l_, + sampling_grid_l_, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + sampling_value_list.append(sampling_value_l_) + # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_) + attention_weights = attention_weights.permute(0, 2, 1, 3, 4).reshape( + bs * n_head, 1, Len_q, n_levels * n_points) + output = (torch.stack( + sampling_value_list, dim=-2).flatten(-2) * + attention_weights).sum(-1).reshape(bs, n_head * c, Len_q) + + return output.permute(0, 2, 1) + + +import math +def bias_init_with_prob(prior_prob=0.01): + """initialize conv/fc bias value according to a given probability value.""" + bias_init = float(-math.log((1 - prior_prob) / prior_prob)) + return bias_init + + + +def get_activation(act: str, inpace: bool=True): + '''get activation + ''' + act = act.lower() + + if act == 'silu': + m = nn.SiLU() + + elif act == 'relu': + m = nn.ReLU() + + elif act == 'leaky_relu': + m = nn.LeakyReLU() + + elif act == 'silu': + m = nn.SiLU() + + elif act == 'gelu': + m = nn.GELU() + + elif act is None: + m = nn.Identity() + + elif isinstance(act, nn.Module): + m = act + + else: + raise RuntimeError('') + + if hasattr(m, 'inplace'): + m.inplace = inpace + + return m + +