Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -14,8 +14,25 @@ models = {
|
|
14 |
'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
|
15 |
}
|
16 |
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
|
20 |
mask_generator = SamAutomaticMaskGenerator(
|
21 |
sam,
|
@@ -33,37 +50,43 @@ def inference(device, model_type, input_img, points_per_side, pred_iou_thresh, s
|
|
33 |
output_mode='binary_mask'
|
34 |
)
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
|
51 |
with gr.Blocks() as demo:
|
52 |
with gr.Row():
|
53 |
gr.Markdown(
|
54 |
'''# Segment Anything!🚀
|
55 |
-
|
56 |
-
[
|
57 |
'''
|
58 |
)
|
59 |
with gr.Row():
|
60 |
-
#
|
61 |
-
model_type = gr.Dropdown(["vit_b", "vit_l", "vit_h"], value='vit_b', label="
|
62 |
-
#
|
63 |
-
device = gr.Dropdown(["cpu"], value='
|
64 |
|
65 |
# 参数
|
66 |
-
with gr.Accordion(label='
|
67 |
with gr.Row():
|
68 |
points_per_side = gr.Number(value=32, label="points_per_side", precision=0,
|
69 |
info='''The number of points to be sampled along one side of the image. The total
|
@@ -88,43 +111,63 @@ with gr.Blocks() as demo:
|
|
88 |
info='''The box IoU cutoff used by non-maximal suppression to filter duplicate
|
89 |
masks between different crops.''')
|
90 |
|
91 |
-
#
|
92 |
-
with gr.
|
93 |
-
with gr.
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
|
127 |
-
demo.launch()
|
128 |
|
129 |
|
130 |
|
|
|
14 |
'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
|
15 |
}
|
16 |
|
17 |
+
|
18 |
+
def segment_one(img, mask_generator, seed=None):
|
19 |
+
if seed is not None:
|
20 |
+
np.random.seed(seed)
|
21 |
+
masks = mask_generator.generate(img)
|
22 |
+
sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
|
23 |
+
mask_all = np.ones((img.shape[0], img.shape[1], 3))
|
24 |
+
for ann in sorted_anns:
|
25 |
+
m = ann['segmentation']
|
26 |
+
color_mask = np.random.random((1, 3)).tolist()[0]
|
27 |
+
for i in range(3):
|
28 |
+
mask_all[m == True, i] = color_mask[i]
|
29 |
+
result = img / 255 * 0.3 + mask_all * 0.7
|
30 |
+
return result, mask_all
|
31 |
+
|
32 |
+
|
33 |
+
def inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh, min_mask_region_area,
|
34 |
+
stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh, input_x, progress=gr.Progress()):
|
35 |
+
# sam model
|
36 |
sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
|
37 |
mask_generator = SamAutomaticMaskGenerator(
|
38 |
sam,
|
|
|
50 |
output_mode='binary_mask'
|
51 |
)
|
52 |
|
53 |
+
# input is image, type: numpy
|
54 |
+
if type(input_x) == np.ndarray:
|
55 |
+
result, mask_all = segment_one(input_x, mask_generator)
|
56 |
+
return result, mask_all
|
57 |
+
elif isinstance(input_x, str): # input is video, type: path (str)
|
58 |
+
cap = cv2.VideoCapture(input_x) # read video
|
59 |
+
frames_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
|
60 |
+
W, H = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
61 |
+
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
62 |
+
print(fps)
|
63 |
+
out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc('x', '2', '6', '4'), fps, (W, H), isColor=True)
|
64 |
+
for _ in progress.tqdm(range(int(frames_num)), desc='Processing video ({} frames, size {}x{})'.format(int(frames_num), W, H)):
|
65 |
+
ret, frame = cap.read() # read a frame
|
66 |
+
result, mask_all = segment_one(frame, mask_generator, seed=2023)
|
67 |
+
result = (result * 255).astype(np.uint8)
|
68 |
+
out.write(result)
|
69 |
+
out.release()
|
70 |
+
cap.release()
|
71 |
+
return 'output.mp4'
|
72 |
|
73 |
|
74 |
with gr.Blocks() as demo:
|
75 |
with gr.Row():
|
76 |
gr.Markdown(
|
77 |
'''# Segment Anything!🚀
|
78 |
+
The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.
|
79 |
+
[**Official Project**](https://segment-anything.com/)
|
80 |
'''
|
81 |
)
|
82 |
with gr.Row():
|
83 |
+
# select model
|
84 |
+
model_type = gr.Dropdown(["vit_b", "vit_l", "vit_h"], value='vit_b', label="Select Model")
|
85 |
+
# select device
|
86 |
+
device = gr.Dropdown(["cpu", "cuda"], value='cuda', label="Select Device")
|
87 |
|
88 |
# 参数
|
89 |
+
with gr.Accordion(label='Parameters', open=False):
|
90 |
with gr.Row():
|
91 |
points_per_side = gr.Number(value=32, label="points_per_side", precision=0,
|
92 |
info='''The number of points to be sampled along one side of the image. The total
|
|
|
111 |
info='''The box IoU cutoff used by non-maximal suppression to filter duplicate
|
112 |
masks between different crops.''')
|
113 |
|
114 |
+
# Show image
|
115 |
+
with gr.Tab(label='Image'):
|
116 |
+
with gr.Row().style(equal_height=True):
|
117 |
+
with gr.Column():
|
118 |
+
input_image = gr.Image(type="numpy")
|
119 |
+
with gr.Row():
|
120 |
+
button = gr.Button("Auto!")
|
121 |
+
with gr.Tab(label='Image+Mask'):
|
122 |
+
output_image = gr.Image(type='numpy')
|
123 |
+
with gr.Tab(label='Mask'):
|
124 |
+
output_mask = gr.Image(type='numpy')
|
125 |
+
|
126 |
+
gr.Examples(
|
127 |
+
examples=[os.path.join(os.path.dirname(__file__), "./images/53960-scaled.jpg"),
|
128 |
+
os.path.join(os.path.dirname(__file__), "./images/2388455-scaled.jpg"),
|
129 |
+
os.path.join(os.path.dirname(__file__), "./images/1.jpg"),
|
130 |
+
os.path.join(os.path.dirname(__file__), "./images/2.jpg"),
|
131 |
+
os.path.join(os.path.dirname(__file__), "./images/3.jpg"),
|
132 |
+
os.path.join(os.path.dirname(__file__), "./images/4.jpg"),
|
133 |
+
os.path.join(os.path.dirname(__file__), "./images/5.jpg"),
|
134 |
+
os.path.join(os.path.dirname(__file__), "./images/6.jpg"),
|
135 |
+
os.path.join(os.path.dirname(__file__), "./images/7.jpg"),
|
136 |
+
os.path.join(os.path.dirname(__file__), "./images/8.jpg"),
|
137 |
+
],
|
138 |
+
inputs=input_image,
|
139 |
+
outputs=output_image,
|
140 |
+
)
|
141 |
+
# Show video
|
142 |
+
with gr.Tab(label='Video'):
|
143 |
+
with gr.Row().style(equal_height=True):
|
144 |
+
with gr.Column():
|
145 |
+
input_video = gr.Video()
|
146 |
+
with gr.Row():
|
147 |
+
button_video = gr.Button("Auto!")
|
148 |
+
output_video = gr.Video(format='mp4')
|
149 |
+
gr.Markdown('''
|
150 |
+
**Note:** processing video will take a long time, please upload a short video.
|
151 |
+
''')
|
152 |
+
gr.Examples(
|
153 |
+
examples=[os.path.join(os.path.dirname(__file__), "./images/video1.mp4")],
|
154 |
+
inputs=input_video,
|
155 |
+
outputs=output_video
|
156 |
+
)
|
157 |
|
158 |
+
# button image
|
159 |
+
button.click(inference, inputs=[device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
|
160 |
+
min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers,
|
161 |
+
crop_nms_thresh, input_image],
|
162 |
+
outputs=[output_image, output_mask])
|
163 |
+
# button video
|
164 |
+
button_video.click(inference, inputs=[device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
|
165 |
+
min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers,
|
166 |
+
crop_nms_thresh, input_video],
|
167 |
+
outputs=[output_video])
|
168 |
|
169 |
|
170 |
+
demo.queue().launch(debug=True, enable_queue=True)
|
171 |
|
172 |
|
173 |
|