import gc
import random

import numpy as np
import torch
from tqdm import tqdm
from dataloader.sam_transforms import ResizeLongestSide
from segment_anything import (
    sam_model_registry,
    our_vit
)
from dataloader.dataloader import sam_dataloader
from utils.SurfaceDice import compute_dice_coefficient

#%% test
def eval_dice(sam_model,
              lvm_med_encoder_path,
              loader,
              device):
    
    """
    Function to evaluate model (for both validation and testing phase)
    """
    
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    
    """
    Declare LVM Med backbone instead of using SAM's backbone
    """
    arch = 'vit_encoder_b'
    lvm_med_backbone = our_vit.__dict__[arch]()
    lvm_weight = torch.load(lvm_med_encoder_path, map_location ='cpu')
    lvm_med_backbone.load_state_dict(lvm_weight)
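    # load_state_dict is strict by default, so the checkpoint at
    # lvm_med_encoder_path is assumed to store this encoder's state_dict as-is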
    
    dice_score = 0.
    for batch in tqdm(loader, leave=False):
        """
        Load images, mask labels, and bounding boxes computed directly from the ground-truth masks
        """
        image, true_mask, boxes = batch['image'], batch['mask'], batch['bboxes']
        image = image.to(f"cuda:{device}")
        true_mask = true_mask.to(f"cuda:{device}", dtype=torch.float32)

        """
        Compute image embeddings
        """
        # NOTE: the GPU ids are hard-coded and assume a machine with 4 GPUs
        encoder = torch.nn.DataParallel(lvm_med_backbone, device_ids=[3, 2, 1, 0], output_device=device)
        encoder = encoder.to(f"cuda:{encoder.device_ids[0]}")
        sam_model = sam_model.to(f"cuda:{encoder.device_ids[0]}")
        image = image.to(f"cuda:{encoder.device_ids[0]}")
        image = sam_model.preprocess(image)  # normalize and pad to SAM's 1024 x 1024 input
        image_embedding = encoder(image)
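        # The encoder output is expected to match SAM's embedding shape
        # (B, 256, 64, 64), which the mask decoder consumes below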

        """
        Get bboxes
        """
        box_np = boxes.numpy()
        sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
        box = sam_trans.apply_boxes(box_np, (image_embedding.shape[0], image_embedding.shape[1]))
        box_torch = torch.as_tensor(box, dtype=torch.float32, device=device)
        if len(box_torch.shape) == 2:
            box_torch = box_torch[:, None, :] # (B, 1, 4)
        
        """
        Prompt encoder component
        """
        prompt_encoder = torch.nn.DataParallel(sam_model.prompt_encoder, device_ids=[0, 1, 2, 3], output_device=device)
        prompt_encoder = prompt_encoder.to(f"cuda:{prompt_encoder.device_ids[0]}")
        box_torch = box_torch.to(f"cuda:{prompt_encoder.device_ids[0]}")
        sparse_embeddings, dense_embeddings = prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )
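        # sparse_embeddings encode the box prompt; dense_embeddings fall back to
        # SAM's learned "no mask" embedding because masks=None here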

        """
        Mask decoder component
        """
        sam_model = sam_model.to(f"cuda:{device}")
        mask_segmentation, iou_predictions = sam_model.mask_decoder(
            image_embeddings=image_embedding.to(f"cuda:{device}"), # (B, 256, 64, 64)
            image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
            multimask_output=False,
        ) # -> (B, 256, 256)
        
        """
        Transform prediction and evaluate
        """
        true_mask = true_mask.to("cpu")
        medsam_seg_prob = torch.sigmoid(mask_segmentation)
        medsam_seg_prob = medsam_seg_prob.detach().cpu().numpy().squeeze()
        medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8) # transform from hard masks to soft masks
        dice_score += compute_dice_coefficient(true_mask>0, medsam_seg>0)
    
    return dice_score.cpu().numpy()/len(loader) 
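
# A minimal sketch of the metric above, assuming compute_dice_coefficient
# implements the standard Dice overlap 2|A ∩ B| / (|A| + |B|) on boolean masks;
# the arrays below are illustrative only:
#
#   a = np.array([[1, 1, 0], [0, 1, 0]], dtype=bool)
#   b = np.array([[1, 0, 0], [0, 1, 1]], dtype=bool)
#   dice = 2 * np.logical_and(a, b).sum() / (a.sum() + b.sum())  # 2*2/(3+3) ≈ 0.67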

def zero_shot_lvmmed_sam_2d(yml_args, cfg):
    """
    Training warm up
    """
    torch.multiprocessing.set_start_method('spawn')
    
    random.seed(cfg.base.random_seed)
    np.random.seed(cfg.base.random_seed)
    torch.manual_seed(cfg.base.random_seed)
    torch.cuda.manual_seed(cfg.base.random_seed)

    torch.backends.cudnn.deterministic = True

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    
    """
    General configuration
    """
    img_shape = (3, 1024) # the input image shape is fixed at 3 x 1024 x 1024

    """
    Load SAM with its original checkpoint 
    """
    sam_model = sam_model_registry["vit_b"](checkpoint=cfg.base.original_checkpoint)

    """
    Load precomputed embeddings
    """    
    _, _, test_loader, _, _ = sam_dataloader(cfg)

    """
    Test model
    """
    with torch.no_grad():
        sam_model.eval()
        test_dice_score = eval_dice(sam_model,
                                    lvm_med_encoder_path=yml_args.lvm_med_encoder_path,
                                    loader=test_loader,
                                    device=cfg.base.gpu_id)
        print(f"Dice score from zero-shot SAM with the LVM-Med encoder: {test_dice_score * 100:.2f}")
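
# A minimal usage sketch (an assumption, not the repo's actual entry point):
# SimpleNamespace stands in for the repo's config objects and exposes only the
# fields read above; all paths and values are placeholders, and
# sam_dataloader(cfg) will additionally read dataset fields not shown here.
#
#   from types import SimpleNamespace
#
#   cfg = SimpleNamespace(base=SimpleNamespace(
#       random_seed=42,                               # placeholder
#       gpu_id=0,                                     # placeholder
#       original_checkpoint="sam_vit_b_01ec64.pth",   # placeholder path
#   ))
#   yml_args = SimpleNamespace(lvm_med_encoder_path="lvmmed_vit.pth")  # placeholder path
#   zero_shot_lvmmed_sam_2d(yml_args, cfg)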