import gradio as gr
import cv2
import numpy as np
from groq import Groq
from PIL import Image as PILImage
import io
import os
import base64


def create_monitor_interface():
    api_key = os.getenv("GROQ_API_KEY")

    class SafetyMonitor:
        def __init__(self):
            self.client = Groq(api_key=api_key)
            self.model_name = "llama-3.2-90b-vision-preview"
            self.max_image_size = (800, 800)
            self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0),
                           (255, 255, 0), (255, 0, 255)]

        def analyze_frame(self, frame: np.ndarray) -> str:
            if frame is None:
                return ""

            # Normalize to 3-channel RGB before JPEG encoding
            if len(frame.shape) == 2:
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif len(frame.shape) == 3 and frame.shape[2] == 4:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)

            frame = self.resize_image(frame)
            frame_pil = PILImage.fromarray(frame)
            buffered = io.BytesIO()
            frame_pil.save(buffered, format="JPEG", quality=85, optimize=True)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            image_url = f"data:image/jpeg;base64,{img_base64}"

            try:
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {
                            "role": "system",
                            "content": "You are a safety analysis expert. Analyze images for safety concerns and provide detailed observations."
                        },
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": """Analyze this image for safety concerns and risks. For each issue you identify:
1. Specify the exact location in the image where the issue is visible
2. Describe what the safety concern is
3. Include any relevant details about PPE, posture, equipment, or environmental hazards

Format EACH observation exactly like this:
- position:detailed description of the concern

Example format:
- center:Worker bending incorrectly while lifting heavy materials
- top-right:Missing safety guardrail near elevated platform

Provide multiple observations if you see multiple issues."""
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {"url": image_url}
                                }
                            ]
                        }
                    ],
                    temperature=0.5,  # Moderate temperature for more varied observations
                    max_tokens=500,
                    stream=False
                )
                response = completion.choices[0].message.content
                print(f"Raw response: {response}")  # For debugging
                return response
            except Exception as e:
                print(f"Analysis error: {str(e)}")
                return ""

        def resize_image(self, image):
            # Downscale while preserving aspect ratio so the encoded
            # payload stays small
            height, width = image.shape[:2]
            if height > self.max_image_size[1] or width > self.max_image_size[0]:
                aspect = width / height
                if width > height:
                    new_width = self.max_image_size[0]
                    new_height = int(new_width / aspect)
                else:
                    new_height = self.max_image_size[1]
                    new_width = int(new_height * aspect)
                return cv2.resize(image, (new_width, new_height),
                                  interpolation=cv2.INTER_AREA)
            return image

        def get_region_coordinates(self, position: str, image_shape: tuple) -> tuple:
            height, width = image_shape[:2]
            # Cells of a 3x3 grid over the image, as (x1, y1, x2, y2)
            regions = {
                'top-left': (0, 0, width//3, height//3),
                'top': (width//3, 0, 2*width//3, height//3),
                'top-right': (2*width//3, 0, width, height//3),
                'left': (0, height//3, width//3, 2*height//3),
                'center': (width//3, height//3, 2*width//3, 2*height//3),
                'right': (2*width//3, height//3, width, 2*height//3),
                'bottom-left': (0, 2*height//3, width//3, height),
                'bottom': (width//3, 2*height//3, 2*width//3, height),
                'bottom-right': (2*width//3, 2*height//3, width, height)
            }

            # Match the position against region names, preferring the longest
            # match so that "top-right" wins over "top" or "right"
            matched_region = None
            max_match_length = 0
            position_lower = position.lower()

            for region_name in regions:
                if region_name in position_lower and len(region_name) > max_match_length:
                    matched_region = region_name
                    max_match_length = len(region_name)

            if matched_region:
                return regions[matched_region]
            return regions['center']
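        # Quick sanity-check sketch of the grid mapping (hypothetical values,
        # assuming a 600x900 RGB frame, i.e. image_shape == (600, 900, 3)):
        #
        #   self.get_region_coordinates("top-right", (600, 900, 3))
        #   # -> (600, 0, 900, 200), the upper-right cell of the 3x3 grid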
        def draw_observations(self, image, observations):
            height, width = image.shape[:2]
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.6
            thickness = 2

            for idx, obs in enumerate(observations):
                color = self.colors[idx % len(self.colors)]
                parts = obs.split(':')
                if len(parts) >= 2:
                    position = parts[0]
                    description = ':'.join(parts[1:])
                    x1, y1, x2, y2 = self.get_region_coordinates(position, image.shape)

                    # Draw rectangle around the region of concern
                    cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

                    # Add a truncated label on a filled background, clamped
                    # so it stays inside the image
                    label = description[:50] + "..." if len(description) > 50 else description
                    label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
                    label_x = max(0, min(x1, width - label_size[0]))
                    label_y = max(20, y1 - 5)
                    cv2.rectangle(image, (label_x, label_y - 20),
                                  (label_x + label_size[0], label_y), color, -1)
                    cv2.putText(image, label, (label_x, label_y - 5),
                                font, font_scale, (255, 255, 255), thickness)

            return image

        def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
            if frame is None:
                return None, "No image provided"

            analysis = self.analyze_frame(frame)
            print(f"Analysis received: {analysis}")  # Debug print

            # Parse observations of the form "- position:description",
            # matching the format requested in the prompt
            observations = []
            for line in analysis.split('\n'):
                line = line.strip()
                if line.startswith('-'):
                    observation = line.lstrip('- ').strip()
                    if observation and ':' in observation:
                        observations.append(observation)

            print(f"Parsed observations: {observations}")  # Debug print

            display_frame = frame.copy()
            if observations:
                annotated_frame = self.draw_observations(display_frame, observations)
                return annotated_frame, analysis

            # If no observations were parsed but we got some analysis back,
            # show the unannotated image alongside the raw text
            if analysis and not analysis.isspace():
                return display_frame, analysis

            return display_frame, "Please try again - no safety analysis was generated."

    monitor = SafetyMonitor()

    with gr.Blocks() as demo:
        gr.Markdown("# Safety Analysis System powered by Llama 3.2 90B Vision")

        with gr.Row():
            input_image = gr.Image(label="Upload Image")
            output_image = gr.Image(label="Analysis Results")

        analysis_text = gr.Textbox(label="Safety Analysis", lines=5)

        def analyze_image(image):
            if image is None:
                return None, "No image provided"
            try:
                processed_frame, analysis = monitor.process_frame(image)
                return processed_frame, analysis
            except Exception as e:
                print(f"Processing error: {str(e)}")
                return None, f"Error processing image: {str(e)}"

        input_image.change(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text]
        )

    return demo


demo = create_monitor_interface()
demo.launch()
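# Usage sketch (assumptions, not part of the app itself): the script expects
# GROQ_API_KEY to be set in the environment before launch, e.g.
#
#   export GROQ_API_KEY="gsk_..."   # hypothetical key value
#   python app.py                   # assuming this file is saved as app.py
#
# Gradio then serves the interface locally (http://127.0.0.1:7860 by default).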