Commit 94dd0a9: add duplicate space
Parent: 39c26fe

app.py CHANGED
@@ -13,9 +13,13 @@ import requests
 import json
 import torchvision
 import torch
+from tools.interact_tools import SamControler
+from tracker.base_tracker import BaseTracker
 from tools.painter import mask_painter
-
-
+try:
+    from mmcv.cnn import ConvModule
+except:
+    os.system("mim install mmcv")
 
 # download checkpoints
 def download_checkpoint(url, folder, filename):
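Note on the lazy mmcv install above: the bare `except:` catches every exception, and `os.system("mim install mmcv")` assumes `os` is imported earlier in app.py and never retries the import after installing. A minimal sketch of a more defensive variant; the helper name `ensure_mmcv` and the use of the `mim` CLI from openmim are assumptions, not part of this commit:

    import importlib
    import subprocess

    def ensure_mmcv():
        # Hypothetical helper: import ConvModule, installing mmcv once on failure.
        try:
            from mmcv.cnn import ConvModule  # noqa: F401
        except ImportError:
            subprocess.check_call(["mim", "install", "mmcv"])  # assumes mim is on PATH
            importlib.invalidate_caches()
            from mmcv.cnn import ConvModule  # noqa: F401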
@@ -202,6 +206,7 @@ def show_mask(video_state, interactive_state, mask_dropdown):
 
 # tracking vos
 def vos_tracking_video(video_state, interactive_state, mask_dropdown):
+
     model.xmem.clear_memory()
     if interactive_state["track_end_number"]:
         following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@@ -221,6 +226,8 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
     template_mask = video_state["masks"][video_state["select_frame_number"]]
     fps = video_state["fps"]
     masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
+    # clear GPU memory
+    model.xmem.clear_memory()
 
     if interactive_state["track_end_number"]:
         video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
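Note: the added `clear_memory()` after `model.generator(...)` presumably releases XMem's accumulated feature memory on the GPU once tracking finishes, mirroring the call at the top of the function so memory does not build up across runs. One way to make that cleanup unconditional is a try/finally; this is a sketch of the idea using the names from app.py, not the committed code:

    def track_with_cleanup(model, frames, template_mask):
        # Hypothetical wrapper: clear XMem's memory bank before tracking and
        # guarantee it is cleared again even if generation raises mid-video.
        model.xmem.clear_memory()
        try:
            return model.generator(images=frames, template_mask=template_mask)
        finally:
            model.xmem.clear_memory()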
@@ -260,6 +267,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
 
 # inpaint
 def inpaint_video(video_state, interactive_state, mask_dropdown):
+
     frames = np.asarray(video_state["origin_images"])
     fps = video_state["fps"]
     inpaint_masks = np.asarray(video_state["masks"])
@@ -306,27 +314,44 @@ def generate_video_from_frames(frames, output_path, fps=30):
     torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
     return output_path
 
+
+# args, defined in track_anything.py
+args = parse_augment()
+
 # check and download checkpoints if needed
-
-
+SAM_checkpoint_dict = {
+    'vit_h': "sam_vit_h_4b8939.pth",
+    'vit_l': "sam_vit_l_0b3195.pth",
+    "vit_b": "sam_vit_b_01ec64.pth"
+}
+SAM_checkpoint_url_dict = {
+    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
+    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
+}
+sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
+sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
 xmem_checkpoint = "XMem-s012.pth"
 xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
 e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
 e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
 
+
 folder ="./checkpoints"
-SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder,
+SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
 xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
 e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
-
-args = parse_augment()
-# args.port = 12315
-# args.device = "cuda:2"
-# args.mask_save = True
+
 
 # initialize sam, xmem, e2fgvi models
 model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
 
+
+title = """<p><h1 align="center">Track-Anything</h1></p>
+"""
+description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">https://github.com/gaomingqi/Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
+
+
 with gr.Blocks() as iface:
     """
     state for
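Note: this hunk moves `args = parse_augment()` ahead of checkpoint selection so the SAM variant can be keyed off `args.sam_model_type` ('vit_h', 'vit_l', or 'vit_b'), and it rebinds each `*_checkpoint` name from a filename to the local path returned by `download_checkpoint`. The helper's body lives earlier in app.py and is not shown in this diff; a plausible idempotent implementation matching its signature might look like the sketch below (only the signature and the return-a-path convention come from the diff, the body is an assumption):

    import os
    import requests

    def download_checkpoint(url, folder, filename):
        # Download `url` into `folder/filename` once; later runs reuse the file.
        os.makedirs(folder, exist_ok=True)
        filepath = os.path.join(folder, filename)
        if not os.path.exists(filepath):
            resp = requests.get(url, stream=True)
            resp.raise_for_status()
            with open(filepath, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    f.write(chunk)
        return filepath  # callers rebind e.g. xmem_checkpoint to this path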
@@ -358,7 +383,8 @@ with gr.Blocks() as iface:
         "fps": 30
         }
     )
-
+    gr.Markdown(title)
+    gr.Markdown(description)
    with gr.Row():
 
         # for user video input
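Note: the two `gr.Markdown` calls render the HTML-bearing `title` and `description` strings defined at module level; inside a `gr.Blocks` context, components appear in creation order, so these land at the top of the page above the video widgets. A minimal self-contained sketch of the same pattern (the `demo` name and the textbox are placeholders, not from the commit):

    import gradio as gr

    title = """<p><h1 align="center">Track-Anything</h1></p>"""

    with gr.Blocks() as demo:
        gr.Markdown(title)                    # rendered first, at the top
        status = gr.Textbox(label="status")   # later components stack below

    # demo.launch()  # uncomment to serve locally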
@@ -367,7 +393,7 @@ with gr.Blocks() as iface:
             video_input = gr.Video(autosize=True)
         with gr.Column():
             video_info = gr.Textbox()
-
+            resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
                 Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
             resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
 
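Note: the new `resize_info` textbox explains the VRAM trade-off, and `resize_ratio_slider` exposes a 0.02–1.0 scale factor; presumably that ratio is applied to the frames before tracking and inpainting to reduce memory use. A hypothetical sketch of applying such a ratio with OpenCV (the helper and its wiring are assumptions; the commit only shows the UI widgets):

    import cv2

    def resize_frames(frames, ratio):
        # Downscale each frame by `ratio`; 1.0 leaves frames untouched.
        if ratio == 1:
            return frames
        return [cv2.resize(f, None, fx=ratio, fy=ratio,
                           interpolation=cv2.INTER_AREA)
                for f in frames]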