# Author: datasciencesage
# Last commit: "Modified the Readme" (b4192f2)
import streamlit as st
from ultralytics import YOLO
import cv2
import time
import numpy as np
import torch
from PIL import Image
import tempfile
import warnings
warnings.filterwarnings('ignore')
def get_direction(old_center, new_center, min_movement=10):
    """Classify the dominant movement between two (x, y) points.

    Args:
        old_center: previous position, or None if there is no previous position.
        new_center: current position, or None.
        min_movement: displacement (per axis) below which motion is ignored.

    Returns:
        "stationary" when either point is missing or the displacement is below
        ``min_movement`` on both axes. Otherwise one of "left"/"right"/"up"/"down"
        for the axis with the larger displacement. Ties go to the vertical axis,
        and a positive y displacement is reported as "down".
    """
    if old_center is None or new_center is None:
        return "stationary"

    delta_x = new_center[0] - old_center[0]
    delta_y = new_center[1] - old_center[1]
    span_x = abs(delta_x)
    span_y = abs(delta_y)

    # Small jitter on both axes is treated as no movement at all.
    if all(span < min_movement for span in (span_x, span_y)):
        return "stationary"

    if span_x > span_y:
        return "right" if delta_x > 0 else "left"
    return "down" if delta_y > 0 else "up"
class ObjectTracker:
    """Assign persistent IDs to per-frame detections by nearest-centre matching.

    IDs have the form ``"<class_id>_<n>"``, where ``n`` is the number of
    distinct objects of that class issued before this one. Only objects from
    the immediately preceding ``update`` call are matching candidates. An
    object that is missing for one frame gets a new ID when it reappears.
    """

    def __init__(self, max_distance=100, min_movement=10):
        """Create an empty tracker.

        Args:
            max_distance: a detection may inherit a previous object's ID only if
                their centres are strictly closer than this. Units are those of
                the detection coordinates.
            min_movement: forwarded to ``get_direction``. Smaller displacements
                on both axes are reported as "stationary".
        """
        # object_id -> {'center', 'direction', 'detection'} from the last update.
        self.tracked_objects = {}
        # class_id -> list of every object ID ever issued for that class.
        self.object_count = {}
        self.max_distance = max_distance
        self.min_movement = min_movement

    def update(self, detections):
        """Match this frame's detections to the previous frame's objects.

        Args:
            detections: iterable of ``[x1, y1, x2, y2, conf, class_id]`` rows.

        Returns:
            A list of ``(detection, object_id, direction)`` tuples, in the same
            order as ``detections``. Every object_id in one call is distinct.
        """
        detections = list(detections)
        centers = [((d[0] + d[2]) // 2, (d[1] + d[3]) // 2) for d in detections]

        # Gather every same-class (detection, previous object) pair within
        # range. Assign from the closest pair outwards so each previous ID is
        # inherited by at most one detection. Without this exclusivity, two
        # nearby detections could both take the same ID, and one would
        # overwrite the other in the tracked state.
        candidates = []
        for det_idx, (detection, center) in enumerate(zip(detections, centers)):
            class_id = detection[5]
            for prev_id, prev_data in self.tracked_objects.items():
                if prev_data['detection'][5] != class_id:
                    continue
                prev_center = prev_data['center']
                dist = float(np.hypot(center[0] - prev_center[0],
                                      center[1] - prev_center[1]))
                if dist < self.max_distance:
                    candidates.append((dist, det_idx, prev_id))
        candidates.sort()

        assigned = {}   # det_idx -> inherited previous object_id
        claimed = set()
        for _, det_idx, prev_id in candidates:
            if det_idx in assigned or prev_id in claimed:
                continue
            assigned[det_idx] = prev_id
            claimed.add(prev_id)

        current_objects = {}
        results = []
        for det_idx, (detection, center) in enumerate(zip(detections, centers)):
            object_id = assigned.get(det_idx)
            if object_id is None:
                # Unmatched: issue the next sequential ID for this class.
                issued = self.object_count.setdefault(detection[5], [])
                object_id = f"{detection[5]}_{len(issued)}"
                issued.append(object_id)
            # New objects have no previous centre, so they report "stationary".
            prev_center = self.tracked_objects.get(object_id, {}).get('center')
            direction = get_direction(prev_center, center, self.min_movement)
            current_objects[object_id] = {
                'center': center,
                'direction': direction,
                'detection': detection,
            }
            results.append((detection, object_id, direction))

        self.tracked_objects = current_objects
        return results
# Box/label colour (B, G, R) per movement direction. OpenCV frames are BGR;
# they are converted to RGB only just before display.
_DIRECTION_COLORS = {
    "left": (255, 0, 0),
    "right": (0, 255, 0),
    "up": (0, 255, 255),
    "down": (0, 0, 255),
    "stationary": (128, 128, 128),
}


def _save_upload_to_tempfile(uploaded_file):
    """Copy an uploaded file to a closed temporary file and return its path.

    The file is closed (hence flushed) before OpenCV opens it by name. It keeps
    the upload's extension, which may help the video backend identify the
    container. The caller must delete the file.
    """
    import os

    suffix = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
        # getvalue() returns the whole upload regardless of the read position.
        tfile.write(uploaded_file.getvalue())
    return tfile.name


def _detect(model, frame):
    """Run the model on one frame; return ``[x1, y1, x2, y2, conf, cls]`` rows."""
    results = model(frame,
                    conf=0.25,
                    iou=0.45,
                    max_det=20,
                    verbose=False)[0]
    # Convert the whole box tensor once instead of calling tolist() per row.
    return [
        [int(x1), int(y1), int(x2), int(y2), float(conf), int(cls)]
        for x1, y1, x2, y2, conf, cls in results.boxes.data.tolist()
    ]


def _draw_detection(frame, detection, direction, class_name):
    """Draw one detection's box and a filled "<class> <direction> <conf>" label in place."""
    x1, y1, x2, y2, conf, _ = detection
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    color = _DIRECTION_COLORS.get(direction, (128, 128, 128))
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

    label = f"{class_name} {direction} {conf:.2f}"
    font_scale = 1.2
    thickness = 3
    padding_y = 15
    (text_w, text_h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX,
                                          font_scale, thickness)
    # Put the label above the box. If the box touches the top of the frame,
    # push the label down so it is not drawn at negative coordinates.
    label_bottom = max(y1, text_h + padding_y)
    cv2.rectangle(frame,
                  (x1, label_bottom - text_h - padding_y),
                  (x1 + text_w, label_bottom),
                  color, -1)
    cv2.putText(frame, label,
                (x1, label_bottom - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_scale,
                (255, 255, 255),
                thickness)


def _format_counts(detection_counts):
    """Render per-class counts as the text shown under the video."""
    lines = "".join(f"{name}: {count}\n" for name, count in detection_counts.items())
    return "Detected Objects:\n" + lines


def main():
    """Streamlit entry point: upload a video, then run detection and tracking frame by frame."""
    import os

    st.title("Real-time Object Detection with Direction")
    uploaded_file = st.file_uploader("Choose a video file", type=['mp4', 'avi', 'mov'])
    start_detection = st.button("Start Detection")
    stop_detection = st.button("Stop Detection")

    # A button is True only in the rerun triggered by its own click, so Stop
    # must be handled here. A check inside a Start run's frame loop can never
    # observe it. The Stop click's rerun also interrupts the running loop.
    if stop_detection:
        st.session_state.running = False

    if not start_detection:
        return
    if uploaded_file is None:
        st.error("Please upload a video file first!")
        return

    # Set on every Start click. Initialising only when absent meant that after
    # the first run set it to False, later runs never entered the loop.
    st.session_state.running = True
    video_path = _save_upload_to_tempfile(uploaded_file)
    cap = None
    try:
        with st.spinner('Loading model...'):
            model = YOLO('yolov8x.pt', verbose=False)
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model.to(device)

        tracker = ObjectTracker()
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            st.error("Could not open the uploaded video file.")
            return

        frame_placeholder = st.empty()
        info_placeholder = st.empty()
        st.success("Detection Started!")

        while cap.isOpened() and st.session_state.running:
            success, frame = cap.read()
            if not success:
                break

            detection_counts = {}
            for detection, _obj_id, direction in tracker.update(_detect(model, frame)):
                class_name = model.names[int(detection[5])]
                _draw_detection(frame, detection, direction, class_name)
                detection_counts[class_name] = detection_counts.get(class_name, 0) + 1

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_placeholder.image(frame_rgb, channels="RGB", use_column_width=True)
            info_placeholder.text(_format_counts(detection_counts))
    finally:
        # Also runs when a rerun (e.g. the Stop click) interrupts the loop.
        if cap is not None:
            cap.release()
        try:
            os.remove(video_path)
        except OSError:
            # Best-effort cleanup of a temp file; failure is not fatal here.
            pass

    st.session_state.running = False
    st.warning("Detection Stopped")
if __name__ == "__main__":
    # Script entry point. This is a Streamlit app, so it is presumably
    # launched with `streamlit run <file>` — confirm deployment command.
    main()