|
import base64
import io
import os
import time

import cv2
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from groq import Groq
from PIL import Image
|
|
|
|
|
# Load variables from a .env file into os.environ so GROQ_API_KEY can be kept
# out of the shell environment. override=False (python-dotenv's default) keeps
# any value already exported in the real environment authoritative.
load_dotenv(override=False)
|
|
|
class SafetyMonitor:
    """Send frames to a Groq-hosted vision model and annotate them with the reply.

    ``analyze_frame`` returns the model's free-text safety assessment;
    ``process_frame`` additionally draws that text onto a copy of the frame.
    """

    # Instruction sent alongside every frame.
    _PROMPT = (
        "Please analyze this image for workplace safety issues. Focus on:\n"
        "1. Required PPE usage (hard hats, safety glasses, reflective vests)\n"
        "2. Unsafe behaviors or positions\n"
        "3. Equipment and machinery safety\n"
        "4. Environmental hazards (spills, obstacles, poor lighting)\n"
        "\n"
        "Provide specific observations and any immediate safety concerns."
    )

    # Overlay layout (pixels / characters), matching the original drawing code.
    _FIRST_LINE_Y = 60
    _LINE_HEIGHT = 30
    _MIN_PANEL_BOTTOM = 200
    _MAX_LINE_CHARS = 80

    def __init__(self, model_name: str = "mixtral-8x7b-vision"):
        """
        Initialize the safety monitor using environment variables for API key.

        Args:
            model_name: Groq model id used for every request.
                NOTE(review): "mixtral-8x7b-vision" does not look like a Groq
                vision model id; callers should likely pass a current
                vision-capable model — confirm against Groq's model list.

        Raises:
            ValueError: if the GROQ_API_KEY environment variable is unset or empty.
        """
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            raise ValueError("GROQ_API_KEY environment variable is not set")
        self.client = Groq(api_key=api_key)
        self.model_name = model_name

    @staticmethod
    def _frame_to_data_url(frame: np.ndarray) -> str:
        """Encode *frame* as a base64 JPEG ``data:`` URL for an ``image_url`` part."""
        image = Image.fromarray(frame)
        # JPEG cannot store alpha or palette data (e.g. an RGBA frame), so
        # flatten anything that is not already RGB or greyscale.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
        return f"data:image/jpeg;base64,{encoded}"

    @classmethod
    def _build_messages(cls, image_url: str) -> list[dict]:
        """Build the chat payload: the safety prompt plus the image as an ``image_url`` part."""
        return [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": cls._PROMPT},
                    # The chat API takes images as URLs (a base64 data URL
                    # here); raw bytes are not JSON-serializable.
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ]

    @classmethod
    def _overlay_lines(cls, analysis: str) -> list[str]:
        """Split *analysis* into overlay rows, each truncated to ``_MAX_LINE_CHARS``."""
        # splitlines() also drops "\r", which cv2.putText would render as "?".
        return [line[: cls._MAX_LINE_CHARS] for line in analysis.splitlines()]

    def analyze_frame(self, frame: np.ndarray) -> str:
        """
        Analyze a single frame using the configured vision model.

        Args:
            frame: image array accepted by ``PIL.Image.fromarray``
                (presumably RGB uint8 from Gradio — confirm for other callers).

        Returns:
            The model's reply text ("" if the reply has no content),
            "No frame received" when *frame* is None, or
            "Analysis Error: ..." when the API request fails.
        """
        if frame is None:
            return "No frame received"

        messages = self._build_messages(self._frame_to_data_url(frame))

        try:
            completion = self.client.chat.completions.create(
                messages=messages,
                model=self.model_name,
                max_tokens=200,
                temperature=0.2,
            )
            # ``content`` may be None; callers split this string, so never return None.
            return completion.choices[0].message.content or ""
        except Exception as e:
            # UI boundary: report the failure as text instead of breaking the stream.
            return f"Analysis Error: {str(e)}"

    def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
        """
        Process and analyze a single frame.

        Args:
            frame: image array to analyze; it is copied, never modified.

        Returns:
            ``(annotated_copy, analysis_text)``, or ``(None, "No frame received")``
            when *frame* is None.
        """
        if frame is None:
            return None, "No frame received"

        analysis = self.analyze_frame(frame)
        lines = self._overlay_lines(analysis)

        display_frame = frame.copy()

        # Darken a panel behind the text. It is at least 200 px tall (as before)
        # and grows so every drawn row sits on it; cv2 clips it to the frame.
        last_baseline = self._FIRST_LINE_Y + self._LINE_HEIGHT * max(len(lines) - 1, 0)
        panel_bottom = max(self._MIN_PANEL_BOTTOM, last_baseline + 10)
        overlay = display_frame.copy()
        cv2.rectangle(overlay, (5, 5), (640, panel_bottom), (0, 0, 0), -1)
        cv2.addWeighted(overlay, 0.3, display_frame, 0.7, 0, display_frame)

        cv2.putText(display_frame, "Safety Analysis:", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

        for row, line in enumerate(lines):
            y = self._FIRST_LINE_Y + row * self._LINE_HEIGHT
            cv2.putText(display_frame, line, (10, y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        return display_frame, analysis
|
|
|
def _webcam_input(label):
    """Return a streaming webcam ``gr.Image`` for the installed Gradio version.

    Gradio 4.0 replaced ``source="webcam"`` with ``sources=["webcam"]`` and no
    longer accepts the old keyword, while 3.x only understands the old one.
    """
    major = int(gr.__version__.split(".", 1)[0])
    if major >= 4:
        return gr.Image(sources=["webcam"], streaming=True, label=label)
    return gr.Image(source="webcam", streaming=True, label=label)


def create_gradio_interface(*, analysis_interval: float = 2.0):
    """
    Create and launch the Gradio interface with webcam input.

    Args:
        analysis_interval: minimum number of seconds between two calls to the
            vision model. Frames streamed sooner get the previous result back
            instead of triggering another (rate-limited) API request.
            Use 0 to analyze every frame.
    """
    # Only the monitor construction can raise the "missing API key" ValueError;
    # keep UI/launch errors from being misreported as a configuration problem.
    try:
        monitor = SafetyMonitor(model_name="mixtral-8x7b-vision")
    except ValueError as e:
        print(f"Error: {e}")
        print("Please make sure to set the GROQ_API_KEY environment variable")
        return

    # Most recent (annotated_frame, analysis) pair and when it finished.
    last_result = None
    last_time = 0.0

    def analyze_stream(frame):
        """Stream callback: analyze *frame*, throttled to one API call per interval."""
        nonlocal last_result, last_time
        if frame is None:
            return None, "Webcam not started"
        if last_result is not None and time.monotonic() - last_time < analysis_interval:
            return last_result
        last_result = monitor.process_frame(frame)
        # Stamp after the (slow) API call so the interval is a real gap between requests.
        last_time = time.monotonic()
        return last_result

    with gr.Blocks() as demo:
        gr.Markdown("""
        # Real-time Safety Monitoring System
        Click 'Start Webcam' to begin monitoring.
        """)

        with gr.Row():
            webcam = _webcam_input("Webcam Feed")
            output_image = gr.Image(label="Analyzed Feed")

        with gr.Row():
            analysis_text = gr.Textbox(label="Safety Analysis", lines=5)

        # `inputs` is required: without it Gradio calls the callback with no
        # arguments and every streamed frame fails with a TypeError.
        webcam.stream(
            fn=analyze_stream,
            inputs=[webcam],
            outputs=[output_image, analysis_text],
            show_progress="hidden",
        )

    demo.queue()
    demo.launch()
|
|
|
if __name__ == "__main__":
    # Start the web UI only when executed as a script, not on import.
    create_gradio_interface()