File size: 4,816 Bytes
7b04d4e
 
 
 
 
 
 
33fd6ad
 
 
 
 
7b04d4e
 
33fd6ad
7b04d4e
33fd6ad
7b04d4e
33fd6ad
 
 
 
7b04d4e
 
 
 
 
 
 
33fd6ad
 
 
7b04d4e
33fd6ad
7b04d4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33fd6ad
7b04d4e
 
 
 
 
33fd6ad
7b04d4e
33fd6ad
7b04d4e
33fd6ad
 
7b04d4e
33fd6ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b04d4e
33fd6ad
7b04d4e
33fd6ad
7b04d4e
33fd6ad
7b04d4e
33fd6ad
 
 
7b04d4e
33fd6ad
 
 
 
 
7b04d4e
33fd6ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b04d4e
 
33fd6ad
 
 
 
 
 
 
7b04d4e
33fd6ad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import base64
import io
import os
import time

import cv2
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from groq import Groq
from PIL import Image

# Load environment variables from a local .env file (if one exists) so that
# GROQ_API_KEY, read later by SafetyMonitor, can be supplied there instead of
# being exported in the shell.
load_dotenv()

class SafetyMonitor:
    """
    Analyze camera frames for workplace-safety issues with a Groq-hosted
    vision model and overlay the model's reply on a copy of the frame.
    """

    def __init__(self, model_name: str = "mixtral-8x7b-vision"):
        """
        Initialize the safety monitor using environment variables for API key.

        Args:
            model_name: Groq model identifier sent with every request.
                NOTE(review): "mixtral-8x7b-vision" does not look like a Groq
                vision model id — confirm against Groq's current model list
                and pass a supported vision model if requests fail.

        Raises:
            ValueError: if GROQ_API_KEY is unset or empty.
        """
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError("GROQ_API_KEY environment variable is not set")

        self.client = Groq(api_key=api_key)
        self.model_name = model_name

    @staticmethod
    def _encode_frame(frame: np.ndarray) -> str:
        """Encode *frame* as a base64 JPEG ``data:`` URL for an ``image_url`` message part."""
        # JPEG cannot store an alpha channel (Pillow raises OSError when saving
        # RGBA as JPEG), so normalize every frame to RGB before encoding.
        frame_pil = Image.fromarray(frame).convert("RGB")
        buffer = io.BytesIO()
        frame_pil.save(buffer, format="JPEG")
        encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
        return f"data:image/jpeg;base64,{encoded}"

    def analyze_frame(self, frame: np.ndarray) -> str:
        """
        Ask the vision model for a safety assessment of a single frame.

        Args:
            frame: image array accepted by ``PIL.Image.fromarray``, or None.
                Presumably a uint8 RGB array from Gradio's webcam input — TODO confirm.

        Returns:
            The model's text reply ("" if the reply carries no text),
            "No frame received" for a None frame, or "Analysis Error: ..."
            if the API request fails.
        """
        if frame is None:
            return "No frame received"

        # The chat API is JSON-based: images must be sent as an image_url part
        # (here a base64 data URL), not as raw bytes.
        image_url = self._encode_frame(frame)

        # Safety analysis prompt
        prompt = """Please analyze this image for workplace safety issues. Focus on:
        1. Required PPE usage (hard hats, safety glasses, reflective vests)
        2. Unsafe behaviors or positions
        3. Equipment and machinery safety
        4. Environmental hazards (spills, obstacles, poor lighting)
        
        Provide specific observations and any immediate safety concerns."""

        try:
            completion = self.client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {"type": "image_url", "image_url": {"url": image_url}},
                        ],
                    }
                ],
                model=self.model_name,
                max_tokens=200,
                temperature=0.2,
            )
            # content can be None; callers split this string, so always return str.
            return completion.choices[0].message.content or ""
        except Exception as e:
            # Best-effort: surface the failure as display text instead of
            # killing the streaming callback.
            return f"Analysis Error: {str(e)}"

    def process_frame(self, frame: np.ndarray) -> "tuple[np.ndarray | None, str]":
        """
        Analyze a frame and return an annotated copy alongside the raw analysis.

        Returns:
            (annotated_frame, analysis). annotated_frame is None when *frame*
            is None. The input array itself is not modified.
        """
        if frame is None:
            return None, "No frame received"

        analysis = self.analyze_frame(frame)

        # Draw on a copy so the caller's frame is left untouched.
        display_frame = frame.copy()

        # Darken a fixed band behind the text (the region becomes 70% of its
        # original brightness) so the green text stays legible.
        overlay = display_frame.copy()
        cv2.rectangle(overlay, (5, 5), (640, 200), (0, 0, 0), -1)
        cv2.addWeighted(overlay, 0.3, display_frame, 0.7, 0, display_frame)

        cv2.putText(display_frame, "Safety Analysis:", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        # One 30 px row per reply line, each truncated to 80 characters.
        # splitlines() also drops the '\r' of CRLF replies so it is not drawn.
        for row, line in enumerate(analysis.splitlines()):
            cv2.putText(display_frame, line[:80], (10, 60 + 30 * row),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        return display_frame, analysis

def create_gradio_interface():
    """
    Create and launch the Gradio interface with webcam input.

    Builds a UI with the live webcam feed, the annotated feed and a text box
    for the raw analysis, wires the webcam stream to SafetyMonitor, then
    queues and launches the app. If the monitor cannot be created (missing
    GROQ_API_KEY), prints a hint and returns without launching.
    """
    # Only monitor construction is expected to raise the missing-key
    # ValueError; a narrow try keeps unrelated ValueErrors from Gradio from
    # being misreported as a missing API key.
    try:
        monitor = SafetyMonitor(model_name="mixtral-8x7b-vision")
    except ValueError as e:
        print(f"Error: {e}")
        print("Please make sure to set the GROQ_API_KEY environment variable")
        return

    # Gradio 4 replaced Image(source=...) with Image(sources=[...]); pick the
    # keyword by major version so both 3.x and 4+ get a webcam input.
    if int(gr.__version__.split(".", 1)[0]) >= 4:
        webcam_source = {"sources": ["webcam"]}
    else:
        webcam_source = {"source": "webcam"}

    with gr.Blocks() as demo:
        gr.Markdown("""
            # Real-time Safety Monitoring System
            Click 'Start Webcam' to begin monitoring.
            """)

        with gr.Row():
            # Webcam input
            webcam = gr.Image(streaming=True, label="Webcam Feed", **webcam_source)
            # Analysis output
            output_image = gr.Image(label="Analyzed Feed")

        with gr.Row():
            analysis_text = gr.Textbox(label="Safety Analysis", lines=5)

        def analyze_stream(frame):
            """Stream callback: annotate *frame*, or report that no frame arrived yet."""
            if frame is None:
                return None, "Webcam not started"
            return monitor.process_frame(frame)

        # inputs must be given explicitly, otherwise the callback is not
        # passed the current webcam frame.
        webcam.stream(
            fn=analyze_stream,
            inputs=[webcam],
            outputs=[output_image, analysis_text],
            show_progress="hidden",
        )

    demo.queue()
    demo.launch()
        
# Build and launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    create_gradio_interface()