Update image visualization for app.py
app.py
CHANGED
@@ -14,8 +14,8 @@ HOME = os.getenv("HOME")
 DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 MINIMUM_AREA_THRESHOLD = 0.01
 
-SAM_CHECKPOINT = os.path.join(HOME, "app/weights/sam_vit_h_4b8939.pth")
-
+# SAM_CHECKPOINT = os.path.join(HOME, "app/weights/sam_vit_h_4b8939.pth")
+SAM_CHECKPOINT = "weights/sam_vit_h_4b8939.pth"
 SAM_MODEL_TYPE = "vit_h"
 
 MARKDOWN = """
@@ -26,13 +26,23 @@ MARKDOWN = """
     />
     Set-of-Mark (SoM) Prompting Unleashes Extraordinary Visual Grounding in GPT-4V
 </h1>
+
+## 🚧 Roadmap
+
+- [ ] Support for alphabetic labels
+- [ ] Support for Semantic-SAM (multi-level)
+- [ ] Support for interactive mode
 """
 
 SAM = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(device=DEVICE)
-VISUALIZER = Visualizer()
 
 
-def inference(image: np.ndarray, annotation_mode: List[str]) -> np.ndarray:
+def inference(
+    image: np.ndarray,
+    annotation_mode: List[str],
+    mask_alpha: float
+) -> np.ndarray:
+    visualizer = Visualizer(mask_opacity=mask_alpha)
     mask_generator = SamAutomaticMaskGenerator(SAM)
     result = mask_generator.generate(image=image)
     detections = sv.Detections.from_sam(result)
@@ -40,7 +50,7 @@ def inference(image: np.ndarray, annotation_mode: List[str]) -> np.ndarray:
         detections=detections,
         area_threshold=MINIMUM_AREA_THRESHOLD)
     bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-    annotated_image = VISUALIZER.visualize(
+    annotated_image = visualizer.visualize(
         image=bgr_image,
         detections=detections,
         with_box="Box" in annotation_mode,
@@ -58,6 +68,11 @@ checkbox_annotation_mode = gr.CheckboxGroup(
     choices=["Mark", "Polygon", "Mask", "Box"],
     value=['Mark'],
     label="Annotation Mode")
+slider_mask_alpha = gr.Slider(
+    minimum=0,
+    maximum=1,
+    value=0.05,
+    label="Mask Alpha")
 image_output = gr.Image(
     label="SoM Visual Prompt",
     type="numpy",
@@ -70,14 +85,17 @@ with gr.Blocks() as demo:
         with gr.Column():
            image_input.render()
            with gr.Accordion(label="Detailed prompt settings (e.g., mark type)", open=False):
-                checkbox_annotation_mode.render()
+                with gr.Row():
+                    checkbox_annotation_mode.render()
+                with gr.Row():
+                    slider_mask_alpha.render()
         with gr.Column():
             image_output.render()
             run_button.render()
 
     run_button.click(
         fn=inference,
-        inputs=[image_input, checkbox_annotation_mode],
+        inputs=[image_input, checkbox_annotation_mode, slider_mask_alpha],
         outputs=image_output)
 
 demo.queue().launch(debug=False, show_error=True)
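
For reference, a minimal sketch of how the updated inference() signature could be exercised outside the Gradio UI, e.g. as a local smoke test. Only the three-argument signature and the 0.05 slider default come from the diff above; the import path and the example image name are assumptions for illustration, not part of this commit.

# Hypothetical smoke test for the new inference(image, annotation_mode, mask_alpha)
# signature; assumes app.py is importable and the SAM checkpoint referenced by
# SAM_CHECKPOINT ("weights/sam_vit_h_4b8939.pth") is present locally.
import cv2

from app import inference

# gr.Image(type="numpy") hands the function an RGB array, so mirror that here.
bgr = cv2.imread("example.jpg")             # placeholder test image
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

# mask_alpha is the new third argument, fed by the "Mask Alpha" slider (default 0.05).
annotated = inference(rgb, annotation_mode=["Mark", "Box"], mask_alpha=0.05)
print(annotated.shape)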