Sanshruth commited on
Commit
0848e32
·
verified ·
1 Parent(s): 2b307a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -101
app.py CHANGED
@@ -14,8 +14,6 @@ cv2.setNumThreads(cpu_cores)
14
  print(f"OpenCV using {cv2.getNumThreads()} threads out of {cpu_cores} available cores")
15
 
16
  ##############
17
-
18
-
19
  import cv2
20
  import gradio as gr
21
  import numpy as np
@@ -180,7 +178,6 @@ def draw_angled_line(image, line_params, color=(0, 255, 0), thickness=2):
180
  _, _, start_point, end_point = line_params
181
  cv2.line(image, start_point, end_point, color, thickness)
182
 
183
-
184
  def process_video(confidence_threshold=0.5, selected_classes=None, stream_url=None):
185
  """
186
  Processes the IP camera stream to count objects of the selected classes crossing the line.
@@ -205,16 +202,9 @@ def process_video(confidence_threshold=0.5, selected_classes=None, stream_url=No
205
  errors.append("Error: Could not open stream.")
206
  return None, "\n".join(errors)
207
 
208
- # Set capture properties for better performance
209
- cap.set(cv2.CAP_PROP_BUFFERSIZE, 30)
210
- cap.set(cv2.CAP_PROP_FPS, 30)
211
- cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG'))
212
-
213
  model = YOLO(model="yolo11n.pt")
214
  crossed_objects = {}
215
- max_tracked_objects = 1000
216
- frames_buffer = []
217
- batch_size = 16
218
 
219
  logger.info("Starting to process the stream...")
220
  while cap.isOpened():
@@ -223,96 +213,48 @@ def process_video(confidence_threshold=0.5, selected_classes=None, stream_url=No
223
  errors.append("Error: Could not read frame from the stream.")
224
  break
225
 
226
- frames_buffer.append(frame)
227
-
228
- if len(frames_buffer) >= batch_size:
229
- # Process batch of frames
230
- results = model.track(frames_buffer, persist=True, conf=confidence_threshold)
231
-
232
- # Process and yield each frame immediately to maintain real-time appearance
233
- for idx, result in enumerate(results):
234
- if result.boxes.id is not None:
235
- track_ids = result.boxes.id.int().cpu().tolist()
236
- clss = result.boxes.cls.cpu().tolist()
237
- boxes = result.boxes.xyxy.cpu()
238
- confs = result.boxes.conf.cpu().tolist()
239
-
240
- for box, cls, t_id, conf in zip(boxes, clss, track_ids, confs):
241
- if conf >= confidence_threshold and model.names[cls] in selected_classes:
242
- if is_object_crossing_line(box, line_params) and t_id not in crossed_objects:
243
- crossed_objects[t_id] = True
244
-
245
- if len(crossed_objects) > max_tracked_objects:
246
- crossed_objects.clear()
247
-
248
- # Visualize the results with bounding boxes, masks, and IDs
249
- annotated_frame = result.plot()
250
-
251
- # Draw the angled line on the frame
252
- draw_angled_line(annotated_frame, line_params, color=(0, 255, 0), thickness=2)
253
-
254
- # Display the count on the frame with a modern look
255
- count = len(crossed_objects)
256
- (text_width, text_height), _ = cv2.getTextSize(f"COUNT: {count}", cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
257
-
258
- # Calculate the position for the middle of the top
259
- margin = 10
260
- x = (annotated_frame.shape[1] - text_width) // 2
261
- y = text_height + margin
262
-
263
- # Draw the black background rectangle
264
- cv2.rectangle(annotated_frame,
265
- (x - margin, y - text_height - margin),
266
- (x + text_width + margin, y + margin),
267
- (0, 0, 0), -1)
268
-
269
- # Draw the text
270
- cv2.putText(annotated_frame, f"COUNT: {count}",
271
- (x, y), cv2.FONT_HERSHEY_SIMPLEX,
272
- 1, (0, 255, 0), 2)
273
-
274
- # Yield each frame as soon as it's processed
275
- yield annotated_frame, ""
276
-
277
- # Clear the buffer after processing
278
- frames_buffer = []
279
-
280
- # If we have remaining frames that don't make a full batch, process them too
281
- elif frames_buffer:
282
- results = model.track(frames_buffer, persist=True, conf=confidence_threshold)
283
-
284
- for result in results:
285
- if result.boxes.id is not None:
286
- track_ids = result.boxes.id.int().cpu().tolist()
287
- clss = result.boxes.cls.cpu().tolist()
288
- boxes = result.boxes.xyxy.cpu()
289
- confs = result.boxes.conf.cpu().tolist()
290
-
291
- for box, cls, t_id, conf in zip(boxes, clss, track_ids, confs):
292
- if conf >= confidence_threshold and model.names[cls] in selected_classes:
293
- if is_object_crossing_line(box, line_params) and t_id not in crossed_objects:
294
- crossed_objects[t_id] = True
295
-
296
- annotated_frame = result.plot()
297
- draw_angled_line(annotated_frame, line_params, color=(0, 255, 0), thickness=2)
298
-
299
- count = len(crossed_objects)
300
- (text_width, text_height), _ = cv2.getTextSize(f"COUNT: {count}", cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
301
- margin = 10
302
- x = (annotated_frame.shape[1] - text_width) // 2
303
- y = text_height + margin
304
-
305
- cv2.rectangle(annotated_frame,
306
- (x - margin, y - text_height - margin),
307
- (x + text_width + margin, y + margin),
308
- (0, 0, 0), -1)
309
- cv2.putText(annotated_frame, f"COUNT: {count}",
310
- (x, y), cv2.FONT_HERSHEY_SIMPLEX,
311
- 1, (0, 255, 0), 2)
312
-
313
- yield annotated_frame, ""
314
-
315
- frames_buffer = []
316
 
317
  cap.release()
318
  logger.info("Stream processing completed.")
 
14
  print(f"OpenCV using {cv2.getNumThreads()} threads out of {cpu_cores} available cores")
15
 
16
  ##############
 
 
17
  import cv2
18
  import gradio as gr
19
  import numpy as np
 
178
  _, _, start_point, end_point = line_params
179
  cv2.line(image, start_point, end_point, color, thickness)
180
 
 
181
  def process_video(confidence_threshold=0.5, selected_classes=None, stream_url=None):
182
  """
183
  Processes the IP camera stream to count objects of the selected classes crossing the line.
 
202
  errors.append("Error: Could not open stream.")
203
  return None, "\n".join(errors)
204
 
 
 
 
 
 
205
  model = YOLO(model="yolo11n.pt")
206
  crossed_objects = {}
207
+ max_tracked_objects = 1000 # Maximum number of objects to track before clearing
 
 
208
 
209
  logger.info("Starting to process the stream...")
210
  while cap.isOpened():
 
213
  errors.append("Error: Could not read frame from the stream.")
214
  break
215
 
216
+ # Perform object tracking with confidence threshold
217
+ results = model.track(frame, persist=True, conf=confidence_threshold)
218
+
219
+ if results[0].boxes.id is not None:
220
+ track_ids = results[0].boxes.id.int().cpu().tolist()
221
+ clss = results[0].boxes.cls.cpu().tolist()
222
+ boxes = results[0].boxes.xyxy.cpu()
223
+ confs = results[0].boxes.conf.cpu().tolist()
224
+
225
+ for box, cls, t_id, conf in zip(boxes, clss, track_ids, confs):
226
+ if conf >= confidence_threshold and model.names[cls] in selected_classes:
227
+ # Check if the object crosses the line
228
+ if is_object_crossing_line(box, line_params) and t_id not in crossed_objects:
229
+ crossed_objects[t_id] = True
230
+
231
+ # Clear the dictionary if it gets too large
232
+ if len(crossed_objects) > max_tracked_objects:
233
+ crossed_objects.clear()
234
+
235
+ # Visualize the results with bounding boxes, masks, and IDs
236
+ annotated_frame = results[0].plot()
237
+
238
+ # Draw the angled line on the frame
239
+ draw_angled_line(annotated_frame, line_params, color=(0, 255, 0), thickness=2)
240
+
241
+ # Display the count on the frame with a modern look
242
+ count = len(crossed_objects)
243
+ (text_width, text_height), _ = cv2.getTextSize(f"COUNT: {count}", cv2.FONT_HERSHEY_SIMPLEX, 1, 2)
244
+
245
+ # Calculate the position for the middle of the top
246
+ margin = 10 # Margin from the top
247
+ x = (annotated_frame.shape[1] - text_width) // 2 # Center-align the text horizontally
248
+ y = text_height + margin # Top-align the text
249
+
250
+ # Draw the black background rectangle
251
+ cv2.rectangle(annotated_frame, (x - margin, y - text_height - margin), (x + text_width + margin, y + margin), (0, 0, 0), -1)
252
+
253
+ # Draw the text
254
+ cv2.putText(annotated_frame, f"COUNT: {count}", (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
255
+
256
+ # Yield the annotated frame to Gradio
257
+ yield annotated_frame, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  cap.release()
260
  logger.info("Stream processing completed.")