Mar2Ding committed on
Commit 27ee3a9 · verified · 1 Parent(s): 10773b5

Update setup.py

Files changed (1)
  1. setup.py +632 -68
setup.py CHANGED
@@ -1,72 +1,636 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
-
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
- from setuptools import find_packages, setup
- from torch.utils.cpp_extension import BuildExtension, CUDAExtension
-
- # Package metadata
- NAME = "SAM 2"
- VERSION = "1.0"
- DESCRIPTION = "SAM 2: Segment Anything in Images and Videos"
- URL = "https://github.com/facebookresearch/segment-anything-2"
- AUTHOR = "Meta AI"
- AUTHOR_EMAIL = "[email protected]"
- LICENSE = "Apache 2.0"
-
- # Read the contents of README file
- with open("README.md", "r") as f:
-     LONG_DESCRIPTION = f.read()
-
- # Required dependencies
- REQUIRED_PACKAGES = [
-     "torch>=2.3.1",
-     "torchvision>=0.18.1",
-     "numpy>=1.24.4",
-     "tqdm>=4.66.1",
-     "hydra-core>=1.3.2",
-     "iopath>=0.1.10",
-     "pillow>=9.4.0",
- ]
-
- EXTRA_PACKAGES = {
-     "demo": ["matplotlib>=3.9.1", "jupyter>=1.0.0", "opencv-python>=4.7.0"],
-     "dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"],
- }
-
-
- def get_extensions():
-     srcs = ["sam2/csrc/connected_components.cu"]
-     compile_args = {
-         "cxx": [],
-         "nvcc": [
-             "-DCUDA_HAS_FP16=1",
-             "-D__CUDA_NO_HALF_OPERATORS__",
-             "-D__CUDA_NO_HALF_CONVERSIONS__",
-             "-D__CUDA_NO_HALF2_OPERATORS__",
-         ],
-     }
-     ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
-     return ext_modules
-
-
- # Setup configuration
- setup(
-     name=NAME,
-     version=VERSION,
-     description=DESCRIPTION,
-     long_description=LONG_DESCRIPTION,
-     long_description_content_type="text/markdown",
-     url=URL,
-     author=AUTHOR,
-     author_email=AUTHOR_EMAIL,
-     license=LICENSE,
-     packages=find_packages(exclude="notebooks"),
-     install_requires=REQUIRED_PACKAGES,
-     extras_require=EXTRA_PACKAGES,
-     python_requires=">=3.10.0",
-     ext_modules=get_extensions(),
-     cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)},
- )
+ import subprocess
+ import re
+ from typing import List, Tuple, Optional
+ import os
+
+
+ # Define the command to be executed
+ command = ["python", "setup.py", "build_ext", "--inplace"]
+
+ # Execute the command
+ result = subprocess.run(command, capture_output=True, text=True)
+
+
+
+
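+ # Install the CUDA 12.2 toolkit inside the Space and export CUDA_HOME / PATH / LD_LIBRARY_PATH so
+ # nvcc can be found; TORCH_CUDA_ARCH_LIST is pinned to avoid the empty arch_list IndexError noted below.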
+ def install_cuda_toolkit():
+     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
+     CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
+     CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
+     subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
+     subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
+     subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
+
+     os.environ["CUDA_HOME"] = "/usr/local/cuda"
+     os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
+     os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
+         os.environ["CUDA_HOME"],
+         "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
+     )
+     # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
+     os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
+
+ install_cuda_toolkit()
+
+ css="""
+ div#component-18, div#component-25, div#component-35, div#component-41{
+     align-items: stretch!important;
+ }
+ """
+
+ # Print the output and error (if any)
+ print("Output:\n", result.stdout)
+ print("Errors:\n", result.stderr)
+
+ # Check if the command was successful
+ if result.returncode == 0:
+     print("Command executed successfully.")
+ else:
+     print("Command failed with return code:", result.returncode)
+
+ import gradio as gr
+ from datetime import datetime
+ os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
+ import torch
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+ from PIL import Image, ImageFilter
+ from sam2.build_sam import build_sam2_video_predictor
+
+ from moviepy.editor import ImageSequenceClip
+
+ def get_video_fps(video_path):
+     # Open the video file
+     cap = cv2.VideoCapture(video_path)
+
+     if not cap.isOpened():
+         print("Error: Could not open video.")
+         return None
+
+     # Get the FPS of the video
+     fps = cap.get(cv2.CAP_PROP_FPS)
+
+     return fps
+
+ def clear_points(image):
+     # we clean all
+     return [
+         image, # first_frame_path
+         gr.State([]), # tracking_points
+         gr.State([]), # trackings_input_label
+         image, # points_map
+         #gr.State() # stored_inference_state
+     ]
+
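+ # Extract up to the first 10 seconds of the uploaded video as '00000.jpg'-style numbered frames,
+ # since the SAM2 video predictor consumes a directory of JPEG frames.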
+ def preprocess_video_in(video_path):
+
+     # Generate a unique ID based on the current date and time
+     unique_id = datetime.now().strftime('%Y%m%d%H%M%S')
+
+     # Set directory with this ID to store video frames
+     extracted_frames_output_dir = f'frames_{unique_id}'
+
+     # Create the output directory
+     os.makedirs(extracted_frames_output_dir, exist_ok=True)
+
+     ### Process video frames ###
+     # Open the video file
+     cap = cv2.VideoCapture(video_path)
+
+     if not cap.isOpened():
+         print("Error: Could not open video.")
+         return None
+
+     # Get the frames per second (FPS) of the video
+     fps = cap.get(cv2.CAP_PROP_FPS)
+
+     # Calculate the number of frames to process (10 seconds of video)
+     max_frames = int(fps * 10)
+
+     frame_number = 0
+     first_frame = None
+
+     while True:
+         ret, frame = cap.read()
+         if not ret or frame_number >= max_frames:
+             break
+
+         # Format the frame filename as '00000.jpg'
+         frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
+
+         # Save the frame as a JPEG file
+         cv2.imwrite(frame_filename, frame)
+
+         # Store the first frame
+         if frame_number == 0:
+             first_frame = frame_filename
+
+         frame_number += 1
+
+     # Release the video capture object
+     cap.release()
+
+     # scan all the JPEG frame names in this directory
+     scanned_frames = [
+         p for p in os.listdir(extracted_frames_output_dir)
+         if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+     ]
+     scanned_frames.sort(key=lambda p: int(os.path.splitext(p)[0]))
+     # print(f"SCANNED_FRAMES: {scanned_frames}")
+
+     return [
+         first_frame, # first_frame_path
+         gr.State([]), # tracking_points
+         gr.State([]), # trackings_input_label
+         first_frame, # input_first_frame_image
+         first_frame, # points_map
+         extracted_frames_output_dir, # video_frames_dir
+         scanned_frames, # scanned_frames
+         None, # stored_inference_state
+         None, # stored_frame_names
+         gr.update(open=False) # video_in_drawer
+     ]
+
+ def get_point(point_type, tracking_points, trackings_input_label, input_first_frame_image, evt: gr.SelectData):
+     print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+
+     tracking_points.value.append(evt.index)
+     print(f"TRACKING POINT: {tracking_points.value}")
+
+     if point_type == "include":
+         trackings_input_label.value.append(1)
+     elif point_type == "exclude":
+         trackings_input_label.value.append(0)
+     print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")
+
+     # Open the image and get its dimensions
+     transparent_background = Image.open(input_first_frame_image).convert('RGBA')
+     w, h = transparent_background.size
+
+     # Define the circle radius as a fraction of the smaller dimension
+     fraction = 0.02 # You can adjust this value as needed
+     radius = int(fraction * min(w, h))
+
+     # Create a transparent layer to draw on
+     transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)
+
+     for index, track in enumerate(tracking_points.value):
+         if trackings_input_label.value[index] == 1:
+             cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
+         else:
+             cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)
+
+     # Convert the transparent layer back to an image
+     transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
+     selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+     return tracking_points, trackings_input_label, selected_point_map
+
+ # use bfloat16 for the entire script
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+
+ if torch.cuda.get_device_properties(0).major >= 8:
+     # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True
+
+ def show_mask(mask, ax, obj_id=None, random_color=False):
+     if random_color:
+         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+     else:
+         cmap = plt.get_cmap("tab10")
+         cmap_idx = 0 if obj_id is None else obj_id
+         color = np.array([*cmap(cmap_idx)[:3], 0.6])
+     h, w = mask.shape[-2:]
+     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     ax.imshow(mask_image)
+
+
+ def show_points(coords, labels, ax, marker_size=200):
+     pos_points = coords[labels==1]
+     neg_points = coords[labels==0]
+     ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+     ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+
+ def show_box(box, ax):
+     x0, y0 = box[0], box[1]
+     w, h = box[2] - box[0], box[3] - box[1]
+     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
+
+
+ def load_model(checkpoint):
+     # Load the model according to the user's choice
+     if checkpoint == "tiny":
+         sam2_checkpoint = "./checkpoints/sam2.1_hiera_tiny.pt"
+         model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
+         return [sam2_checkpoint, model_cfg]
+     elif checkpoint == "small":
+         sam2_checkpoint = "./checkpoints/sam2.1_hiera_small.pt"
+         model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
+         return [sam2_checkpoint, model_cfg]
+     elif checkpoint == "base-plus":
+         sam2_checkpoint = "./checkpoints/sam2.1_hiera_base_plus.pt"
+         model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
+         return [sam2_checkpoint, model_cfg]
+     # elif checkpoint == "large":
+     #     sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
+     #     model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+     #     return [sam2_checkpoint, model_cfg]
+
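+ # Run the predictor on the clicks collected for the current working frame and return the mask
+ # preview plus the predictor / inference state kept in gr.State for later propagation.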
+
+
+ def get_mask_sam_process(
+     stored_inference_state,
+     input_first_frame_image,
+     checkpoint,
+     tracking_points,
+     trackings_input_label,
+     video_frames_dir, # extracted_frames_output_dir defined in 'preprocess_video_in' function
+     scanned_frames,
+     working_frame: str = None, # current frame to which points are being added
+     available_frames_to_check: List[str] = [],
+     # progress=gr.Progress(track_tqdm=True)
+ ):
+
+     # get model and model config paths
+     print(f"USER CHOSEN CHECKPOINT: {checkpoint}")
+     sam2_checkpoint, model_cfg = load_model(checkpoint)
+     print("MODEL LOADED")
+
+     # set predictor
+     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+     print("PREDICTOR READY")
+
+     # `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
+     # print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
+     video_dir = video_frames_dir
+
+     # scan all the JPEG frame names in this directory
+     frame_names = scanned_frames
+
+     # print(f"STORED INFERENCE STEP: {stored_inference_state}")
+     if stored_inference_state is None:
+         # Init SAM2 inference_state
+         inference_state = predictor.init_state(video_path=video_dir)
+         inference_state['num_pathway'] = 3
+         inference_state['iou_thre'] = 0.3
+         inference_state['uncertainty'] = 2
+         print("NEW INFERENCE_STATE INITIATED")
+     else:
+         inference_state = stored_inference_state
+
+     # segment and track one object
+     # predictor.reset_state(inference_state) # if any previous tracking, reset
+
+
+     ### HANDLING WORKING FRAME
+     # new_working_frame = None
+     # Add new point
+     if working_frame is None:
+         ann_frame_idx = 0 # the frame index we interact with, 0 if it is the first frame
+         working_frame = "00000.jpg"
+     else:
+         # Use a regular expression to find the integer
+         match = re.search(r'frame_(\d+)', working_frame)
+         if match:
+             # Extract the integer from the match
+             frame_number = int(match.group(1))
+             ann_frame_idx = frame_number
+
+     print(f"NEW_WORKING_FRAME PATH: {working_frame}")
+
+     ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integer)
+
+     # Add the collected click prompts for this frame
+     points = np.array(tracking_points.value, dtype=np.float32)
+     # for labels, `1` means positive click and `0` means negative click
+     labels = np.array(trackings_input_label.value, np.int32)
+     _, out_obj_ids, out_mask_logits = predictor.add_new_points(
+         inference_state=inference_state,
+         frame_idx=ann_frame_idx,
+         obj_id=ann_obj_id,
+         points=points,
+         labels=labels,
+     )
+
+     # Create the plot
+     plt.figure(figsize=(12, 8))
+     plt.title(f"frame {ann_frame_idx}")
+     plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
+     show_points(points, labels, plt.gca())
+     show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
+
+     # Save the plot as a JPG file
+     first_frame_output_filename = "output_first_frame.jpg"
+     plt.savefig(first_frame_output_filename, format='jpg')
+     plt.close()
+     torch.cuda.empty_cache()
+
+     # Assuming available_frames_to_check.value is a list
+     if working_frame not in available_frames_to_check:
+         available_frames_to_check.append(working_frame)
+         print(available_frames_to_check)
+
+     # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
+     return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)
+
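+ # Propagate the mask through all extracted frames: "check" renders every 15th frame to a gallery,
+ # "render" saves every frame and assembles an MP4 at the source video's FPS.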
+ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
+     #### PROPAGATION ####
+     sam2_checkpoint, model_cfg = load_model(checkpoint)
+     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+
+     inference_state = stored_inference_state
+     frame_names = stored_frame_names
+     video_dir = video_frames_dir
+
+     # Define a directory to save the JPEG images
+     frames_output_dir = "frames_output_images"
+     os.makedirs(frames_output_dir, exist_ok=True)
+
+     # Initialize a list to store file paths of saved images
+     jpeg_images = []
+
+     # run propagation throughout the video and collect the results in a dict
+     video_segments = {} # video_segments contains the per-frame segmentation results
+     # for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
+     #     video_segments[out_frame_idx] = {
+     #         out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+     #         for i, out_obj_id in enumerate(out_obj_ids)
+     #     }
+
+     out_obj_ids, out_mask_logits = predictor.propagate_in_video(inference_state, start_frame_idx=0, reverse=False,)
+     print(out_obj_ids)
+     for frame_idx in range(0, inference_state['num_frames']):
+         video_segments[frame_idx] = {out_obj_ids[0]: (out_mask_logits[frame_idx] > 0.0).cpu().numpy()}
+         # output_scores_per_object[object_id][frame_idx] = out_mask_logits[frame_idx].cpu().numpy()
+
+     # render the segmentation results every few frames
+     if vis_frame_type == "check":
+         vis_frame_stride = 15
+     elif vis_frame_type == "render":
+         vis_frame_stride = 1
+
+     plt.close("all")
+     for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
+         plt.figure(figsize=(6, 4))
+         plt.title(f"frame {out_frame_idx}")
+         plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
+         for out_obj_id, out_mask in video_segments[out_frame_idx].items():
+             show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
+
+         # Define the output filename and save the figure as a JPEG file
+         output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
+         plt.savefig(output_filename, format='jpg')
+
+         # Close the plot
+         plt.close()
+
+         # Append the file path to the list
+         jpeg_images.append(output_filename)
+
+         if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
+             available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
+
+     torch.cuda.empty_cache()
+     print(f"JPEG_IMAGES: {jpeg_images}")
+
+     if vis_frame_type == "check":
+         return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True)
+     elif vis_frame_type == "render":
+         # Create a video clip from the image sequence
+         original_fps = get_video_fps(video_in)
+         fps = original_fps # Frames per second
+         total_frames = len(jpeg_images)
+         clip = ImageSequenceClip(jpeg_images, fps=fps)
+         # Write the result to a file
+         final_vid_output_path = "output_video.mp4"
+
+         # Write the result to a file
+         clip.write_videofile(
+             final_vid_output_path,
+             codec='libx264'
+         )
+
+         return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
+
+ def update_ui(vis_frame_type):
+     if vis_frame_type == "check":
+         return gr.update(visible=True), gr.update(visible=False)
+     elif vis_frame_type == "render":
+         return gr.update(visible=False), gr.update(visible=True)
+
+ def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
+     new_working_frame = None
+     if working_frame == None:
+         new_working_frame = os.path.join(video_frames_dir, scanned_frames[0])
+
+     else:
+         # Use a regular expression to find the integer
+         match = re.search(r'frame_(\d+)', working_frame)
+         if match:
+             # Extract the integer from the match
+             frame_number = int(match.group(1))
+             ann_frame_idx = frame_number
+         new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
+     return gr.State([]), gr.State([]), new_working_frame, new_working_frame
+
+ def reset_propagation(first_frame_path, predictor, stored_inference_state):
+
+     predictor.reset_state(stored_inference_state)
+     # print(f"RESET State: {stored_inference_state} ")
+     return first_frame_path, gr.State([]), gr.State([]), gr.update(value=None, visible=False), stored_inference_state, None, ["frame_0.jpg"], first_frame_path, "frame_0.jpg", gr.update(visible=False)
+
+
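+ # Gradio UI: the left column collects the video and click prompts, the right column shows the
+ # working-frame mask and the propagated results (gallery or rendered video).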
+ with gr.Blocks(css=css) as demo:
+     first_frame_path = gr.State()
+     tracking_points = gr.State([])
+     trackings_input_label = gr.State([])
+     video_frames_dir = gr.State()
+     scanned_frames = gr.State()
+     loaded_predictor = gr.State()
+     stored_inference_state = gr.State()
+     stored_frame_names = gr.State()
+     available_frames_to_check = gr.State([])
+     with gr.Column():
+         gr.Markdown(
+             """
+             <h1 style="text-align: center;">🔥 SAM2Long Demo 🔥</h1>
+             """
+         )
+         gr.Markdown(
+             """
+             This is a simple demo for video segmentation with [SAM2Long](https://github.com/Mark12Ding/SAM2Long).
+             """
+         )
+         gr.Markdown(
+             """
+             ### 📋 Instructions:
+
+             It is largely built on the [SAM2-Video-Predictor](https://huggingface.co/spaces/fffiloni/SAM2-Video-Predictor).
+
+             1. **Upload your video** [MP4-24fps]
+             2. With **'include' point type** selected, click on the object to mask on the first frame
+             3. Switch to **'exclude' point type** if you want to specify an area to avoid
+             4. **Get Mask!**
+             5. **Check Propagation** every 15 frames
+             6. **Propagate with "render"** to render the final masked video
+             7. **Hit Reset** button if you want to refresh and start again
+
+             *Note: Input video will be processed for up to 10 seconds only for demo purposes.*
+             """
+         )
+         with gr.Row():
+
+             with gr.Column():
+                 with gr.Group():
+                     with gr.Group():
+                         with gr.Row():
+                             point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include", scale=2)
+                             clear_points_btn = gr.Button("Clear Points", scale=1)
+
+                         input_first_frame_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
+
+                         points_map = gr.Image(
+                             label="Point n Click map",
+                             type="filepath",
+                             interactive=False
+                         )
+
+                     with gr.Group():
+                         with gr.Row():
+                             checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus"], value="tiny")
+                             submit_btn = gr.Button("Get Mask", size="lg")
+
+                 with gr.Accordion("Your video IN", open=True) as video_in_drawer:
+                     video_in = gr.Video(label="Video IN", format="mp4")
+
+                 gr.HTML("""
+
+                 <a href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true">
+                     <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg-dark.svg" alt="Duplicate this Space" />
+                 </a> to skip queue and avoid OOM errors from heavy public load
+                 """)
+
+             with gr.Column():
+                 with gr.Group():
+                     # with gr.Group():
+                     #     with gr.Row():
+                     working_frame = gr.Dropdown(label="working frame ID", choices=[""], value="frame_0.jpg", visible=False, allow_custom_value=False, interactive=True)
+                     # change_current = gr.Button("change current", visible=False)
+                     # working_frame = []
+                     output_result = gr.Image(label="current working mask ref")
+                 with gr.Group():
+                     with gr.Row():
+                         vis_frame_type = gr.Radio(label="Propagation level", choices=["check", "render"], value="check", scale=2)
+                         propagate_btn = gr.Button("Propagate", scale=1)
+                 reset_prpgt_brn = gr.Button("Reset", visible=False)
+                 output_propagated = gr.Gallery(label="Propagated Mask samples gallery", columns=4, visible=False)
+                 output_video = gr.Video(visible=False)
+                 # output_result_mask = gr.Image()
+
+
+
+     # When new video is uploaded
+     video_in.upload(
+         fn = preprocess_video_in,
+         inputs = [video_in],
+         outputs = [
+             first_frame_path,
+             tracking_points, # update Tracking Points in the gr.State([]) object
+             trackings_input_label, # update Tracking Labels in the gr.State([]) object
+             input_first_frame_image, # hidden component used as ref when clearing points
+             points_map, # Image component where we add new tracking points
+             video_frames_dir, # Directory where frames extracted from video_in are stored
+             scanned_frames, # Scanned frames by SAM2
+             stored_inference_state, # Sam2 inference state
+             stored_frame_names, #
+             video_in_drawer, # Accordion to hide uploaded video player
+         ],
+         queue = False
+     )
+
+
+     # triggered when we click on image to add new points
+     points_map.select(
+         fn = get_point,
+         inputs = [
+             point_type, # "include" or "exclude"
+             tracking_points, # get tracking_points values
+             trackings_input_label, # get tracking label values
+             input_first_frame_image, # gr.State() first frame path
+         ],
+         outputs = [
+             tracking_points, # updated with new points
+             trackings_input_label, # updated with corresponding labels
+             points_map, # updated image with points
+         ],
+         queue = False
+     )
+
+     # Clear every point clicked and added to the map
+     clear_points_btn.click(
+         fn = clear_points,
+         inputs = input_first_frame_image, # we get the untouched hidden image
+         outputs = [
+             first_frame_path,
+             tracking_points,
+             trackings_input_label,
+             points_map,
+             #stored_inference_state,
+         ],
+         queue=False
+     )
+
+
+     # change_current.click(
+     #     fn = switch_working_frame,
+     #     inputs = [working_frame, scanned_frames, video_frames_dir],
+     #     outputs = [tracking_points, trackings_input_label, input_first_frame_image, points_map],
+     #     queue=False
+     # )
+
+
+     submit_btn.click(
+         fn = get_mask_sam_process,
+         inputs = [
+             stored_inference_state,
+             input_first_frame_image,
+             checkpoint,
+             tracking_points,
+             trackings_input_label,
+             video_frames_dir,
+             scanned_frames,
+             working_frame,
+             available_frames_to_check,
+         ],
+         outputs = [
+             output_result,
+             stored_frame_names,
+             loaded_predictor,
+             stored_inference_state,
+             working_frame,
+         ],
+         queue=False
+     )
+
+     reset_prpgt_brn.click(
+         fn = reset_propagation,
+         inputs = [first_frame_path, loaded_predictor, stored_inference_state],
+         outputs = [points_map, tracking_points, trackings_input_label, output_propagated, stored_inference_state, output_result, available_frames_to_check, input_first_frame_image, working_frame, reset_prpgt_brn],
+         queue=False
+     )
+
+     propagate_btn.click(
+         fn = update_ui,
+         inputs = [vis_frame_type],
+         outputs = [output_propagated, output_video],
+         queue=False
+     ).then(
+         fn = propagate_to_all,
+         inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
+         outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn]
+     )
+
+ demo.queue().launch(show_api=False, show_error=True, share=True, server_name="0.0.0.0", server_port=11111)