	Update setup.py

setup.py CHANGED

@@ -1,72 +1,636 @@
import subprocess
import re
from typing import List, Tuple, Optional
import os


# Define the command to be executed
command = ["python", "setup.py", "build_ext", "--inplace"]

# Execute the command
result = subprocess.run(command, capture_output=True, text=True)
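
# Note: this setup.py doubles as the Space's entry point; the build_ext call
# above presumably compiles the SAM2Long CUDA extension in place before the
# Gradio app below is defined. Its stdout/stderr are printed further down,
# after the CUDA toolkit has been installed.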


def install_cuda_toolkit():
    # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
    subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
    subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])

    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
    os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
        os.environ["CUDA_HOME"],
        "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
    )
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"

install_cuda_toolkit()
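
# Optional sanity check (an added sketch, not part of the original flow):
# confirm nvcc is visible on the PATH set above, assuming the .run installer
# used the default /usr/local/cuda prefix.
try:
    nvcc_check = subprocess.run(["nvcc", "--version"], capture_output=True, text=True)
    print(nvcc_check.stdout or nvcc_check.stderr)
except FileNotFoundError:
    print("nvcc not found on PATH; the CUDA toolkit install may have failed.")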

css="""
div#component-18, div#component-25, div#component-35, div#component-41{
    align-items: stretch!important;
}
"""
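# The div#component-N selectors above target Gradio's auto-generated element
# IDs, which renumber whenever the layout changes; a sturdier approach (a
# sketch, not used here) would be explicit IDs such as
#   gr.Image(label="Point n Click map", elem_id="points_map")
# paired with a `div#points_map { ... }` rule in the css string.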

# Print the output and error (if any)
print("Output:\n", result.stdout)
print("Errors:\n", result.stderr)

# Check if the command was successful
if result.returncode == 0:
    print("Command executed successfully.")
else:
    print("Command failed with return code:", result.returncode)

import gradio as gr
from datetime import datetime
os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
from sam2.build_sam import build_sam2_video_predictor

from moviepy.editor import ImageSequenceClip

def get_video_fps(video_path):
    # Open the video file
    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        print("Error: Could not open video.")
        return None

    # Get the FPS of the video and release the capture handle
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    return fps

def clear_points(image):
    # clear all stored clicks, keeping the untouched reference image
    return [
        image,             # first_frame_path
        gr.State([]),      # tracking_points
        gr.State([]),      # trackings_input_label
        image,             # points_map
        #gr.State()        # stored_inference_state
    ]

            +
            def preprocess_video_in(video_path):
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                # Generate a unique ID based on the current date and time
         | 
| 89 | 
            +
                unique_id = datetime.now().strftime('%Y%m%d%H%M%S')
         | 
| 90 | 
            +
                
         | 
| 91 | 
            +
                # Set directory with this ID to store video frames 
         | 
| 92 | 
            +
                extracted_frames_output_dir = f'frames_{unique_id}'
         | 
| 93 | 
            +
                
         | 
| 94 | 
            +
                # Create the output directory
         | 
| 95 | 
            +
                os.makedirs(extracted_frames_output_dir, exist_ok=True)
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                ### Process video frames ###
         | 
| 98 | 
            +
                # Open the video file
         | 
| 99 | 
            +
                cap = cv2.VideoCapture(video_path)
         | 
| 100 | 
            +
                
         | 
| 101 | 
            +
                if not cap.isOpened():
         | 
| 102 | 
            +
                    print("Error: Could not open video.")
         | 
| 103 | 
            +
                    return None
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                # Get the frames per second (FPS) of the video
         | 
| 106 | 
            +
                fps = cap.get(cv2.CAP_PROP_FPS)
         | 
| 107 | 
            +
                
         | 
| 108 | 
            +
                # Calculate the number of frames to process (10 seconds of video)
         | 
| 109 | 
            +
                max_frames = int(fps * 10)
         | 
| 110 | 
            +
                
         | 
| 111 | 
            +
                frame_number = 0
         | 
| 112 | 
            +
                first_frame = None
         | 
| 113 | 
            +
                
         | 
| 114 | 
            +
                while True:
         | 
| 115 | 
            +
                    ret, frame = cap.read()
         | 
| 116 | 
            +
                    if not ret or frame_number >= max_frames:
         | 
| 117 | 
            +
                        break
         | 
| 118 | 
            +
                    
         | 
| 119 | 
            +
                    # Format the frame filename as '00000.jpg'
         | 
| 120 | 
            +
                    frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
         | 
| 121 | 
            +
                    
         | 
| 122 | 
            +
                    # Save the frame as a JPEG file
         | 
| 123 | 
            +
                    cv2.imwrite(frame_filename, frame)
         | 
| 124 | 
            +
                    
         | 
| 125 | 
            +
                    # Store the first frame
         | 
| 126 | 
            +
                    if frame_number == 0:
         | 
| 127 | 
            +
                        first_frame = frame_filename
         | 
| 128 | 
            +
                    
         | 
| 129 | 
            +
                    frame_number += 1
         | 
| 130 | 
            +
                
         | 
| 131 | 
            +
                # Release the video capture object
         | 
| 132 | 
            +
                cap.release()
         | 
| 133 | 
            +
                
         | 
| 134 | 
            +
                # scan all the JPEG frame names in this directory
         | 
| 135 | 
            +
                scanned_frames = [
         | 
| 136 | 
            +
                    p for p in os.listdir(extracted_frames_output_dir)
         | 
| 137 | 
            +
                    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
         | 
| 138 | 
            +
                ]
         | 
| 139 | 
            +
                scanned_frames.sort(key=lambda p: int(os.path.splitext(p)[0]))
         | 
| 140 | 
            +
                # print(f"SCANNED_FRAMES: {scanned_frames}")
         | 
| 141 | 
            +
                
         | 
| 142 | 
            +
                return [
         | 
| 143 | 
            +
                    first_frame,           # first_frame_path
         | 
| 144 | 
            +
                    gr.State([]),          # tracking_points
         | 
| 145 | 
            +
                    gr.State([]),          # trackings_input_label
         | 
| 146 | 
            +
                    first_frame,           # input_first_frame_image
         | 
| 147 | 
            +
                    first_frame,           # points_map
         | 
| 148 | 
            +
                    extracted_frames_output_dir,            # video_frames_dir
         | 
| 149 | 
            +
                    scanned_frames,        # scanned_frames
         | 
| 150 | 
            +
                    None,                  # stored_inference_state
         | 
| 151 | 
            +
                    None,                  # stored_frame_names
         | 
| 152 | 
            +
                    gr.update(open=False)  # video_in_drawer
         | 
| 153 | 
            +
                ]
         | 
| 154 | 
            +
             | 
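# Note: SAM2's init_state expects a directory of JPEG frames named by frame
# index ('00000.jpg', '00001.jpg', ...), which is why the frames are written
# with zero-padded names above and sorted numerically here.
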
def get_point(point_type, tracking_points, trackings_input_label, input_first_frame_image, evt: gr.SelectData):
    print(f"You selected {evt.value} at {evt.index} from {evt.target}")

    tracking_points.value.append(evt.index)
    print(f"TRACKING POINT: {tracking_points.value}")

    if point_type == "include":
        trackings_input_label.value.append(1)
    elif point_type == "exclude":
        trackings_input_label.value.append(0)
    print(f"TRACKING INPUT LABEL: {trackings_input_label.value}")

    # Open the image and get its dimensions
    transparent_background = Image.open(input_first_frame_image).convert('RGBA')
    w, h = transparent_background.size

    # Define the circle radius as a fraction of the smaller dimension
    fraction = 0.02  # adjust this value as needed
    radius = int(fraction * min(w, h))

    # Create a transparent layer to draw on
    transparent_layer = np.zeros((h, w, 4), dtype=np.uint8)

    for index, track in enumerate(tracking_points.value):
        if trackings_input_label.value[index] == 1:
            cv2.circle(transparent_layer, track, radius, (0, 255, 0, 255), -1)
        else:
            cv2.circle(transparent_layer, track, radius, (255, 0, 0, 255), -1)

    # Convert the transparent layer back to an image
    transparent_layer = Image.fromarray(transparent_layer, 'RGBA')
    selected_point_map = Image.alpha_composite(transparent_background, transparent_layer)

    return tracking_points, trackings_input_label, selected_point_map

# use bfloat16 for the entire app
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

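# Visualization helpers: these appear to be lifted from the official SAM2
# video-prediction notebook, overlaying masks, click points and boxes on a
# matplotlib axis.
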
def load_model(checkpoint):
    # Load the checkpoint and config matching the user's choice
    if checkpoint == "tiny":
        sam2_checkpoint = "./checkpoints/sam2.1_hiera_tiny.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
        return [sam2_checkpoint, model_cfg]
    elif checkpoint == "small":
        sam2_checkpoint = "./checkpoints/sam2.1_hiera_small.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
        return [sam2_checkpoint, model_cfg]
    elif checkpoint == "base-plus":
        sam2_checkpoint = "./checkpoints/sam2.1_hiera_base_plus.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
        return [sam2_checkpoint, model_cfg]
    # elif checkpoint == "large":
    #     sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
    #     model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
    #     return [sam2_checkpoint, model_cfg]


def get_mask_sam_process(
    stored_inference_state,
    input_first_frame_image,
    checkpoint,
    tracking_points,
    trackings_input_label,
    video_frames_dir, # extracted_frames_output_dir defined in the 'preprocess_video_in' function
    scanned_frames,
    working_frame: str = None, # current frame being annotated with points
    available_frames_to_check: List[str] = [],
    # progress=gr.Progress(track_tqdm=True)
):

    # get model and model config paths
    print(f"USER CHOSEN CHECKPOINT: {checkpoint}")
    sam2_checkpoint, model_cfg = load_model(checkpoint)
    print("MODEL LOADED")

    # set predictor
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
    print("PREDICTOR READY")

    # `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
    # print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
    video_dir = video_frames_dir

    # all the JPEG frame names previously scanned in this directory
    frame_names = scanned_frames

    # print(f"STORED INFERENCE STEP: {stored_inference_state}")
    if stored_inference_state is None:
        # Init SAM2 inference_state
        inference_state = predictor.init_state(video_path=video_dir)
        inference_state['num_pathway'] = 3
        inference_state['iou_thre'] = 0.3
        inference_state['uncertainty'] = 2
        print("NEW INFERENCE_STATE INITIATED")
    else:
        inference_state = stored_inference_state
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                # segment and track one object
         | 
| 284 | 
            +
                # predictor.reset_state(inference_state) # if any previous tracking, reset
         | 
| 285 | 
            +
                
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                ### HANDLING WORKING FRAME
         | 
| 288 | 
            +
                # new_working_frame = None
         | 
| 289 | 
            +
                # Add new point
         | 
| 290 | 
            +
                if working_frame is None:
         | 
| 291 | 
            +
                    ann_frame_idx = 0  # the frame index we interact with, 0 if it is the first frame
         | 
| 292 | 
            +
                    working_frame = "00000.jpg"
         | 
| 293 | 
            +
                else:
         | 
| 294 | 
            +
                    # Use a regular expression to find the integer
         | 
| 295 | 
            +
                    match = re.search(r'frame_(\d+)', working_frame)
         | 
| 296 | 
            +
                    if match:
         | 
| 297 | 
            +
                        # Extract the integer from the match
         | 
| 298 | 
            +
                        frame_number = int(match.group(1))
         | 
| 299 | 
            +
                        ann_frame_idx = frame_number
         | 
| 300 | 
            +
                        
         | 
| 301 | 
            +
                print(f"NEW_WORKING_FRAME PATH: {working_frame}")
         | 
| 302 | 
            +
                
         | 
| 303 | 
            +
                ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)
         | 
| 304 | 
            +
                
         | 
| 305 | 
            +
                # Let's add a positive click at (x, y) = (210, 350) to get started
         | 
| 306 | 
            +
                points = np.array(tracking_points.value, dtype=np.float32)
         | 
| 307 | 
            +
                # for labels, `1` means positive click and `0` means negative click
         | 
| 308 | 
            +
                labels = np.array(trackings_input_label.value, np.int32)
         | 
| 309 | 
            +
                _, out_obj_ids, out_mask_logits = predictor.add_new_points(
         | 
| 310 | 
            +
                    inference_state=inference_state,
         | 
| 311 | 
            +
                    frame_idx=ann_frame_idx,
         | 
| 312 | 
            +
                    obj_id=ann_obj_id,
         | 
| 313 | 
            +
                    points=points,
         | 
| 314 | 
            +
                    labels=labels,
         | 
| 315 | 
            +
                )
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                # Create the plot
         | 
| 318 | 
            +
                plt.figure(figsize=(12, 8))
         | 
| 319 | 
            +
                plt.title(f"frame {ann_frame_idx}")
         | 
| 320 | 
            +
                plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
         | 
| 321 | 
            +
                show_points(points, labels, plt.gca())
         | 
| 322 | 
            +
                show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
         | 
| 323 | 
            +
                
         | 
| 324 | 
            +
                # Save the plot as a JPG file
         | 
| 325 | 
            +
                first_frame_output_filename = "output_first_frame.jpg"
         | 
| 326 | 
            +
                plt.savefig(first_frame_output_filename, format='jpg')
         | 
| 327 | 
            +
                plt.close()
         | 
| 328 | 
            +
                torch.cuda.empty_cache()
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                # Assuming available_frames_to_check.value is a list
         | 
| 331 | 
            +
                if working_frame not in available_frames_to_check:
         | 
| 332 | 
            +
                    available_frames_to_check.append(working_frame)
         | 
| 333 | 
            +
                    print(available_frames_to_check)
         | 
| 334 | 
            +
                
         | 
| 335 | 
            +
                # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
         | 
| 336 | 
            +
                return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)
         | 
| 337 | 
            +
             | 
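# The returned predictor and inference_state land in gr.State holders
# (loaded_predictor / stored_inference_state in the Blocks below), so further
# "Get Mask" clicks reuse the same session state instead of re-initializing it.
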
def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
    #### PROPAGATION ####
    sam2_checkpoint, model_cfg = load_model(checkpoint)
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

    inference_state = stored_inference_state
    frame_names = stored_frame_names
    video_dir = video_frames_dir

    # Define a directory to save the JPEG images
    frames_output_dir = "frames_output_images"
    os.makedirs(frames_output_dir, exist_ok=True)

    # Initialize a list to store file paths of saved images
    jpeg_images = []

    # run propagation throughout the video and collect the results in a dict
    video_segments = {}  # video_segments contains the per-frame segmentation results
    # for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    #     video_segments[out_frame_idx] = {
    #         out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
    #         for i, out_obj_id in enumerate(out_obj_ids)
    #     }

    out_obj_ids, out_mask_logits = predictor.propagate_in_video(inference_state, start_frame_idx=0, reverse=False)
    print(out_obj_ids)
    for frame_idx in range(0, inference_state['num_frames']):
        video_segments[frame_idx] = {out_obj_ids[0]: (out_mask_logits[frame_idx] > 0.0).cpu().numpy()}
        # output_scores_per_object[object_id][frame_idx] = out_mask_logits[frame_idx].cpu().numpy()

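    # Unlike the per-frame generator in the commented-out upstream SAM2 loop,
    # this predictor appears to return all mask logits from a single
    # propagate_in_video call, so the per-frame dict is built with a plain
    # loop over num_frames.
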
    # render the segmentation results every few frames
    if vis_frame_type == "check":
        vis_frame_stride = 15
    elif vis_frame_type == "render":
        vis_frame_stride = 1

    plt.close("all")
    for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
        plt.figure(figsize=(6, 4))
        plt.title(f"frame {out_frame_idx}")
        plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
            show_mask(out_mask, plt.gca(), obj_id=out_obj_id)

        # Define the output filename and save the figure as a JPEG file
        output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
        plt.savefig(output_filename, format='jpg')

        # Close the plot
        plt.close()

        # Append the file path to the list
        jpeg_images.append(output_filename)

        if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
            available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")

    torch.cuda.empty_cache()
    print(f"JPEG_IMAGES: {jpeg_images}")

    if vis_frame_type == "check":
        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True)
    elif vis_frame_type == "render":
        # Create a video clip from the image sequence
        original_fps = get_video_fps(video_in)
        fps = original_fps  # Frames per second
        total_frames = len(jpeg_images)
        clip = ImageSequenceClip(jpeg_images, fps=fps)

        # Write the result to a file
        final_vid_output_path = "output_video.mp4"
        clip.write_videofile(
            final_vid_output_path,
            codec='libx264'
        )

        return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)

def update_ui(vis_frame_type):
    if vis_frame_type == "check":
        return gr.update(visible=True), gr.update(visible=False)
    elif vis_frame_type == "render":
        return gr.update(visible=False), gr.update(visible=True)

def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
    new_working_frame = None
    if working_frame is None:
        new_working_frame = os.path.join(video_frames_dir, scanned_frames[0])
    else:
        # Use a regular expression to find the integer
        match = re.search(r'frame_(\d+)', working_frame)
        if match:
            # Extract the integer from the match
            frame_number = int(match.group(1))
            ann_frame_idx = frame_number
            new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
    return gr.State([]), gr.State([]), new_working_frame, new_working_frame

def reset_propagation(first_frame_path, predictor, stored_inference_state):

    predictor.reset_state(stored_inference_state)
    # print(f"RESET State: {stored_inference_state} ")
    return first_frame_path, gr.State([]), gr.State([]), gr.update(value=None, visible=False), stored_inference_state, None, ["frame_0.jpg"], first_frame_path, "frame_0.jpg", gr.update(visible=False)


with gr.Blocks(css=css) as demo:
    first_frame_path = gr.State()
    tracking_points = gr.State([])
    trackings_input_label = gr.State([])
    video_frames_dir = gr.State()
    scanned_frames = gr.State()
    loaded_predictor = gr.State()
    stored_inference_state = gr.State()
    stored_frame_names = gr.State()
    available_frames_to_check = gr.State([])
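    # Each gr.State above is per-session storage (paths, click lists, the
    # SAM2 predictor and its inference state) shared between the callbacks
    # wired below.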
    with gr.Column():
        gr.Markdown(
            """
            <h1 style="text-align: center;">🔥 SAM2Long Demo 🔥</h1>
            """
        )
        gr.Markdown(
            """
            This is a simple demo for video segmentation with [SAM2Long](https://github.com/Mark12Ding/SAM2Long).
            It is largely built on the [SAM2-Video-Predictor](https://huggingface.co/spaces/fffiloni/SAM2-Video-Predictor).
            """
        )
        gr.Markdown(
            """
            ### 📋 Instructions:

            1. **Upload your video** [MP4-24fps]
            2. With the **'include' point type** selected, click on the object to mask on the first frame
            3. Switch to the **'exclude' point type** if you want to specify an area to avoid
            4. **Get Mask!**
            5. **Check Propagation** every 15 frames
            6. **Propagate with "render"** to render the final masked video
            7. **Hit the Reset button** if you want to refresh and start again

            *Note: Input video will be processed for up to 10 seconds only for demo purposes.*
            """
        )
        with gr.Row():

            with gr.Column():
                with gr.Group():
                    with gr.Group():
                        with gr.Row():
                            point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include", scale=2)
                            clear_points_btn = gr.Button("Clear Points", scale=1)

                    input_first_frame_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)

                    points_map = gr.Image(
                        label="Point n Click map",
                        type="filepath",
                        interactive=False
                    )

                    with gr.Group():
                        with gr.Row():
                            checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus"], value="tiny")
                            submit_btn = gr.Button("Get Mask", size="lg")

                with gr.Accordion("Your video IN", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Video IN", format="mp4")

                # f-string so the SPACE_ID placeholder is actually interpolated
                gr.HTML(f"""
                <a href="https://huggingface.co/spaces/{os.environ['SPACE_ID']}?duplicate=true">
                    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg-dark.svg" alt="Duplicate this Space" />
                </a> to skip the queue and avoid OOM errors from heavy public load
                """)
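                # SPACE_ID is an environment variable that the Hugging Face
                # Spaces runtime provides to the running app.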

            with gr.Column():
                with gr.Group():
                    # with gr.Group():
                    #     with gr.Row():
                    working_frame = gr.Dropdown(label="working frame ID", choices=[""], value="frame_0.jpg", visible=False, allow_custom_value=False, interactive=True)
                    #         change_current = gr.Button("change current", visible=False)
                    # working_frame = []
                    output_result = gr.Image(label="current working mask ref")
                with gr.Group():
                    with gr.Row():
                        vis_frame_type = gr.Radio(label="Propagation level", choices=["check", "render"], value="check", scale=2)
                        propagate_btn = gr.Button("Propagate", scale=1)
                reset_prpgt_brn = gr.Button("Reset", visible=False)
                output_propagated = gr.Gallery(label="Propagated Mask samples gallery", columns=4, visible=False)
                output_video = gr.Video(visible=False)
                # output_result_mask = gr.Image()

    # When a new video is uploaded
    video_in.upload(
        fn = preprocess_video_in,
        inputs = [video_in],
        outputs = [
            first_frame_path,
            tracking_points, # update Tracking Points in the gr.State([]) object
            trackings_input_label, # update Tracking Labels in the gr.State([]) object
            input_first_frame_image, # hidden component used as ref when clearing points
            points_map, # Image component where we add new tracking points
            video_frames_dir, # directory where the frames extracted from video_in are stored
            scanned_frames, # frame names scanned from that directory
            stored_inference_state, # SAM2 inference state, reset to None on new upload
            stored_frame_names, # reset to None on new upload
            video_in_drawer, # Accordion to hide the uploaded video player
        ],
        queue = False
    )
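    # Gradio maps preprocess_video_in's returned list onto this outputs list
    # positionally, so the order above must match the function's return order.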

    # Triggered when we click on the image to add new points
    points_map.select(
        fn = get_point,
        inputs = [
            point_type, # "include" or "exclude"
            tracking_points, # get tracking_points values
            trackings_input_label, # get tracking label values
            input_first_frame_image, # gr.State() first frame path
        ],
        outputs = [
            tracking_points, # updated with new points
            trackings_input_label, # updated with corresponding labels
            points_map, # updated image with points
        ],
        queue = False
    )
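    # get_point also receives an implicit evt: gr.SelectData argument with the
    # click coordinates; Gradio injects it automatically for .select events.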

    # Clear every point clicked and added to the map
    clear_points_btn.click(
        fn = clear_points,
        inputs = input_first_frame_image, # we get the untouched hidden image
        outputs = [
            first_frame_path,
            tracking_points,
            trackings_input_label,
            points_map,
            #stored_inference_state,
        ],
        queue=False
    )


    # change_current.click(
    #     fn = switch_working_frame,
    #     inputs = [working_frame, scanned_frames, video_frames_dir],
    #     outputs = [tracking_points, trackings_input_label, input_first_frame_image, points_map],
    #     queue=False
    # )


    submit_btn.click(
        fn = get_mask_sam_process,
        inputs = [
            stored_inference_state,
            input_first_frame_image,
            checkpoint,
            tracking_points,
            trackings_input_label,
            video_frames_dir,
            scanned_frames,
            working_frame,
            available_frames_to_check,
        ],
        outputs = [
            output_result,
            stored_frame_names,
            loaded_predictor,
            stored_inference_state,
            working_frame,
        ],
        queue=False
    )

    reset_prpgt_brn.click(
        fn = reset_propagation,
        inputs = [first_frame_path, loaded_predictor, stored_inference_state],
        outputs = [points_map, tracking_points, trackings_input_label, output_propagated, stored_inference_state, output_result, available_frames_to_check, input_first_frame_image, working_frame, reset_prpgt_brn],
        queue=False
    )

    propagate_btn.click(
        fn = update_ui,
        inputs = [vis_frame_type],
        outputs = [output_propagated, output_video],
        queue=False
    ).then(
        fn = propagate_to_all,
        inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame],
        outputs = [output_propagated, output_video, working_frame, available_frames_to_check, reset_prpgt_brn]
    )
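    # .then() chains the callbacks: update_ui swaps the gallery/video
    # visibility first, and propagate_to_all only runs once that UI update
    # has completed.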

demo.queue().launch(show_api=False, show_error=True, share=True, server_name="0.0.0.0", server_port=11111)
