import gradio as gr
import cv2
import numpy as np
from groq import Groq
import time
from PIL import Image as PILImage
import io
import os
import base64


def create_monitor_interface():
    api_key = os.getenv("GROQ_API_KEY")

    class SafetyMonitor:
        def __init__(self):
            # Pass the key explicitly so the client works even if the
            # library does not read GROQ_API_KEY itself.
            self.client = Groq(api_key=api_key)
            self.model_name = "llama-3.2-90b-vision-preview"
            self.max_image_size = (800, 800)
            # Rotating palette of RGB colors for the annotation boxes.
            self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0),
                           (255, 255, 0), (255, 0, 255)]
            self.last_analysis_time = 0
            self.analysis_interval = 2  # seconds between model calls
            self.last_observations = []

        def resize_image(self, image):
            """Downscale to fit max_image_size while preserving aspect ratio."""
            height, width = image.shape[:2]
            if height > self.max_image_size[1] or width > self.max_image_size[0]:
                aspect = width / height
                if width > height:
                    new_width = self.max_image_size[0]
                    new_height = int(new_width / aspect)
                else:
                    new_height = self.max_image_size[1]
                    new_width = int(new_height * aspect)
                return cv2.resize(image, (new_width, new_height),
                                  interpolation=cv2.INTER_AREA)
            return image

        def analyze_frame(self, frame: np.ndarray) -> str:
            if frame is None:
                return "No frame received"

            # Normalize to 3-channel RGB before encoding.
            if len(frame.shape) == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif len(frame.shape) == 3 and frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

            frame = self.resize_image(frame)
            frame_pil = PILImage.fromarray(frame)

            # Encode the frame as a base64 data URL for the vision model.
            buffered = io.BytesIO()
            frame_pil.save(buffered, format="JPEG", quality=85, optimize=True)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            image_url = f"data:image/jpeg;base64,{img_base64}"

            try:
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": """Analyze this image for safety hazards. For each issue, describe:
1. The location (top-left, center, bottom-right, etc.)
2. The specific safety concern

Format:
- position:description"""
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": image_url
                                    }
                                }
                            ]
                        }
                    ],
                    temperature=0.1,
                    max_tokens=200,
                    top_p=1,
                    stream=False,
                    stop=None
                )
                return completion.choices[0].message.content
            except Exception as e:
                print(f"Detailed error: {str(e)}")
                return f"Analysis Error: {str(e)}"

        def get_region_coordinates(self, position: str, image_shape: tuple) -> tuple:
            """Map a position keyword (e.g. 'top-left') to a cell of a 3x3 grid."""
            height, width = image_shape[:2]
            regions = {
                'top-left': (0, 0, width // 3, height // 3),
                'top': (width // 3, 0, 2 * width // 3, height // 3),
                'top-right': (2 * width // 3, 0, width, height // 3),
                'left': (0, height // 3, width // 3, 2 * height // 3),
                'center': (width // 3, height // 3, 2 * width // 3, 2 * height // 3),
                'right': (2 * width // 3, height // 3, width, 2 * height // 3),
                'bottom-left': (0, 2 * height // 3, width // 3, height),
                'bottom': (width // 3, 2 * height // 3, 2 * width // 3, height),
                'bottom-right': (2 * width // 3, 2 * height // 3, width, height),
            }

            # Match the longest region name first so "bottom-right" is not
            # claimed by the shorter "bottom" or "right" keywords.
            for region_name, coords in sorted(regions.items(),
                                              key=lambda item: len(item[0]),
                                              reverse=True):
                if region_name in position.lower():
                    return coords
            return regions['center']

        def draw_observations(self, image, observations):
            """Draw a labeled bounding box for each observation."""
            height, width = image.shape[:2]
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.6
            thickness = 2

            for idx, obs in enumerate(observations):
                color = self.colors[idx % len(self.colors)]

                # Each observation is expected as "position:description".
                parts = obs.split(':')
                if len(parts) >= 2:
                    position = parts[0]
                    description = ':'.join(parts[1:])
                else:
                    position = 'center'
                    description = obs

                x1, y1, x2, y2 = self.get_region_coordinates(position, image.shape)
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

                # Truncate long labels and keep them inside the frame.
                label = description[:50] + "..." if len(description) > 50 else description
                label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
                label_x = max(0, min(x1, width - label_size[0]))
                label_y = max(20, y1 - 5)

                cv2.rectangle(image, (label_x, label_y - 20),
                              (label_x + label_size[0], label_y), color, -1)
                cv2.putText(image, label, (label_x, label_y - 5),
                            font, font_scale, (255, 255, 255), thickness)

            return image

        def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
            if frame is None:
                return None, "No image provided"

            # Throttle model calls; between calls, reuse the last observations.
            current_time = time.time()
            if current_time - self.last_analysis_time >= self.analysis_interval:
                analysis = self.analyze_frame(frame)
                self.last_analysis_time = current_time

                # Parse response lines of the form "- position:description".
                observations = []
                for line in analysis.split('\n'):
                    line = line.strip()
                    if line.startswith('-'):
                        observation = line.lstrip('- ').strip()
                        if observation:
                            observations.append(observation)

                self.last_observations = observations

            display_frame = frame.copy()
            annotated_frame = self.draw_observations(display_frame, self.last_observations)
            return annotated_frame, '\n'.join(f"- {obs}" for obs in self.last_observations)

    # Create the main interface.
    monitor = SafetyMonitor()

    with gr.Blocks() as demo:
        gr.Markdown("# Safety Analysis System powered by Llama 3.2 90B Vision")

        with gr.Row():
            input_image = gr.Image(label="Upload Image")
            output_image = gr.Image(label="Analysis")

        analysis_text = gr.Textbox(label="Safety Concerns", lines=5)

        def analyze_image(image):
            if image is None:
                return None, "No image provided"
            try:
                processed_frame, analysis = monitor.process_frame(image)
                return processed_frame, analysis
            except Exception as e:
                print(f"Processing error: {str(e)}")
                return None, f"Error processing image: {str(e)}"

        input_image.change(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text]
        )

        gr.Markdown("""
        ## Instructions:
        1. Upload an image to analyze safety concerns
        2. View annotated results and detailed analysis
        3. Each box highlights a potential safety issue
        """)

    return demo


demo = create_monitor_interface()
demo.launch()