File size: 4,017 Bytes
993db33
8b83fe8
993db33
 
 
ea7c227
c7782e0
 
ea7c227
89530ea
993db33
c7782e0
993db33
c7782e0
993db33
c153693
 
 
 
 
 
 
c7782e0
c153693
 
 
 
c7782e0
 
c153693
 
c7782e0
 
 
 
 
 
 
 
 
 
 
 
 
c153693
e7397f9
993db33
88cc205
 
 
c7782e0
993db33
 
 
 
c7782e0
993db33
c0be576
c153693
993db33
 
 
 
 
c153693
c7782e0
 
 
993db33
 
 
 
 
c7782e0
 
 
 
 
 
c153693
 
 
 
 
 
c7782e0
 
 
 
 
 
 
 
 
 
993db33
 
c153693
993db33
 
 
 
c7782e0
993db33
 
 
c7782e0
 
 
 
abacf0c
993db33
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import cv2
import spaces
import numpy as np
import gradio as gr
import tempfile
import os
from datetime import timedelta

# Make only device index 0 visible to CUDA. This runs at import time, before
# TensorFlow is imported (that import happens inside predict_drowsiness).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def preprocess_frame(frame):
    """Turn one video frame into a single-item batch for the classifier.

    The frame is resized to 224x224, divided by 255.0, and given a new
    leading axis of length 1.

    Args:
        frame: Image array accepted by ``cv2.resize``.

    Returns:
        The resized, scaled image with one extra leading axis.
    """
    scaled = cv2.resize(frame, (224, 224)) / 255.0
    # Indexing with np.newaxis prepends the batch axis, like expand_dims(axis=0).
    return scaled[np.newaxis, ...]

def draw_label(frame, label, position=(50, 50), font_scale=1, thickness=2):
    """Draw *label* on *frame* in place, on top of a filled backing rectangle.

    The label 'Drowsy' uses a red palette; any other label uses a green one.
    The text is drawn twice: a thicker white pass first, then the coloured
    pass over it, so the white shows as an outline.

    Args:
        frame: Image to draw on; modified in place.
        label: Text to render.
        position: ``(x, y)`` origin passed to ``cv2.putText``.
        font_scale: Font scale passed to OpenCV.
        thickness: Stroke thickness of the coloured text; the white pass
            uses ``thickness + 2``.
    """
    is_drowsy = label == 'Drowsy'
    # Red text on dark red for Drowsy, green text on dark green otherwise.
    text_color = (0, 0, 255) if is_drowsy else (0, 255, 0)
    box_color = (0, 0, 100) if is_drowsy else (0, 100, 0)

    face = cv2.FONT_HERSHEY_SIMPLEX
    (text_w, text_h), _ = cv2.getTextSize(label, face, font_scale, thickness)
    x, y = position

    # Backing rectangle padded around the measured text box.
    top_left = (x - 5, y - text_h - 15)
    bottom_right = (x + text_w + 5, y + 5)
    cv2.rectangle(frame, top_left, bottom_right, box_color, -1)

    strokes = (
        ((255, 255, 255), thickness + 2),  # white outline pass
        (text_color, thickness),           # coloured text pass
    )
    for stroke_color, stroke_width in strokes:
        cv2.putText(frame, label, (x, y), face, font_scale, stroke_color, stroke_width, lineType=cv2.LINE_AA)

def add_timestamp(frame, timestamp):
    """Write *timestamp* in small white text near the bottom-left of *frame*.

    Args:
        frame: Image to draw on; modified in place.
        timestamp: Text to render.
    """
    frame_height = frame.shape[0]
    # Text origin: 10 px in from the left edge, 10 px above the bottom edge.
    anchor = (10, frame_height - 10)
    cv2.putText(
        frame, timestamp, anchor,
        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA,
    )

def draw_progress_bar(frame, progress):
    """Draw a 5-pixel-high progress bar along the top edge of *frame*.

    Args:
        frame: Image to draw on; modified in place.
        progress: Fraction of the frame width to fill with green.
    """
    height_px = 5
    full_width = frame.shape[1]
    filled_width = int(full_width * progress)
    origin = (0, 0)

    # Filled green portion first, then a 1-px white outline over the full width.
    cv2.rectangle(frame, origin, (filled_width, height_px), (0, 255, 0), -1)
    cv2.rectangle(frame, origin, (full_width, height_px), (255, 255, 255), 1)

@spaces.GPU(duration=60)
def predict_drowsiness(video_path):
    """Annotate a video with drowsiness predictions and return the result.

    Loads the Keras model from ``cnn.keras``, classifies one frame roughly
    every 0.2 seconds of footage, and writes a copy of the video carrying a
    progress bar, a timestamp, the most recent label and running
    Drowsy/Alert counts.

    Args:
        video_path: Path of the input video, opened with ``cv2.VideoCapture``.

    Returns:
        Path of a temporary ``.mp4`` file holding the annotated video. The
        file is created with ``delete=False`` and is not removed here.

    Raises:
        gr.Error: If no video was supplied or the video cannot be opened.
    """
    if not video_path:
        raise gr.Error("No video was provided.")

    import tensorflow as tf
    print(tf.config.list_physical_devices("GPU"))
    model = tf.keras.models.load_model('cnn.keras')

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        # Fail loudly instead of returning the path of an empty output file.
        raise gr.Error("Could not open the input video.")

    out = None
    try:
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Keep the fractional rate (e.g. 29.97); truncating it to an int makes
        # the output play at a slightly wrong speed. If the reported rate is
        # missing or invalid (0, negative, NaN), fall back to an assumed 30 fps
        # so the divisions below and the writer still work.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        # Classify about every 0.2 s. Clamp to 1 so that frame rates below
        # 5 fps do not produce a zero modulus.
        skip_interval = max(1, int(fps * 0.2))

        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_output:
            temp_output_path = temp_output.name

        out = cv2.VideoWriter(temp_output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))

        frame_count = 0
        drowsy_count = 0
        alert_count = 0
        label = None  # most recent prediction, redrawn on every frame

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % skip_interval == 0:
                # Run inference on the clean frame, before any overlay is
                # drawn, so the overlays never become part of the model input.
                # verbose=0 suppresses the per-call progress output.
                prediction = model.predict(preprocess_frame(frame), verbose=0)
                label = 'Drowsy' if np.argmax(prediction) == 0 else 'Alert'

                if label == 'Drowsy':
                    drowsy_count += 1
                else:
                    alert_count += 1

            # The frame count can be reported as 0 or be an underestimate;
            # avoid dividing by zero and keep the bar within the frame.
            if total_frames > 0:
                progress = min(frame_count / total_frames, 1.0)
            else:
                progress = 0.0
            draw_progress_bar(frame, progress)

            timestamp = str(timedelta(seconds=int(frame_count / fps)))
            add_timestamp(frame, timestamp)

            # Show the latest label on every frame; drawing it only on the
            # classified frame made it flash once per interval.
            if label is not None:
                draw_label(frame, label, position=(50, 50))

            # Add drowsiness statistics
            stats_text = f"Drowsy: {drowsy_count} | Alert: {alert_count}"
            cv2.putText(frame, stats_text, (frame_width - 200, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)

            out.write(frame)
            frame_count += 1
    finally:
        # Release the capture and writer even if processing raised.
        cap.release()
        if out is not None:
            out.release()

    return temp_output_path

# Gradio UI: takes one video and shows the annotated video returned by
# predict_drowsiness.
interface = gr.Interface(
    title="Enhanced Drowsiness Detection in Video",
    description="Upload a video or record one to detect drowsiness with improved visuals and statistics.",
    fn=predict_drowsiness,
    inputs=gr.Video(),
    outputs="video",
    examples=["003_nightglasses_mix.mp4"],
)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()