from typing import Any, Dict
import spaces
import cv2
import gradio as gr
import numpy as np
import torch
from gradio_image_annotation import image_annotator
from sam2 import load_model
from sam2.sam2_image_predictor import SAM2ImagePredictor
from src.plot_utils import export_mask
@spaces.GPU()
def predict(model_choice, annotations: Dict[str, Any]):
    """Run SAM2 box-prompted segmentation on the annotated image.

    Args:
        model_choice: SAM2 variant name ("tiny", "small", "base_plus",
            "large"); also selects the matching checkpoint file on disk.
        annotations: dict produced by ``image_annotator`` with keys
            "image" (ndarray) and "boxes" (list of dicts carrying
            xmin/ymin/xmax/ymax).  # assumes this schema — TODO confirm

    Returns:
        Whatever ``export_mask`` renders from the (N, 1, H, W) mask batch.

    Raises:
        gr.Error: if no bounding boxes were drawn.
    """
    # device = "cuda" if torch.cuda.is_available() else "cpu"
    sam2_model = load_model(
        variant=model_choice,
        ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
        device="cuda",
    )
    predictor = SAM2ImagePredictor(sam2_model)  # type:ignore
    predictor.set_image(annotations["image"])

    # Collect every drawn box as [xmin, ymin, xmax, ymax] int coordinates.
    coordinates = [
        [
            int(box["xmin"]),
            int(box["ymin"]),
            int(box["xmax"]),
            int(box["ymax"]),
        ]
        for box in annotations["boxes"]
    ]
    if not coordinates:
        # Fail with a user-visible message instead of letting the predictor
        # choke on an empty box array.
        raise gr.Error("Please draw at least one bounding box on the image.")

    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=np.array(coordinates),
        multimask_output=False,
    )
    # A single box yields an unbatched (1, H, W) mask, while batched boxes
    # yield (N, 1, H, W).  Add the batch axis only when it is actually
    # missing — the original `masks.shape[0] == 1` test would also fire on
    # an already-batched (1, 1, H, W) result and produce a 5-D array.
    if masks.ndim == 3:
        masks = np.expand_dims(masks, axis=0)
    return export_mask(masks)
# --- Gradio UI: model picker, box annotator, and segmentation trigger ------
with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
    # 1. Choose Model Checkpoint
    """
    )
    with gr.Row():
        # Which SAM2 checkpoint variant to load inside `predict`.
        checkpoint_dropdown = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )

    gr.Markdown(
        """
    # 2. Upload your Image and draw bounding box(es)
    """
    )
    # Annotation canvas pre-loaded with the bundled example image.
    box_annotator = image_annotator(
        value={"image": cv2.imread("assets/example.png")},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )

    run_button = gr.Button("Get Segmentation Mask(s)")
    run_button.click(
        fn=predict,
        inputs=[checkpoint_dropdown, box_annotator],
        outputs=[gr.Image(label="Mask(s)")],
    )

demo.launch()