|
from .swinv2_model import SwinTransformerV2 |
|
from .unet_swin_model import unet_swin |
|
from .mim.mim import build_mim_model |
|
from ..training.mim_utils import load_pretrained |
|
|
|
import logging |
|
|
|
|
|
def build_model(config,
                pretrain: bool = False,
                pretrain_method: str = 'mim',
                logger: logging.Logger = None):
    """
    Given a config object, builds a pytorch model.

    Args:
        config: config object. Reads ``config.MODEL.TYPE`` (encoder name),
            ``config.MODEL.DECODER`` (decoder name or None),
            ``config.MODEL.PRETRAINED`` / ``config.MODEL.RESUME`` (whether to
            load pretrained weights), ``config.MODEL.NUM_CLASSES`` and the
            ``config.MODEL.SWINV2`` / ``config.DATA`` / ``config.TRAIN``
            fields consumed by the encoder constructor.
        pretrain: if True, build a self-supervised pretraining model instead
            of an encoder / encoder-decoder model.
        pretrain_method: pretraining method; only 'mim' is supported.
        logger: logger used for progress and error messages. When None, a
            module-level logger is used.

    Returns:
        model: built model

    Raises:
        NotImplementedError: if the pretrain method, encoder architecture or
            decoder architecture is not supported.
    """
    # Fall back to a module logger so the default argument does not crash
    # on the logger.info / logger.error calls below.
    if logger is None:
        logger = logging.getLogger(__name__)

    if pretrain:
        if pretrain_method == 'mim':
            return build_mim_model(config)
        # Previously an unknown method silently fell through and built a
        # non-pretraining model; fail loudly instead.
        error_msg = f'Unknown pretrain method: {pretrain_method}'
        logger.error(error_msg)
        raise NotImplementedError(error_msg)

    encoder_architecture = config.MODEL.TYPE
    decoder_architecture = config.MODEL.DECODER

    # Validate the whole configuration before building anything, so an
    # invalid decoder does not first cost a full encoder build.
    if encoder_architecture != 'swinv2':
        error_msg = f'Unknown encoder architecture {encoder_architecture}'
        logger.error(error_msg)
        raise NotImplementedError(error_msg)

    if decoder_architecture is not None and decoder_architecture != 'unet':
        error_msg = f'Unknown decoder architecture: {decoder_architecture}'
        logger.error(error_msg)
        raise NotImplementedError(error_msg)

    if decoder_architecture is None:
        logger.info(f'Hit encoder only build, building {encoder_architecture}')
    else:
        logger.info(f'Building {encoder_architecture} encoder with '
                    f'{decoder_architecture} decoder')

    # Build the encoder exactly once; the previous version rebuilt it (and
    # reloaded pretrained weights) when a decoder was requested.
    model = _build_swinv2_encoder(config)

    # Pretrained weights are loaded into the bare encoder, before any
    # decoder wrapping, and skipped when resuming from a checkpoint.
    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model, logger)

    if decoder_architecture == 'unet':
        model = unet_swin(encoder=model, num_classes=config.MODEL.NUM_CLASSES)

    return model


def _build_swinv2_encoder(config):
    """Construct a SwinTransformerV2 from the ``config.MODEL.SWINV2`` settings."""
    swin_cfg = config.MODEL.SWINV2
    return SwinTransformerV2(
        img_size=config.DATA.IMG_SIZE,
        patch_size=swin_cfg.PATCH_SIZE,
        in_chans=swin_cfg.IN_CHANS,
        num_classes=config.MODEL.NUM_CLASSES,
        embed_dim=swin_cfg.EMBED_DIM,
        depths=swin_cfg.DEPTHS,
        num_heads=swin_cfg.NUM_HEADS,
        window_size=swin_cfg.WINDOW_SIZE,
        mlp_ratio=swin_cfg.MLP_RATIO,
        qkv_bias=swin_cfg.QKV_BIAS,
        drop_rate=config.MODEL.DROP_RATE,
        drop_path_rate=config.MODEL.DROP_PATH_RATE,
        ape=swin_cfg.APE,
        patch_norm=swin_cfg.PATCH_NORM,
        use_checkpoint=config.TRAIN.USE_CHECKPOINT,
        pretrained_window_sizes=swin_cfg.PRETRAINED_WINDOW_SIZES)
|
|