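"""Gradio demo comparing SAM and EfficientSAM.

Upload an image and view the original alongside the "segment everything"
results from both models, side by side.
"""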
import os
import subprocess
import sys
import requests
import zipfile
import gradio as gr
import torch
import numpy as np
from torchvision.transforms import ToTensor
from PIL import Image
import cv2

# Ensure the necessary model files are available
def download_file(url, destination):
    # Stream the download to disk in chunks so large checkpoints
    # do not need to fit in memory
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(destination, 'wb') as f:
        for chunk in response.iter_content(chunk_size=1 << 20):
            f.write(chunk)

# Download SAM model
if not os.path.exists("weights/sam_vit_h_4b8939.pth"):
    os.makedirs("weights", exist_ok=True)
    download_file("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "weights/sam_vit_h_4b8939.pth")

# Install SAM and fetch the EfficientSAM repository before importing either
subprocess.run([sys.executable, "-m", "pip", "install",
                "git+https://github.com/facebookresearch/segment-anything.git"], check=True)
if not os.path.exists("EfficientSAM"):
    subprocess.run(["git", "clone", "https://github.com/yformer/EfficientSAM.git"], check=True)

# Add the cloned EfficientSAM repository to the Python path so the
# efficient_sam package can be imported
sys.path.append(os.path.abspath("EfficientSAM"))

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from efficient_sam.build_efficient_sam import build_efficient_sam_vits

# Constants
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_h"
CHECKPOINT_PATH = "weights/sam_vit_h_4b8939.pth"

# Load SAM model
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
mask_generator_sam = SamAutomaticMaskGenerator(sam)

# Load EfficientSAM model: the checkpoint ships zipped inside the cloned
# repo; extract it to weights/, where build_efficient_sam_vits() looks for it
with zipfile.ZipFile("EfficientSAM/weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
    zip_ref.extractall("weights")
efficient_sam_vits_model = build_efficient_sam_vits()
efficient_sam_vits_model.eval()

# Mask post-processing utilities
from segment_anything.utils.amg import (
    batched_mask_to_box,
    calculate_stability_score,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
)
from torchvision.ops.boxes import batched_nms

def process_small_region(rles):
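    """Clean up RLE-encoded masks by filling small holes and removing
    small islands, then de-duplicate them with non-maximum suppression.

    Masks changed by the cleanup get score 0.0 so NMS prefers unchanged
    originals. Returns the surviving masks as binary numpy arrays.
    """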
    new_masks = []
    scores = []
    min_area = 100
    nms_thresh = 0.7
    for rle in rles:
        mask = rle_to_mask(rle[0])
        mask, changed = remove_small_regions(mask, min_area, mode="holes")
        unchanged = not changed
        mask, changed = remove_small_regions(mask, min_area, mode="islands")
        unchanged = unchanged and not changed
        new_masks.append(torch.as_tensor(mask).unsqueeze(0))
        scores.append(float(unchanged))

    masks = torch.cat(new_masks, dim=0)
    boxes = batched_mask_to_box(masks)
    keep_by_nms = batched_nms(
        boxes.float(),
        torch.as_tensor(scores),
        torch.zeros_like(boxes[:, 0]),
        iou_threshold=nms_thresh,
    )
    for i_mask in keep_by_nms:
        if scores[i_mask] == 0.0:
            mask_torch = masks[i_mask].unsqueeze(0)
            rles[i_mask] = mask_to_rle_pytorch(mask_torch)
    masks = [rle_to_mask(rles[i][0]) for i in keep_by_nms]
    return masks

def get_predictions_given_embeddings_and_queries(img, points, point_labels, model):
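    """Prompt `model` with a batch of point queries, sort candidate masks
    by predicted IoU, and keep only confident (IoU > 0.7) and stable
    (stability score > 0.9) masks, binarized at threshold 0.
    """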
    predicted_masks, predicted_iou = model(
        img[None, ...], points, point_labels
    )
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou_scores = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    predicted_masks = torch.take_along_dim(
        predicted_masks, sorted_ids[..., None, None], dim=2
    )
    predicted_masks = predicted_masks[0]
    iou = predicted_iou_scores[0, :, 0]
    # First filter by predicted IoU, then by mask stability score
    index_iou = iou > 0.7
    iou_ = iou[index_iou]
    masks = predicted_masks[index_iou]
    score = calculate_stability_score(masks, 0.0, 1.0)
    score = score[:, 0]
    index = score > 0.9
    masks = masks[index]
    iou_ = iou_[index]
    # Binarize the mask logits at threshold 0
    masks = torch.ge(masks, 0.0)
    return masks, iou_

def run_everything_ours(image_np, model):
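    """Segment everything in `image_np` by prompting EfficientSAM with a
    uniform 32x32 grid of foreground points, then post-process the masks.
    """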
    # Run EfficientSAM on CPU; the ViT-S model is small enough for CPU inference
    model = model.cpu()
    # image_np comes from PIL and is already RGB; converting with
    # cv2.COLOR_BGR2RGB here would wrongly swap the channels
    img_tensor = ToTensor()(image_np)
    _, original_image_h, original_image_w = img_tensor.shape
    # Build a uniform GRID_SIZE x GRID_SIZE grid of point prompts
    # covering the whole image
    xy = []
    GRID_SIZE = 32
    for i in range(GRID_SIZE):
        curr_x = 0.5 + i / GRID_SIZE * original_image_w
        for j in range(GRID_SIZE):
            curr_y = 0.5 + j / GRID_SIZE * original_image_h
            xy.append([curr_x, curr_y])
    points = torch.from_numpy(np.array(xy))
    num_pts = points.shape[0]
    # Label 1 marks every prompt as a foreground point
    point_labels = torch.ones(num_pts, 1)
    with torch.no_grad():
        predicted_masks, predicted_iou = get_predictions_given_embeddings_and_queries(
            img_tensor.cpu(),
            points.reshape(1, num_pts, 1, 2).cpu(),
            point_labels.reshape(1, num_pts, 1).cpu(),
            model.cpu(),
        )
    rle = [mask_to_rle_pytorch(m[0:1]) for m in predicted_masks]
    predicted_masks = process_small_region(rle)
    return predicted_masks

def show_anns_ours(masks, image):
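    """Draw the contours of each binary mask onto `image` in green."""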
    for mask in masks:
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image, contours, -1, (0, 255, 0), 2)
    return image

def process_image(image):
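    """Run both SAM and EfficientSAM on the uploaded image and return the
    original plus both annotated results for side-by-side comparison.
    """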
    # Convert PIL image to a numpy array (RGB channel order)
    image_np = np.array(image)

    # Process with SAM; SamAutomaticMaskGenerator expects an RGB array,
    # and the PIL-derived array is already RGB, so no conversion is needed
    sam_result = mask_generator_sam.generate(image_np)

    # Annotate SAM result: color each mask randomly so individual
    # segments stay distinguishable
    sam_annotated_image = image_np.copy()
    for mask in sam_result:
        sam_annotated_image[mask['segmentation']] = np.random.randint(0, 256, size=3)

    # Process with EfficientSAM
    mask_efficient_sam_vits = run_everything_ours(image_np, efficient_sam_vits_model)
    efficient_sam_annotated_image = show_anns_ours(mask_efficient_sam_vits, image_np.copy())

    return [image, sam_annotated_image, efficient_sam_annotated_image]

# Gradio interface
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil", label="Original"), gr.Image(type="pil", label="SAM Segmented"), gr.Image(type="pil", label="EfficientSAM Segmented")],
    title="SAM vs EfficientSAM Comparison",
    description="Upload an image to compare the segmentation results of SAM and EfficientSAM."
)

if __name__ == "__main__":
    interface.launch(debug=True)