Spaces:
Runtime error
Runtime error
Commit
·
05187ec
1
Parent(s):
c2afc01
huggingface -- version 2
Browse files- .gitattributes +1 -0
- app.py +171 -78
- app_test.py +44 -21
- test.txt +0 -0
- test_beta.txt +0 -0
- test_sample/test-sample1.mp4 +3 -0
- tools/interact_tools.py +67 -67
- track_anything.py +14 -13
- tracker/base_tracker.py +36 -23
.gitattributes
CHANGED
|
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 35 |
assets/demo_version_1.MP4 filter=lfs diff=lfs merge=lfs -text
|
| 36 |
assets/inpainting.gif filter=lfs diff=lfs merge=lfs -text
|
| 37 |
assets/qingming.mp4 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 35 |
assets/demo_version_1.MP4 filter=lfs diff=lfs merge=lfs -text
|
| 36 |
assets/inpainting.gif filter=lfs diff=lfs merge=lfs -text
|
| 37 |
assets/qingming.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
test_sample/test-sample1.mp4 filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
|
@@ -17,7 +17,7 @@ import torchvision
|
|
| 17 |
import torch
|
| 18 |
import concurrent.futures
|
| 19 |
import queue
|
| 20 |
-
|
| 21 |
# download checkpoints
|
| 22 |
def download_checkpoint(url, folder, filename):
|
| 23 |
os.makedirs(folder, exist_ok=True)
|
|
@@ -84,12 +84,21 @@ def get_frames_from_video(video_input, video_state):
|
|
| 84 |
"masks": [None]*len(frames),
|
| 85 |
"logits": [None]*len(frames),
|
| 86 |
"select_frame_number": 0,
|
| 87 |
-
"fps":
|
| 88 |
}
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
# get the select frame from gradio slider
|
| 92 |
-
def select_template(image_selection_slider, video_state):
|
| 93 |
|
| 94 |
# images = video_state[1]
|
| 95 |
image_selection_slider -= 1
|
|
@@ -100,8 +109,14 @@ def select_template(image_selection_slider, video_state):
|
|
| 100 |
model.samcontroler.sam_controler.reset_image()
|
| 101 |
model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
|
| 102 |
|
|
|
|
|
|
|
| 103 |
|
| 104 |
-
return video_state["painted_images"][image_selection_slider], video_state
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# use sam to get the mask
|
| 107 |
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
|
|
@@ -133,17 +148,65 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
|
|
| 133 |
|
| 134 |
return painted_image, video_state, interactive_state
|
| 135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
# tracking vos
|
| 137 |
-
def vos_tracking_video(video_state, interactive_state):
|
| 138 |
model.xmem.clear_memory()
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
fps = video_state["fps"]
|
| 142 |
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
|
| 149 |
interactive_state["inference_times"] += 1
|
|
@@ -152,7 +215,7 @@ def vos_tracking_video(video_state, interactive_state):
|
|
| 152 |
interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
|
| 153 |
interactive_state["positive_click_times"],
|
| 154 |
interactive_state["negative_click_times"]))
|
| 155 |
-
|
| 156 |
#### shanggao code for mask save
|
| 157 |
if interactive_state["mask_save"]:
|
| 158 |
if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])):
|
|
@@ -176,6 +239,14 @@ def generate_video_from_frames(frames, output_path, fps=30):
|
|
| 176 |
output_path (str): The path to save the generated video.
|
| 177 |
fps (int, optional): The frame rate of the output video. Defaults to 30.
|
| 178 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
frames = torch.from_numpy(np.asarray(frames))
|
| 180 |
if not os.path.exists(os.path.dirname(output_path)):
|
| 181 |
os.makedirs(os.path.dirname(output_path))
|
|
@@ -193,8 +264,8 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
|
|
| 193 |
|
| 194 |
# args, defined in track_anything.py
|
| 195 |
args = parse_augment()
|
| 196 |
-
# args.port =
|
| 197 |
-
# args.device = "cuda:
|
| 198 |
# args.mask_save = True
|
| 199 |
|
| 200 |
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
|
@@ -208,8 +279,15 @@ with gr.Blocks() as iface:
|
|
| 208 |
"inference_times": 0,
|
| 209 |
"negative_click_times" : 0,
|
| 210 |
"positive_click_times": 0,
|
| 211 |
-
"mask_save": args.mask_save
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
video_state = gr.State(
|
| 214 |
{
|
| 215 |
"video_name": "",
|
|
@@ -225,43 +303,47 @@ with gr.Blocks() as iface:
|
|
| 225 |
with gr.Row():
|
| 226 |
|
| 227 |
# for user video input
|
| 228 |
-
with gr.Column(
|
| 229 |
-
|
|
|
|
|
|
|
| 230 |
|
| 231 |
|
| 232 |
|
| 233 |
-
with gr.Row(
|
| 234 |
# put the template frame under the radio button
|
| 235 |
-
with gr.Column(
|
| 236 |
# extract frames
|
| 237 |
with gr.Column():
|
| 238 |
extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
|
| 239 |
|
| 240 |
# click points settins, negative or positive, mode continuous or single
|
| 241 |
with gr.Row():
|
| 242 |
-
with gr.Row(
|
| 243 |
point_prompt = gr.Radio(
|
| 244 |
choices=["Positive", "Negative"],
|
| 245 |
value="Positive",
|
| 246 |
label="Point Prompt",
|
| 247 |
-
interactive=True
|
|
|
|
| 248 |
click_mode = gr.Radio(
|
| 249 |
choices=["Continuous", "Single"],
|
| 250 |
value="Continuous",
|
| 251 |
label="Clicking Mode",
|
| 252 |
-
interactive=True
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
|
| 262 |
-
with gr.Column(
|
| 263 |
-
|
| 264 |
-
|
|
|
|
|
|
|
| 265 |
|
| 266 |
# first step: get the video information
|
| 267 |
extract_frames_button.click(
|
|
@@ -269,27 +351,52 @@ with gr.Blocks() as iface:
|
|
| 269 |
inputs=[
|
| 270 |
video_input, video_state
|
| 271 |
],
|
| 272 |
-
outputs=[video_state,
|
|
|
|
|
|
|
| 273 |
)
|
| 274 |
|
| 275 |
# second step: select images from slider
|
| 276 |
image_selection_slider.release(fn=select_template,
|
| 277 |
-
inputs=[image_selection_slider, video_state],
|
| 278 |
-
outputs=[template_frame, video_state], api_name="select_image")
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
-
|
| 281 |
template_frame.select(
|
| 282 |
fn=sam_refine,
|
| 283 |
inputs=[video_state, point_prompt, click_state, interactive_state],
|
| 284 |
outputs=[template_frame, video_state, interactive_state]
|
| 285 |
)
|
| 286 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
tracking_video_predict_button.click(
|
| 288 |
fn=vos_tracking_video,
|
| 289 |
-
inputs=[video_state, interactive_state],
|
| 290 |
outputs=[video_output, video_state, interactive_state]
|
| 291 |
)
|
| 292 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
# clear input
|
| 295 |
video_input.clear(
|
|
@@ -306,57 +413,43 @@ with gr.Blocks() as iface:
|
|
| 306 |
"inference_times": 0,
|
| 307 |
"negative_click_times" : 0,
|
| 308 |
"positive_click_times": 0,
|
| 309 |
-
"mask_save": args.mask_save
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
[],
|
| 314 |
-
[
|
| 315 |
-
video_state,
|
| 316 |
-
interactive_state,
|
| 317 |
-
click_state,
|
| 318 |
-
],
|
| 319 |
-
queue=False,
|
| 320 |
-
show_progress=False
|
| 321 |
-
)
|
| 322 |
-
clear_button_image.click(
|
| 323 |
-
lambda: (
|
| 324 |
-
{
|
| 325 |
-
"origin_images": None,
|
| 326 |
-
"painted_images": None,
|
| 327 |
-
"masks": None,
|
| 328 |
-
"logits": None,
|
| 329 |
-
"select_frame_number": 0,
|
| 330 |
-
"fps": 30
|
| 331 |
},
|
| 332 |
-
|
| 333 |
-
"inference_times": 0,
|
| 334 |
-
"negative_click_times" : 0,
|
| 335 |
-
"positive_click_times": 0,
|
| 336 |
-
"mask_save": args.mask_save
|
| 337 |
},
|
| 338 |
-
[[],[]]
|
| 339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
[],
|
| 341 |
[
|
| 342 |
video_state,
|
| 343 |
interactive_state,
|
| 344 |
click_state,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
],
|
| 346 |
-
|
| 347 |
queue=False,
|
| 348 |
-
show_progress=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
-
)
|
| 351 |
-
clear_button_clike.click(
|
| 352 |
-
lambda: ([[],[]]),
|
| 353 |
-
[],
|
| 354 |
-
[click_state],
|
| 355 |
-
queue=False,
|
| 356 |
-
show_progress=False
|
| 357 |
)
|
| 358 |
iface.queue(concurrency_count=1)
|
| 359 |
-
iface.launch(enable_queue=True)
|
| 360 |
|
| 361 |
|
| 362 |
|
|
|
|
| 17 |
import torch
|
| 18 |
import concurrent.futures
|
| 19 |
import queue
|
| 20 |
+
from tools.painter import mask_painter, point_painter
|
| 21 |
# download checkpoints
|
| 22 |
def download_checkpoint(url, folder, filename):
|
| 23 |
os.makedirs(folder, exist_ok=True)
|
|
|
|
| 84 |
"masks": [None]*len(frames),
|
| 85 |
"logits": [None]*len(frames),
|
| 86 |
"select_frame_number": 0,
|
| 87 |
+
"fps": fps
|
| 88 |
}
|
| 89 |
+
video_info = "Video Name: {}, FPS: {}, Total Frames: {}".format(video_state["video_name"], video_state["fps"], len(frames))
|
| 90 |
+
|
| 91 |
+
model.samcontroler.sam_controler.reset_image()
|
| 92 |
+
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
|
| 93 |
+
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
|
| 94 |
+
gr.update(visible=True), gr.update(visible=True), \
|
| 95 |
+
gr.update(visible=True), gr.update(visible=True), \
|
| 96 |
+
gr.update(visible=True), gr.update(visible=True), \
|
| 97 |
+
gr.update(visible=True), gr.update(visible=True), \
|
| 98 |
+
gr.update(visible=True)
|
| 99 |
|
| 100 |
# get the select frame from gradio slider
|
| 101 |
+
def select_template(image_selection_slider, video_state, interactive_state):
|
| 102 |
|
| 103 |
# images = video_state[1]
|
| 104 |
image_selection_slider -= 1
|
|
|
|
| 109 |
model.samcontroler.sam_controler.reset_image()
|
| 110 |
model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
|
| 111 |
|
| 112 |
+
# # clear multi mask
|
| 113 |
+
# interactive_state["multi_mask"] = {"masks":[], "mask_names":[]}
|
| 114 |
|
| 115 |
+
return video_state["painted_images"][image_selection_slider], video_state, interactive_state
|
| 116 |
+
|
| 117 |
+
def get_end_number(track_pause_number_slider, interactive_state):
|
| 118 |
+
interactive_state["track_end_number"] = track_pause_number_slider
|
| 119 |
+
return interactive_state
|
| 120 |
|
| 121 |
# use sam to get the mask
|
| 122 |
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
|
|
|
|
| 148 |
|
| 149 |
return painted_image, video_state, interactive_state
|
| 150 |
|
| 151 |
+
def add_multi_mask(video_state, interactive_state, mask_dropdown):
|
| 152 |
+
mask = video_state["masks"][video_state["select_frame_number"]]
|
| 153 |
+
interactive_state["multi_mask"]["masks"].append(mask)
|
| 154 |
+
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
|
| 155 |
+
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
|
| 156 |
+
select_frame = show_mask(video_state, interactive_state, mask_dropdown)
|
| 157 |
+
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
|
| 158 |
+
|
| 159 |
+
def clear_click(video_state, click_state):
|
| 160 |
+
click_state = [[],[]]
|
| 161 |
+
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
| 162 |
+
return template_frame, click_state
|
| 163 |
+
|
| 164 |
+
def remove_multi_mask(interactive_state):
|
| 165 |
+
interactive_state["multi_mask"]["mask_names"]= []
|
| 166 |
+
interactive_state["multi_mask"]["masks"] = []
|
| 167 |
+
return interactive_state
|
| 168 |
+
|
| 169 |
+
def show_mask(video_state, interactive_state, mask_dropdown):
|
| 170 |
+
mask_dropdown.sort()
|
| 171 |
+
select_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
| 172 |
+
|
| 173 |
+
for i in range(len(mask_dropdown)):
|
| 174 |
+
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
|
| 175 |
+
mask = interactive_state["multi_mask"]["masks"][mask_number]
|
| 176 |
+
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
|
| 177 |
+
|
| 178 |
+
return select_frame
|
| 179 |
+
|
| 180 |
# tracking vos
|
| 181 |
+
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
| 182 |
model.xmem.clear_memory()
|
| 183 |
+
if interactive_state["track_end_number"]:
|
| 184 |
+
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
| 185 |
+
else:
|
| 186 |
+
following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
|
| 187 |
+
|
| 188 |
+
if interactive_state["multi_mask"]["masks"]:
|
| 189 |
+
if len(mask_dropdown) == 0:
|
| 190 |
+
mask_dropdown = ["mask_001"]
|
| 191 |
+
mask_dropdown.sort()
|
| 192 |
+
template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
|
| 193 |
+
for i in range(1,len(mask_dropdown)):
|
| 194 |
+
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
|
| 195 |
+
template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
|
| 196 |
+
video_state["masks"][video_state["select_frame_number"]]= template_mask
|
| 197 |
+
else:
|
| 198 |
+
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
| 199 |
fps = video_state["fps"]
|
| 200 |
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
| 201 |
|
| 202 |
+
if interactive_state["track_end_number"]:
|
| 203 |
+
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
|
| 204 |
+
video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
|
| 205 |
+
video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images
|
| 206 |
+
else:
|
| 207 |
+
video_state["masks"][video_state["select_frame_number"]:] = masks
|
| 208 |
+
video_state["logits"][video_state["select_frame_number"]:] = logits
|
| 209 |
+
video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
|
| 210 |
|
| 211 |
video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
|
| 212 |
interactive_state["inference_times"] += 1
|
|
|
|
| 215 |
interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
|
| 216 |
interactive_state["positive_click_times"],
|
| 217 |
interactive_state["negative_click_times"]))
|
| 218 |
+
|
| 219 |
#### shanggao code for mask save
|
| 220 |
if interactive_state["mask_save"]:
|
| 221 |
if not os.path.exists('./result/mask/{}'.format(video_state["video_name"].split('.')[0])):
|
|
|
|
| 239 |
output_path (str): The path to save the generated video.
|
| 240 |
fps (int, optional): The frame rate of the output video. Defaults to 30.
|
| 241 |
"""
|
| 242 |
+
# height, width, layers = frames[0].shape
|
| 243 |
+
# fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
| 244 |
+
# video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
| 245 |
+
# print(output_path)
|
| 246 |
+
# for frame in frames:
|
| 247 |
+
# video.write(frame)
|
| 248 |
+
|
| 249 |
+
# video.release()
|
| 250 |
frames = torch.from_numpy(np.asarray(frames))
|
| 251 |
if not os.path.exists(os.path.dirname(output_path)):
|
| 252 |
os.makedirs(os.path.dirname(output_path))
|
|
|
|
| 264 |
|
| 265 |
# args, defined in track_anything.py
|
| 266 |
args = parse_augment()
|
| 267 |
+
# args.port = 12315
|
| 268 |
+
# args.device = "cuda:1"
|
| 269 |
# args.mask_save = True
|
| 270 |
|
| 271 |
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
|
|
|
| 279 |
"inference_times": 0,
|
| 280 |
"negative_click_times" : 0,
|
| 281 |
"positive_click_times": 0,
|
| 282 |
+
"mask_save": args.mask_save,
|
| 283 |
+
"multi_mask": {
|
| 284 |
+
"mask_names": [],
|
| 285 |
+
"masks": []
|
| 286 |
+
},
|
| 287 |
+
"track_end_number": None
|
| 288 |
+
}
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
video_state = gr.State(
|
| 292 |
{
|
| 293 |
"video_name": "",
|
|
|
|
| 303 |
with gr.Row():
|
| 304 |
|
| 305 |
# for user video input
|
| 306 |
+
with gr.Column():
|
| 307 |
+
with gr.Row(scale=0.4):
|
| 308 |
+
video_input = gr.Video(autosize=True)
|
| 309 |
+
video_info = gr.Textbox()
|
| 310 |
|
| 311 |
|
| 312 |
|
| 313 |
+
with gr.Row():
|
| 314 |
# put the template frame under the radio button
|
| 315 |
+
with gr.Column():
|
| 316 |
# extract frames
|
| 317 |
with gr.Column():
|
| 318 |
extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
|
| 319 |
|
| 320 |
# click points settins, negative or positive, mode continuous or single
|
| 321 |
with gr.Row():
|
| 322 |
+
with gr.Row():
|
| 323 |
point_prompt = gr.Radio(
|
| 324 |
choices=["Positive", "Negative"],
|
| 325 |
value="Positive",
|
| 326 |
label="Point Prompt",
|
| 327 |
+
interactive=True,
|
| 328 |
+
visible=False)
|
| 329 |
click_mode = gr.Radio(
|
| 330 |
choices=["Continuous", "Single"],
|
| 331 |
value="Continuous",
|
| 332 |
label="Clicking Mode",
|
| 333 |
+
interactive=True,
|
| 334 |
+
visible=False)
|
| 335 |
+
with gr.Row():
|
| 336 |
+
clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
|
| 337 |
+
Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
|
| 338 |
+
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
|
| 339 |
+
image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", visible=False)
|
| 340 |
+
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
|
|
|
|
| 341 |
|
| 342 |
+
with gr.Column():
|
| 343 |
+
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask_select", info=".", visible=False)
|
| 344 |
+
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
|
| 345 |
+
video_output = gr.Video(autosize=True, visible=False).style(height=360)
|
| 346 |
+
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
|
| 347 |
|
| 348 |
# first step: get the video information
|
| 349 |
extract_frames_button.click(
|
|
|
|
| 351 |
inputs=[
|
| 352 |
video_input, video_state
|
| 353 |
],
|
| 354 |
+
outputs=[video_state, video_info, template_frame,
|
| 355 |
+
image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
|
| 356 |
+
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button]
|
| 357 |
)
|
| 358 |
|
| 359 |
# second step: select images from slider
|
| 360 |
image_selection_slider.release(fn=select_template,
|
| 361 |
+
inputs=[image_selection_slider, video_state, interactive_state],
|
| 362 |
+
outputs=[template_frame, video_state, interactive_state], api_name="select_image")
|
| 363 |
+
track_pause_number_slider.release(fn=get_end_number,
|
| 364 |
+
inputs=[track_pause_number_slider, interactive_state],
|
| 365 |
+
outputs=[interactive_state], api_name="end_image")
|
| 366 |
|
| 367 |
+
# click select image to get mask using sam
|
| 368 |
template_frame.select(
|
| 369 |
fn=sam_refine,
|
| 370 |
inputs=[video_state, point_prompt, click_state, interactive_state],
|
| 371 |
outputs=[template_frame, video_state, interactive_state]
|
| 372 |
)
|
| 373 |
|
| 374 |
+
# add different mask
|
| 375 |
+
Add_mask_button.click(
|
| 376 |
+
fn=add_multi_mask,
|
| 377 |
+
inputs=[video_state, interactive_state, mask_dropdown],
|
| 378 |
+
outputs=[interactive_state, mask_dropdown, template_frame, click_state]
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
remove_mask_button.click(
|
| 382 |
+
fn=remove_multi_mask,
|
| 383 |
+
inputs=[interactive_state],
|
| 384 |
+
outputs=[interactive_state]
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
# tracking video from select image and mask
|
| 388 |
tracking_video_predict_button.click(
|
| 389 |
fn=vos_tracking_video,
|
| 390 |
+
inputs=[video_state, interactive_state, mask_dropdown],
|
| 391 |
outputs=[video_output, video_state, interactive_state]
|
| 392 |
)
|
| 393 |
|
| 394 |
+
# click to get mask
|
| 395 |
+
mask_dropdown.change(
|
| 396 |
+
fn=show_mask,
|
| 397 |
+
inputs=[video_state, interactive_state, mask_dropdown],
|
| 398 |
+
outputs=[template_frame]
|
| 399 |
+
)
|
| 400 |
|
| 401 |
# clear input
|
| 402 |
video_input.clear(
|
|
|
|
| 413 |
"inference_times": 0,
|
| 414 |
"negative_click_times" : 0,
|
| 415 |
"positive_click_times": 0,
|
| 416 |
+
"mask_save": args.mask_save,
|
| 417 |
+
"multi_mask": {
|
| 418 |
+
"mask_names": [],
|
| 419 |
+
"masks": []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
},
|
| 421 |
+
"track_end_number": 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
},
|
| 423 |
+
[[],[]],
|
| 424 |
+
None,
|
| 425 |
+
None,
|
| 426 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
| 427 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
| 428 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False) \
|
| 429 |
+
|
| 430 |
+
),
|
| 431 |
[],
|
| 432 |
[
|
| 433 |
video_state,
|
| 434 |
interactive_state,
|
| 435 |
click_state,
|
| 436 |
+
video_output,
|
| 437 |
+
template_frame,
|
| 438 |
+
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
|
| 439 |
+
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button
|
| 440 |
],
|
|
|
|
| 441 |
queue=False,
|
| 442 |
+
show_progress=False)
|
| 443 |
+
|
| 444 |
+
# points clear
|
| 445 |
+
clear_button_click.click(
|
| 446 |
+
fn = clear_click,
|
| 447 |
+
inputs = [video_state, click_state,],
|
| 448 |
+
outputs = [template_frame,click_state],
|
| 449 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
)
|
| 451 |
iface.queue(concurrency_count=1)
|
| 452 |
+
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
|
| 453 |
|
| 454 |
|
| 455 |
|
app_test.py
CHANGED
|
@@ -1,23 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
iface.launch(server_name='0.0.0.0', server_port=12212)
|
|
|
|
| 1 |
+
# import gradio as gr
|
| 2 |
+
|
| 3 |
+
# def update_iframe(slider_value):
|
| 4 |
+
# return f'''
|
| 5 |
+
# <script>
|
| 6 |
+
# window.addEventListener('message', function(event) {{
|
| 7 |
+
# if (event.data.sliderValue !== undefined) {{
|
| 8 |
+
# var iframe = document.getElementById("text_iframe");
|
| 9 |
+
# iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
|
| 10 |
+
# }}
|
| 11 |
+
# }}, false);
|
| 12 |
+
# </script>
|
| 13 |
+
# <iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
|
| 14 |
+
# '''
|
| 15 |
+
|
| 16 |
+
# iface = gr.Interface(
|
| 17 |
+
# fn=update_iframe,
|
| 18 |
+
# inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
|
| 19 |
+
# outputs=gr.outputs.HTML(),
|
| 20 |
+
# allow_flagging=False,
|
| 21 |
+
# )
|
| 22 |
+
|
| 23 |
+
# iface.launch(server_name='0.0.0.0', server_port=12212)
|
| 24 |
+
|
| 25 |
import gradio as gr
|
| 26 |
|
| 27 |
+
|
| 28 |
+
def change_mask(drop):
|
| 29 |
+
return gr.update(choices=["hello", "kitty"])
|
| 30 |
+
|
| 31 |
+
with gr.Blocks() as iface:
|
| 32 |
+
drop = gr.Dropdown(
|
| 33 |
+
choices=["cat", "dog", "bird"], label="Animal", info="Will add more animals later!"
|
| 34 |
+
)
|
| 35 |
+
radio = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?")
|
| 36 |
+
multi_drop = gr.Dropdown(
|
| 37 |
+
["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True, label="Activity", info="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed auctor, nisl eget ultricies aliquam, nunc nisl aliquet nunc, eget aliquam nisl nunc vel nisl."
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
multi_drop.change(
|
| 41 |
+
fn=change_mask,
|
| 42 |
+
inputs = multi_drop,
|
| 43 |
+
outputs=multi_drop
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
iface.launch(server_name='0.0.0.0', server_port=1223)
|
|
|
test.txt
ADDED
|
File without changes
|
test_beta.txt
ADDED
|
File without changes
|
test_sample/test-sample1.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:403b711376a79026beedb7d0d919d35268298150120438a22a5330d0c8cdd6b6
|
| 3 |
+
size 6039223
|
tools/interact_tools.py
CHANGED
|
@@ -37,16 +37,16 @@ class SamControler():
|
|
| 37 |
self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
| 38 |
|
| 39 |
|
| 40 |
-
def seg_again(self, image: np.ndarray):
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
|
| 48 |
|
| 49 |
-
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
| 50 |
'''
|
| 51 |
it is used in first frame in video
|
| 52 |
return: mask, logit, painted image(mask+point)
|
|
@@ -88,47 +88,47 @@ class SamControler():
|
|
| 88 |
|
| 89 |
return mask, logit, painted_image
|
| 90 |
|
| 91 |
-
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
|
| 131 |
-
|
| 132 |
|
| 133 |
|
| 134 |
|
|
@@ -226,31 +226,31 @@ class SamControler():
|
|
| 226 |
|
| 227 |
|
| 228 |
|
| 229 |
-
if __name__ == "__main__":
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
|
| 255 |
|
| 256 |
|
|
|
|
| 37 |
self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
| 38 |
|
| 39 |
|
| 40 |
+
# def seg_again(self, image: np.ndarray):
|
| 41 |
+
# '''
|
| 42 |
+
# it is used when interact in video
|
| 43 |
+
# '''
|
| 44 |
+
# self.sam_controler.reset_image()
|
| 45 |
+
# self.sam_controler.set_image(image)
|
| 46 |
+
# return
|
| 47 |
|
| 48 |
|
| 49 |
+
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
|
| 50 |
'''
|
| 51 |
it is used in first frame in video
|
| 52 |
return: mask, logit, painted image(mask+point)
|
|
|
|
| 88 |
|
| 89 |
return mask, logit, painted_image
|
| 90 |
|
| 91 |
+
# def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
| 92 |
+
# origal_image = self.sam_controler.orignal_image
|
| 93 |
+
# if same:
|
| 94 |
+
# '''
|
| 95 |
+
# true; loop in the same image
|
| 96 |
+
# '''
|
| 97 |
+
# prompts = {
|
| 98 |
+
# 'point_coords': points,
|
| 99 |
+
# 'point_labels': labels,
|
| 100 |
+
# 'mask_input': logits[None, :, :]
|
| 101 |
+
# }
|
| 102 |
+
# masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
|
| 103 |
+
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
| 104 |
|
| 105 |
+
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
| 106 |
+
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
| 107 |
+
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
| 108 |
+
# painted_image = Image.fromarray(painted_image)
|
| 109 |
|
| 110 |
+
# return mask, logit, painted_image
|
| 111 |
+
# else:
|
| 112 |
+
# '''
|
| 113 |
+
# loop in the different image, interact in the video
|
| 114 |
+
# '''
|
| 115 |
+
# if image is None:
|
| 116 |
+
# raise('Image error')
|
| 117 |
+
# else:
|
| 118 |
+
# self.seg_again(image)
|
| 119 |
+
# prompts = {
|
| 120 |
+
# 'point_coords': points,
|
| 121 |
+
# 'point_labels': labels,
|
| 122 |
+
# }
|
| 123 |
+
# masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
| 124 |
+
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
| 125 |
|
| 126 |
+
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
| 127 |
+
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
| 128 |
+
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
| 129 |
+
# painted_image = Image.fromarray(painted_image)
|
| 130 |
|
| 131 |
+
# return mask, logit, painted_image
|
| 132 |
|
| 133 |
|
| 134 |
|
|
|
|
| 226 |
|
| 227 |
|
| 228 |
|
| 229 |
+
# if __name__ == "__main__":
|
| 230 |
+
# points = np.array([[500, 375], [1125, 625]])
|
| 231 |
+
# labels = np.array([1, 1])
|
| 232 |
+
# image = cv2.imread('/hhd3/gaoshang/truck.jpg')
|
| 233 |
+
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
| 234 |
|
| 235 |
+
# sam_controler = initialize()
|
| 236 |
+
# mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True)
|
| 237 |
+
# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
|
| 238 |
+
# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
| 239 |
+
# cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
|
| 240 |
+
# cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
|
| 241 |
+
# painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
|
| 242 |
|
| 243 |
+
# mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)
|
| 244 |
+
# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
|
| 245 |
+
# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
| 246 |
+
# cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
|
| 247 |
+
# painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')
|
| 248 |
|
| 249 |
+
# mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
|
| 250 |
+
# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
|
| 251 |
+
# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
| 252 |
+
# cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
|
| 253 |
+
# painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')
|
| 254 |
|
| 255 |
|
| 256 |
|
track_anything.py
CHANGED
|
@@ -15,26 +15,26 @@ class TrackingAnything():
|
|
| 15 |
self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
|
| 16 |
|
| 17 |
|
| 18 |
-
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
|
| 31 |
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
| 32 |
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
|
| 33 |
return mask, logit, painted_image
|
| 34 |
|
| 35 |
-
def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
| 36 |
-
|
| 37 |
-
|
| 38 |
|
| 39 |
def generator(self, images: list, template_mask:np.ndarray):
|
| 40 |
|
|
@@ -53,6 +53,7 @@ class TrackingAnything():
|
|
| 53 |
masks.append(mask)
|
| 54 |
logits.append(logit)
|
| 55 |
painted_images.append(painted_image)
|
|
|
|
| 56 |
return masks, logits, painted_images
|
| 57 |
|
| 58 |
|
|
|
|
| 15 |
self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
|
| 16 |
|
| 17 |
|
| 18 |
+
# def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
| 19 |
+
# same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
| 20 |
+
# if first_flag:
|
| 21 |
+
# mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
|
| 22 |
+
# return mask, logit, painted_image
|
| 23 |
|
| 24 |
+
# if interact_flag:
|
| 25 |
+
# mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
| 26 |
+
# return mask, logit, painted_image
|
| 27 |
|
| 28 |
+
# mask, logit, painted_image = self.xmem.track(image, logit)
|
| 29 |
+
# return mask, logit, painted_image
|
| 30 |
|
| 31 |
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
| 32 |
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
|
| 33 |
return mask, logit, painted_image
|
| 34 |
|
| 35 |
+
# def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
| 36 |
+
# mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
| 37 |
+
# return mask, logit, painted_image
|
| 38 |
|
| 39 |
def generator(self, images: list, template_mask:np.ndarray):
|
| 40 |
|
|
|
|
| 53 |
masks.append(mask)
|
| 54 |
logits.append(logit)
|
| 55 |
painted_images.append(painted_image)
|
| 56 |
+
print("tracking image {}".format(i))
|
| 57 |
return masks, logits, painted_images
|
| 58 |
|
| 59 |
|
tracker/base_tracker.py
CHANGED
|
@@ -67,6 +67,7 @@ class BaseTracker:
|
|
| 67 |
logit: numpy arrays, probability map (H, W)
|
| 68 |
painted_image: numpy array (H, W, 3)
|
| 69 |
"""
|
|
|
|
| 70 |
if first_frame_annotation is not None: # first frame mask
|
| 71 |
# initialisation
|
| 72 |
mask, labels = self.mapper.convert_mask(first_frame_annotation)
|
|
@@ -87,12 +88,20 @@ class BaseTracker:
|
|
| 87 |
out_mask = torch.argmax(probs, dim=0)
|
| 88 |
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
|
| 89 |
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
painted_image = frame
|
| 92 |
for obj in range(1, num_objs+1):
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
| 96 |
|
| 97 |
@torch.no_grad()
|
| 98 |
def sam_refinement(self, frame, logits, ti):
|
|
@@ -142,34 +151,38 @@ if __name__ == '__main__':
|
|
| 142 |
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
|
| 143 |
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
|
| 144 |
|
| 145 |
-
# test for storage efficiency
|
| 146 |
-
frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
|
| 147 |
-
first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
else:
|
| 156 |
-
mask, prob, painted_image = tracker.track(frame)
|
| 157 |
-
# save
|
| 158 |
-
painted_image = Image.fromarray(painted_image)
|
| 159 |
-
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
|
| 160 |
|
| 161 |
-
tracker.clear_memory()
|
| 162 |
for ti, frame in enumerate(frames):
|
| 163 |
-
print(ti)
|
| 164 |
-
# if ti > 200:
|
| 165 |
-
# break
|
| 166 |
if ti == 0:
|
| 167 |
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
| 168 |
else:
|
| 169 |
mask, prob, painted_image = tracker.track(frame)
|
| 170 |
# save
|
| 171 |
painted_image = Image.fromarray(painted_image)
|
| 172 |
-
painted_image.save(f'/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
# # track anything given in the first frame annotation
|
| 175 |
# for ti, frame in enumerate(frames):
|
|
|
|
| 67 |
logit: numpy arrays, probability map (H, W)
|
| 68 |
painted_image: numpy array (H, W, 3)
|
| 69 |
"""
|
| 70 |
+
|
| 71 |
if first_frame_annotation is not None: # first frame mask
|
| 72 |
# initialisation
|
| 73 |
mask, labels = self.mapper.convert_mask(first_frame_annotation)
|
|
|
|
| 88 |
out_mask = torch.argmax(probs, dim=0)
|
| 89 |
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
|
| 90 |
|
| 91 |
+
final_mask = np.zeros_like(out_mask)
|
| 92 |
+
|
| 93 |
+
# map back
|
| 94 |
+
for k, v in self.mapper.remappings.items():
|
| 95 |
+
final_mask[out_mask == v] = k
|
| 96 |
+
|
| 97 |
+
num_objs = final_mask.max()
|
| 98 |
painted_image = frame
|
| 99 |
for obj in range(1, num_objs+1):
|
| 100 |
+
if np.max(final_mask==obj) == 0:
|
| 101 |
+
continue
|
| 102 |
+
painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
|
| 103 |
+
|
| 104 |
+
return final_mask, final_mask, painted_image
|
| 105 |
|
| 106 |
@torch.no_grad()
|
| 107 |
def sam_refinement(self, frame, logits, ti):
|
|
|
|
| 151 |
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
|
| 152 |
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
|
| 153 |
|
| 154 |
+
# # test for storage efficiency
|
| 155 |
+
# frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
|
| 156 |
+
# first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
|
| 157 |
|
| 158 |
+
first_frame_annotation[first_frame_annotation==1] = 15
|
| 159 |
+
first_frame_annotation[first_frame_annotation==2] = 20
|
| 160 |
+
|
| 161 |
+
save_path = '/ssd1/gaomingqi/results/TrackA/multi-change1'
|
| 162 |
+
if not os.path.exists(save_path):
|
| 163 |
+
os.mkdir(save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
|
|
|
| 165 |
for ti, frame in enumerate(frames):
|
|
|
|
|
|
|
|
|
|
| 166 |
if ti == 0:
|
| 167 |
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
| 168 |
else:
|
| 169 |
mask, prob, painted_image = tracker.track(frame)
|
| 170 |
# save
|
| 171 |
painted_image = Image.fromarray(painted_image)
|
| 172 |
+
painted_image.save(f'{save_path}/{ti:05d}.png')
|
| 173 |
+
|
| 174 |
+
# tracker.clear_memory()
|
| 175 |
+
# for ti, frame in enumerate(frames):
|
| 176 |
+
# print(ti)
|
| 177 |
+
# # if ti > 200:
|
| 178 |
+
# # break
|
| 179 |
+
# if ti == 0:
|
| 180 |
+
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
| 181 |
+
# else:
|
| 182 |
+
# mask, prob, painted_image = tracker.track(frame)
|
| 183 |
+
# # save
|
| 184 |
+
# painted_image = Image.fromarray(painted_image)
|
| 185 |
+
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
|
| 186 |
|
| 187 |
# # track anything given in the first frame annotation
|
| 188 |
# for ti, frame in enumerate(frames):
|