Caleb Spradlin
initial commit
ab687e7
raw
history blame
3.59 kB
from .swinv2_model import SwinTransformerV2
from .unet_swin_model import unet_swin
from .mim.mim import build_mim_model
from ..training.mim_utils import load_pretrained
import logging
def _build_swinv2_encoder(config):
    """Instantiate a SwinTransformerV2 from the ``config.MODEL.SWINV2`` section.

    Shared by the encoder-only and encoder+decoder build paths so the
    constructor arguments are defined in exactly one place.
    """
    swinv2_cfg = config.MODEL.SWINV2
    return SwinTransformerV2(
        img_size=config.DATA.IMG_SIZE,
        patch_size=swinv2_cfg.PATCH_SIZE,
        in_chans=swinv2_cfg.IN_CHANS,
        num_classes=config.MODEL.NUM_CLASSES,
        embed_dim=swinv2_cfg.EMBED_DIM,
        depths=swinv2_cfg.DEPTHS,
        num_heads=swinv2_cfg.NUM_HEADS,
        window_size=swinv2_cfg.WINDOW_SIZE,
        mlp_ratio=swinv2_cfg.MLP_RATIO,
        qkv_bias=swinv2_cfg.QKV_BIAS,
        drop_rate=config.MODEL.DROP_RATE,
        drop_path_rate=config.MODEL.DROP_PATH_RATE,
        ape=swinv2_cfg.APE,
        patch_norm=swinv2_cfg.PATCH_NORM,
        use_checkpoint=config.TRAIN.USE_CHECKPOINT,
        pretrained_window_sizes=swinv2_cfg.PRETRAINED_WINDOW_SIZES)


def build_model(config,
                pretrain: bool = False,
                pretrain_method: str = 'mim',
                logger: logging.Logger = None):
    """
    Given a config object, builds a pytorch model.

    Args:
        config: configuration object; reads ``config.MODEL.TYPE`` (encoder),
            ``config.MODEL.DECODER`` (``None`` for an encoder-only model),
            ``config.MODEL.PRETRAINED`` / ``config.MODEL.RESUME`` and the
            ``config.MODEL.SWINV2`` section.
        pretrain: when True and ``pretrain_method == 'mim'``, return the
            model produced by ``build_mim_model(config)``.
        pretrain_method: pretraining method; only ``'mim'`` is supported.
            Any other value logs a warning and falls through to the normal
            (non-pretraining) build.
        logger: logger used for progress/error messages. Defaults to this
            module's logger when not given.

    Returns:
        model: built model. For ``config.MODEL.DECODER == 'unet'`` this is
            the SwinV2 encoder passed through ``unet_swin``.

    Raises:
        NotImplementedError: if the encoder (``config.MODEL.TYPE``) or the
            decoder (``config.MODEL.DECODER``) is not supported.
    """
    # The parameter defaults to None but is used unconditionally below;
    # fall back to a real logger instead of raising AttributeError.
    if logger is None:
        logger = logging.getLogger(__name__)

    if pretrain:
        if pretrain_method == 'mim':
            return build_mim_model(config)
        # Preserve the historical fall-through, but make it visible.
        logger.warning(
            f'Unsupported pretrain_method {pretrain_method!r}; '
            'falling back to a non-pretraining model build')

    encoder_architecture = config.MODEL.TYPE
    decoder_architecture = config.MODEL.DECODER

    # Validate both architectures before doing any expensive work
    # (model construction, checkpoint loading). Encoder is checked first
    # so its error wins when both are unsupported.
    if encoder_architecture != 'swinv2':
        error_msg = f'Unknown encoder architecture {encoder_architecture}'
        logger.error(error_msg)
        raise NotImplementedError(error_msg)

    if decoder_architecture is not None and decoder_architecture != 'unet':
        error_msg = f'Unknown decoder architecture: {decoder_architecture}'
        logger.error(error_msg)
        raise NotImplementedError(error_msg)

    if decoder_architecture is None:
        logger.info(f'Hit encoder only build, building {encoder_architecture}')
    else:
        logger.info(f'Building {encoder_architecture} encoder with '
                    f'{decoder_architecture} decoder')

    # Build the encoder exactly once; previously the decoder path built
    # (and loaded weights into) a second, identical encoder.
    model = _build_swinv2_encoder(config)

    # Pretrained weights go into the bare encoder, before any decoder wrap,
    # and are skipped when resuming a run.
    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model, logger)

    if decoder_architecture == 'unet':
        model = unet_swin(encoder=model, num_classes=config.MODEL.NUM_CLASSES)

    return model