File size: 2,142 Bytes
3d74359
8260e47
 
bf29adc
 
e6eaebf
bf29adc
410698b
bf29adc
 
410698b
bf29adc
9aea32e
630e69b
9aea32e
630e69b
e6eaebf
410698b
 
 
 
 
 
630e69b
8260e47
 
 
 
 
 
 
bf29adc
8260e47
 
bf29adc
 
 
8260e47
bf29adc
 
 
410698b
 
 
 
 
bf29adc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95190fc
bf29adc
 
 
 
630e69b
bf29adc
 
 
95190fc
410698b
 
8260e47
bf29adc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from typing import Any, Dict

import cv2
import gradio as gr
import numpy as np
import torch
from gradio_image_annotation import image_annotator
from sam2 import load_model
from sam2.sam2_image_predictor import SAM2ImagePredictor

from src.plot_utils import export_mask

import spaces

@spaces.GPU()
def predict(model_choice, annotations: Dict[str, Any]):
    """Run SAM2 box-prompted segmentation on the annotated image.

    Args:
        model_choice: SAM2 checkpoint variant name ("tiny", "small",
            "base_plus", or "large"); used to locate the checkpoint file.
        annotations: Payload from ``gradio_image_annotation`` with keys
            ``"image"`` (the image array) and ``"boxes"`` (a list of dicts
            carrying ``xmin``/``ymin``/``xmax``/``ymax``).

    Returns:
        The mask visualization produced by ``export_mask``.

    Raises:
        gr.Error: If the user has not drawn any bounding boxes.
    """
    # Fail fast with a user-visible message instead of a cryptic shape
    # error from predictor.predict on an empty box array.
    if not annotations.get("boxes"):
        raise gr.Error("Please draw at least one bounding box first.")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam2_model = load_model(
        variant=model_choice,
        ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
        device=device,
    )
    predictor = SAM2ImagePredictor(sam2_model)  # type:ignore
    predictor.set_image(annotations["image"])

    # Collect the drawn boxes as [xmin, ymin, xmax, ymax] int rows.
    coordinates = [
        [int(box["xmin"]), int(box["ymin"]), int(box["xmax"]), int(box["ymax"])]
        for box in annotations["boxes"]
    ]

    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=np.array(coordinates),
        multimask_output=False,
    )

    # SAM2 returns (1, H, W) for a single box but (N, 1, H, W) for multiple
    # boxes. Discriminate on ndim, not shape[0]: a 4-D result with one box
    # also has shape[0] == 1 and must NOT be expanded to 5-D.
    if masks.ndim == 3:
        masks = np.expand_dims(masks, axis=0)

    return export_mask(masks)


# Build the Gradio UI: model picker, box-annotation canvas, and a button
# that runs `predict` and shows the resulting segmentation mask image.
# delete_cache=(30, 30): purge cached files every 30s, deleting those
# older than 30s, to keep disk usage bounded on the hosting platform.
with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
        # 1. Choose Model Checkpoint
        """
    )
    with gr.Row():
        # Variant name is interpolated into the checkpoint path inside
        # `predict` (assets/checkpoints/sam2_hiera_<variant>.pt).
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )

    gr.Markdown(
        """
        # 2. Upload your Image and draw bounding box(es)
        """
    )

    # Annotation canvas pre-loaded with an example image; users draw the
    # box prompts here. Editing box labels is disabled — only geometry
    # matters for SAM2 box prompts.
    annotator = image_annotator(
        value={"image": cv2.imread("assets/example.png")},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask(s)")
    # Wire the button to `predict`: (dropdown value, annotator payload) in,
    # rendered mask image out.
    btn.click(
        fn=predict, inputs=[model, annotator], outputs=[gr.Image(label="Mask(s)")]
    )

demo.launch()