import logging
from typing import Optional

from .swinv2_model import SwinTransformerV2
from .unet_swin_model import unet_swin
from .mim.mim import build_mim_model
from ..training.mim_utils import load_pretrained


def build_model(config,
                pretrain: bool = False,
                pretrain_method: str = 'mim',
                logger: Optional[logging.Logger] = None):
    """
    Given a config object, build a PyTorch model.

    Args:
        config: config object describing the data, model and training setup.
        pretrain: if True, build the pre-training model and return it.
        pretrain_method: pre-training method; only 'mim' is supported.
        logger: logger used for progress and error messages.

    Returns:
        model: the built model.
    """
    # Pre-training path: currently only masked image modeling (MIM) is supported.
    if pretrain:
        if pretrain_method == 'mim':
            model = build_mim_model(config)
            return model

    encoder_architecture = config.MODEL.TYPE
    decoder_architecture = config.MODEL.DECODER

    # Build the encoder backbone.
    if encoder_architecture == 'swinv2':
        logger.info(f'Building encoder: {encoder_architecture}')
        window_sizes = config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES
        model = SwinTransformerV2(
            img_size=config.DATA.IMG_SIZE,
            patch_size=config.MODEL.SWINV2.PATCH_SIZE,
            in_chans=config.MODEL.SWINV2.IN_CHANS,
            num_classes=config.MODEL.NUM_CLASSES,
            embed_dim=config.MODEL.SWINV2.EMBED_DIM,
            depths=config.MODEL.SWINV2.DEPTHS,
            num_heads=config.MODEL.SWINV2.NUM_HEADS,
            window_size=config.MODEL.SWINV2.WINDOW_SIZE,
            mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
            qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
            drop_rate=config.MODEL.DROP_RATE,
            drop_path_rate=config.MODEL.DROP_PATH_RATE,
            ape=config.MODEL.SWINV2.APE,
            patch_norm=config.MODEL.SWINV2.PATCH_NORM,
            use_checkpoint=config.TRAIN.USE_CHECKPOINT,
            pretrained_window_sizes=window_sizes)
        # Load pre-trained encoder weights unless resuming from a checkpoint.
        if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
            load_pretrained(config, model, logger)
    else:
        error_msg = f'Unknown encoder architecture: {encoder_architecture}'
        logger.error(error_msg)
        raise NotImplementedError(error_msg)

    # If a decoder is requested, rebuild the encoder and wrap it with the decoder.
    if decoder_architecture is not None:
        if encoder_architecture == 'swinv2':
            window_sizes = config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES
            model = SwinTransformerV2(
                img_size=config.DATA.IMG_SIZE,
                patch_size=config.MODEL.SWINV2.PATCH_SIZE,
                in_chans=config.MODEL.SWINV2.IN_CHANS,
                num_classes=config.MODEL.NUM_CLASSES,
                embed_dim=config.MODEL.SWINV2.EMBED_DIM,
                depths=config.MODEL.SWINV2.DEPTHS,
                num_heads=config.MODEL.SWINV2.NUM_HEADS,
                window_size=config.MODEL.SWINV2.WINDOW_SIZE,
                mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
                qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
                drop_rate=config.MODEL.DROP_RATE,
                drop_path_rate=config.MODEL.DROP_PATH_RATE,
                ape=config.MODEL.SWINV2.APE,
                patch_norm=config.MODEL.SWINV2.PATCH_NORM,
                use_checkpoint=config.TRAIN.USE_CHECKPOINT,
                pretrained_window_sizes=window_sizes)
        else:
            raise NotImplementedError(
                f'Unknown encoder architecture: {encoder_architecture}')

        if decoder_architecture == 'unet':
            num_classes = config.MODEL.NUM_CLASSES
            # Load pre-trained encoder weights before attaching the decoder head.
            if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
                load_pretrained(config, model, logger)
            model = unet_swin(encoder=model, num_classes=num_classes)
        else:
            error_msg = f'Unknown decoder architecture: {decoder_architecture}'
            logger.error(error_msg)
            raise NotImplementedError(error_msg)

    return model
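
# Example usage (a minimal sketch, not part of the original module). The
# `get_config` helper and the YAML path below are assumptions for illustration;
# any object exposing the config attributes read above will work:
#
#     import logging
#     from .models import build_model            # hypothetical import location
#     from .config import get_config             # hypothetical config loader
#
#     config = get_config('configs/swinv2_unet.yaml')  # hypothetical config file
#     logger = logging.getLogger(__name__)
#
#     # Fine-tuning model (encoder plus optional decoder)
#     model = build_model(config, pretrain=False, logger=logger)
#
#     # MIM pre-training model
#     mim_model = build_model(config, pretrain=True, pretrain_method='mim',
#                             logger=logger)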