# NOTE: the web file-viewer residue that used to sit here (file-size banner,
# per-line commit hashes and the line-number gutter) is not part of the
# program and has been stripped so the module parses as Python.

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.transforms import ToTensor
from PIL import Image
import io
import cv2
import gradio as gr
import os
import requests
import zipfile
import subprocess

# Ensure the necessary model files are available
def download_file(url, destination):
    """Download ``url`` to the local path ``destination``.

    The body is streamed to disk in chunks instead of being buffered in
    memory, and it is first written to ``destination + ".part"`` and only
    moved into place once the transfer finished, so an interrupted download
    never leaves a truncated file under the final name.

    Args:
        url: HTTP(S) URL to fetch.
        destination: Path of the file to create or overwrite. Missing parent
            directories are created.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        OSError: If the destination cannot be written.
    """
    parent_dir = os.path.dirname(destination)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    partial_path = destination + ".part"
    # The timeout bounds connect/read stalls, not the total transfer time.
    with requests.get(url, stream=True, timeout=60) as response:
        # Without this an HTML error page would be saved as the payload.
        response.raise_for_status()
        with open(partial_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(partial_path, destination)

# Install the necessary packages. check=True aborts right here when pip or git
# fails instead of letting the script run on and fail later at import time.
subprocess.run(
    ["pip", "install", "git+https://github.com/facebookresearch/segment-anything.git"],
    check=True,
)
# `git clone` refuses to clone into an existing directory, so only clone when
# the checkout is not there yet (i.e. on the first run).
if not os.path.isdir("EfficientSAM"):
    subprocess.run(
        ["git", "clone", "https://github.com/yformer/EfficientSAM.git"],
        check=True,
    )

# Change directory to EfficientSAM so the relative "weights/..." paths used
# further down resolve inside the cloned repository.
os.chdir("EfficientSAM")

from segment_anything.utils.amg import (
    batched_mask_to_box,
    calculate_stability_score,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
)
from torchvision.ops.boxes import batched_nms, box_area

def process_small_region(rles):
    """Remove small holes/islands from RLE-encoded masks and de-duplicate them.

    Every mask is decoded, cleaned with ``remove_small_regions`` (first in
    ``"holes"`` mode, then in ``"islands"`` mode, both with ``min_area=100``)
    and boxed. Box-level NMS with an IoU threshold of 0.7 then drops
    overlapping masks; masks that needed no cleanup get score 1.0 and masks
    that were modified get score 0.0, so untouched masks win ties.

    Args:
        rles: List in which ``rles[i][0]`` is an RLE accepted by
            ``rle_to_mask`` (the caller in this file builds each entry with
            ``mask_to_rle_pytorch``). Entries of kept masks that were
            modified by the cleanup are overwritten in place with the RLE of
            the cleaned mask.

    Returns:
        A list with one decoded mask (as returned by ``rle_to_mask``) per
        mask that survived NMS; an empty list when ``rles`` is empty.
    """
    # torch.cat cannot concatenate an empty list, so bail out early when the
    # upstream filtering left nothing to process.
    if not rles:
        return []

    min_area = 100
    nms_thresh = 0.7
    new_masks = []
    scores = []
    for rle in rles:
        mask = rle_to_mask(rle[0])
        mask, changed = remove_small_regions(mask, min_area, mode="holes")
        unchanged = not changed
        mask, changed = remove_small_regions(mask, min_area, mode="islands")
        unchanged = unchanged and not changed
        new_masks.append(torch.as_tensor(mask).unsqueeze(0))
        scores.append(float(unchanged))

    masks = torch.cat(new_masks, dim=0)
    boxes = batched_mask_to_box(masks)
    keep_by_nms = batched_nms(
        boxes.float(),
        torch.as_tensor(scores),
        torch.zeros_like(boxes[:, 0]),  # single category: all masks compete
        iou_threshold=nms_thresh,
    ).tolist()

    # Only masks altered by the cleanup need their RLE refreshed.
    for i_mask in keep_by_nms:
        if scores[i_mask] == 0.0:
            rles[i_mask] = mask_to_rle_pytorch(masks[i_mask].unsqueeze(0))
    return [rle_to_mask(rles[i][0]) for i in keep_by_nms]

def get_predictions_given_embeddings_and_queries(
    img, points, point_labels, model, *, iou_threshold=0.7, stability_threshold=0.9
):
    """Run ``model`` on point prompts and keep only confident, stable masks.

    For every query the candidate masks are ordered by predicted IoU. A query
    is kept when its best predicted IoU exceeds ``iou_threshold`` and the
    stability score of its best candidate (``calculate_stability_score`` with
    mask threshold 0.0 and offset 1.0) exceeds ``stability_threshold``.

    Args:
        img: Image tensor without a batch dimension; one is added with
            ``img[None, ...]`` before calling the model.
        points: Point prompts, forwarded to ``model`` unchanged.
        point_labels: Labels for ``points``, forwarded to ``model`` unchanged.
        model: Callable returning ``(predicted_masks, predicted_iou)`` where
            ``predicted_iou`` is indexed as ``[batch, query, candidate]`` and
            ``predicted_masks`` carries two further trailing dimensions.
        iou_threshold: Minimum predicted IoU of a query's best candidate.
        stability_threshold: Minimum stability score of that candidate.

    Returns:
        Tuple ``(masks, iou)``: ``masks`` is a boolean tensor (mask values
        ``>= 0.0``) holding all IoU-sorted candidates of the kept queries,
        and ``iou`` holds the best predicted IoU of each kept query.
    """
    predicted_masks, predicted_iou = model(img[None, ...], points, point_labels)

    # Reorder each query's candidates from highest to lowest predicted IoU.
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    predicted_masks = torch.take_along_dim(
        predicted_masks, sorted_ids[..., None, None], dim=2
    )

    # Only the first batch element is used.
    predicted_masks = predicted_masks[0]
    best_iou = predicted_iou[0, :, 0]

    confident = best_iou > iou_threshold
    best_iou = best_iou[confident]
    masks = predicted_masks[confident]

    # Column 0 is the top-ranked candidate after the sort above.
    stability = calculate_stability_score(masks, 0.0, 1.0)[:, 0]
    stable = stability > stability_threshold

    return torch.ge(masks[stable], 0.0), best_iou[stable]

def run_everything_ours(image_np, model, grid_size=32):
    """Segment a whole image by prompting ``model`` with a regular point grid.

    Args:
        image_np: Image array in BGR channel order (it is converted with
            ``cv2.COLOR_BGR2RGB`` before being fed to the model).
        model: Segmentation model; it is moved to the CPU and called through
            ``get_predictions_given_embeddings_and_queries``.
        grid_size: Number of prompt points per image axis; the grid holds
            ``grid_size ** 2`` points. Defaults to 32.

    Returns:
        List of masks as produced by ``process_small_region``; an empty list
        when no prediction passes the quality filters.
    """
    model = model.cpu()
    image = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
    img_tensor = ToTensor()(image)
    _, original_image_h, original_image_w = img_tensor.shape

    # One [x, y] prompt per grid cell, x varying slowest.
    xy = [
        [
            0.5 + i / grid_size * original_image_w,
            0.5 + j / grid_size * original_image_h,
        ]
        for i in range(grid_size)
        for j in range(grid_size)
    ]
    points = torch.from_numpy(np.array(xy))
    num_pts = points.shape[0]
    point_labels = torch.ones(num_pts, 1)

    with torch.no_grad():
        predicted_masks, _ = get_predictions_given_embeddings_and_queries(
            img_tensor.cpu(),
            points.reshape(1, num_pts, 1, 2).cpu(),
            point_labels.reshape(1, num_pts, 1).cpu(),
            model,
        )

    # m[0:1] keeps only the first candidate mask of every surviving query.
    rle = [mask_to_rle_pytorch(m[0:1]) for m in predicted_masks]
    if not rle:
        # Nothing survived the filters; the post-processing step cannot
        # handle an empty batch.
        return []
    return process_small_region(rle)

def show_anns_ours(masks, image):
    """Outline every mask on ``image`` and return the image.

    Contours are drawn directly onto ``image`` (colour ``(0, 255, 0)``,
    thickness 2), so pass a copy if the input must stay untouched.

    Args:
        masks: Iterable of array masks; each is cast to ``uint8`` for the
            contour search.
        image: Image array to draw on.

    Returns:
        The same ``image`` object that was passed in.
    """
    outline_colour = (0, 255, 0)
    outline_thickness = 2
    for binary_mask in masks:
        outlines, _hierarchy = cv2.findContours(
            binary_mask.astype(np.uint8),
            cv2.RETR_TREE,
            cv2.CHAIN_APPROX_SIMPLE,
        )
        cv2.drawContours(image, outlines, -1, outline_colour, outline_thickness)
    return image

def process_image(image):
    """Segment ``image`` with SAM and EfficientSAM for side-by-side display.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when nothing was uploaded.

    Returns:
        A three-element list: the original image, a copy with every SAM mask
        filled with ``[0, 255, 0]``, and a copy with the EfficientSAM mask
        contours drawn on it. All three entries are ``None`` when ``image``
        is ``None``.
    """
    if image is None:
        return [None, None, None]

    # Convert PIL image to numpy array. PIL yields RGB channel order; forcing
    # three channels keeps the 3-value colour assignment below valid.
    image_np = np.array(image.convert("RGB"))

    # Process with SAM. The array is already RGB, so it is passed as-is (the
    # previous BGR->RGB conversion actually swapped it to BGR).
    # NOTE(review): `mask_generator_sam` is not defined in the visible part of
    # this file; it has to be created before the first request or this line
    # raises NameError.
    sam_result = mask_generator_sam.generate(image_np)

    # Annotate SAM result
    sam_annotated_image = image_np.copy()
    for mask in sam_result:
        sam_annotated_image[mask['segmentation']] = [0, 255, 0]

    # Process with EfficientSAM. run_everything_ours converts BGR->RGB
    # internally, so hand it a BGR view of the image.
    image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
    mask_efficient_sam_vits = run_everything_ours(image_bgr, efficient_sam_vits_model)
    efficient_sam_annotated_image = show_anns_ours(mask_efficient_sam_vits, image_np.copy())

    return [image, sam_annotated_image, efficient_sam_annotated_image]

# Download EfficientSAM model
# NOTE(review): the URL below is a placeholder (example.com). If the archive
# is ever missing, this fetches something that is not the weights and the
# extraction step fails; replace it with the real location of the file.
if not os.path.exists("weights/efficient_sam_vits.pt.zip"):
    download_file("https://example.com/path/to/efficient_sam_vits.pt.zip", "weights/efficient_sam_vits.pt.zip")

# Extract EfficientSAM model, skipping the work when a previous run already
# unpacked it (assumes the archive contains `efficient_sam_vits.pt`; if it
# does not, this simply extracts on every start as before).
if not os.path.exists("weights/efficient_sam_vits.pt"):
    with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
        zip_ref.extractall("weights")

# When this file is run as a script, sys.path starts with the script's own
# directory, not the current working directory, so make the current directory
# importable before importing the `efficient_sam` package from it.
import sys
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

from efficient_sam.build_efficient_sam import build_efficient_sam_vits
efficient_sam_vits_model = build_efficient_sam_vits()
efficient_sam_vits_model.eval()

# Gradio interface: one uploaded image in, three images out (the original and
# one annotated image per model), in the order process_image returns them.
_input_component = gr.Image(type="pil")
_output_components = [
    gr.Image(type="pil", label=caption)
    for caption in ("Original", "SAM Segmented", "EfficientSAM Segmented")
]

interface = gr.Interface(
    fn=process_image,
    inputs=_input_component,
    outputs=_output_components,
    title="SAM vs EfficientSAM Comparison",
    description="Upload an image to compare the segmentation results of SAM and EfficientSAM.",
)

# Start the web server for the interface.
interface.launch(debug=True)