# Sanshruth's picture
# Update app.py
# 29901d7 verified
# raw
# history blame
# 6.18 kB
# Maximize CPU usage and GPU utilization
import multiprocessing
import cv2
# Number of CPUs reported by the system; reused in the diagnostic print below.
cpu_cores = multiprocessing.cpu_count()
# Ask OpenCV to use that many threads for its internal parallel routines.
cv2.setNumThreads(cpu_cores)
# Report the thread count OpenCV reports back next to the core count (optional diagnostic).
print(f"OpenCV using {cv2.getNumThreads()} threads out of {cpu_cores} available cores")
##############
import torch
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors
import logging
import math
# Set up logging: INFO level via basicConfig, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global state shared between update_line (writer) and process_video (reader).
# start_point / end_point are the (x, y) click positions of the line being drawn;
# start_point is reset to None once a line has been completed.
start_point = None
end_point = None
# Set by update_line to (slope, intercept, start_point, end_point); for a
# vertical line the first two entries are (inf, x position) instead.
line_params = None # Stores (slope, intercept, start_point, end_point) of the line
# Initialize model once, at import time, so every request reuses the same instance
model = YOLO('yolov8n.pt') # Use smaller model if needed
# Check for GPU availability and move the model to CUDA when present, else CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
logger.info(f"Using device: {device}")
# Video processing parameters (read by process_video)
FRAME_SKIP = 1 # Process every nth frame
FRAME_SCALE = 0.5 # Scale factor for input frames
def extract_first_frame(stream_url):
    """Grab a single frame from the stream at *stream_url*.

    Returns:
        A ``(image, status)`` pair. On success ``image`` is the frame as an
        RGB PIL image and ``status`` is a confirmation message. When the
        stream cannot be opened, or no frame can be read from it, ``image``
        is ``None`` and ``status`` carries the error text.
    """
    logger.info("Extracting first frame...")
    capture = cv2.VideoCapture(stream_url)

    if capture.isOpened():
        grabbed, bgr_frame = capture.read()
        # Only one frame is needed, so the capture is released right away.
        capture.release()
        if grabbed:
            # OpenCV delivers BGR; PIL expects RGB channel order.
            rgb_frame = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
            return Image.fromarray(rgb_frame), "First frame extracted."
        logger.error("Could not read frame.")
        return None, "Error: Could not read frame."

    logger.error("Could not open stream.")
    return None, "Error: Could not open stream."
def update_line(image, evt: gr.SelectData):
    """Record a click on the preview image as one endpoint of the counting line.

    The first click stores the start point and marks it in blue. The second
    click stores the end point, draws the line in red with a green end
    marker, publishes the line through the module-level ``line_params`` and
    resets ``start_point`` so the next click begins a new line.

    Args:
        image: Image to draw on (passed to ``ImageDraw.Draw``); it is
            modified in place and returned.
        evt: Selection event; ``evt.index[0]`` and ``evt.index[1]`` are used
            as the x and y of the click.

    Returns:
        ``(image, status)`` where ``status`` describes the start point or
        the finished line.
    """
    global start_point, end_point, line_params
    click = (evt.index[0], evt.index[1])

    if start_point is None:
        start_point = click
        draw = ImageDraw.Draw(image)
        draw.ellipse((start_point[0]-5, start_point[1]-5, start_point[0]+5, start_point[1]+5),
                     fill="blue", outline="blue")
        return image, f"Line Start: {start_point}"

    end_point = click
    draw = ImageDraw.Draw(image)
    draw.line([start_point, end_point], fill="red", width=2)
    draw.ellipse((end_point[0]-5, end_point[1]-5, end_point[0]+5, end_point[1]+5),
                 fill="green", outline="green")

    # Calculate line parameters
    if start_point[0] != end_point[0]:
        slope = (end_point[1] - start_point[1]) / (end_point[0] - start_point[0])
        intercept = start_point[1] - slope * start_point[0]
        line_params = (slope, intercept, start_point, end_point)
        status = f"Line: {slope:.2f}x + {intercept:.2f}"
    else:
        # Vertical line: no finite slope, so store (inf, x position) and
        # describe it as "x = ..." instead of formatting inf as a slope.
        line_params = (float('inf'), start_point[0], start_point, end_point)
        status = f"Line: x = {start_point[0]:.2f}"

    # Ready for a fresh line on the next click.
    start_point = None
    return image, status
def optimized_intersection_check(box, line_params):
    """Return True when the line segment touches or crosses the box.

    Uses the slab (parametric clipping) method: the segment is written as
    ``P(t) = start + t * (end - start)`` with ``t`` in ``[0, 1]`` and the
    range of ``t`` is narrowed one axis at a time to the part lying between
    the box's two edges on that axis.

    Args:
        box: ``(x1, y1, x2, y2)`` corner coordinates of an axis-aligned box.
        line_params: 4-tuple whose last two items are the segment's
            ``(x, y)`` start and end points; the first two items are ignored.

    Returns:
        bool: True if any point of the segment lies inside or on the box.
    """
    _, _, (x1, y1), (x2, y2) = line_params
    box_x1, box_y1, box_x2, box_y2 = box

    # Start from the whole segment and shrink the parameter window per axis.
    t_near, t_far = 0.0, 1.0
    axes = (
        (x1, x2 - x1, min(box_x1, box_x2), max(box_x1, box_x2)),
        (y1, y2 - y1, min(box_y1, box_y2), max(box_y1, box_y2)),
    )
    for origin, delta, low, high in axes:
        if delta == 0:
            # Segment is parallel to this slab (or is a single point): it can
            # only intersect if its fixed coordinate lies within the slab.
            if origin < low or origin > high:
                return False
            continue
        t0 = (low - origin) / delta
        t1 = (high - origin) / delta
        if t0 > t1:
            t0, t1 = t1, t0
        t_near = max(t_near, t0)
        t_far = min(t_far, t1)
        if t_near > t_far:
            # The per-axis windows no longer overlap: no intersection.
            return False
    return True
def _scale_line(line, scale):
    """Map a (slope, offset, start, end) line into a frame resized by *scale*."""
    slope, offset, (sx, sy), (ex, ey) = line
    start = (int(round(sx * scale)), int(round(sy * scale)))
    end = (int(round(ex * scale)), int(round(ey * scale)))
    # Uniform scaling keeps the slope; the second entry (y-intercept, or the
    # x position of a vertical line) scales with the coordinates.
    return slope, offset * scale, start, end


def process_video(confidence_threshold=0.5, selected_classes=None, stream_url=None):
    """Stream annotated frames while counting tracked objects that hit the line.

    This is a generator. Each processed frame is resized by ``FRAME_SCALE``,
    run through ``model.track`` and annotated with the counting line and the
    number of distinct track ids whose box has intersected that line.

    Args:
        confidence_threshold: Passed to ``model.track`` as ``conf``.
        selected_classes: Iterable of class names (as in ``model.names``)
            that are eligible for counting.
        stream_url: Source handed to ``cv2.VideoCapture``.

    Yields:
        ``(frame, message)`` tuples. On a configuration or stream-open error
        a single ``(None, <error text>)`` is yielded; otherwise
        ``(annotated_frame, "")`` for every processed frame.
    """
    # Validation checks. Errors are yielded, not returned: inside a generator
    # a returned value never reaches the code iterating over it.
    if not stream_url or not selected_classes or not line_params:
        yield None, "Missing configuration parameters"
        return

    # Convert to set for faster lookups
    selected_classes = set(selected_classes)

    cap = cv2.VideoCapture(stream_url)
    try:
        if not cap.isOpened():
            yield None, "Error opening stream"
            return

        crossed_objects = set()
        frame_count = 0

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            if frame_count % FRAME_SKIP != 0:
                continue

            # Preprocess frame
            frame = cv2.resize(frame, None, fx=FRAME_SCALE, fy=FRAME_SCALE)

            # Detection boxes are in resized-frame coordinates, so bring the
            # line into the same space. Recomputed per frame so a line
            # redrawn while streaming takes effect.
            # NOTE(review): assumes line_params holds coordinates of the
            # unscaled frame returned by extract_first_frame - confirm
            # against the UI wiring.
            scaled_line = _scale_line(line_params, FRAME_SCALE)

            # Object detection
            results = model.track(
                frame,
                persist=True,
                conf=confidence_threshold,
                verbose=False,
                device=device,
                tracker="botsort.yaml",
            )

            # Process detections (id is None when nothing is being tracked)
            detections = results[0].boxes
            if detections.id is not None:
                boxes = detections.xyxy.cpu().numpy()
                track_ids = detections.id.int().cpu().numpy()
                classes = detections.cls.cpu().numpy()

                for box, track_id, cls in zip(boxes, track_ids, classes):
                    if model.names[int(cls)] not in selected_classes:
                        continue
                    track_id = int(track_id)
                    if track_id not in crossed_objects and optimized_intersection_check(box, scaled_line):
                        crossed_objects.add(track_id)
                        # NOTE(review): bounding the set this way also resets
                        # the displayed count to zero.
                        if len(crossed_objects) > 1000:
                            crossed_objects.clear()

            # Annotation
            # NOTE(review): presumably a BGR array from Ultralytics; confirm
            # the consumer expects that channel order.
            annotated_frame = results[0].plot()
            cv2.line(annotated_frame, scaled_line[2], scaled_line[3], (0, 255, 0), 2)
            cv2.putText(annotated_frame, f"COUNT: {len(crossed_objects)}",
                        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            yield annotated_frame, ""
    finally:
        # Runs on normal exhaustion, on errors, and when the consumer closes
        # the generator early, so the stream handle is never leaked.
        cap.release()