File size: 3,589 Bytes
ab687e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from .swinv2_model import SwinTransformerV2
from .unet_swin_model import unet_swin
from .mim.mim import build_mim_model
from ..training.mim_utils import load_pretrained

import logging


def _log_and_raise(logger: logging.Logger, error_msg: str):
    """Log ``error_msg`` at error level, then raise it as NotImplementedError."""
    logger.error(error_msg)
    raise NotImplementedError(error_msg)


def _build_swinv2_encoder(config):
    """
    Construct a SwinTransformerV2 backbone from the config.

    Architecture hyperparameters come from ``config.MODEL.SWINV2``; shared
    settings (image size, number of classes, drop rates, checkpointing) come
    from ``config.DATA``, ``config.MODEL`` and ``config.TRAIN``.

    Returns:
        model: a new SwinTransformerV2 instance. Pretrained weights are NOT
        loaded here; the caller decides whether to call ``load_pretrained``.
    """
    swinv2_config = config.MODEL.SWINV2

    return SwinTransformerV2(
        img_size=config.DATA.IMG_SIZE,
        patch_size=swinv2_config.PATCH_SIZE,
        in_chans=swinv2_config.IN_CHANS,
        num_classes=config.MODEL.NUM_CLASSES,
        embed_dim=swinv2_config.EMBED_DIM,
        depths=swinv2_config.DEPTHS,
        num_heads=swinv2_config.NUM_HEADS,
        window_size=swinv2_config.WINDOW_SIZE,
        mlp_ratio=swinv2_config.MLP_RATIO,
        qkv_bias=swinv2_config.QKV_BIAS,
        drop_rate=config.MODEL.DROP_RATE,
        drop_path_rate=config.MODEL.DROP_PATH_RATE,
        ape=swinv2_config.APE,
        patch_norm=swinv2_config.PATCH_NORM,
        use_checkpoint=config.TRAIN.USE_CHECKPOINT,
        pretrained_window_sizes=swinv2_config.PRETRAINED_WINDOW_SIZES)


def build_model(config,
                pretrain: bool = False,
                pretrain_method: str = 'mim',
                logger: logging.Logger = None):
    """
    Given a config object, builds a pytorch model.

    Args:
        config: experiment config. Reads ``MODEL.TYPE`` (encoder),
            ``MODEL.DECODER`` (decoder or None), ``MODEL.PRETRAINED``,
            ``MODEL.RESUME``, ``MODEL.NUM_CLASSES`` and the
            architecture-specific sections.
        pretrain: if True, build the pretraining model selected by
            ``pretrain_method`` instead of the encoder/decoder model.
        pretrain_method: pretraining method; only ``'mim'`` is supported.
        logger: logger for progress and error messages. When None, this
            module's logger is used.

    Returns:
        model: built model. When ``pretrain`` is True this is the model
        returned by ``build_mim_model``. Otherwise it is the encoder,
        wrapped by the decoder when ``MODEL.DECODER`` is set.

    Raises:
        NotImplementedError: if the pretrain method, encoder architecture or
            decoder architecture is not supported.
    """
    if logger is None:
        # The signature allows logger=None; fall back instead of failing with
        # AttributeError on the first logger call.
        logger = logging.getLogger(__name__)

    if pretrain:
        if pretrain_method == 'mim':
            return build_mim_model(config)
        # Previously an unknown method silently fell through to building the
        # finetuning model; fail loudly like the other unknown-option paths.
        _log_and_raise(logger, f'Unknown pretrain method {pretrain_method}')

    encoder_architecture = config.MODEL.TYPE
    decoder_architecture = config.MODEL.DECODER

    # Validate the whole configuration before constructing anything, so an
    # invalid decoder does not cost an encoder build and a weight load.
    if encoder_architecture != 'swinv2':
        _log_and_raise(
            logger, f'Unknown encoder architecture {encoder_architecture}')

    if decoder_architecture is not None and decoder_architecture != 'unet':
        _log_and_raise(
            logger, f'Unknown decoder architecture: {decoder_architecture}')

    if decoder_architecture is None:
        logger.info(f'Hit encoder only build, building {encoder_architecture}')
    else:
        logger.info(f'Building {encoder_architecture} encoder with '
                    f'{decoder_architecture} decoder')

    # Build the encoder and load its pretrained weights exactly once; the
    # decoder (if any) wraps this same instance.
    model = _build_swinv2_encoder(config)

    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model, logger)

    if decoder_architecture == 'unet':
        model = unet_swin(encoder=model, num_classes=config.MODEL.NUM_CLASSES)

    return model