"""by lyuwenyu |
|
""" |
|
|
|
from pprint import pprint |
|
import torch |
|
import torch.nn as nn |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.optim import Optimizer |
|
from torch.optim.lr_scheduler import LRScheduler |
|
from torch.cuda.amp.grad_scaler import GradScaler |
|
|
|
from typing import Callable, List, Dict |
|
|
|
|
|
__all__ = ['BaseConfig', ] |
|
|
|
|
|
|
|
class BaseConfig(object): |
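    """Base training configuration.

    Components (model, criterion, optimizer, dataloaders, ...) are assigned
    through the properties below; dataloaders, the AMP ``GradScaler`` and
    several defaults are built lazily on first access. A minimal usage
    sketch (``MyModel`` / ``MyDataset`` are hypothetical placeholders, not
    part of this module):

        cfg = BaseConfig()
        cfg.model = MyModel()
        cfg.train_dataset = MyDataset('train')
        cfg.batch_size = 8             # shared train/val default
        loader = cfg.train_dataloader  # built on first access
    """
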
    def __init__(self) -> None:
        super().__init__()

        self.task: str = None

        # core components; assigned and type-checked via the properties below
        self._model: nn.Module = None
        self._postprocessor: nn.Module = None
        self._criterion: nn.Module = None
        self._optimizer: Optimizer = None
        self._lr_scheduler: LRScheduler = None
        self._train_dataloader: DataLoader = None
        self._val_dataloader: DataLoader = None
        self._ema: nn.Module = None
        self._scaler: GradScaler = None

        # dataset / dataloader options
        self.train_dataset: Dataset = None
        self.val_dataset: Dataset = None
        self.num_workers: int = 0
        self.collate_fn: Callable = None

        self.batch_size: int = None
        self._train_batch_size: int = None
        self._val_batch_size: int = None
        self._train_shuffle: bool = None
        self._val_shuffle: bool = None

        self.evaluator: Callable[[nn.Module, DataLoader, str], None] = None

        # checkpoint paths (resume training / load weights for tuning)
        self.resume: str = None
        self.tuning: str = None

        # epoch bookkeeping
        self.epoches: int = None
        self.last_epoch: int = -1
        self.end_epoch: int = None

        # training options
        self.use_amp: bool = False
        self.use_ema: bool = False
        self.sync_bn: bool = False
        self.clip_max_norm: float = None
        self.find_unused_parameters: bool = None

        # logging
        self.log_dir: str = './logs/'
        self.log_step: int = 10
        self._output_dir: str = None
        self._print_freq: int = None
        self.checkpoint_step: int = 1

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

    @property
    def model(self) -> nn.Module:
        return self._model

    @model.setter
    def model(self, m):
        assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your model class'
        self._model = m

    @property
    def postprocessor(self) -> nn.Module:
        return self._postprocessor

    @postprocessor.setter
    def postprocessor(self, m):
        assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your postprocessor class'
        self._postprocessor = m

    @property
    def criterion(self) -> nn.Module:
        return self._criterion

    @criterion.setter
    def criterion(self, m):
        assert isinstance(m, nn.Module), f'{type(m)} != nn.Module, please check your criterion class'
        self._criterion = m

    @property
    def optimizer(self) -> Optimizer:
        return self._optimizer

    @optimizer.setter
    def optimizer(self, m):
        assert isinstance(m, Optimizer), f'{type(m)} != optim.Optimizer, please check your optimizer class'
        self._optimizer = m

    @property
    def lr_scheduler(self) -> LRScheduler:
        return self._lr_scheduler

    @lr_scheduler.setter
    def lr_scheduler(self, m):
        assert isinstance(m, LRScheduler), f'{type(m)} != LRScheduler, please check your lr_scheduler class'
self._lr_scheduler = m |
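
    # Example wiring (a sketch; assumes ``cfg.model`` has been assigned first):
    #
    #   cfg.optimizer = torch.optim.AdamW(cfg.model.parameters(), lr=1e-4)
    #   cfg.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    #       cfg.optimizer, milestones=[60], gamma=0.1)
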
    @property
    def train_dataloader(self):
        if self._train_dataloader is None and self.train_dataset is not None:
            loader = DataLoader(
                self.train_dataset,
                batch_size=self.train_batch_size,
                num_workers=self.num_workers,
                collate_fn=self.collate_fn,
                shuffle=self.train_shuffle,
            )
            # record the shuffle flag; DataLoader itself does not keep it
            loader.shuffle = self.train_shuffle
            self._train_dataloader = loader

        return self._train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, loader):
        self._train_dataloader = loader

    @property
    def val_dataloader(self):
        if self._val_dataloader is None and self.val_dataset is not None:
            loader = DataLoader(
                self.val_dataset,
                batch_size=self.val_batch_size,
                num_workers=self.num_workers,
                drop_last=False,
                collate_fn=self.collate_fn,
                shuffle=self.val_shuffle,
            )
            # record the shuffle flag; DataLoader itself does not keep it
            loader.shuffle = self.val_shuffle
            self._val_dataloader = loader

        return self._val_dataloader

    @val_dataloader.setter
    def val_dataloader(self, loader):
        self._val_dataloader = loader

    @property
    def ema(self) -> nn.Module:
        return self._ema

    @ema.setter
    def ema(self, obj):
self._ema = obj |
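
    # Note: unlike the components above, no isinstance check is applied here;
    # ``ema`` is typically an EMA wrapper around the model (e.g. a
    # ModelEMA-style module) rather than a plain nn.Module.
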
    @property
    def scaler(self) -> GradScaler:
        if self._scaler is None and self.use_amp and torch.cuda.is_available():
            self._scaler = GradScaler()
        return self._scaler

    @scaler.setter
    def scaler(self, obj: GradScaler):
self._scaler = obj |
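
    # A typical AMP step with this scaler (a sketch; ``samples``/``targets``
    # are illustrative names, and a scalar loss is assumed):
    #
    #   with torch.autocast('cuda', enabled=cfg.use_amp):
    #       loss = cfg.criterion(cfg.model(samples), targets)
    #   cfg.scaler.scale(loss).backward()
    #   cfg.scaler.step(cfg.optimizer)
    #   cfg.scaler.update()
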
    @property
    def val_shuffle(self):
        if self._val_shuffle is None:
            print('warning: set default val_shuffle=False')
            return False
        return self._val_shuffle

    @val_shuffle.setter
    def val_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), 'shuffle must be bool'
        self._val_shuffle = shuffle

    @property
    def train_shuffle(self):
        if self._train_shuffle is None:
            print('warning: set default train_shuffle=True')
            return True
        return self._train_shuffle

    @train_shuffle.setter
    def train_shuffle(self, shuffle):
        assert isinstance(shuffle, bool), 'shuffle must be bool'
        self._train_shuffle = shuffle

    @property
    def train_batch_size(self):
        if self._train_batch_size is None and isinstance(self.batch_size, int):
            print(f'warning: set train_batch_size=batch_size={self.batch_size}')
            return self.batch_size
        return self._train_batch_size

    @train_batch_size.setter
    def train_batch_size(self, batch_size):
        assert isinstance(batch_size, int), 'batch_size must be int'
        self._train_batch_size = batch_size

    @property
    def val_batch_size(self):
        if self._val_batch_size is None and isinstance(self.batch_size, int):
            print(f'warning: set val_batch_size=batch_size={self.batch_size}')
            return self.batch_size
        return self._val_batch_size

    @val_batch_size.setter
    def val_batch_size(self, batch_size):
        assert isinstance(batch_size, int), 'batch_size must be int'
        self._val_batch_size = batch_size

    @property
    def output_dir(self):
        if self._output_dir is None:
            return self.log_dir
        return self._output_dir

    @output_dir.setter
    def output_dir(self, root):
        self._output_dir = root

    @property
    def print_freq(self):
        if self._print_freq is None:
            return self.log_step
        return self._print_freq

    @print_freq.setter
    def print_freq(self, n):
        assert isinstance(n, int), 'print_freq must be int'
self._print_freq = n |
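

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original API surface):
    cfg = BaseConfig()
    cfg.model = nn.Linear(10, 10)
    cfg.batch_size = 4
    print(cfg.device)       # cuda if available, else cpu
    print(cfg.print_freq)   # falls back to log_step (10)
    print(cfg.val_shuffle)  # warns, then defaults to False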