import os
import subprocess
import sys
import requests
import zipfile
import gradio as gr
import torch
import numpy as np
from torchvision.transforms import ToTensor
from PIL import Image
import cv2
# Ensure the necessary model files are available
def download_file(url, destination):
    """Download a file from `url` to `destination`, streaming it in chunks."""
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(destination, 'wb') as f:
        # Iterate over chunks so large checkpoints never sit fully in memory
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
# Download the SAM checkpoint if it is not already cached
if not os.path.exists("weights/sam_vit_h_4b8939.pth"):
    os.makedirs("weights", exist_ok=True)
    download_file("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "weights/sam_vit_h_4b8939.pth")
# Install segment-anything and fetch the EfficientSAM repository before importing from them
subprocess.run([sys.executable, "-m", "pip", "install",
                "git+https://github.com/facebookresearch/segment-anything.git"], check=True)
if not os.path.exists("EfficientSAM"):
    subprocess.run(["git", "clone", "https://github.com/yformer/EfficientSAM.git"], check=True)

# Add the cloned EfficientSAM repo to the Python path
# (git clone creates "EfficientSAM", not "EfficientSAM-main")
sys.path.append(os.path.abspath("EfficientSAM"))

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from efficient_sam.build_efficient_sam import build_efficient_sam_vits
# Constants
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
MODEL_TYPE = "vit_h"
CHECKPOINT_PATH = "weights/sam_vit_h_4b8939.pth"
# Load SAM model
sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
mask_generator_sam = SamAutomaticMaskGenerator(sam)
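# SamAutomaticMaskGenerator samples a 32x32 grid of point prompts by default
# (points_per_side=32), comparable to the grid EfficientSAM uses below.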
# Load EfficientSAM (ViT-S); its checkpoint ships zipped inside the cloned repo
with zipfile.ZipFile("EfficientSAM/weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
    zip_ref.extractall("weights")
efficient_sam_vits_model = build_efficient_sam_vits()
efficient_sam_vits_model.eval()  # inference only
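# Note: build_efficient_sam_vits() looks for its checkpoint at
# "weights/efficient_sam_vits.pt", which is why the zip above is extracted
# into "weights/".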
from segment_anything.utils.amg import (
    batched_mask_to_box,
    calculate_stability_score,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
)
from torchvision.ops.boxes import batched_nms, box_area
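# These helpers cover the mask post-processing pipeline: rle_to_mask and
# mask_to_rle_pytorch convert between binary masks and run-length encodings,
# remove_small_regions fills holes / drops islands below a pixel-area
# threshold, batched_mask_to_box fits a bounding box per mask for NMS, and
# calculate_stability_score measures how robust a mask is to changes in the
# binarization threshold.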
def process_small_region(rles):
    """Clean up masks (fill holes, drop islands) and de-duplicate them with NMS."""
    new_masks = []
    scores = []
    min_area = 100
    nms_thresh = 0.7
    for rle in rles:
        mask = rle_to_mask(rle[0])
        mask, changed = remove_small_regions(mask, min_area, mode="holes")
        unchanged = not changed
        mask, changed = remove_small_regions(mask, min_area, mode="islands")
        unchanged = unchanged and not changed
        new_masks.append(torch.as_tensor(mask).unsqueeze(0))
        # Score 0 for modified masks so NMS prefers masks that needed no cleanup
        scores.append(float(unchanged))
    # Recompute boxes and remove any duplicates introduced by the cleanup
    masks = torch.cat(new_masks, dim=0)
    boxes = batched_mask_to_box(masks)
    keep_by_nms = batched_nms(
        boxes.float(),
        torch.as_tensor(scores),
        torch.zeros_like(boxes[:, 0]),  # a single category for all masks
        iou_threshold=nms_thresh,
    )
    # Re-encode only the masks that were actually modified
    for i_mask in keep_by_nms:
        if scores[i_mask] == 0.0:
            mask_torch = masks[i_mask].unsqueeze(0)
            rles[i_mask] = mask_to_rle_pytorch(mask_torch)
    masks = [rle_to_mask(rles[i][0]) for i in keep_by_nms]
    return masks
def get_predictions_given_embeddings_and_queries(img, points, point_labels, model):
    """Query EfficientSAM with point prompts and keep only confident, stable masks."""
    predicted_masks, predicted_iou = model(
        img[None, ...], points, point_labels
    )
    # Sort each prompt's candidate masks by predicted IoU, best first
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou_scores = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    predicted_masks = torch.take_along_dim(
        predicted_masks, sorted_ids[..., None, None], dim=2
    )
    predicted_masks = predicted_masks[0]
    iou = predicted_iou_scores[0, :, 0]
    # Keep only masks whose predicted IoU exceeds 0.7
    index_iou = iou > 0.7
    iou_ = iou[index_iou]
    masks = predicted_masks[index_iou]
    # Further filter by stability score (robustness to threshold shifts)
    score = calculate_stability_score(masks, 0.0, 1.0)
    score = score[:, 0]
    index = score > 0.9
    score_ = score[index]
    masks = masks[index]
    iou_ = iou_[index]
    # Binarize the mask logits at 0
    masks = torch.ge(masks, 0.0)
    return masks, iou_
def run_everything_ours(image_np, model):
    """Run EfficientSAM in segment-everything mode over a grid of point prompts."""
    model = model.cpu()
    # image_np comes from PIL via np.array(), so it is already RGB;
    # no BGR-to-RGB conversion is needed here
    img_tensor = ToTensor()(image_np)
    _, original_image_h, original_image_w = img_tensor.shape
    # Build a uniform 32x32 grid of point prompts over the image
    xy = []
    GRID_SIZE = 32
    for i in range(GRID_SIZE):
        curr_x = 0.5 + i / GRID_SIZE * original_image_w
        for j in range(GRID_SIZE):
            curr_y = 0.5 + j / GRID_SIZE * original_image_h
            xy.append([curr_x, curr_y])
    points = torch.from_numpy(np.array(xy))
    num_pts = points.shape[0]
    point_labels = torch.ones(num_pts, 1)
    with torch.no_grad():
        predicted_masks, predicted_iou = get_predictions_given_embeddings_and_queries(
            img_tensor.cpu(),
            points.reshape(1, num_pts, 1, 2).cpu(),
            point_labels.reshape(1, num_pts, 1).cpu(),
            model.cpu(),
        )
    # Convert each mask to RLE, then clean up small regions and de-duplicate
    rle = [mask_to_rle_pytorch(m[0:1]) for m in predicted_masks]
    predicted_masks = process_small_region(rle)
    return predicted_masks
def show_anns_ours(masks, image):
    """Draw a green contour for each predicted mask onto the image."""
    for mask in masks:
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image, contours, -1, (0, 255, 0), 2)
    return image
def process_image(image):
    # Gradio provides a PIL image; np.array() yields an RGB array
    image_np = np.array(image)

    # Process with SAM (SamAutomaticMaskGenerator expects RGB input)
    sam_result = mask_generator_sam.generate(image_np)

    # Annotate SAM result: paint every segmented pixel green
    sam_annotated_image = image_np.copy()
    for mask in sam_result:
        sam_annotated_image[mask['segmentation']] = [0, 255, 0]

    # Process with EfficientSAM and draw the mask contours
    mask_efficient_sam_vits = run_everything_ours(image_np, efficient_sam_vits_model)
    efficient_sam_annotated_image = show_anns_ours(mask_efficient_sam_vits, image_np.copy())

    return [image, sam_annotated_image, efficient_sam_annotated_image]
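# A minimal sketch of running the same pipeline outside Gradio, assuming a
# local file "example.jpg" (hypothetical name):
#   img = np.array(Image.open("example.jpg").convert("RGB"))
#   masks = run_everything_ours(img, efficient_sam_vits_model)
#   Image.fromarray(show_anns_ours(masks, img.copy())).save("annotated.png")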
# Gradio interface
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Original"),
        gr.Image(type="pil", label="SAM Segmented"),
        gr.Image(type="pil", label="EfficientSAM Segmented"),
    ],
    title="SAM vs EfficientSAM Comparison",
    description="Upload an image to compare the segmentation results of SAM and EfficientSAM.",
)
interface.launch(debug=True)
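# On a Hugging Face Space the host and port are managed for you; when running
# locally, interface.launch(share=True) would additionally create a public link.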