import gradio as gr
import cv2
import numpy as np
from groq import Groq
from PIL import Image as PILImage
import io
import base64
import torch
import warnings
from typing import Tuple, List, Dict, Optional
# Suppress the CUDA autocast warning
warnings.filterwarnings('ignore', category=FutureWarning)
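
# NOTE: Groq() below reads the GROQ_API_KEY environment variable; export it
# before launching the app, or client construction will fail.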


class RobustSafetyMonitor:
    def __init__(self):
        """Initialize the safety detection tool with its configuration."""
        self.client = Groq()
        self.model_name = "llama-3.2-11b-vision-preview"  # Groq-hosted Llama 3.2 vision model
        self.max_image_size = (800, 800)
        self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
        # Load YOLOv5 model for general object detection
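        # torch.hub.load fetches the yolov5s weights on first run and caches
        # them locally (by default under ~/.cache/torch/hub), so later
        # launches skip the download.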
        self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
        # Force CPU inference if CUDA is causing issues
        self.yolo_model.cpu()
        self.yolo_model.eval()

    def preprocess_image(self, frame: np.ndarray) -> np.ndarray:
        """Process image for analysis."""
        if frame is None:
            raise ValueError("No image provided")
        if len(frame.shape) == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        elif len(frame.shape) == 3 and frame.shape[2] == 4:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
        return self.resize_image(frame)

    def resize_image(self, image: np.ndarray) -> np.ndarray:
        """Resize image while maintaining aspect ratio."""
        height, width = image.shape[:2]
        if height > self.max_image_size[1] or width > self.max_image_size[0]:
            aspect = width / height
            if width > height:
                new_width = self.max_image_size[0]
                new_height = int(new_width / aspect)
            else:
                new_height = self.max_image_size[1]
                new_width = int(new_height * aspect)
            return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
        return image

    def encode_image(self, frame: np.ndarray) -> str:
        """Convert an RGB image to a base64-encoded JPEG data URL."""
        try:
            # Frames arrive as RGB arrays (Gradio's numpy format), so no
            # BGR-to-RGB channel swap is needed before handing them to PIL.
            frame_pil = PILImage.fromarray(frame)
            buffered = io.BytesIO()
            frame_pil.save(buffered, format="JPEG", quality=95)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            return f"data:image/jpeg;base64,{img_base64}"
        except Exception as e:
            raise ValueError(f"Error encoding image: {str(e)}")

    def detect_objects(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
        """Detect objects using YOLOv5."""
        try:
            with torch.no_grad():
                results = self.yolo_model(frame)
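            # Each row of results.xyxy[0] is [x1, y1, x2, y2, confidence, class_id];
            # results.names maps class_id to its label string (standard YOLOv5 hub output).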
            bbox_data = results.xyxy[0].cpu().numpy()
            labels = results.names
            return bbox_data, labels
        except Exception as e:
            raise ValueError(f"Error detecting objects: {str(e)}")

    def analyze_frame(self, frame: np.ndarray) -> Tuple[List[Dict], str]:
        """Perform safety analysis on the frame using Llama Vision."""
        if frame is None:
            return [], "No frame received"
        try:
            frame = self.preprocess_image(frame)
            image_base64 = self.encode_image(frame)
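            # Assumption: Groq's chat completions endpoint mirrors the OpenAI
            # vision message format, so the image is passed as an image_url
            # entry with a base64 data URL in place of a hosted URL.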
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": (
                                    "Analyze this workplace image and identify any potential safety risks.\n"
                                    "List each risk on a new line starting with 'Risk:'.\n"
                                    "Format: Risk: [Object/Area] - [Description of hazard]"
                                )
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": image_base64
                                }
                            }
                        ]
                    }
                ],
                temperature=0.7,
                max_tokens=1024,
                stream=False
            )
            # Get the response content safely
            try:
                response = completion.choices[0].message.content
            except AttributeError:
                response = str(completion.choices[0].message)
            safety_issues = self.parse_safety_analysis(response)
            return safety_issues, response
        except Exception as e:
            print(f"Analysis error: {str(e)}")
            return [], f"Analysis Error: {str(e)}"

    def draw_bounding_boxes(self, image: np.ndarray, bboxes: np.ndarray,
                            labels: Dict, safety_issues: List[Dict]) -> np.ndarray:
        """Draw bounding boxes around objects based on safety issues."""
        image_copy = image.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 2
        for idx, bbox in enumerate(bboxes):
            try:
                x1, y1, x2, y2, conf, class_id = bbox
                label = labels[int(class_id)]
                color = self.colors[idx % len(self.colors)]
                # Convert coordinates to integers
                x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
                # Draw bounding box
                cv2.rectangle(image_copy, (x1, y1), (x2, y2), color, thickness)
                # Check if the object is associated with any reported safety issue
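                # This is a loose substring match between the LLM-reported
                # object name and the YOLO class label; boxes with no match
                # fall through to the plain "label confidence" annotation below.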
                risk_found = False
                for safety_issue in safety_issues:
                    if safety_issue.get('object', '').lower() in label.lower():
                        label_text = f"Risk: {safety_issue.get('description', '')}"
                        y_pos = max(y1 - 10, 20)
                        cv2.putText(image_copy, label_text, (x1, y_pos), font,
                                    font_scale, (0, 0, 255), thickness)
                        risk_found = True
                        break
                if not risk_found:
                    label_text = f"{label} {conf:.2f}"
                    y_pos = max(y1 - 10, 20)
                    cv2.putText(image_copy, label_text, (x1, y_pos), font,
                                font_scale, color, thickness)
            except Exception as e:
                print(f"Error drawing box: {str(e)}")
                continue
        return image_copy

    def process_frame(self, frame: np.ndarray) -> Tuple[Optional[np.ndarray], str]:
        """Main processing pipeline for safety analysis."""
        if frame is None:
            return None, "No image provided"
        try:
            # Detect objects
            bbox_data, labels = self.detect_objects(frame)
            # Get safety analysis
            safety_issues, analysis = self.analyze_frame(frame)
            # Draw annotations
            annotated_frame = self.draw_bounding_boxes(frame, bbox_data, labels, safety_issues)
            return annotated_frame, analysis
        except Exception as e:
            print(f"Processing error: {str(e)}")
            return None, f"Error processing image: {str(e)}"

    def parse_safety_analysis(self, analysis: str) -> List[Dict]:
        """Parse the safety analysis text into structured data."""
        safety_issues = []
        if not isinstance(analysis, str):
            return safety_issues
        for line in analysis.split('\n'):
            if "risk:" in line.lower():
                try:
                    # Split on 'Risk:' case-insensitively while preserving the
                    # original casing of the object and description text
                    idx = line.lower().index('risk:')
                    parts = line[idx + len('risk:'):].strip()
                    if '-' in parts:
                        obj, desc = parts.split('-', 1)
                    else:
                        obj, desc = parts, parts
                    safety_issues.append({
                        "object": obj.strip(),
                        "description": desc.strip()
                    })
                except Exception as e:
                    print(f"Error parsing line: {line}, Error: {str(e)}")
                    continue
        return safety_issues


def create_monitor_interface():
    """Create the Gradio interface for the safety monitoring system."""
    monitor = RobustSafetyMonitor()
    with gr.Blocks() as demo:
        gr.Markdown("# Workplace Safety Analysis System")
        gr.Markdown("Powered by Groq Llama 3.2 Vision and YOLOv5")
        with gr.Row():
            input_image = gr.Image(label="Upload Workplace Image", type="numpy")
            output_image = gr.Image(label="Safety Analysis Visualization")
        analysis_text = gr.Textbox(label="Detailed Safety Analysis", lines=5)

        def analyze_image(image):
            if image is None:
                return None, "Please upload an image"
            try:
                processed_frame, analysis = monitor.process_frame(image)
                return processed_frame, analysis
            except Exception as e:
                print(f"Analysis error: {str(e)}")
                return None, f"Error analyzing image: {str(e)}"
        input_image.upload(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text]
        )

        gr.Markdown("""
        ## Instructions
        1. Upload a workplace image for safety analysis
        2. View detected hazards and their locations in the visualization
        3. Read the detailed safety analysis below the images

        ## Features
        - Automatic object detection (YOLOv5)
        - AI-powered safety risk analysis
        - Visual risk highlighting
        - Detailed safety recommendations
        """)
    return demo


if __name__ == "__main__":
    demo = create_monitor_interface()
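    # share=True asks Gradio to create a temporary public gradio.live link in
    # addition to the local server; drop it to keep the app local-only.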
    demo.launch(share=True)