import gradio as gr
import cv2
import numpy as np
from groq import Groq
import time
from PIL import Image as PILImage
import io
import os
import base64
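

# Builds the Gradio interface. The SafetyMonitor class (and its Groq client)
# is defined inside this factory so everything is created from the single
# GROQ_API_KEY read at startup.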
def create_monitor_interface():
    api_key = os.getenv("GROQ_API_KEY")

    class SafetyMonitor:
        def __init__(self):
            # Pass the key explicitly so a missing GROQ_API_KEY fails fast.
            self.client = Groq(api_key=api_key)
            self.model_name = "llama-3.2-90b-vision-preview"
            self.max_image_size = (800, 800)
            # Colors cycled through when drawing observation boxes.
            self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
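
        # analyze_frame sends one prompt + image per request; the prompt pins
        # the output format so process_frame() can recover hazards with simple
        # string matching (no function calling or JSON mode needed).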
        def analyze_frame(self, frame: np.ndarray) -> str:
            if frame is None:
                return "No frame received"

            frame = self.preprocess_image(frame)
            image_url = self.encode_image(frame)

            try:
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": """Analyze this image for safety hazards and issues. For each identified hazard:
1. Specify the exact location in the image where the hazard exists
2. Describe the specific safety concern
3. Note any violations or risks

Format each observation exactly as:
- <location>area:hazard description</location>

Examples of locations: top-left, center, bottom-right, full-area, near-machine, workspace, etc.

Look for ALL types of safety issues including:
- Personal protective equipment (PPE)
- Machine and equipment hazards
- Ergonomic risks
- Environmental hazards
- Fire and electrical safety
- Chemical safety
- Fall protection
- Material handling
- Access/egress issues
- Housekeeping
- Tool safety
- Emergency equipment

Be specific about locations and provide detailed observations."""
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": image_url
                                    }
                                }
                            ]
                        }
                    ],
                    temperature=0.5,
                    max_tokens=500,
                    stream=False
                )
                return completion.choices[0].message.content
            except Exception as e:
                print(f"Analysis error: {str(e)}")
                return f"Analysis Error: {str(e)}"
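
        # Image helpers: the vision endpoint receives the image as a data URL,
        # so frames are normalized to RGB, capped at max_image_size, and
        # encoded as base64 JPEG before upload.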
        def preprocess_image(self, frame):
            """Prepare image for analysis."""
            if len(frame.shape) == 2:
                # Grayscale -> RGB
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            elif len(frame.shape) == 3 and frame.shape[2] == 4:
                # RGBA -> RGB (drop the alpha channel)
                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            return self.resize_image(frame)

        def resize_image(self, image):
            """Resize image while maintaining aspect ratio."""
            height, width = image.shape[:2]
            if height > self.max_image_size[1] or width > self.max_image_size[0]:
                aspect = width / height
                if width > height:
                    new_width = self.max_image_size[0]
                    new_height = int(new_width / aspect)
                else:
                    new_height = self.max_image_size[1]
                    new_width = int(new_height * aspect)
                return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
            return image

        def encode_image(self, frame):
            """Convert image to a base64-encoded JPEG data URL."""
            frame_pil = PILImage.fromarray(frame)
            buffered = io.BytesIO()
            frame_pil.save(buffered, format="JPEG", quality=95)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            return f"data:image/jpeg;base64,{img_base64}"
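
        # The model only names coarse regions ("top-left", "workspace", ...),
        # so boxes come from this fixed lookup table of relative coordinates
        # rather than from real detection; positions are approximate by design.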
        def parse_locations(self, observation: str) -> tuple:
            """Map a textual location to relative (x1, y1, x2, y2) coordinates."""
            locations = {
                'full': (0, 0, 1, 1),
                'top': (0.2, 0, 0.8, 0.3),
                'bottom': (0.2, 0.7, 0.8, 1),
                'left': (0, 0.2, 0.3, 0.8),
                'right': (0.7, 0.2, 1, 0.8),
                'center': (0.3, 0.3, 0.7, 0.7),
                'top-left': (0, 0, 0.3, 0.3),
                'top-right': (0.7, 0, 1, 0.3),
                'bottom-left': (0, 0.7, 0.3, 1),
                'bottom-right': (0.7, 0.7, 1, 1),
                'workspace': (0.2, 0.2, 0.8, 0.8),
                'near-machine': (0.6, 0.1, 1, 0.9),
                'floor-area': (0, 0.7, 1, 1),
                'equipment': (0.5, 0.1, 1, 0.9)
            }

            # Pick the location keyword whose words best match the text;
            # default to 'center' when nothing matches.
            text = observation.lower()
            best_match = 'center'
            max_match = 0
            for loc in locations:
                if loc in text:
                    words = loc.split('-')
                    matches = sum(1 for word in words if word in text)
                    if matches > max_match:
                        max_match = matches
                        best_match = loc
            return locations[best_match]
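
        # Example: a (hypothetical) observation such as
        #   {'location': 'top-left', 'description': 'worker without hard hat'}
        # is drawn as a box over the top-left 30% of the frame, with the
        # description as its label.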
        def draw_observations(self, image, observations):
            """Draw bounding boxes and labels for safety observations."""
            height, width = image.shape[:2]
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.5
            thickness = 2
            padding = 10

            for idx, obs in enumerate(observations):
                color = self.colors[idx % len(self.colors)]

                # Get relative coordinates and convert to absolute
                rel_coords = self.parse_locations(obs['location'])
                x1 = int(rel_coords[0] * width)
                y1 = int(rel_coords[1] * height)
                x2 = int(rel_coords[2] * width)
                y2 = int(rel_coords[3] * height)

                # Draw rectangle
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

                # Prepare label
                label = obs['description'][:50]
                if len(obs['description']) > 50:
                    label += "..."

                # Calculate text position
                label_size, _ = cv2.getTextSize(label, font, font_scale, thickness)
                text_x = max(0, x1)
                text_y = max(label_size[1] + padding, y1 - padding)

                # Draw label background
                cv2.rectangle(image,
                              (text_x, text_y - label_size[1] - padding),
                              (text_x + label_size[0] + padding, text_y),
                              color, -1)

                # Draw label text
                cv2.putText(image, label,
                            (text_x + padding // 2, text_y - padding // 2),
                            font, font_scale, (255, 255, 255), thickness)
            return image
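
        # process_frame expects model output lines shaped like (hypothetical
        # hazard text):
        #   - <location>near-machine:unguarded rotating shaft</location>
        # Lines that don't match stay in the text analysis but are not drawn.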
        def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
            """Process frame and generate safety analysis with visualizations."""
            if frame is None:
                return None, "No image provided"

            # Get analysis
            analysis = self.analyze_frame(frame)
            display_frame = frame.copy()

            # Parse observations from <location>...</location> lines
            observations = []
            for line in analysis.split('\n'):
                line = line.strip()
                if line.startswith('-') and '<location>' in line and '</location>' in line:
                    start = line.find('<location>') + len('<location>')
                    end = line.find('</location>')
                    location_description = line[start:end].strip()

                    # Split location and description
                    if ':' in location_description:
                        location, description = location_description.split(':', 1)
                        observations.append({
                            'location': location.strip(),
                            'description': description.strip()
                        })

            # Draw observations if any were found
            if observations:
                annotated_frame = self.draw_observations(display_frame, observations)
                return annotated_frame, analysis
            return display_frame, analysis
    # Create interface
    monitor = SafetyMonitor()

    with gr.Blocks() as demo:
        gr.Markdown("# Safety Analysis System powered by Llama 3.2 90B Vision")

        with gr.Row():
            input_image = gr.Image(label="Upload Image")
            output_image = gr.Image(label="Safety Analysis")

        analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)

        def analyze_image(image):
            if image is None:
                return None, "No image provided"
            try:
                processed_frame, analysis = monitor.process_frame(image)
                return processed_frame, analysis
            except Exception as e:
                print(f"Processing error: {str(e)}")
                return None, f"Error processing image: {str(e)}"

        # Re-run the analysis whenever a new image is uploaded.
        input_image.change(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text]
        )

        gr.Markdown("""
        ## Instructions:
        1. Upload any workplace/safety-related image
        2. View identified hazards and safety concerns
        3. Check the detailed analysis for recommendations
        """)

    return demo


demo = create_monitor_interface()
demo.launch()
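
# To run locally (assuming a valid key for the Groq API and access to the
# pinned vision model):
#   export GROQ_API_KEY=your_key
#   python app.py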