# Full rewritten version of the provided code with a progress bar, error fixes, and proper Gradio integration
import os
import copy
import tempfile
from datetime import datetime
import gc

import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import gradio as gr
from moviepy.editor import ImageSequenceClip

from sam2.build_sam import build_sam2_video_predictor

# Drop the CUDA-specific SDPA flag, since everything below runs on CPU
os.environ.pop("TORCH_CUDNN_SDPA_ENABLED", None)

# Config
sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"
examples = [
    [f"examples/{vid}"]
    for vid in ["01_dog.mp4", "02_cups.mp4", "03_blocks.mp4", "04_coffee.mp4", "05_default_juggle.mp4"]
]
OBJ_ID = 0

# Model loader
if os.path.exists(sam2_checkpoint) and os.path.exists(model_cfg):
    try:
        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
    except Exception as e:
        print("Error loading predictor:", e)
        predictor = None
else:
    print("Model files missing.")
    predictor = None


def get_fps(video_path):
    """Return the source video's FPS, falling back to 30 if it cannot be read."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return 30.0
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps if fps and fps > 0 else 30.0


def reset(session):
    """Clear all prompts, frames, and tracking state."""
    if predictor is not None and session["inference_state"]:
        predictor.reset_state(session["inference_state"])
    session.update({
        "input_points": [],
        "input_labels": [],
        "first_frame": None,
        "all_frames": None,
        "inference_state": None,
    })
    return None, gr.update(open=True), None, None, gr.update(value=None, visible=False), session


def clear_points(session):
    """Drop the click prompts but keep the loaded video."""
    session["input_points"] = []
    session["input_labels"] = []
    if predictor is not None and session["inference_state"] and session["inference_state"].get("tracking_has_started"):
        predictor.reset_state(session["inference_state"])
    return session["first_frame"], None, gr.update(value=None, visible=False), session


def preprocess_video(video_path, session):
    """Subsample and downscale the video for display, then initialise the predictor state."""
    cap = cv2.VideoCapture(video_path)
    if predictor is None or not cap.isOpened():
        return gr.update(open=True), None, None, gr.update(value=None, visible=False), session
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    stride = max(1, total_frames // 300)  # keep at most ~300 frames for display
    frames, first_frame = [], None
    w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    target_w = 640
    scale = target_w / w if w > target_w else 1.0
    frame_id = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_id % stride == 0:
            if scale < 1.0:
                frame = cv2.resize(frame, (int(w * scale), int(h * scale)))
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if first_frame is None:
                first_frame = frame
            frames.append(frame)
        frame_id += 1
    cap.release()
    session.update({
        "first_frame": first_frame,
        "all_frames": frames,
        "frame_stride": stride,
        "scale_factor": scale,
        "inference_state": predictor.init_state(video_path=video_path),
        "input_points": [],
        "input_labels": [],
    })
    return gr.update(open=False), first_frame, None, gr.update(value=None, visible=False), session


def show_mask(mask, obj_id=None):
    """Render a boolean mask as a semi-transparent RGBA image."""
    cmap = plt.get_cmap("tab10")
    color = np.array([*cmap(0 if obj_id is None else obj_id)[:3], 0.6])
    h, w = mask.shape
    mask_rgba = (mask.reshape(h, w, 1) * color.reshape(1, 1, -1) * 255).astype(np.uint8)
    proper_mask = np.zeros((h, w, 4), dtype=np.uint8)
    proper_mask[:, :, :min(mask_rgba.shape[2], 4)] = mask_rgba[:, :, :min(mask_rgba.shape[2], 4)]
    return Image.fromarray(proper_mask, "RGBA")


def segment_with_points(ptype, session, evt: gr.SelectData):
    """Record an include/exclude click on the first frame and preview the resulting mask."""
    session["input_points"].append(evt.index)
    session["input_labels"].append(1 if ptype == "include" else 0)
    first = session["first_frame"]
    h, w = first.shape[:2]
    layer = np.zeros((h, w, 4), dtype=np.uint8)
    for idx, pt in enumerate(session["input_points"]):
        # Green marker for include points, red for exclude points
        color = (0, 255, 0, 255) if session["input_labels"][idx] == 1 else (255, 0, 0, 255)
        cv2.circle(layer, tuple(pt), int(min(w, h) * 0.01), color, -1)
    overlay = Image.alpha_composite(Image.fromarray(first).convert("RGBA"), Image.fromarray(layer, "RGBA"))
    try:
        _, _, logits = predictor.add_new_points(
            session["inference_state"], 0, OBJ_ID,
            np.array(session["input_points"]), np.array(session["input_labels"]),
        )
        # logits[0] carries a leading singleton dimension; squeeze it down to a 2-D mask
        mask = (logits[0] > 0.0).cpu().numpy().squeeze()
        mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
        mask_img = show_mask(mask)
        return overlay, Image.alpha_composite(Image.fromarray(first).convert("RGBA"), mask_img), session
    except Exception as e:
        print("Segmentation error:", e)
        return overlay, overlay, session


def propagate(video_in, session, progress=gr.Progress()):
    """Propagate the mask through the video and write an overlay MP4."""
    if not session["input_points"] or not session["inference_state"]:
        return None, session
    # Masks are keyed by the original frame index reported by the predictor
    masks = {}
    for i, (frame_idx, obj_ids, logits) in enumerate(predictor.propagate_in_video(session["inference_state"])):
        try:
            masks[frame_idx] = {oid: (logits[j] > 0.0).cpu().numpy().squeeze() for j, oid in enumerate(obj_ids)}
            progress(min(i / 300, 1.0), desc=f"Tracking frame {frame_idx}")
        except Exception:
            continue
    # all_frames holds every `frame_stride`-th frame, so subsampled frame k
    # corresponds to original frame index k * frame_stride when looking up its mask
    frame_stride = session.get("frame_stride", 1)
    all_frames = session["all_frames"]
    frames_out, out_stride = [], max(1, len(all_frames) // 50)
    for k in range(0, len(all_frames), out_stride):
        orig_idx = k * frame_stride
        if orig_idx not in masks or OBJ_ID not in masks[orig_idx]:
            continue
        try:
            frame = all_frames[k]
            mask = masks[orig_idx][OBJ_ID]
            h, w = frame.shape[:2]
            mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
            output = Image.alpha_composite(Image.fromarray(frame).convert("RGBA"), show_mask(mask))
            # Drop the alpha channel: libx264 expects RGB frames
            frames_out.append(np.array(output.convert("RGB")))
        except Exception:
            continue
    if not frames_out:
        return gr.update(value=None, visible=False), session
    out_path = os.path.join(tempfile.gettempdir(), f"output_video_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4")
    fps = min(15, get_fps(video_in))
    ImageSequenceClip(frames_out, fps=fps).write_videofile(
        out_path, codec="libx264", bitrate="800k", threads=2, logger=None
    )
    gc.collect()
    return gr.update(value=out_path, visible=True), session


with gr.Blocks() as demo:
    state = gr.State({
        "first_frame": None,
        "all_frames": None,
        "input_points": [],
        "input_labels": [],
        "inference_state": None,
        "frame_stride": 1,
        "scale_factor": 1.0,
        "original_dimensions": None,
    })
    gr.Markdown("