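# Gradio demo that compares "segment everything" results from SAM and EfficientSAM (ViT-S)
# on an uploaded image, returning the original alongside both models' annotated outputs.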
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.transforms import ToTensor
from PIL import Image
import io
import cv2
import gradio as gr
import os
import requests
import zipfile
import subprocess
# Ensure the necessary model files are available
def download_file(url, destination):
    # Stream the download to disk in chunks instead of buffering the whole file in memory
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(destination, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

# Install the necessary packages and fetch the EfficientSAM sources
subprocess.run(["pip", "install", "git+https://github.com/facebookresearch/segment-anything.git"], check=True)
if not os.path.exists("EfficientSAM"):
    subprocess.run(["git", "clone", "https://github.com/yformer/EfficientSAM.git"], check=True)

# Change directory to EfficientSAM so its package and weights are accessible
os.chdir("EfficientSAM")
from segment_anything.utils.amg import (
    batched_mask_to_box,
    calculate_stability_score,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
)
from torchvision.ops.boxes import batched_nms, box_area
def process_small_region(rles):
    """Remove small holes/islands from RLE masks and deduplicate them with NMS."""
    new_masks = []
    scores = []
    min_area = 100
    nms_thresh = 0.7
    for rle in rles:
        mask = rle_to_mask(rle[0])
        mask, changed = remove_small_regions(mask, min_area, mode="holes")
        unchanged = not changed
        mask, changed = remove_small_regions(mask, min_area, mode="islands")
        unchanged = unchanged and not changed
        new_masks.append(torch.as_tensor(mask).unsqueeze(0))
        # Unmodified masks get score 1.0 so NMS prefers them over postprocessed duplicates
        scores.append(float(unchanged))
    # Run NMS on the bounding boxes of the cleaned masks to drop near-duplicates
    masks = torch.cat(new_masks, dim=0)
    boxes = batched_mask_to_box(masks)
    keep_by_nms = batched_nms(
        boxes.float(),
        torch.as_tensor(scores),
        torch.zeros_like(boxes[:, 0]),  # single category, so NMS is applied globally
        iou_threshold=nms_thresh,
    )
    # Only recompute RLEs for masks that were actually modified
    for i_mask in keep_by_nms:
        if scores[i_mask] == 0.0:
            mask_torch = masks[i_mask].unsqueeze(0)
            rles[i_mask] = mask_to_rle_pytorch(mask_torch)
    masks = [rle_to_mask(rles[i][0]) for i in keep_by_nms]
    return masks
def get_predictions_given_embeddings_and_queries(img, points, point_labels, model):
    """Run EfficientSAM on a batch of point queries and keep only confident, stable masks."""
    predicted_masks, predicted_iou = model(img[None, ...], points, point_labels)
    # Sort each query's candidate masks by predicted IoU, best first
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou_scores = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    predicted_masks = torch.take_along_dim(
        predicted_masks, sorted_ids[..., None, None], dim=2
    )
    predicted_masks = predicted_masks[0]
    # Keep the best mask per query only if its predicted IoU exceeds 0.7
    iou = predicted_iou_scores[0, :, 0]
    index_iou = iou > 0.7
    iou_ = iou[index_iou]
    masks = predicted_masks[index_iou]
    # Further filter by stability score (> 0.9) before binarizing the mask logits
    score = calculate_stability_score(masks, 0.0, 1.0)
    score = score[:, 0]
    index = score > 0.9
    score_ = score[index]
    masks = masks[index]
    iou_ = iou_[index]
    masks = torch.ge(masks, 0.0)
    return masks, iou_
def run_everything_ours(image_np, model):
    """Segment everything in an RGB image with EfficientSAM using a 32x32 grid of point prompts."""
    model = model.cpu()
    # Expects an RGB uint8 array; ToTensor converts it to a CxHxW float tensor in [0, 1]
    img_tensor = ToTensor()(image_np)
    _, original_image_h, original_image_w = img_tensor.shape
    # Build a regular GRID_SIZE x GRID_SIZE grid of query points over the image
    xy = []
    GRID_SIZE = 32
    for i in range(GRID_SIZE):
        curr_x = 0.5 + i / GRID_SIZE * original_image_w
        for j in range(GRID_SIZE):
            curr_y = 0.5 + j / GRID_SIZE * original_image_h
            xy.append([curr_x, curr_y])
    points = torch.from_numpy(np.array(xy))
    num_pts = points.shape[0]
    point_labels = torch.ones(num_pts, 1)
    with torch.no_grad():
        predicted_masks, predicted_iou = get_predictions_given_embeddings_and_queries(
            img_tensor.cpu(),
            points.reshape(1, num_pts, 1, 2).cpu(),
            point_labels.reshape(1, num_pts, 1).cpu(),
            model.cpu(),
        )
    # Convert the surviving masks to RLE, clean up small regions, and return binary masks
    rle = [mask_to_rle_pytorch(m[0:1]) for m in predicted_masks]
    predicted_masks = process_small_region(rle)
    return predicted_masks
def show_anns_ours(masks, image):
    """Draw the contour of every EfficientSAM mask on the image in green."""
    for mask in masks:
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(image, contours, -1, (0, 255, 0), 2)
    return image
def process_image(image):
    # Convert the PIL image (already RGB) to a numpy array; no channel swap is needed
    image_np = np.array(image)
    # Process with SAM's automatic mask generator
    sam_result = mask_generator_sam.generate(image_np)
    # Annotate the SAM result by painting each segmented region green
    sam_annotated_image = image_np.copy()
    for mask in sam_result:
        sam_annotated_image[mask['segmentation']] = [0, 255, 0]
    # Process with EfficientSAM and outline its masks in green
    mask_efficient_sam_vits = run_everything_ours(image_np, efficient_sam_vits_model)
    efficient_sam_annotated_image = show_anns_ours(mask_efficient_sam_vits, image_np.copy())
    return [image, sam_annotated_image, efficient_sam_annotated_image]
# Download the EfficientSAM (ViT-S) checkpoint archive if it is not already present.
# NOTE: the URL below is a placeholder and must be replaced with the real checkpoint location.
if not os.path.exists("weights/efficient_sam_vits.pt.zip"):
    download_file("https://example.com/path/to/efficient_sam_vits.pt.zip", "weights/efficient_sam_vits.pt.zip")

# Extract the archive so build_efficient_sam_vits() can load weights/efficient_sam_vits.pt
with zipfile.ZipFile("weights/efficient_sam_vits.pt.zip", 'r') as zip_ref:
    zip_ref.extractall("weights")

from efficient_sam.build_efficient_sam import build_efficient_sam_vits
efficient_sam_vits_model = build_efficient_sam_vits()
efficient_sam_vits_model.eval()
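
# process_image() also needs a SAM automatic mask generator, which the original script never
# defined. A minimal setup sketch follows, assuming the publicly released SAM ViT-H checkpoint
# from the segment-anything repository; swap in another model type/checkpoint if preferred.
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"  # assumed local path for the SAM checkpoint
if not os.path.exists(SAM_CHECKPOINT):
    download_file(
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        SAM_CHECKPOINT,
    )
sam_model = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT)
sam_model.eval()
mask_generator_sam = SamAutomaticMaskGenerator(sam_model)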
# Gradio interface
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Original"),
        gr.Image(type="pil", label="SAM Segmented"),
        gr.Image(type="pil", label="EfficientSAM Segmented"),
    ],
    title="SAM vs EfficientSAM Comparison",
    description="Upload an image to compare the segmentation results of SAM and EfficientSAM.",
)

interface.launch(debug=True)