ayyuce committed on
Commit
e5f81ae
·
verified ·
1 Parent(s): c198642

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import torch
4
+ from sahi import AutoDetectionModel
5
+ from sahi.predict import get_sliced_prediction
6
+ from motpy import Detection as MotpyDetection, MultiObjectTracker
7
+ import tempfile
8
+
9
# COCO class names (YOLOv8 default).
# The position of a name in this list is its integer class id, so the list
# can be indexed directly with a detection's class id (80 entries, ids 0-79).
COCO_CLASSES = [
    # 0-9
    'person', 'bicycle', 'car', 'motorcycle', 'airplane',
    'bus', 'train', 'truck', 'boat', 'traffic light',
    # 10-19
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
    'cat', 'dog', 'horse', 'sheep', 'cow',
    # 20-29
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
    'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    # 30-39
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    # 40-49
    'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange',
    # 50-59
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
    'cake', 'chair', 'couch', 'potted plant', 'bed',
    # 60-69
    'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 70-79
    'toaster', 'sink', 'refrigerator', 'book', 'clock',
    'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',
]
21
+
22
+
23
# Weights file, given relative to the process working directory.
model_path = "./yolo11n.pt"

# Detection model shared by every request; it is built once, when this
# module is imported, rather than per video.
detection_model = AutoDetectionModel.from_pretrained(
    model_path=model_path,
    model_type="yolov8",
    device="cpu",  # Force CPU usage
    confidence_threshold=0.3,
)
30
+
31
def track_objects(video_path):
    """Detect and track objects in a video and write an annotated copy.

    Each frame is run through SAHI sliced inference using the module-level
    ``detection_model``. The resulting detections are fed to a motpy
    ``MultiObjectTracker`` and every active track is drawn onto the frame as
    a green box labelled with its class name and track id.

    Args:
        video_path: Path of the input video, as accepted by
            ``cv2.VideoCapture``.

    Returns:
        Path of a newly created temporary ``.mp4`` file containing the
        annotated video. The file is created with ``delete=False``, so the
        caller owns its cleanup.

    Raises:
        ValueError: If the input video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path!r}")

    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Some containers report 0 (or NaN) for the frame rate; a writer
        # opened with such a value produces an unusable file. `not fps > 0`
        # is also true for NaN.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        # Only the file *name* is needed. Close the handle right away so the
        # descriptor is not leaked and the VideoWriter is the sole writer.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as output_file:
            output_path = output_file.name

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        tracker = MultiObjectTracker(
            # One tracker step is taken per frame, so the step duration is
            # the frame period rather than a fixed constant.
            dt=1.0 / fps,
            # NOTE(review): spec values kept as originally configured; see
            # the motpy documentation for the meaning of each key.
            model_spec={
                'order_pos': 1, 'dim_pos': 2,
                'order_size': 0, 'dim_size': 2,
                'q_var_pos': 5000., 'r_var_pos': 0.1,
            },
        )

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Frames are read as BGR; the detector is given an RGB copy and
            # the annotations are drawn on the original BGR frame.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            result = get_sliced_prediction(
                rgb_frame,
                detection_model,
                slice_height=512,
                slice_width=512,
                overlap_height_ratio=0.2,
                overlap_width_ratio=0.2,
            )

            detections = [
                MotpyDetection(
                    box=[obj.bbox.minx, obj.bbox.miny, obj.bbox.maxx, obj.bbox.maxy],
                    score=obj.score.value,
                    class_id=obj.category.id,
                )
                for obj in result.object_prediction_list
            ]

            tracker.step(detections)

            for track in tracker.active_tracks():
                x1, y1, x2, y2 = map(int, track.box)
                class_id = track.class_id if track.class_id is not None else -1
                # Fall back to the raw id when it is outside the known names.
                if 0 <= class_id < len(COCO_CLASSES):
                    class_name = COCO_CLASSES[class_id]
                else:
                    class_name = str(class_id)

                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                # Keep the label inside the image for boxes touching the top.
                label_y = max(y1 - 10, 15)
                cv2.putText(frame, f'{class_name} {track.id}', (x1, label_y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            out.write(frame)
    finally:
        # Release both handles even if inference or drawing raised.
        cap.release()
        if out is not None:
            out.release()

    return output_path
96
+
97
def process_video(video):
    """Gradio callback: run object tracking on the uploaded video.

    Args:
        video: The value Gradio passes for the video input. It is forwarded
            unchanged to ``track_objects``, which opens it as a video path.
            ``None`` means nothing was uploaded.

    Returns:
        A ``(preview_path, download_path)`` pair. Both entries are the path
        of the same annotated file: one value for the video player output and
        one for the file-download output.

    Raises:
        gr.Error: If no video was provided.
    """
    if video is None:
        raise gr.Error("Please upload a video before submitting.")

    output_path = track_objects(video)
    # The interface declares two output components, and Gradio expects one
    # returned value per output; a single return value is rejected.
    return output_path, output_path
100
+
101
# UI components, created in the order input -> outputs.
_input_video = gr.Video(label="Input Video")
_output_components = [
    gr.Video(label="Processed Video"),           # in-browser preview
    gr.File(label="Download Processed Video"),   # download link
]

# Single-page app: one video in, processed video (preview + download) out.
interface = gr.Interface(
    fn=process_video,
    inputs=_input_video,
    outputs=_output_components,
    title="SAHI Video Object Tracker",
    description="Object detection and tracking using SAHI and YOLOv11.",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()