import os
import subprocess
import sys
import zipfile

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import ToTensor

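# Gradio demo comparing "segment everything" results from Meta's SAM (ViT-H)
# against EfficientSAM (ViT-S) on a user-uploaded image.
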
def download_file(url, destination):
    """Stream a remote file to disk in chunks instead of buffering it all in memory."""
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(destination, "wb") as f:
        for chunk in response.iter_content(chunk_size=1 << 20):
            f.write(chunk)

# Download the SAM ViT-H checkpoint once; skip it on subsequent runs.
if not os.path.exists("weights/sam_vit_h_4b8939.pth"):
    os.makedirs("weights", exist_ok=True)
    download_file(
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "weights/sam_vit_h_4b8939.pth",
    )

# Install segment-anything and clone the EfficientSAM sources before importing
# from them. Note that `git clone` creates an "EfficientSAM" directory, not
# "EfficientSAM-main", so that is the path added to sys.path.
subprocess.run(
    [sys.executable, "-m", "pip", "install",
     "git+https://github.com/facebookresearch/segment-anything.git"],
    check=True,
)
if not os.path.exists("EfficientSAM"):
    subprocess.run(["git", "clone", "https://github.com/yformer/EfficientSAM.git"], check=True)
sys.path.append(os.path.abspath("EfficientSAM"))

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from efficient_sam.build_efficient_sam import build_efficient_sam_vits

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL_TYPE = "vit_h"
CHECKPOINT_PATH = "weights/sam_vit_h_4b8939.pth"

# SAM's automatic ("segment everything") mask generator, used here with its
# defaults; parameters such as points_per_side and pred_iou_thresh can be
# passed to SamAutomaticMaskGenerator to trade quality against speed.
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
mask_generator_sam = SamAutomaticMaskGenerator(sam)

# The EfficientSAM ViT-S checkpoint ships zipped inside the cloned repo; extract
# it to weights/, where build_efficient_sam_vits() loads it from by default.
with zipfile.ZipFile("EfficientSAM/weights/efficient_sam_vits.pt.zip", "r") as zip_ref:
    zip_ref.extractall("weights")
efficient_sam_vits_model = build_efficient_sam_vits()
efficient_sam_vits_model.eval()

from segment_anything.utils.amg import (
    batched_mask_to_box,
    calculate_stability_score,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
)
from torchvision.ops.boxes import batched_nms

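# The helpers below mirror the post-processing in SAM's automatic mask
# generator: masks come back run-length encoded (RLE), small holes and islands
# are removed, and near-duplicate masks are dropped with box-level NMS.
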
def process_small_region(rles):
    """Fill small holes/islands in each RLE mask, then deduplicate masks via NMS."""
    new_masks = []
    scores = []
    min_area = 100
    nms_thresh = 0.7
    for rle in rles:
        mask = rle_to_mask(rle[0])
        mask, changed = remove_small_regions(mask, min_area, mode="holes")
        unchanged = not changed
        mask, changed = remove_small_regions(mask, min_area, mode="islands")
        unchanged = unchanged and not changed
        new_masks.append(torch.as_tensor(mask).unsqueeze(0))
        # Give altered masks score 0 so NMS prefers the untouched originals.
        scores.append(float(unchanged))

    masks = torch.cat(new_masks, dim=0)
    boxes = batched_mask_to_box(masks)
    keep_by_nms = batched_nms(
        boxes.float(),
        torch.as_tensor(scores),
        torch.zeros_like(boxes[:, 0]),  # single category: suppress across all masks
        iou_threshold=nms_thresh,
    )
    # Re-encode only the masks that hole/island removal actually changed.
    for i_mask in keep_by_nms:
        if scores[i_mask] == 0.0:
            mask_torch = masks[i_mask].unsqueeze(0)
            rles[i_mask] = mask_to_rle_pytorch(mask_torch)
    masks = [rle_to_mask(rles[i][0]) for i in keep_by_nms]
    return masks

def get_predictions_given_embeddings_and_queries(img, points, point_labels, model):
    """Run EfficientSAM on a batch of point prompts and keep only confident masks."""
    predicted_masks, predicted_iou = model(
        img[None, ...], points, point_labels
    )
    # Sort each query's candidate masks by predicted IoU, best first.
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou_scores = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    predicted_masks = torch.take_along_dim(
        predicted_masks, sorted_ids[..., None, None], dim=2
    )
    predicted_masks = predicted_masks[0]
    iou = predicted_iou_scores[0, :, 0]
    # Filter by predicted IoU > 0.7, then by stability score > 0.9.
    index_iou = iou > 0.7
    iou_ = iou[index_iou]
    masks = predicted_masks[index_iou]
    score = calculate_stability_score(masks, 0.0, 1.0)
    score = score[:, 0]
    index = score > 0.9
    masks = masks[index]
    iou_ = iou_[index]
    # Binarize the mask logits at 0.
    masks = torch.ge(masks, 0.0)
    return masks, iou_

def run_everything_ours(image_np, model):
    """EfficientSAM "segment everything": prompt with a uniform 32x32 point grid."""
    model = model.cpu()
    # image_np comes from PIL via np.array(), so it is already RGB; no channel
    # swap is needed (the original cv2.COLOR_BGR2RGB call would have corrupted it).
    img_tensor = ToTensor()(image_np)
    _, original_image_h, original_image_w = img_tensor.shape
    # Build a GRID_SIZE x GRID_SIZE grid of (x, y) point prompts in pixel coords.
    xy = []
    GRID_SIZE = 32
    for i in range(GRID_SIZE):
        curr_x = 0.5 + i / GRID_SIZE * original_image_w
        for j in range(GRID_SIZE):
            curr_y = 0.5 + j / GRID_SIZE * original_image_h
            xy.append([curr_x, curr_y])
    points = torch.from_numpy(np.array(xy))
    num_pts = points.shape[0]
    point_labels = torch.ones(num_pts, 1)  # 1 = foreground point
    with torch.no_grad():
        predicted_masks, predicted_iou = get_predictions_given_embeddings_and_queries(
            img_tensor.cpu(),
            points.reshape(1, num_pts, 1, 2).cpu(),
            point_labels.reshape(1, num_pts, 1).cpu(),
            model,
        )
    rle = [mask_to_rle_pytorch(m[0:1]) for m in predicted_masks]
    predicted_masks = process_small_region(rle)
    return predicted_masks

def show_anns_ours(masks, image):
    """Outline every mask in green directly on `image` (modified in place)."""
    for mask in masks:
        contours, _ = cv2.findContours(
            mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(image, contours, -1, (0, 255, 0), 2)
    return image

def process_image(image: Image.Image):
    # Gradio delivers a PIL image; force RGB (drops any alpha channel) and
    # convert to a numpy array once.
    image_np = np.array(image.convert("RGB"))

    # SAM's mask generator expects RGB input, which we already have; the
    # original cv2.COLOR_BGR2RGB conversion here would have swapped channels.
    sam_result = mask_generator_sam.generate(image_np)

    # Paint every SAM mask solid green.
    sam_annotated_image = image_np.copy()
    for mask in sam_result:
        sam_annotated_image[mask["segmentation"]] = [0, 255, 0]

    # Run EfficientSAM's segment-everything path and outline its masks.
    mask_efficient_sam_vits = run_everything_ours(image_np, efficient_sam_vits_model)
    efficient_sam_annotated_image = show_anns_ours(mask_efficient_sam_vits, image_np.copy())

    return [image, sam_annotated_image, efficient_sam_annotated_image]

interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Original"),
        gr.Image(type="pil", label="SAM Segmented"),
        gr.Image(type="pil", label="EfficientSAM Segmented"),
    ],
    title="SAM vs EfficientSAM Comparison",
    description="Upload an image to compare the segmentation results of SAM and EfficientSAM.",
)

interface.launch(debug=True)
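# Run the script directly to serve the demo locally; launch(share=True) can be
# passed instead if a public Gradio link is needed.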