Sanshruth committed on
Commit ef69651 · verified · 1 Parent(s): e93dd75

Update app.py

Files changed (1)
  1. app.py +49 -53
app.py CHANGED
@@ -182,85 +182,81 @@ def draw_angled_line(image, line_params, color=(0, 255, 0), thickness=2):
 
 def process_video(confidence_threshold=0.5, selected_classes=None, stream_url=None):
     """
-    Processes the IP camera stream to count objects of the selected classes crossing the line.
+    Processes the IP camera stream with batch processing for faster performance.
     """
     global line_params

-    errors = []
+    if line_params is None or selected_classes is None or not stream_url:
+        return None, "Error: Missing required parameters"

-    if line_params is None:
-        errors.append("Error: No line drawn. Please draw a line on the first frame.")
-    if selected_classes is None or len(selected_classes) == 0:
-        errors.append("Error: No classes selected. Please select at least one class to detect.")
-    if stream_url is None or stream_url.strip() == "":
-        errors.append("Error: No stream URL provided.")
-
-    if errors:
-        return None, "\n".join(errors)
-
-    logger.info("Connecting to the IP camera stream...")
     cap = cv2.VideoCapture(stream_url)
     if not cap.isOpened():
-        errors.append("Error: Could not open stream.")
-        return None, "\n".join(errors)
+        return None, "Error: Could not open stream"

-    model = YOLO(model="yolo11n.pt")
+    # Initialize variables
+    frames_buffer = []
     crossed_objects = {}
-    max_tracked_objects = 100 # Maximum number of objects to track before clearing
+    batch_size = 16
+    max_tracked_objects = 1000
+
+    # Set capture properties for better performance
+    cap.set(cv2.CAP_PROP_BUFFERSIZE, 30)
+    cap.set(cv2.CAP_PROP_FPS, 30)
+    cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG'))
+
+    model = YOLO(model="yolo11n.pt")

-    logger.info("Starting to process the stream...")
     while cap.isOpened():
         ret, frame = cap.read()
         if not ret:
-            errors.append("Error: Could not read frame from the stream.")
             break

-        # Perform object tracking with confidence threshold
-        results = model.track(frame, persist=True, conf=confidence_threshold,iou=0.5,max_det=50,verbose=False)
+        frames_buffer.append(frame)

-        if results[0].boxes.id is not None:
-            track_ids = results[0].boxes.id.int().cpu().tolist()
-            clss = results[0].boxes.cls.cpu().tolist()
-            boxes = results[0].boxes.xyxy.cpu()
-            confs = results[0].boxes.conf.cpu().tolist()
+        if len(frames_buffer) >= batch_size:
+            # Process batch of frames
+            results = model.track(frames_buffer, persist=True, conf=confidence_threshold, verbose=False)

-            for box, cls, t_id, conf in zip(boxes, clss, track_ids, confs):
-                if conf >= confidence_threshold and model.names[cls] in selected_classes:
-                    # Check if the object crosses the line
-                    if is_object_crossing_line(box, line_params) and t_id not in crossed_objects:
-                        crossed_objects[t_id] = True
+            # Process each frame's results
+            for frame_idx, result in enumerate(results):
+                if result.boxes.id is not None:
+                    track_ids = result.boxes.id.int().cpu().tolist()
+                    clss = result.boxes.cls.cpu().tolist()
+                    boxes = result.boxes.xyxy.cpu()
+                    confs = result.boxes.conf.cpu().tolist()

-        # Clear the dictionary if it gets too large
-        if len(crossed_objects) > max_tracked_objects:
-            crossed_objects.clear()
+                    # Create annotated frame
+                    annotated_frame = frames_buffer[frame_idx].copy()

-        # Visualize the results with bounding boxes, masks, and IDs
-        annotated_frame = results[0].plot()
+                    for box, cls, t_id, conf in zip(boxes, clss, track_ids, confs):
+                        if conf >= confidence_threshold and model.names[cls] in selected_classes:
+                            # Check line crossing
+                            if is_object_crossing_line(box, line_params) and t_id not in crossed_objects:
+                                crossed_objects[t_id] = True

-        # Draw the angled line on the frame
-        draw_angled_line(annotated_frame, line_params, color=(0, 255, 0), thickness=2)
+                            # Clear if too many objects
+                            if len(crossed_objects) > max_tracked_objects:
+                                crossed_objects.clear()

-        # Display the count on the frame with a modern look
-        count = len(crossed_objects)
-        (text_width, text_height), _ = cv2.getTextSize(f"COUNT: {count}", cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
+                            # Draw bounding box
+                            x1, y1, x2, y2 = map(int, box)
+                            cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

-        # Calculate the position for the middle of the top
-        margin = 10 # Margin from the top
-        x = (annotated_frame.shape[1] - text_width) // 2 # Center-align the text horizontally
-        y = text_height + margin # Top-align the text
+                    # Draw line
+                    draw_angled_line(annotated_frame, line_params, color=(0, 255, 0), thickness=2)

-        # Draw the black background rectangle
-        cv2.rectangle(annotated_frame, (x - margin, y - text_height - margin), (x + text_width + margin, y + margin), (0, 0, 0), -1)
+                    # Draw count
+                    count = len(crossed_objects)
+                    cv2.putText(annotated_frame, f"COUNT: {count}", (10, 30),
+                                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

-        # Draw the text
-        cv2.putText(annotated_frame, f"COUNT: {count}", (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+                    # Yield the processed frame
+                    yield annotated_frame, ""

-        # Yield the annotated frame to Gradio
-        yield annotated_frame, ""
+            # Clear buffer after processing batch
+            frames_buffer = []

     cap.release()
-    logger.info("Stream processing completed.")
-
 # Define the Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("<h1>Real-time monitoring, object tracking, and line-crossing detection for CCTV camera streams.</h1></center>")
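
For context, a minimal, self-contained sketch of the batched-tracking pattern this commit introduces, not the app's actual code: the input path sample.mp4, the line height line_y, and the crossed_line() helper (a stand-in for app.py's is_object_crossing_line) are illustrative assumptions; only the Ultralytics YOLO.track and OpenCV calls mirror the diff above.

# Illustrative sketch only; assumed names: sample.mp4, crossed_line, line_y.
import cv2
from ultralytics import YOLO

def crossed_line(box, line_y):
    # Hypothetical stand-in for app.py's is_object_crossing_line():
    # true once the box centre falls below a horizontal line at height line_y.
    x1, y1, x2, y2 = box
    return (y1 + y2) / 2 > line_y

model = YOLO("yolo11n.pt")
cap = cv2.VideoCapture("sample.mp4")  # assumed local clip instead of the app's stream URL
frames, counted = [], set()
batch_size, line_y = 16, 360

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break
    frames.append(frame)
    if len(frames) < batch_size:
        continue
    # One tracker call for the whole buffer; persist=True keeps track IDs across batches.
    results = model.track(frames, persist=True, conf=0.5, verbose=False)
    for result in results:
        if result.boxes.id is None:
            continue
        boxes = result.boxes.xyxy.cpu().tolist()
        ids = result.boxes.id.int().cpu().tolist()
        for box, t_id in zip(boxes, ids):
            # Count each track ID at most once, as the commit does with crossed_objects.
            if crossed_line(box, line_y) and t_id not in counted:
                counted.add(t_id)
    frames = []  # drop the processed batch; any leftover partial batch is discarded, as in the commit

cap.release()
print(f"objects counted: {len(counted)}")

Batching amortizes per-call model overhead across frames, at the cost of delaying output by up to batch_size frames before the first annotated frame is yielded.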