import streamlit as st
from ultralytics import YOLO
import cv2
import numpy as np
import torch
import tempfile
import warnings

# Silence noisy but non-fatal library warnings
warnings.filterwarnings('ignore')


def get_direction(old_center, new_center, min_movement=10):
    """Classify the motion between two (x, y) box centers as
    left/right/up/down, or "stationary" for sub-threshold movement."""
    if old_center is None or new_center is None:
        return "stationary"

    dx = new_center[0] - old_center[0]
    dy = new_center[1] - old_center[1]

    # Treat small displacements as jitter
    if abs(dx) < min_movement and abs(dy) < min_movement:
        return "stationary"

    # The dominant axis wins; image y grows downward, so positive dy is "down"
    if abs(dx) > abs(dy):
        return "right" if dx > 0 else "left"
    else:
        return "down" if dy > 0 else "up"

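# Example: get_direction((100, 100), (140, 105)) returns "right",
# since |dx| = 40 exceeds the 10 px threshold and dominates |dy| = 5.

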
class ObjectTracker:
    def __init__(self):
        self.tracked_objects = {}   # object_id -> {'center', 'direction', 'detection'}
        self.object_count = {}      # class_id -> list of ids ever assigned

    def update(self, detections):
        """Match detections to last frame's tracks by nearest center and
        return a list of (detection, object_id, direction) tuples."""
        current_objects = {}
        results = []
        matched_ids = set()  # stop two detections from claiming the same track

        for detection in detections:
            x1, y1, x2, y2 = detection[0:4]
            center = ((x1 + x2) // 2, (y1 + y2) // 2)
            class_id = detection[5]

            # Candidate id in case this detection starts a new track
            object_id = f"{class_id}_{len(self.object_count.get(class_id, []))}"

            # Greedy nearest-neighbor search over same-class tracks
            min_dist = float('inf')
            closest_id = None
            for prev_id, prev_data in self.tracked_objects.items():
                if prev_id in matched_ids:
                    continue
                if prev_id.split('_')[0] == str(class_id):
                    dist = np.hypot(center[0] - prev_data['center'][0],
                                    center[1] - prev_data['center'][1])
                    if dist < min_dist and dist < 100:  # 100 px gating radius
                        min_dist = dist
                        closest_id = prev_id

            if closest_id is not None:
                object_id = closest_id
                matched_ids.add(closest_id)
            else:
                if class_id not in self.object_count:
                    self.object_count[class_id] = []
                self.object_count[class_id].append(object_id)

            prev_center = self.tracked_objects.get(object_id, {}).get('center')
            direction = get_direction(prev_center, center)

            current_objects[object_id] = {
                'center': center,
                'direction': direction,
                'detection': detection
            }
            results.append((detection, object_id, direction))

        # Tracks with no match this frame are dropped
        self.tracked_objects = current_objects
        return results

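# Note: ObjectTracker is a deliberately simple greedy nearest-neighbor
# matcher with a fixed 100 px gate, so ids can swap when same-class objects
# cross paths. Ultralytics also ships built-in multi-object tracking,
# e.g. model.track(frame, persist=True), which is more robust in busy scenes.

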
def main():
    st.title("Real-time Object Detection with Direction")

    uploaded_file = st.file_uploader("Choose a video file", type=['mp4', 'avi', 'mov'])

    start_detection = st.button("Start Detection")
    stop_detection = st.button("Stop Detection")

    if uploaded_file is not None and start_detection:
        st.session_state.running = True  # reset on every start so restarts work

        # Write the upload to a named temp file so OpenCV can open it by path
        tfile = tempfile.NamedTemporaryFile(delete=False)
        tfile.write(uploaded_file.read())
        tfile.flush()

        with st.spinner('Loading model...'):
            model = YOLO('yolov8x.pt', verbose=False)
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model.to(device)
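        # yolov8x is the largest, most accurate YOLOv8 checkpoint; smaller
        # variants such as 'yolov8n.pt' are drop-in replacements here when
        # frame rate matters more than accuracy.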

        tracker = ObjectTracker()
        cap = cv2.VideoCapture(tfile.name)

        # Box/label colors in BGR (OpenCV's channel order), keyed by direction
        direction_colors = {
            "left": (255, 0, 0),
            "right": (0, 255, 0),
            "up": (0, 255, 255),
            "down": (0, 0, 255),
            "stationary": (128, 128, 128)
        }

        # Streamlit placeholders, overwritten once per frame
        frame_placeholder = st.empty()
        info_placeholder = st.empty()

        st.success("Detection Started!")

        while cap.isOpened() and st.session_state.running:
            success, frame = cap.read()
            if not success:
                break  # end of video

            # Run YOLO inference on the raw BGR frame
            results = model(frame,
                            conf=0.25,      # confidence threshold
                            iou=0.45,       # NMS IoU threshold
                            max_det=20,     # cap detections per frame
                            verbose=False)[0]
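
            # model() returns one Results object per input image, so with a
            # single frame the [0] above selects its only result.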

            # Flatten each box tensor to [x1, y1, x2, y2, conf, class_id]
            detections = []
            for box in results.boxes.data:
                x1, y1, x2, y2, conf, cls = box.tolist()
                detections.append([int(x1), int(y1), int(x2), int(y2), float(conf), int(cls)])

            tracked_objects = tracker.update(detections)

            detection_counts = {}

            for detection, obj_id, direction in tracked_objects:
                x1, y1, x2, y2, conf, cls = detection
                color = direction_colors.get(direction, (128, 128, 128))

                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)

                label = f"{model.names[int(cls)]} {direction} {conf:.2f}"
                font_scale = 1.2
                thickness = 3
                text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)[0]

                # Filled rectangle behind the label for readability
                padding_y = 15
                cv2.rectangle(frame,
                              (int(x1), int(y1) - text_size[1] - padding_y),
                              (int(x1) + text_size[0], int(y1)),
                              color, -1)

                cv2.putText(frame, label,
                            (int(x1), int(y1) - 5),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            font_scale,
                            (255, 255, 255),
                            thickness)

                class_name = model.names[int(cls)]
                detection_counts[class_name] = detection_counts.get(class_name, 0) + 1

            # OpenCV frames are BGR; convert to RGB for display
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_placeholder.image(frame_rgb, channels="RGB", use_column_width=True)
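            # Note: newer Streamlit releases deprecate use_column_width in
            # favor of use_container_width; keep whichever your version supports.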

            # Per-frame summary of detected classes
            info_text = "Detected Objects:\n"
            for class_name, count in detection_counts.items():
                info_text += f"{class_name}: {count}\n"
            info_placeholder.text(info_text)

            if stop_detection:
                st.session_state.running = False
                break
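            # Note: clicking any button makes Streamlit rerun the whole script,
            # which is what actually interrupts this loop; the check above only
            # sees the value captured when the current run started.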

        cap.release()
        st.session_state.running = False
        st.warning("Detection Stopped")

    elif uploaded_file is None and start_detection:
        st.error("Please upload a video file first!")


if __name__ == "__main__":
    main()