# Hugging Face Space app (captured page header: "Spaces: Running").
import gradio as gr | |
import cv2 | |
import torch | |
from sahi import AutoDetectionModel | |
from sahi.predict import get_sliced_prediction | |
from motpy import Detection as MotpyDetection, MultiObjectTracker | |
import tempfile | |
# The 80 COCO class names, in the index order used by COCO-pretrained YOLO
# weights (YOLOv8 default). track_objects maps a tracker class_id to a label
# by indexing into this list, so the order must not change.
COCO_CLASSES = [
    # 0-8: person and vehicles
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
    "truck", "boat",
    # 9-13: outdoor
    "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    # 14-23: animals
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear",
    "zebra", "giraffe",
    # 24-28: accessories
    "backpack", "umbrella", "handbag", "tie", "suitcase",
    # 29-38: sports
    "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket",
    # 39-45: kitchen
    "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
    # 46-55: food
    "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog",
    "pizza", "donut", "cake",
    # 56-61: furniture
    "chair", "couch", "potted plant", "bed", "dining table", "toilet",
    # 62-67: electronics
    "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    # 68-72: appliances
    "microwave", "oven", "toaster", "sink", "refrigerator",
    # 73-79: indoor
    "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
    "toothbrush",
]
# Detector weights; a relative path, so it resolves against the current
# working directory.
model_path = "./yolo11n.pt"

# Built once at import time and reused by every call to track_objects.
# NOTE(review): YOLO11 weights are loaded through SAHI's 'yolov8' model type;
# presumably this works via SAHI's Ultralytics wrapper - confirm against the
# installed SAHI version.
detection_model = AutoDetectionModel.from_pretrained(
    model_path=model_path,
    model_type='yolov8',
    device='cpu',  # Force CPU usage
    confidence_threshold=0.3,
)
def _detect_frame(frame):
    """Run SAHI sliced inference on one BGR frame and return motpy detections.

    OpenCV decodes frames as BGR, so the frame is converted to RGB before
    prediction. Each SAHI prediction becomes a motpy detection carrying its
    ``[minx, miny, maxx, maxy]`` box, score value and category id.
    """
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    result = get_sliced_prediction(
        rgb_frame,
        detection_model,
        slice_height=512,
        slice_width=512,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2,
    )
    return [
        MotpyDetection(
            box=[obj.bbox.minx, obj.bbox.miny, obj.bbox.maxx, obj.bbox.maxy],
            score=obj.score.value,
            class_id=obj.category.id,
        )
        for obj in result.object_prediction_list
    ]


def _draw_tracks(frame, tracks):
    """Draw a green box and a '<class name> <track id>' label per track, in place."""
    for track in tracks:
        x1, y1, x2, y2 = map(int, track.box)
        class_id = track.class_id if track.class_id is not None else -1
        if 0 <= class_id < len(COCO_CLASSES):
            class_name = COCO_CLASSES[class_id]
        else:
            class_name = str(class_id)
        # NOTE(review): motpy track ids appear to be UUID strings; a short
        # prefix keeps the on-frame label readable. Confirm with installed motpy.
        label = f'{class_name} {str(track.id)[:8]}'
        # Keep the label on-screen for boxes touching the top edge.
        text_origin = (x1, max(y1 - 10, 10))
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, label, text_origin,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)


def track_objects(video_path):
    """Detect and track objects in a video and write an annotated copy.

    Every frame goes through SAHI sliced inference with the module-level
    ``detection_model``. The detections update a motpy ``MultiObjectTracker``
    (one tracker step per frame), and every active track is drawn onto the
    frame before it is written out.

    Args:
        video_path: Path to the input video (anything ``cv2.VideoCapture``
            accepts).

    Returns:
        Path to a new ``.mp4`` file (``mp4v`` fourcc) created with
        ``tempfile``. The file is not deleted automatically; the caller owns it.

    Raises:
        ValueError: If the input video cannot be opened.
        RuntimeError: If the output video writer cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open input video: {video_path!r}")

    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN) FPS. VideoWriter and the tracker's
        # time step both need a positive rate, so fall back to 30.
        if not fps > 0:
            fps = 30.0

        # Only a unique path is needed: close our handle immediately so the
        # descriptor is not leaked and VideoWriter can open the file itself.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
            output_path = tmp.name

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise RuntimeError(f"Could not open video writer for {output_path!r}")

        # motpy's dt is the time between consecutive tracker steps. With one
        # step per frame, that is the frame interval 1/fps, not a constant.
        tracker = MultiObjectTracker(
            dt=1.0 / fps,
            model_spec={
                'order_pos': 1, 'dim_pos': 2,
                'order_size': 0, 'dim_size': 2,
                'q_var_pos': 5000., 'r_var_pos': 0.1,
            },
        )

        while True:
            ok, frame = cap.read()
            if not ok:
                break
            tracker.step(_detect_frame(frame))
            _draw_tracks(frame, tracker.active_tracks())
            out.write(frame)
    finally:
        # Release both handles even if inference or drawing raises mid-video.
        cap.release()
        if out is not None:
            out.release()

    return output_path
def process_video(video):
    """Gradio callback: run tracking on the uploaded video.

    The interface declares two outputs (a video player and a file download),
    and Gradio requires one return value per output component. The same
    processed path is therefore returned twice.

    Args:
        video: The value delivered by the ``gr.Video`` input (presumably a
            filepath string), or ``None`` when the user submits without
            uploading anything.

    Returns:
        ``(output_path, output_path)`` for the Video and File outputs.

    Raises:
        gr.Error: If no video was provided; Gradio shows the message in the UI.
    """
    if video is None:
        raise gr.Error("Please upload a video first.")
    output_path = track_objects(video)
    return output_path, output_path
# Web UI: one uploaded video in; the processed video is shown in a player and
# also offered as a downloadable file.
interface = gr.Interface(
    title="SAHI Video Object Tracker",
    description="Object detection and tracking using SAHI and YOLOv11.",
    fn=process_video,
    inputs=gr.Video(label="Input Video"),
    outputs=[
        gr.Video(label="Processed Video"),
        gr.File(label="Download Processed Video"),
    ],
)

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()