Spaces: Running on Zero
Create app.py
app.py
ADDED
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import SamModel, SamProcessor
import cv2
from typing import List

device = "cuda" if torch.cuda.is_available() else "cpu"

# load the model and processor once at startup
model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base").to(device)
processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")

# cache of [pixel_values, image_embedding] for the most recent image,
# so repeated prompts on the same image skip the image encoder
cache_data = None

def mask_2_dots(mask: np.ndarray) -> List[List[int]]:
    """Find the centroid of each drawn dot and return them as point prompts."""
    gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
    _, thresh = cv2.threshold(gray, 127, 255, 0)
    kernel = np.ones((5, 5), np.uint8)
    closed = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    contours, _ = cv2.findContours(closed, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    points = []
    for contour in contours:
        moments = cv2.moments(contour)
        if moments['m00'] == 0:
            # skip degenerate contours with zero area to avoid division by zero
            continue
        cx = int(moments['m10'] / moments['m00'])
        cy = int(moments['m01'] / moments['m00'])
        points.append([cx, cy])
    return [points]

@torch.no_grad()
def forward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
    global cache_data
    image_input = Image.fromarray(image_input)
    inputs = processor(image_input, input_points=points, return_tensors="pt").to(device)
    # recompute the image embedding only when the input image changes
    if cache_data is None or not torch.equal(inputs['pixel_values'], cache_data[0]):
        embedding = model.get_image_embeddings(inputs["pixel_values"])
        pixels = inputs["pixel_values"]
        cache_data = [pixels, embedding]
    del inputs["pixel_values"]

    outputs = model(image_embeddings=cache_data[1], **inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
    )
    masks = masks[0].squeeze(0).numpy().transpose(1, 2, 0)

    return masks

def main_func(inputs) -> List[Image.Image]:
    # with Gradio 4's ImageEditor the value is a dict of "background", "layers"
    # and "composite"; the brush strokes live in the first layer's alpha channel,
    # so build a white-on-black RGB mask from it for mask_2_dots
    image_input = inputs['background']
    layer = np.array(inputs['layers'][0])
    dots = np.dstack([(layer[:, :, 3] > 0).astype(np.uint8) * 255] * 3)
    points = mask_2_dots(dots)
    masks = forward_pass(image_input, points)

    # draw the detected prompt points on the input image
    image_input = Image.fromarray(image_input)
    draw = ImageDraw.Draw(image_input)
    for point in points[0]:
        draw.ellipse((point[0] - 10, point[1] - 10, point[0] + 10, point[1] + 10), fill="red")

    # return the annotated image followed by one grayscale image per predicted mask
    pred_masks = [image_input]
    for i in range(masks.shape[2]):
        pred_masks.append(Image.fromarray((masks[:, :, i] * 255).astype(np.uint8)))

    return pred_masks

def reset_data():
    global cache_data
    cache_data = None

with gr.Blocks() as demo:
    gr.Markdown("# Demo of the Robust Segment Anything base model")
    gr.Markdown("""This app uses the [Robust Segment Anything](https://huggingface.co/jadechoghari/robustsam-vit-base) model from Snap Research to get a mask from points in an image.
    """)
    gr.Markdown("# How to use")
    gr.Markdown("To start, upload an image, then use the brush to draw dots on the object you want to segment. Don't worry if the dots aren't perfect: the app finds the center of each drawn mark. Then press the Segment Image button to create masks for the object the dots are on.")
    with gr.Tab("Segment Image"):
        with gr.Row():
            image_input = gr.ImageEditor()
            image_output = gr.Gallery()

        image_button = gr.Button("Segment Image")

        image_button.click(main_func, inputs=image_input, outputs=image_output)
        image_input.upload(reset_data)

demo.launch()
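
For reference, here is a minimal standalone sketch of the point-prompted inference that forward_pass wraps, without the Gradio UI or the embedding cache. The image path and the point coordinate are placeholder assumptions.

# standalone usage sketch -- not part of app.py
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base").to(device)
processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")

image = Image.open("example.jpg").convert("RGB")  # placeholder image path
input_points = [[[450, 600]]]  # batch of one image with one (x, y) point prompt

inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
print(masks[0].shape)  # candidate masks (3 by default) at the original image size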