import streamlit as st
from PIL import Image
import numpy as np
import subprocess
import time
import tempfile
import os
from ultralytics import YOLO
import cv2 as cv
# Path to the trained YOLO weights; adjust to match your environment.
model_path = "models/best2.pt"
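
# Helper added for efficiency: loading YOLO weights is relatively expensive, so
# cache the model across Streamlit reruns. st.cache_resource is the idiomatic
# Streamlit pattern for sharing a single model object between sessions.
@st.cache_resource
def load_model(path):
    return YOLO(path)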
# --- Page Configuration ---
st.set_page_config(
    page_title="Driver Distraction System",
    page_icon="🚗",
    layout="wide",
    initial_sidebar_state="expanded",
)
# --- Sidebar ---
st.sidebar.title("🚗 Driver Distraction System")
st.sidebar.write("Choose an option below:")
# Sidebar navigation
page = st.sidebar.radio("Select Feature", [
    "Distraction System",
    "Real-time Drowsiness Detection",
    "Video Drowsiness Detection",
])
# --- Class Labels (for YOLO model) ---
class_names = ['drinking', 'hair and makeup', 'operating the radio', 'reaching behind',
               'safe driving', 'talking on the phone', 'talking to passenger', 'texting']
# Sidebar Class Name Display
st.sidebar.subheader("Class Names")
for idx, class_name in enumerate(class_names):
    st.sidebar.write(f"{idx}: {class_name}")
# --- Feature: YOLO Distraction Detection ---
if page == "Distraction System":
st.title("Driver Distraction System")
st.write("Upload an image or video to detect distractions using YOLO model.")
# File type selection
file_type = st.radio("Select file type:", ["Image", "Video"])
if file_type == "Image":
uploaded_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
image = Image.open(uploaded_file).convert('RGB')
image_np = np.array(image)
col1, col2 = st.columns([1, 1])
with col1:
st.subheader("Uploaded Image")
st.image(image, caption="Original Image", use_container_width=True)
with col2:
st.subheader("Detection Results")
model = YOLO(model_path)
start_time = time.time()
results = model(image_np)
end_time = time.time()
prediction_time = end_time - start_time
result = results[0]
if len(result.boxes) > 0:
boxes = result.boxes
confidences = boxes.conf.cpu().numpy()
classes = boxes.cls.cpu().numpy()
class_names_dict = result.names
max_conf_idx = confidences.argmax()
predicted_class = class_names_dict[int(classes[max_conf_idx])]
confidence_score = confidences[max_conf_idx]
st.markdown(f"### Predicted Class: **{predicted_class}**")
st.markdown(f"### Confidence Score: **{confidence_score:.4f}** ({confidence_score*100:.1f}%)")
st.markdown(f"Inference Time: {prediction_time:.2f} seconds")
                else:
                    st.warning("No distractions detected.")
    else:  # Video processing
        uploaded_video = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv", "webm"])
        if uploaded_video is not None:
            tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
            tfile.write(uploaded_video.read())
            tfile.close()  # flush buffered bytes to disk so OpenCV can read the full file
            temp_input_path = tfile.name
            temp_output_path = tempfile.mktemp(suffix="_distraction_detected.mp4")
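            # NOTE: tempfile.mktemp is deprecated (the returned name can race with
            # other processes). If that matters here, a safer sketch is to pre-create
            # the file and let cv.VideoWriter overwrite it:
            #   out_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
            #   temp_output_path = out_tmp.name
            #   out_tmp.close()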
st.subheader("Video Information")
cap = cv.VideoCapture(temp_input_path)
fps = cap.get(cv.CAP_PROP_FPS)
width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
duration = total_frames / fps if fps > 0 else 0
cap.release()
col1, col2 = st.columns(2)
with col1:
st.metric("Duration", f"{duration:.2f} seconds")
st.metric("Original FPS", f"{fps:.2f}")
with col2:
st.metric("Resolution", f"{width}x{height}")
st.metric("Total Frames", total_frames)
st.subheader("Original Video Preview")
st.video(uploaded_video)
if st.button("Process Video for Distraction Detection"):
TARGET_PROCESSING_FPS = 10
# --- NEW: Hyperparameter for the temporal smoothing logic ---
PERSISTENCE_CONFIDENCE_THRESHOLD = 0.40 # Stick with old class if found with >= 40% confidence
st.info(f"πŸš€ For faster results, video will be processed at ~{TARGET_PROCESSING_FPS} FPS.")
st.info(f"🧠 Applying temporal smoothing to reduce status flickering (Persistence Threshold: {PERSISTENCE_CONFIDENCE_THRESHOLD*100:.0f}%).")
progress_bar = st.progress(0, text="Starting video processing...")
with st.spinner(f"Processing video... This may take a while."):
model = YOLO(model_path)
cap = cv.VideoCapture(temp_input_path)
fourcc = cv.VideoWriter_fourcc(*'mp4v')
out = cv.VideoWriter(temp_output_path, fourcc, fps, (width, height))
frame_skip_interval = max(1, round(fps / TARGET_PROCESSING_FPS))
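                    # Example: a 30 FPS source with TARGET_PROCESSING_FPS = 10 yields a
                    # frame_skip_interval of 3, so every third frame runs through the
                    # model and the last annotations persist on the frames in between.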
                    frame_count = 0
                    last_best_box_coords = None
                    last_best_box_label = ""
                    last_status_text = "Status: Initializing..."
                    last_status_color = (128, 128, 128)
                    # State variable storing the last confirmed class for smoothing
                    last_confirmed_class_name = 'safe driving'
                    while cap.isOpened():
                        ret, frame = cap.read()
                        if not ret:
                            break
                        frame_count += 1

                        # Clamp in case the frame-count metadata undercounts
                        progress = min(100, int((frame_count / total_frames) * 100)) if total_frames > 0 else 0
                        progress_bar.progress(progress, text=f"Analyzing frame {frame_count}/{total_frames}")

                        annotated_frame = frame.copy()
                        if frame_count % frame_skip_interval == 0:
                            results = model(annotated_frame)
                            result = results[0]
                            last_best_box_coords = None  # Reset box for this processing cycle
                            if len(result.boxes) > 0:
                                boxes = result.boxes
                                class_names_dict = result.names
                                confidences = boxes.conf.cpu().numpy()
                                classes = boxes.cls.cpu().numpy()

                                # --- Stability logic ---
                                final_box_to_use = None
                                # 1. Check whether the last confirmed class reappears
                                #    with at least the persistence confidence.
                                for i in range(len(boxes)):
                                    current_class_name = class_names_dict[int(classes[i])]
                                    if current_class_name == last_confirmed_class_name and confidences[i] >= PERSISTENCE_CONFIDENCE_THRESHOLD:
                                        final_box_to_use = boxes[i]
                                        break
                                # 2. Otherwise fall back to the highest-confidence
                                #    detection in the current frame.
                                if final_box_to_use is None:
                                    max_conf_idx = confidences.argmax()
                                    final_box_to_use = boxes[max_conf_idx]

                                # Process the chosen box
                                x1, y1, x2, y2 = final_box_to_use.xyxy[0].cpu().numpy()
                                confidence = final_box_to_use.conf[0].cpu().numpy()
                                class_id = int(final_box_to_use.cls[0].cpu().numpy())
                                class_name = class_names_dict[class_id]

                                # Update the state for the following frames
                                last_confirmed_class_name = class_name
                                last_best_box_coords = (int(x1), int(y1), int(x2), int(y2))
                                last_best_box_label = f"{class_name}: {confidence:.2f}"
                                if class_name != 'safe driving':
                                    last_status_text = f"Status: {class_name.replace('_', ' ').title()}"
                                    last_status_color = (0, 0, 255)
                                else:
                                    last_status_text = "Status: Safe Driving"
                                    last_status_color = (0, 128, 0)
                            else:
                                # No detections: reset to safe driving
                                last_confirmed_class_name = 'safe driving'
                                last_status_text = "Status: Safe Driving"
                                last_status_color = (0, 128, 0)

                        # Draw annotations on EVERY frame using the last known data
                        if last_best_box_coords:
                            cv.rectangle(annotated_frame, (last_best_box_coords[0], last_best_box_coords[1]),
                                         (last_best_box_coords[2], last_best_box_coords[3]), (0, 255, 0), 2)
                            cv.putText(annotated_frame, last_best_box_label,
                                       (last_best_box_coords[0], last_best_box_coords[1] - 10),
                                       cv.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

                        # Draw the status banner in the top-left corner
                        font_scale, font_thickness = 1.0, 2
                        (text_w, text_h), _ = cv.getTextSize(last_status_text, cv.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)
                        padding = 10
                        rect_start = (padding, padding)
                        rect_end = (padding + text_w + padding, padding + text_h + padding)
                        cv.rectangle(annotated_frame, rect_start, rect_end, last_status_color, -1)
                        text_pos = (padding + 5, padding + text_h + 5)
                        cv.putText(annotated_frame, last_status_text, text_pos, cv.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), font_thickness)

                        out.write(annotated_frame)
                    cap.release()
                    out.release()

                progress_bar.progress(100, text="Video processing completed!")
                st.success("Video processed successfully!")

                if os.path.exists(temp_output_path):
                    with open(temp_output_path, "rb") as file:
                        video_bytes = file.read()
                    st.download_button(
                        label="📥 Download Processed Video",
                        data=video_bytes,
                        file_name=f"distraction_detected_{uploaded_video.name}",
                        mime="video/mp4",
                        key="download_distraction_video"
                    )

                    st.subheader("Sample Frame from Processed Video")
                    cap_out = cv.VideoCapture(temp_output_path)
                    ret, frame = cap_out.read()
                    if ret:
                        frame_rgb = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
                        st.image(frame_rgb, caption="Sample frame with distraction detection", use_container_width=True)
                    cap_out.release()

                try:
                    os.unlink(temp_input_path)
                    if os.path.exists(temp_output_path):
                        os.unlink(temp_output_path)
                except Exception as e:
                    st.warning(f"Failed to clean up temporary files: {e}")
# --- Feature: Real-time Drowsiness Detection ---
elif page == "Real-time Drowsiness Detection":
st.title("🧠 Real-time Drowsiness Detection")
st.write("This will open your webcam and run the detection script.")
if st.button("Start Drowsiness Detection"):
with st.spinner("Launching webcam..."):
subprocess.Popen(["python3", "src/drowsiness_detection.py", "--mode", "webcam"])
st.success("Drowsiness detection started in a separate window. Press 'q' in that window to quit.")
# --- Feature: Video Drowsiness Detection ---
elif page == "Video Drowsiness Detection":
st.title("πŸ“Ή Video Drowsiness Detection")
st.write("Upload a video file to detect drowsiness and download the processed video.")
uploaded_video = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv", "webm"])
if uploaded_video is not None:
tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
tfile.write(uploaded_video.read())
temp_input_path = tfile.name
temp_output_path = tempfile.mktemp(suffix="_processed.mp4")
st.subheader("Original Video Preview")
st.video(uploaded_video)
if st.button("Process Video for Drowsiness Detection"):
progress_bar = st.progress(0, text="Preparing to process video...")
with st.spinner("Processing video... This may take a while."):
process = subprocess.Popen([
"python3", "src/drowsiness_detection.py",
"--mode", "video",
"--input", temp_input_path,
"--output", temp_output_path
], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
stdout, stderr = process.communicate()
if process.returncode == 0:
progress_bar.progress(100, text="Video processing completed!")
if os.path.exists(temp_output_path):
st.success("Video processed successfully!")
if stdout: st.code(stdout)
with open(temp_output_path, "rb") as file: video_bytes = file.read()
st.download_button(
label="πŸ“₯ Download Processed Video",
data=video_bytes,
file_name=f"drowsiness_detected_{uploaded_video.name}",
mime="video/mp4",
key="download_processed_video"
)
st.subheader("Sample Frame from Processed Video")
cap = cv.VideoCapture(temp_output_path)
ret, frame = cap.read()
if ret: st.image(cv.cvtColor(frame, cv.COLOR_BGR2RGB), caption="Sample frame with drowsiness detection", use_container_width=True)
cap.release()
else:
st.error("Error: Processed video file not found.")
if stderr: st.code(stderr)
else:
st.error("An error occurred during video processing.")
if stderr: st.code(stderr)
try:
if os.path.exists(temp_input_path): os.unlink(temp_input_path)
if os.path.exists(temp_output_path): os.unlink(temp_output_path)
except Exception as e:
st.warning(f"Failed to clean up temporary files: {e}")