import os
import subprocess
import zipfile

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from torchvision.transforms import ToTensor


def download_file(url, destination):
    """Stream a file from `url` to `destination` on disk."""
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(destination, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)


# Install segment-anything and fetch the EfficientSAM sources.
subprocess.run(
    ["pip", "install", "git+https://github.com/facebookresearch/segment-anything.git"],
    check=True,
)
subprocess.run(["git", "clone", "https://github.com/yformer/EfficientSAM.git"], check=True)

# Change directory to EfficientSAM so its modules and bundled weights are reachable.
os.chdir("EfficientSAM")

from segment_anything.utils.amg import (
    batched_mask_to_box,
    calculate_stability_score,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
)
from torchvision.ops.boxes import batched_nms


def process_small_region(rles):
    """Post-process RLE masks: fill small holes, drop small islands, then deduplicate via NMS."""
    new_masks = []
    scores = []
    min_area = 100
    nms_thresh = 0.7
    for rle in rles:
        mask = rle_to_mask(rle[0])
        mask, changed = remove_small_regions(mask, min_area, mode="holes")
        unchanged = not changed
        mask, changed = remove_small_regions(mask, min_area, mode="islands")
        unchanged = unchanged and not changed
        new_masks.append(torch.as_tensor(mask).unsqueeze(0))
        # Score modified masks 0 so NMS prefers the untouched ones.
        scores.append(float(unchanged))
    masks = torch.cat(new_masks, dim=0)
    boxes = batched_mask_to_box(masks)
    keep_by_nms = batched_nms(
        boxes.float(),
        torch.as_tensor(scores),
        torch.zeros_like(boxes[:, 0]),  # single category: NMS across all masks
        iou_threshold=nms_thresh,
    )
    # Re-encode any mask that was modified during hole/island removal.
    for i_mask in keep_by_nms:
        if scores[i_mask] == 0.0:
            mask_torch = masks[i_mask].unsqueeze(0)
            rles[i_mask] = mask_to_rle_pytorch(mask_torch)
    masks = [rle_to_mask(rles[i][0]) for i in keep_by_nms]
    return masks


def get_predictions_given_embeddings_and_queries(img, points, point_labels, model):
    predicted_masks, predicted_iou = model(img[None, ...], points, point_labels)
    # Sort each query's candidate masks by predicted IoU, best first.
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou_scores = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    predicted_masks = torch.take_along_dim(
        predicted_masks, sorted_ids[..., None, None], dim=2
    )
    predicted_masks = predicted_masks[0]
    iou = predicted_iou_scores[0, :, 0]
    # Keep only masks with predicted IoU > 0.7 ...
    index_iou = iou > 0.7
    iou_ = iou[index_iou]
    masks = predicted_masks[index_iou]
    # ... and stability score > 0.9.
    score = calculate_stability_score(masks, 0.0, 1.0)
    score = score[:, 0]
    index = score > 0.9
    masks = masks[index]
    iou_ = iou_[index]
    # Binarize the mask logits at threshold 0.
    masks = torch.ge(masks, 0.0)
    return masks, iou_


def run_everything_ours(image_np, model):
    """Run EfficientSAM over a uniform grid of point prompts on an RGB image."""
    model = model.cpu()
    # `image_np` arrives as RGB (from PIL via np.array), so no channel conversion is needed.
    img_tensor = ToTensor()(image_np)
    _, original_image_h, original_image_w = img_tensor.shape
    # Build a GRID_SIZE x GRID_SIZE grid of query points covering the image.
    xy = []
    GRID_SIZE = 32
    for i in range(GRID_SIZE):
        curr_x = 0.5 + i / GRID_SIZE * original_image_w
        for j in range(GRID_SIZE):
            curr_y = 0.5 + j / GRID_SIZE * original_image_h
            xy.append([curr_x, curr_y])
    points = torch.from_numpy(np.array(xy))
    num_pts = points.shape[0]
    point_labels = torch.ones(num_pts, 1)
    with torch.no_grad():
        predicted_masks, predicted_iou = get_predictions_given_embeddings_and_queries(
            img_tensor.cpu(),
            points.reshape(1, num_pts, 1, 2).cpu(),
            point_labels.reshape(1, num_pts, 1).cpu(),
            model.cpu(),
        )
    rle = [mask_to_rle_pytorch(m[0:1]) for m in predicted_masks]
    predicted_masks = process_small_region(rle)
    return predicted_masks


def show_anns_ours(masks, image):
    """Draw the contour of every mask on `image` in green."""
    for mask in masks:
        contours, _ = cv2.findContours(
            mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(image, contours, -1, (0, 255, 0), 2)
    return image
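
# `process_image` below references `mask_generator_sam`, but the original script
# never constructs it. A minimal sketch of that setup follows, assuming the
# standard segment-anything API and the public ViT-H checkpoint; the checkpoint
# choice ("vit_h") and device handling here are assumptions, not part of the
# original script. Swap in "vit_l" or "vit_b" if memory is tight.
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
if not os.path.exists(SAM_CHECKPOINT):
    download_file(
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        SAM_CHECKPOINT,
    )
sam_model = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT)
sam_model.to("cuda" if torch.cuda.is_available() else "cpu")
mask_generator_sam = SamAutomaticMaskGenerator(sam_model)
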
def process_image(image):
    """Segment `image` with both SAM and EfficientSAM and return all three views."""
    # Gradio hands us a PIL image; np.array yields an RGB array, which is what
    # both SamAutomaticMaskGenerator and EfficientSAM expect.
    image_np = np.array(image)

    # Process with SAM.
    sam_result = mask_generator_sam.generate(image_np)

    # Annotate the SAM result by painting every mask region green.
    sam_annotated_image = image_np.copy()
    for mask in sam_result:
        sam_annotated_image[mask["segmentation"]] = [0, 255, 0]

    # Process with EfficientSAM and draw mask contours.
    mask_efficient_sam_vits = run_everything_ours(image_np, efficient_sam_vits_model)
    efficient_sam_annotated_image = show_anns_ours(mask_efficient_sam_vits, image_np.copy())

    return [image, sam_annotated_image, efficient_sam_annotated_image]


# Download the EfficientSAM checkpoint if it is not already present.
# (The URL below is a placeholder; the cloned EfficientSAM repo already ships
# the zipped checkpoint at weights/efficient_sam_vits.pt.zip.)
os.makedirs("weights", exist_ok=True)
if not os.path.exists("weights/efficient_sam_vits.pt.zip"):
    download_file(
        "https://example.com/path/to/efficient_sam_vits.pt.zip",
        "weights/efficient_sam_vits.pt.zip",
    )

# Extract the EfficientSAM checkpoint.
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", "r") as zip_ref:
    zip_ref.extractall("weights")

from efficient_sam.build_efficient_sam import build_efficient_sam_vits

efficient_sam_vits_model = build_efficient_sam_vits()
efficient_sam_vits_model.eval()

# Gradio interface
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Original"),
        gr.Image(type="pil", label="SAM Segmented"),
        gr.Image(type="pil", label="EfficientSAM Segmented"),
    ],
    title="SAM vs EfficientSAM Comparison",
    description="Upload an image to compare the segmentation results of SAM and EfficientSAM.",
)

interface.launch(debug=True)