capradeepgujaran's picture
Update app.py
e43f38f verified
raw
history blame
8.84 kB
import base64
import io
import os
import re
import time

import cv2
import gradio as gr
import numpy as np
from groq import Groq
from PIL import Image as PILImage
def create_monitor_interface():
    """Build the Gradio app that runs a vision-LLM safety analysis on uploaded images.

    Reads ``GROQ_API_KEY`` from the environment, constructs a ``SafetyMonitor``
    backed by the Groq chat-completions API, and wires an image-upload
    component to it.

    Returns:
        gr.Blocks: the assembled (not yet launched) Gradio app.
    """
    api_key = os.getenv("GROQ_API_KEY")

    class SafetyMonitor:
        """Send an image to a vision model and draw the concerns it reports."""

        # Instruction sent with every image. The model is asked to wrap each
        # finding as "<location>position:description</location>".
        _PROMPT = (
            "Analyze this image for safety concerns. For each specific issue you identify, provide:\n"
            "1. Exact location in the image (e.g., 'top-left', 'center', 'bottom-right', etc.)\n"
            "2. Description of the safety concern\n"
            "Format your response with each issue on a new line as:\n"
            "- <location>position:detailed description of the safety concern</location>\n"
            "Be specific about what you observe in the image."
        )

        # Words in a free-text position -> row / column index of a 3x3 grid.
        # Anything unrecognized (e.g. "center", "middle") falls back to index 1.
        _ROW_WORDS = {"top": 0, "upper": 0, "bottom": 2, "lower": 2}
        _COL_WORDS = {"left": 0, "right": 2}

        def __init__(self):
            # Use the key read above. If it is None, the Groq client falls
            # back to the GROQ_API_KEY env var itself.
            self.client = Groq(api_key=api_key)
            self.model_name = "llama-3.2-90b-vision-preview"
            self.max_image_size = (800, 800)  # (max width, max height)
            # Box/label colors, cycled per observation. They are applied to
            # frames normalized to 3-channel RGB by _to_rgb.
            self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]

        @staticmethod
        def _to_rgb(frame: np.ndarray) -> np.ndarray:
            """Return *frame* as 3 channels: grayscale and RGBA are converted, others returned as-is."""
            if frame.ndim == 2 or (frame.ndim == 3 and frame.shape[2] == 1):
                return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
            if frame.ndim == 3 and frame.shape[2] == 4:
                return cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
            return frame

        def resize_image(self, image):
            """Downscale *image* to fit within ``max_image_size``, keeping its aspect ratio.

            Images already within the limit are returned unchanged.
            """
            height, width = image.shape[:2]
            max_width, max_height = self.max_image_size
            if height <= max_height and width <= max_width:
                return image
            aspect = width / height
            if width > height:
                new_width = max_width
                new_height = int(new_width / aspect)
            else:
                new_height = max_height
                new_width = int(new_height * aspect)
            # Extreme aspect ratios could otherwise truncate one side to 0,
            # which cv2.resize rejects.
            new_width, new_height = max(1, new_width), max(1, new_height)
            return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)

        def analyze_frame(self, frame: np.ndarray) -> str:
            """Ask the vision model for safety concerns in *frame*.

            Returns:
                The model's raw reply text. Returns "" if *frame* is None or
                the reply has no content.

            Raises:
                Exception: any error from the API call is logged and
                re-raised. The caller can then report a failure instead of
                treating the empty result as "no concerns".
            """
            if frame is None:
                return ""
            frame = self.resize_image(self._to_rgb(frame))

            buffered = io.BytesIO()
            PILImage.fromarray(frame).save(buffered, format="JPEG", quality=85, optimize=True)
            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
            image_url = f"data:image/jpeg;base64,{img_base64}"

            try:
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": self._PROMPT},
                                {"type": "image_url", "image_url": {"url": image_url}},
                            ],
                        },
                        # NOTE(review): empty assistant turn kept from the
                        # original request; confirm the API still needs it.
                        {"role": "assistant", "content": ""},
                    ],
                    temperature=0.2,
                    max_tokens=500,
                    top_p=1,
                    stream=False,
                    stop=None,
                )
            except Exception as e:
                print(f"Analysis error: {str(e)}")
                raise
            # message.content may be None; normalize so callers can parse it.
            return completion.choices[0].message.content or ""

        def get_region_coordinates(self, position: str, image_shape: tuple) -> tuple:
            """Map a free-text *position* to an (x1, y1, x2, y2) cell of a 3x3 grid.

            Row words (top/upper, bottom/lower) and column words (left, right)
            are matched as whole words, independently of each other. So
            "bottom-right", "lower right" or "right-bottom" all select the
            bottom-right cell. A substring scan would instead stop at a
            shorter name such as "right" or "bottom". Positions with no
            recognized word map to the center cell.
            """
            height, width = image_shape[:2]
            words = re.findall(r"[a-z]+", position.lower())
            row = next((self._ROW_WORDS[w] for w in words if w in self._ROW_WORDS), 1)
            col = next((self._COL_WORDS[w] for w in words if w in self._COL_WORDS), 1)
            x_edges = (0, width // 3, 2 * width // 3, width)
            y_edges = (0, height // 3, 2 * height // 3, height)
            return (x_edges[col], y_edges[row], x_edges[col + 1], y_edges[row + 1])

        def draw_observations(self, image, observations):
            """Draw a box and a truncated label for each "position:description" string.

            Draws in place on *image* and also returns it. Entries without a
            ':' are skipped, although they still consume a color slot. Labels
            that fall in the same grid cell are stacked vertically so they do
            not cover each other.
            """
            height, width = image.shape[:2]
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 0.6
            thickness = 2
            label_height = 20  # height of the filled strip behind each label
            labels_in_cell = {}

            for idx, obs in enumerate(observations):
                color = self.colors[idx % len(self.colors)]
                position, sep, description = obs.partition(':')
                if not sep:
                    continue
                description = description.strip()

                x1, y1, x2, y2 = self.get_region_coordinates(position, image.shape)
                cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

                label = description[:50] + "..." if len(description) > 50 else description
                label_width = cv2.getTextSize(label, font, font_scale, thickness)[0][0]
                stack = labels_in_cell.get((x1, y1), 0)
                labels_in_cell[(x1, y1)] = stack + 1

                # Keep the label inside the image horizontally. Place it just
                # above the box, or at the top edge, shifted down one strip
                # per earlier label in the same cell.
                label_x = max(0, min(x1, width - label_width))
                label_y = min(height, max(label_height, y1 - 5) + stack * label_height)
                cv2.rectangle(image, (label_x, label_y - label_height),
                              (label_x + label_width, label_y), color, -1)
                cv2.putText(image, label, (label_x, label_y - 5),
                            font, font_scale, (255, 255, 255), thickness)
            return image

        def process_frame(self, frame: np.ndarray) -> tuple[np.ndarray, str]:
            """Analyze *frame* and return ``(display image, text for the UI)``.

            The display image is an RGB copy of *frame*, annotated when tagged
            observations are found. If the model replied without any usable
            tags, its raw reply is shown rather than claiming nothing was found.
            """
            if frame is None:
                return None, "No image provided"
            analysis = self.analyze_frame(frame)

            # Accept "<location>...</location>" anywhere on a line, not only
            # after a "- " bullet. Keep only entries shaped "position:description".
            observations = [
                obs.strip()
                for obs in re.findall(r"<location>(.*?)</location>", analysis)
                if ':' in obs
            ]

            # Normalize channels so the 3-tuple colors draw correctly.
            display_frame = self._to_rgb(frame).copy()
            if observations:
                return self.draw_observations(display_frame, observations), analysis
            if analysis.strip():
                return display_frame, analysis
            return display_frame, "No safety concerns detected in the image."

    # Create the main interface
    monitor = SafetyMonitor()

    with gr.Blocks() as demo:
        gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
        with gr.Row():
            input_image = gr.Image(label="Upload Image")
            output_image = gr.Image(label="Analysis Results")
        analysis_text = gr.Textbox(label="Safety Analysis", lines=5)

        def analyze_image(image):
            """Gradio callback: run the monitor and surface any failure in the text box."""
            if image is None:
                return None, "No image provided"
            try:
                return monitor.process_frame(image)
            except Exception as e:
                print(f"Processing error: {str(e)}")
                return None, f"Error processing image: {str(e)}"

        input_image.change(
            fn=analyze_image,
            inputs=input_image,
            outputs=[output_image, analysis_text],
        )

        gr.Markdown(
            "\n## Instructions:\n"
            "1. Upload an image to analyze\n"
            "2. View identified safety concerns with bounding boxes\n"
            "3. Read detailed analysis results\n"
        )

    return demo
# Build the app at import time so the `demo` object stays available to
# tooling that imports this module. Only start the server when the file is
# run directly (e.g. `python app.py`).
demo = create_monitor_interface()

if __name__ == "__main__":
    demo.launch()