datasciencesage commited on
Commit
f384d65
·
1 Parent(s): c31b3e2

changes in app

Browse files
Files changed (1) hide show
  1. app.py +110 -94
app.py CHANGED
@@ -1,10 +1,13 @@
 
1
  from ultralytics import YOLO
2
- from ultralytics import YOLOv10
3
-
4
  import cv2
5
  import time
6
  import numpy as np
7
  import torch
 
 
 
 
8
 
9
  def get_direction(old_center, new_center, min_movement=10):
10
  if old_center is None or new_center is None:
@@ -69,115 +72,128 @@ class ObjectTracker:
69
  self.tracked_objects = current_objects
70
  return results
71
 
72
- def main():
73
- # Use YOLOv8x with optimizations
74
- # model = YOLO('yolov8x.pt')
75
-
76
- model = YOLOv10.from_pretrained("Ultralytics/YOLOv8")
77
 
78
 
79
- # Enable GPU if available and set half precision
80
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
81
- model.to(device)
82
-
83
- if device.type != 'cpu':
84
- torch.backends.cudnn.benchmark = True
85
-
86
- tracker = ObjectTracker()
87
- video_path = "test2.mp4"
88
- cap = cv2.VideoCapture(video_path)
89
-
90
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
91
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
92
- fps = int(cap.get(cv2.CAP_PROP_FPS))
93
-
94
- cv2.namedWindow("YOLOv8x Detection with Direction", cv2.WINDOW_NORMAL)
95
- cv2.resizeWindow("YOLOv8x Detection with Direction", 1280, 720)
96
-
97
- direction_colors = {
98
- "left": (255, 0, 0),
99
- "right": (0, 255, 0),
100
- "up": (0, 255, 255),
101
- "down": (0, 0, 255),
102
- "stationary": (128, 128, 128)
103
- }
104
 
105
- # FPS calculation
106
- fps_start_time = time.time()
107
- fps_counter = 0
108
- fps_display = 0
109
 
110
- # Process every 2nd frame for better performance
111
- frame_skip = 2
112
- frame_count = 0
113
 
114
- print(f"Running on device: {device}")
 
115
 
116
- while cap.isOpened():
117
- success, frame = cap.read()
118
- if not success:
119
- break
120
-
121
- frame_count += 1
122
- if frame_count % frame_skip != 0:
123
- continue
124
-
125
- # Update FPS
126
- fps_counter += 1
127
- if time.time() - fps_start_time > 1:
128
- fps_display = fps_counter * frame_skip # Adjust for skipped frames
129
- fps_counter = 0
130
- fps_start_time = time.time()
131
 
132
- # Optimize inference
133
- results = model(frame,
134
- conf=0.25,
135
- iou=0.45,
136
- max_det=20,
137
- verbose=False)[0]
138
 
139
- detections = []
140
- for box in results.boxes.data:
141
- x1, y1, x2, y2, conf, cls = box.tolist()
142
- detections.append([int(x1), int(y1), int(x2), int(y2), float(conf), int(cls)])
143
 
144
- tracked_objects = tracker.update(detections)
 
 
 
 
 
 
145
 
146
- # Draw FPS
147
- cv2.putText(frame, f"FPS: {fps_display}",
148
- (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
149
- 1, (0, 255, 0), 2)
150
 
151
- # Draw total detections
152
- cv2.putText(frame, f"Detections: {len(tracked_objects)}",
153
- (10, 70), cv2.FONT_HERSHEY_SIMPLEX,
154
- 1, (0, 255, 0), 2)
155
 
156
- for detection, obj_id, direction in tracked_objects:
157
- x1, y1, x2, y2, conf, cls = detection
158
- color = direction_colors.get(direction, (128, 128, 128))
 
159
 
160
- cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
 
 
 
 
 
161
 
162
- label = f"{model.names[int(cls)]} {direction} {conf:.2f}"
163
- text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)[0]
 
 
164
 
165
- cv2.rectangle(frame,
166
- (int(x1), int(y1) - text_size[1] - 10),
167
- (int(x1) + text_size[0], int(y1)),
168
- color, -1)
169
 
170
- cv2.putText(frame, label,
171
- (int(x1), int(y1) - 5),
172
- cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
173
-
174
- cv2.imshow("YOLOv8x Detection with Direction", frame)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
- if cv2.waitKey(1) & 0xFF == ord('q'):
177
- break
178
-
179
- cap.release()
180
- cv2.destroyAllWindows()
181
 
182
  if __name__ == "__main__":
183
  main()
 
1
+ import streamlit as st
2
  from ultralytics import YOLO
 
 
3
  import cv2
4
  import time
5
  import numpy as np
6
  import torch
7
+ from PIL import Image
8
+ import tempfile
9
+ import warnings
10
+ warnings.filterwarnings('ignore')
11
 
12
  def get_direction(old_center, new_center, min_movement=10):
13
  if old_center is None or new_center is None:
 
72
  self.tracked_objects = current_objects
73
  return results
74
 
 
 
 
 
 
75
 
76
 
77
+ def main():
78
+ st.title("Real-time Object Detection with Direction")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ # File uploader for video
81
+ uploaded_file = st.file_uploader("Choose a video file", type=['mp4', 'avi', 'mov'])
 
 
82
 
83
+ # Add start button
84
+ start_detection = st.button("Start Detection")
 
85
 
86
+ # Add stop button
87
+ stop_detection = st.button("Stop Detection")
88
 
89
+ if uploaded_file is not None and start_detection:
90
+ # Create a session state to track if detection is running
91
+ if 'running' not in st.session_state:
92
+ st.session_state.running = True
93
+
94
+ # Save uploaded file temporarily
95
+ tfile = tempfile.NamedTemporaryFile(delete=False)
96
+ tfile.write(uploaded_file.read())
 
 
 
 
 
 
 
97
 
98
+ # Load model
99
+ with st.spinner('Loading model...'):
100
+ model = YOLO('yolov8x.pt',verbose=False)
101
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
102
+ model.to(device)
 
103
 
104
+ tracker = ObjectTracker()
105
+ cap = cv2.VideoCapture(tfile.name)
 
 
106
 
107
+ direction_colors = {
108
+ "left": (255, 0, 0),
109
+ "right": (0, 255, 0),
110
+ "up": (0, 255, 255),
111
+ "down": (0, 0, 255),
112
+ "stationary": (128, 128, 128)
113
+ }
114
 
115
+ # Create placeholder for video frame
116
+ frame_placeholder = st.empty()
117
+ # Create placeholder for detection info
118
+ info_placeholder = st.empty()
119
 
120
+ st.success("Detection Started!")
 
 
 
121
 
122
+ while cap.isOpened() and st.session_state.running:
123
+ success, frame = cap.read()
124
+ if not success:
125
+ break
126
 
127
+ # Run detection
128
+ results = model(frame,
129
+ conf=0.25,
130
+ iou=0.45,
131
+ max_det=20,
132
+ verbose=False)[0]
133
 
134
+ detections = []
135
+ for box in results.boxes.data:
136
+ x1, y1, x2, y2, conf, cls = box.tolist()
137
+ detections.append([int(x1), int(y1), int(x2), int(y2), float(conf), int(cls)])
138
 
139
+ tracked_objects = tracker.update(detections)
 
 
 
140
 
141
+ # Dictionary to store detection counts
142
+ detection_counts = {}
143
+
144
+ for detection, obj_id, direction in tracked_objects:
145
+ x1, y1, x2, y2, conf, cls = detection
146
+ color = direction_colors.get(direction, (128, 128, 128))
147
+
148
+ cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
149
+
150
+ label = f"{model.names[int(cls)]} {direction} {conf:.2f}"
151
+ # Increased font size and thickness
152
+ font_scale = 1.2
153
+ thickness = 3
154
+ text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)[0]
155
+
156
+ # Increased padding for label background
157
+ padding_y = 15
158
+ cv2.rectangle(frame,
159
+ (int(x1), int(y1) - text_size[1] - padding_y),
160
+ (int(x1) + text_size[0], int(y1)),
161
+ color, -1)
162
+
163
+ cv2.putText(frame, label,
164
+ (int(x1), int(y1) - 5),
165
+ cv2.FONT_HERSHEY_SIMPLEX,
166
+ font_scale,
167
+ (255, 255, 255),
168
+ thickness)
169
+
170
+ # Count detections by class
171
+ class_name = model.names[int(cls)]
172
+ detection_counts[class_name] = detection_counts.get(class_name, 0) + 1
173
+
174
+ # Convert BGR to RGB
175
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
176
+
177
+ # Update frame
178
+ frame_placeholder.image(frame_rgb, channels="RGB", use_column_width=True)
179
+
180
+ # Update detection info
181
+ info_text = "Detected Objects:\n"
182
+ for class_name, count in detection_counts.items():
183
+ info_text += f"{class_name}: {count}\n"
184
+ info_placeholder.text(info_text)
185
+
186
+ # Check if stop button is pressed
187
+ if stop_detection:
188
+ st.session_state.running = False
189
+ break
190
+
191
+ cap.release()
192
+ st.session_state.running = False
193
+ st.warning("Detection Stopped")
194
 
195
+ elif uploaded_file is None and start_detection:
196
+ st.error("Please upload a video file first!")
 
 
 
197
 
198
  if __name__ == "__main__":
199
  main()