"""Gradio demo: segment the object inside a user-drawn bounding box with SAM2."""

import functools
import threading
from typing import Any, Dict, List

import gradio as gr
import numpy as np
import cv2
import torch
from src.plot_utils import show_masks
from gradio_image_annotation import image_annotator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Dropdown label -> [model config file, checkpoint path], as passed to build_sam2.
choice_mapping: Dict[str, List[str]] = {
    "tiny": ["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
    "small": ["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
    "base_plus": ["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
    "large": ["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
}

# The cached predictor below is shared across requests and holds the image set by
# set_image() until predict() runs, so each set_image/predict pair is serialized.
_PREDICT_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_predictor(model_choice: str) -> SAM2ImagePredictor:
    """Build the SAM2 model for *model_choice* and wrap it in an image predictor.

    Cached so repeated clicks with the same checkpoint reuse the loaded weights
    instead of rebuilding the network from disk on every request; ``maxsize=1``
    keeps only the most recently selected checkpoint resident.

    Raises:
        KeyError: If *model_choice* is not a key of ``choice_mapping``.
    """
    config_file, ckpt_path = choice_mapping[model_choice]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam2_model = build_sam2(config_file, ckpt_path, device=device)
    return SAM2ImagePredictor(sam2_model)


def predict(model_choice: str, annotations: Dict[str, Any]) -> List[Any]:
    """Segment the object inside the first drawn bounding box.

    Args:
        model_choice: Key of ``choice_mapping`` selected in the dropdown.
        annotations: Value of the ``image_annotator`` component. This function
            reads ``annotations["image"]`` and the ``xmin``/``ymin``/``xmax``/
            ``ymax`` entries of ``annotations["boxes"][0]``.

    Returns:
        A two-element list for the ``[gr.Plot, gr.DownloadButton]`` outputs:
        whatever ``show_masks`` produces for the image, masks, scores and box,
        and a now-visible download button pointing at ``mask.png``.

    Raises:
        gr.Error: If no image is loaded or no bounding box has been drawn.
    """
    annotations = annotations or {}
    if annotations.get("image") is None:
        raise gr.Error("Please upload an image first.")
    boxes = annotations.get("boxes") or []
    if not boxes:
        raise gr.Error("Please draw a bounding box around the object to segment.")

    # Only the first box is used as the prompt; any additional boxes are ignored.
    box = boxes[0]
    coordinates = np.array(
        [int(box[key]) for key in ("xmin", "ymin", "xmax", "ymax")]
    )

    predictor = _load_predictor(str(model_choice))
    with _PREDICT_LOCK:
        predictor.set_image(annotations["image"])
        masks, scores, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=coordinates[None, :],  # add a leading batch axis: shape (1, 4)
            multimask_output=False,
        )

    # Move the first axis last so cv2 writes an (H, W, 1) image; presumably masks is
    # (1, H, W) with 0/1 values when multimask_output=False -- TODO confirm.
    mask_image = (masks.transpose(1, 2, 0) * 255).astype(np.uint8)
    cv2.imwrite("mask.png", mask_image)

    return [
        show_masks(annotations["image"], masks, scores, box_coords=coordinates),
        gr.DownloadButton("Download Mask", value="mask.png", visible=True),
    ]


# cv2.imread returns BGR (or None if the file is missing). Uploaded images reach
# predict() as RGB arrays, and SAM2ImagePredictor.set_image expects RGB, so the
# example is converted to match.
example_image = cv2.imread("assets/example.png")
if example_image is not None:
    example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)

with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
        # 1. Choose Model Checkpoint
        """
    )
    with gr.Row():
        model = gr.Dropdown(
            choices=list(choice_mapping),
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )

    gr.Markdown(
        """
        # 2. Upload your Image and draw a bounding box
        """
    )
    annotator = image_annotator(
        value={"image": example_image},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask")
    # Hidden until predict() has written mask.png and returns a visible button.
    download_btn = gr.DownloadButton("Download Mask", value="mask.png", visible=False)
    plot = gr.Plot()
    btn.click(fn=predict, inputs=[model, annotator], outputs=[plot, download_btn])

if __name__ == "__main__":
    demo.launch()