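# Open-Vocabulary SAM (OV-SAM) demo model: a SAM-style segmentor whose image
# encoder is a frozen CLIP RN50x16 backbone, adapted to SAM's prompt/mask
# interface and extended with an open-vocabulary classification branch over
# the LVIS v1 vocabulary.
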
from mmcv.ops import RoIAlign
from mmdet.models import FPN, SingleRoIExtractor

from app.models.model import SAMSegmentor
from app.models.openclip_backbone import OpenCLIPBackbone
from app.models.ovsam_head import OVSAMHead
from app.models.sam_pe import SAMPromptEncoder
from app.models.transformer_neck import MultiLayerTransformerNeck

model = dict(
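    # Top-level assembly: SAMSegmentor wires the CLIP backbone, the two necks,
    # the SAM prompt encoder, and the open-vocabulary mask decoder together.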
    type=SAMSegmentor,
    data_preprocessor=None,
    enable_backbone=True,
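    # Frozen CLIP RN50x16 image encoder, initialized from the OpenAI release.
    # Its four stage outputs feed both the transformer neck and the FPN below.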
    backbone=dict(
        type=OpenCLIPBackbone,
        model_name='RN50x16',
        fix=True,
        init_cfg=dict(
            type='clip_pretrain',
            checkpoint='openai'
        )
    ),
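    # Transformer neck that fuses the multi-scale CLIP features into a single
    # 256-channel, SAM-style image embedding. Frozen; judging by the checkpoint
    # name and the 'neck_student' prefix, these weights come from the SAM2CLIP
    # distillation stage.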
    neck=dict(
        type=MultiLayerTransformerNeck,
        input_size=(1024, 1024),
        in_channels=[384, 768, 1536, 3072],
        strides=[4, 8, 16, 32],
        layer_ids=(0, 1, 2, 3),
        embed_channels=1280,
        out_channels=256,
        fix=True,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='./models/sam2clip_vith_rn50.pth',
            prefix='neck_student',
        )
    ),
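    # Standard SAM (ViT-H) prompt encoder for point/box prompts, frozen and
    # loaded from the released SAM weights.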
    prompt_encoder=dict(
        type=SAMPromptEncoder,
        model_name='vit_h',
        fix=True,
        init_cfg=dict(
            type='sam_pretrain',
            checkpoint='vit_h'
        )
    ),
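    # FPN over the same four CLIP stages; its pyramid supplies the RoI
    # features consumed by the mask decoder's roi_extractor below.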
    fpn_neck=dict(
        type=FPN,
        in_channels=[384, 768, 1536, 3072],
        out_channels=256,
        num_outs=4,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='./models/R50x16_fpn_lvis_norare_v3det.pth',
            prefix='fpn_neck',
        ),
    ),
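    # Open-vocabulary mask decoder: SAM's ViT-H decoder extended with a label
    # token. gen_box=True derives boxes from predicted masks so RoIAlign can
    # pool FPN features for the RN50x16 LVIS-v1 text classifier. This is the
    # only module left trainable (fix=False).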
    mask_decoder=dict(
        type=OVSAMHead,
        model_name='vit_h',
        with_label_token=True,
        gen_box=True,
        ov_classifier_name='RN50x16_LVISV1Dataset',
        roi_extractor=dict(
            type=SingleRoIExtractor,
            roi_layer=dict(type=RoIAlign, output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]
        ),
        fix=False,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='./models/ovsam_R50x16_lvisnorare.pth',
            prefix='mask_decoder',
        ),
        load_roi_conv=dict(
            checkpoint='./models/R50x16_fpn_lvis_norare_v3det.pth',
            prefix='roi_conv',
        )
    )
)
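
# Usage sketch (hypothetical, left commented out so this file stays a pure
# config). mmengine instantiates class-based `type` values directly, so the
# model can be built through the registry once the `app` package is
# importable; the config path below is an assumption, not a real file name.
#
#   from mmengine.config import Config
#   from mmdet.registry import MODELS
#
#   cfg = Config.fromfile('app/configs/ovsam_r50x16.py')  # hypothetical path
#   model = MODELS.build(cfg.model)
#   model.eval()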