File size: 5,484 Bytes
7b04d4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import base64
import io
import os
import time
from typing import Optional

import cv2
import gradio as gr
import numpy as np
from groq import Groq
from PIL import Image

class SafetyMonitor:
    """Webcam safety monitor that annotates frames with vision-model feedback.

    Frames are read from the default camera, periodically sent to a Groq
    chat-completions model together with a safety prompt, and the latest
    textual answer is drawn on top of every frame that is yielded.
    """

    def __init__(self, api_key: str, model_name: str = "mixtral-8x7b-vision"):
        """
        Initialize the safety monitor with configurable model

        Args:
            api_key (str): Groq API key
            model_name (str): Name of the vision model to use
        """
        self.client = Groq(api_key=api_key)
        self.model_name = model_name
        self.analysis_interval = 2  # seconds

    @staticmethod
    def _build_messages(prompt: str, jpeg_bytes: bytes) -> list:
        """Build the chat payload: the prompt plus the JPEG as a base64 data URL.

        Raw bytes cannot be serialized into the JSON request body, so the
        image travels as an ``image_url`` content part holding a data URL.
        """
        encoded = base64.b64encode(jpeg_bytes).decode("ascii")
        return [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
                    },
                ],
            }
        ]

    def analyze_frame(self, frame: np.ndarray) -> str:
        """
        Analyze a single frame using specified vision model

        Args:
            frame (np.ndarray): Image in OpenCV BGR channel order.

        Returns:
            str: The model's answer, or a string starting with
            "Analysis Error: " when the API request fails. Never ``None``.
        """
        # OpenCV delivers BGR; PIL expects RGB when encoding the JPEG.
        frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        jpeg_buffer = io.BytesIO()
        frame_pil.save(jpeg_buffer, format='JPEG')
        jpeg_bytes = jpeg_buffer.getvalue()

        # Safety analysis prompt
        prompt = """Please analyze this image for workplace safety issues. Focus on:
        1. Required PPE usage (hard hats, safety glasses, reflective vests)
        2. Unsafe behaviors or positions
        3. Equipment and machinery safety
        4. Environmental hazards (spills, obstacles, poor lighting)
        5. Emergency exit accessibility
        
        Provide specific observations and any immediate safety concerns."""

        try:
            completion = self.client.chat.completions.create(
                messages=self._build_messages(prompt, jpeg_bytes),
                model=self.model_name,
                max_tokens=200,
                temperature=0.2  # Lower temperature for more focused safety analysis
            )
            # Guard against a missing content field so callers can always
            # treat the result as text.
            return completion.choices[0].message.content or ""
        except Exception as e:
            # Best effort: surface the failure in the overlay instead of
            # aborting the video stream.
            return f"Analysis Error: {str(e)}"

    def process_video_stream(self):
        """
        Process video stream and yield analyzed frames

        Yields:
            np.ndarray: The current camera frame in RGB channel order with
            the most recent analysis text drawn over a darkened banner.
        """
        cap = cv2.VideoCapture(0)  # Use 0 for webcam
        last_analysis_time = 0
        latest_analysis = "Initializing safety analysis..."

        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                # Perform analysis at specified intervals
                if time.time() - last_analysis_time >= self.analysis_interval:
                    latest_analysis = self.analyze_frame(frame)
                    # Stamp the time after the blocking request returns, so a
                    # slow call does not immediately trigger the next one.
                    last_analysis_time = time.time()

                # Create a copy of frame for visualization
                display_frame = frame.copy()

                # Add semi-transparent overlay for text background
                overlay = display_frame.copy()
                cv2.rectangle(overlay, (5, 5), (640, 200), (0, 0, 0), -1)
                cv2.addWeighted(overlay, 0.3, display_frame, 0.7, 0, display_frame)

                # Add analysis text
                cv2.putText(display_frame, "Safety Analysis:", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

                # Split and display analysis text, one overlay row per line
                y_position = 60
                for line in latest_analysis.split('\n'):
                    cv2.putText(display_frame, line[:80], (10, y_position),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
                    y_position += 30

                # The UI image component consumes RGB arrays; without this
                # conversion red and blue would appear swapped.
                yield cv2.cvtColor(display_frame, cv2.COLOR_BGR2RGB)
        finally:
            # Runs on normal exit, on errors, and when the consumer closes the
            # generator early, so the camera handle is always released.
            cap.release()

def create_gradio_interface(monitor: SafetyMonitor):
    """
    Create and launch the Gradio interface

    Builds a page with a live image, Start/Stop buttons and a slider bound to
    ``monitor.analysis_interval``, then launches it with a public share link.

    Args:
        monitor (SafetyMonitor): Monitor whose video stream is displayed and
            whose analysis interval the slider adjusts.
    """
    with gr.Blocks() as demo:
        gr.Markdown(f"""
        # Real-time Safety Monitoring System
        Using model: {monitor.model_name}
        """)
        
        with gr.Row():
            video_output = gr.Image(label="Live Feed with Safety Analysis")
        
        with gr.Row():
            start_button = gr.Button("Start Monitoring", variant="primary")
            stop_button = gr.Button("Stop")
            
        with gr.Row():
            interval_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=monitor.analysis_interval,
                step=0.5,
                label="Analysis Interval (seconds)"
            )

        def update_interval(value):
            # No outputs are wired to this handler, so it must not return a value.
            monitor.analysis_interval = value

        def start_monitoring():
            # A generator handler makes Gradio push every yielded frame to the
            # image; returning the generator object itself would never stream.
            yield from monitor.process_video_stream()

        start_event = start_button.click(fn=start_monitoring, outputs=[video_output])
        # Cancel the running stream (closing the generator) and clear the image.
        stop_button.click(
            fn=lambda: None,
            outputs=[video_output],
            cancels=[start_event],
        )
        interval_slider.change(fn=update_interval, inputs=[interval_slider])

    # Generator handlers and `cancels` rely on the request queue.
    demo.queue()
    # NOTE(review): share=True publishes the webcam feed through a public URL;
    # confirm that this is intended before deploying.
    demo.launch(share=True)

def main():
    """Create the safety monitor and launch the web interface.

    The Groq API key is read from the ``GROQ_API_KEY`` environment variable so
    the secret does not have to be written into the source file.

    Raises:
        SystemExit: If ``GROQ_API_KEY`` is unset or blank.
    """
    api_key = os.environ.get("GROQ_API_KEY", "").strip()
    if not api_key:
        # Fail fast: without a key every analysis request would fail only
        # after the UI has already been launched.
        raise SystemExit(
            "GROQ_API_KEY is not set; export it before starting the monitor."
        )

    # Initialize the safety monitor with desired model
    monitor = SafetyMonitor(
        api_key=api_key,
        model_name="mixtral-8x7b-vision"  # Replace with your preferred model
    )

    # Launch the Gradio interface
    create_gradio_interface(monitor)

# Start the application only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()