datasciencesage commited on
Commit
34faab5
·
1 Parent(s): 3aa216d

Removed the Yolo8x.pt file and changed some code in app file

Browse files
Files changed (2) hide show
  1. app.py +183 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ from ultralytics import YOLOv10
3
+
4
+ import cv2
5
+ import time
6
+ import numpy as np
7
+ import torch
8
+
9
+ def get_direction(old_center, new_center, min_movement=10):
10
+ if old_center is None or new_center is None:
11
+ return "stationary"
12
+
13
+ dx = new_center[0] - old_center[0]
14
+ dy = new_center[1] - old_center[1]
15
+
16
+ if abs(dx) < min_movement and abs(dy) < min_movement:
17
+ return "stationary"
18
+
19
+ if abs(dx) > abs(dy):
20
+ return "right" if dx > 0 else "left"
21
+ else:
22
+ return "down" if dy > 0 else "up"
23
+
24
+ class ObjectTracker:
25
+ def __init__(self):
26
+ self.tracked_objects = {}
27
+ self.object_count = {}
28
+
29
+ def update(self, detections):
30
+ current_objects = {}
31
+ results = []
32
+
33
+ for detection in detections:
34
+ x1, y1, x2, y2 = detection[0:4]
35
+ center = ((x1 + x2) // 2, (y1 + y2) // 2)
36
+ class_id = detection[5]
37
+
38
+ object_id = f"{class_id}_{len(self.object_count.get(class_id, []))}"
39
+
40
+ min_dist = float('inf')
41
+ closest_id = None
42
+
43
+ for prev_id, prev_data in self.tracked_objects.items():
44
+ if prev_id.split('_')[0] == str(class_id):
45
+ dist = np.sqrt((center[0] - prev_data['center'][0])**2 +
46
+ (center[1] - prev_data['center'][1])**2)
47
+ if dist < min_dist and dist < 100:
48
+ min_dist = dist
49
+ closest_id = prev_id
50
+
51
+ if closest_id:
52
+ object_id = closest_id
53
+ else:
54
+ if class_id not in self.object_count:
55
+ self.object_count[class_id] = []
56
+ self.object_count[class_id].append(object_id)
57
+
58
+ prev_center = self.tracked_objects.get(object_id, {}).get('center', None)
59
+ direction = get_direction(prev_center, center)
60
+
61
+ current_objects[object_id] = {
62
+ 'center': center,
63
+ 'direction': direction,
64
+ 'detection': detection
65
+ }
66
+
67
+ results.append((detection, object_id, direction))
68
+
69
+ self.tracked_objects = current_objects
70
+ return results
71
+
72
+ def main():
73
+ # Use YOLOv8x with optimizations
74
+ # model = YOLO('yolov8x.pt')
75
+
76
+ model = YOLOv10.from_pretrained("Ultralytics/YOLOv8")
77
+
78
+
79
+ # Enable GPU if available and set half precision
80
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81
+ model.to(device)
82
+
83
+ if device.type != 'cpu':
84
+ torch.backends.cudnn.benchmark = True
85
+
86
+ tracker = ObjectTracker()
87
+ video_path = "test2.mp4"
88
+ cap = cv2.VideoCapture(video_path)
89
+
90
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
91
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
92
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
93
+
94
+ cv2.namedWindow("YOLOv8x Detection with Direction", cv2.WINDOW_NORMAL)
95
+ cv2.resizeWindow("YOLOv8x Detection with Direction", 1280, 720)
96
+
97
+ direction_colors = {
98
+ "left": (255, 0, 0),
99
+ "right": (0, 255, 0),
100
+ "up": (0, 255, 255),
101
+ "down": (0, 0, 255),
102
+ "stationary": (128, 128, 128)
103
+ }
104
+
105
+ # FPS calculation
106
+ fps_start_time = time.time()
107
+ fps_counter = 0
108
+ fps_display = 0
109
+
110
+ # Process every 2nd frame for better performance
111
+ frame_skip = 2
112
+ frame_count = 0
113
+
114
+ print(f"Running on device: {device}")
115
+
116
+ while cap.isOpened():
117
+ success, frame = cap.read()
118
+ if not success:
119
+ break
120
+
121
+ frame_count += 1
122
+ if frame_count % frame_skip != 0:
123
+ continue
124
+
125
+ # Update FPS
126
+ fps_counter += 1
127
+ if time.time() - fps_start_time > 1:
128
+ fps_display = fps_counter * frame_skip # Adjust for skipped frames
129
+ fps_counter = 0
130
+ fps_start_time = time.time()
131
+
132
+ # Optimize inference
133
+ results = model(frame,
134
+ conf=0.25,
135
+ iou=0.45,
136
+ max_det=20,
137
+ verbose=False)[0]
138
+
139
+ detections = []
140
+ for box in results.boxes.data:
141
+ x1, y1, x2, y2, conf, cls = box.tolist()
142
+ detections.append([int(x1), int(y1), int(x2), int(y2), float(conf), int(cls)])
143
+
144
+ tracked_objects = tracker.update(detections)
145
+
146
+ # Draw FPS
147
+ cv2.putText(frame, f"FPS: {fps_display}",
148
+ (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
149
+ 1, (0, 255, 0), 2)
150
+
151
+ # Draw total detections
152
+ cv2.putText(frame, f"Detections: {len(tracked_objects)}",
153
+ (10, 70), cv2.FONT_HERSHEY_SIMPLEX,
154
+ 1, (0, 255, 0), 2)
155
+
156
+ for detection, obj_id, direction in tracked_objects:
157
+ x1, y1, x2, y2, conf, cls = detection
158
+ color = direction_colors.get(direction, (128, 128, 128))
159
+
160
+ cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
161
+
162
+ label = f"{model.names[int(cls)]} {direction} {conf:.2f}"
163
+ text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)[0]
164
+
165
+ cv2.rectangle(frame,
166
+ (int(x1), int(y1) - text_size[1] - 10),
167
+ (int(x1) + text_size[0], int(y1)),
168
+ color, -1)
169
+
170
+ cv2.putText(frame, label,
171
+ (int(x1), int(y1) - 5),
172
+ cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
173
+
174
+ cv2.imshow("YOLOv8x Detection with Direction", frame)
175
+
176
+ if cv2.waitKey(1) & 0xFF == ord('q'):
177
+ break
178
+
179
+ cap.release()
180
+ cv2.destroyAllWindows()
181
+
182
+ if __name__ == "__main__":
183
+ main()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ ultralytics
3
+ opencv-python
4
+ torch
5
+ numpy
6
+ Pillow
7
+ torchvision
8
+ numpy