Update app.py
app.py CHANGED
```diff
@@ -6,21 +6,32 @@ from PIL import Image as PILImage
 import io
 import base64
 import torch
+import warnings
+from typing import Tuple, List, Dict, Optional
 
+# Suppress the CUDA autocast warning
+warnings.filterwarnings('ignore', category=FutureWarning)
 
 class RobustSafetyMonitor:
     def __init__(self):
         """Initialize the robust safety detection tool with configuration."""
         self.client = Groq()
-        self.model_name = "llama-3.2-
+        self.model_name = "llama-3.2-11b-vision-preview"  # Updated to use the correct model
         self.max_image_size = (800, 800)
         self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
 
         # Load YOLOv5 model for general object detection
         self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
+
+        # Force CPU inference if CUDA is causing issues
+        self.yolo_model.cpu()
+        self.yolo_model.eval()
 
-    def preprocess_image(self, frame):
+    def preprocess_image(self, frame: np.ndarray) -> np.ndarray:
         """Process image for analysis."""
+        if frame is None:
+            raise ValueError("No image provided")
+
         if len(frame.shape) == 2:
             frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
         elif len(frame.shape) == 3 and frame.shape[2] == 4:
```
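For reference, the torch.hub pattern this hunk configures can be exercised on its own. A minimal sketch: `sample.jpg` is a stand-in path, and the first call downloads the `yolov5s` weights.

```python
# Standalone sketch of the YOLOv5 setup from __init__ above.
import torch

model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
model.cpu()   # mirror the commit's CPU-only workaround
model.eval()  # inference mode

with torch.no_grad():
    results = model('sample.jpg')    # the hub model accepts a path, PIL image, or numpy array
    boxes = results.xyxy[0].numpy()  # one row per detection: x1, y1, x2, y2, conf, class_id
    for *_, conf, class_id in boxes:
        print(results.names[int(class_id)], f"{conf:.2f}")
```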
```diff
@@ -28,7 +39,7 @@ class RobustSafetyMonitor:
 
         return self.resize_image(frame)
 
-    def resize_image(self, image):
+    def resize_image(self, image: np.ndarray) -> np.ndarray:
         """Resize image while maintaining aspect ratio."""
         height, width = image.shape[:2]
         if height > self.max_image_size[1] or width > self.max_image_size[0]:
```
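The body of `resize_image` (old lines 46-52) is unchanged and elided from this hunk. A sketch of an aspect-preserving resize consistent with the visible lines; the scaling math here is an assumption, not necessarily the file's exact code.

```python
import cv2
import numpy as np

def resize_image(image: np.ndarray, max_size=(800, 800)) -> np.ndarray:
    """Shrink an image to fit max_size while preserving aspect ratio."""
    height, width = image.shape[:2]
    if height > max_size[1] or width > max_size[0]:
        # Scale by whichever dimension overshoots its limit the most.
        scale = min(max_size[0] / width, max_size[1] / height)
        new_width, new_height = int(width * scale), int(height * scale)
        return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
    return image
```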
```diff
@@ -42,35 +53,37 @@ class RobustSafetyMonitor:
             return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
         return image
 
-    def encode_image(self, frame):
+    def encode_image(self, frame: np.ndarray) -> str:
         """Convert image to base64 encoding with proper formatting."""
-        …
+        try:
+            frame_pil = PILImage.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+            buffered = io.BytesIO()
+            frame_pil.save(buffered, format="JPEG", quality=95)
+            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
+            return f"data:image/jpeg;base64,{img_base64}"
+        except Exception as e:
+            raise ValueError(f"Error encoding image: {str(e)}")
 
-    def detect_objects(self, frame):
+    def detect_objects(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
         """Detect objects using YOLOv5."""
-        …
+        try:
+            with torch.no_grad():
+                results = self.yolo_model(frame)
+            bbox_data = results.xyxy[0].cpu().numpy()
+            labels = results.names
+            return bbox_data, labels
+        except Exception as e:
+            raise ValueError(f"Error detecting objects: {str(e)}")
+
+    def analyze_frame(self, frame: np.ndarray) -> Tuple[List[Dict], str]:
+        """Perform safety analysis on the frame using Llama Vision."""
         if frame is None:
-            return "No frame received"
+            return [], "No frame received"
 
-        frame = self.preprocess_image(frame)
-        image_base64 = self.encode_image(frame)
-
         try:
-            …
+            frame = self.preprocess_image(frame)
+            image_base64 = self.encode_image(frame)
+
             completion = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=[
```
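The new `encode_image` body can be checked in isolation; a minimal sketch with an invented 2x2 frame.

```python
# Verify the data-URL encoding used by encode_image (the test frame is invented).
import base64
import io

import cv2
import numpy as np
from PIL import Image as PILImage

frame = np.zeros((2, 2, 3), dtype=np.uint8)  # tiny BGR stand-in frame
frame_pil = PILImage.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
buffered = io.BytesIO()
frame_pil.save(buffered, format="JPEG", quality=95)
url = "data:image/jpeg;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
print(url[:40] + "...")  # -> data:image/jpeg;base64,/9j/...
```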
```diff
@@ -80,13 +93,13 @@
                     {
                         "type": "text",
                         "text": """Analyze this workplace image and identify any potential safety risks.
-                        …
+                        List each risk on a new line starting with 'Risk:'.
+                        Format: Risk: [Object/Area] - [Description of hazard]"""
                     },
                     {
                         "type": "image_url",
                         "image_url": {
                             "url": image_base64
                         }
                     }
                 ]
```
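The surrounding message envelope (the `role` field and any sampling parameters) sits in lines elided from the diff. Assuming a standard single user message, the request this hunk builds looks roughly like the following; the data-URL placeholder would come from `encode_image`.

```python
# Sketch of the multimodal Groq request assembled above; the "role": "user"
# envelope is assumed (those lines are elided from the diff). Requires GROQ_API_KEY.
from groq import Groq

client = Groq()
completion = client.chat.completions.create(
    model="llama-3.2-11b-vision-preview",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text",
             "text": "Analyze this workplace image and identify any potential safety risks. "
                     "List each risk on a new line starting with 'Risk:'."},
            {"type": "image_url",
             "image_url": {"url": "data:image/jpeg;base64,..."}},  # placeholder data URL
        ],
    }],
    max_tokens=1024,
    stream=False,
)
print(completion.choices[0].message.content)
```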
```diff
@@ -96,98 +109,134 @@
                 max_tokens=1024,
                 stream=False
             )
-            …
-            response
+
+            # Get the response content safely
+            try:
+                response = completion.choices[0].message.content
+            except AttributeError:
+                response = str(completion.choices[0].message)
+
+            safety_issues = self.parse_safety_analysis(response)
+            return safety_issues, response
 
         except Exception as e:
             print(f"Analysis error: {str(e)}")
-            return f"Analysis Error: {str(e)}"
+            return [], f"Analysis Error: {str(e)}"
 
-    def draw_bounding_boxes(self, image, bboxes
+    def draw_bounding_boxes(self, image: np.ndarray, bboxes: np.ndarray,
+                            labels: Dict, safety_issues: List[Dict]) -> np.ndarray:
+        """Draw bounding boxes around objects based on safety issues."""
+        image_copy = image.copy()
         font = cv2.FONT_HERSHEY_SIMPLEX
         font_scale = 0.5
         thickness = 2
 
         for idx, bbox in enumerate(bboxes):
-            …
+            try:
+                x1, y1, x2, y2, conf, class_id = bbox
+                label = labels[int(class_id)]
+                color = self.colors[idx % len(self.colors)]
+
+                # Convert coordinates to integers
+                x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
+
+                # Draw bounding box
+                cv2.rectangle(image_copy, (x1, y1), (x2, y2), color, thickness)
+
+                # Check if object is associated with any safety issues
+                risk_found = False
+                for safety_issue in safety_issues:
+                    if safety_issue.get('object', '').lower() in label.lower():
+                        label_text = f"Risk: {safety_issue.get('description', '')}"
+                        y_pos = max(y1 - 10, 20)
+                        cv2.putText(image_copy, label_text, (x1, y_pos), font,
+                                    font_scale, (0, 0, 255), thickness)
+                        risk_found = True
+                        break
+
+                if not risk_found:
+                    label_text = f"{label} {conf:.2f}"
+                    y_pos = max(y1 - 10, 20)
+                    cv2.putText(image_copy, label_text, (x1, y_pos), font,
+                                font_scale, color, thickness)
+            except Exception as e:
+                print(f"Error drawing box: {str(e)}")
+                continue
 
-        return
+        return image_copy
 
-    def process_frame(self, frame):
-        """Main processing pipeline for
+    def process_frame(self, frame: np.ndarray) -> Tuple[Optional[np.ndarray], str]:
+        """Main processing pipeline for safety analysis."""
         if frame is None:
             return None, "No image provided"
 
         try:
             # Detect objects
             bbox_data, labels = self.detect_objects(frame)
-            …
-            # Get dynamic safety analysis from Llama Vision 3.2
+
+            # Get safety analysis
             safety_issues, analysis = self.analyze_frame(frame)
-            …
-            annotated_frame = self.draw_bounding_boxes(
+
+            # Draw annotations
+            annotated_frame = self.draw_bounding_boxes(frame, bbox_data, labels, safety_issues)
+
             return annotated_frame, analysis
 
         except Exception as e:
             print(f"Processing error: {str(e)}")
             return None, f"Error processing image: {str(e)}"
 
-    def parse_safety_analysis(self, analysis):
-        """Parse the safety analysis
+    def parse_safety_analysis(self, analysis: str) -> List[Dict]:
+        """Parse the safety analysis text into structured data."""
         safety_issues = []
+
+        if not isinstance(analysis, str):
+            return safety_issues
+
         for line in analysis.split('\n'):
-            if "risk" in line.lower()
-            …
+            if "risk:" in line.lower():
+                try:
+                    # Extract object and description
+                    parts = line.lower().split('risk:', 1)[1].strip()
+                    if '-' in parts:
+                        obj, desc = parts.split('-', 1)
+                    else:
+                        obj, desc = parts, parts
+
                     safety_issues.append({
-                        "object":
-                        "description":
+                        "object": obj.strip(),
+                        "description": desc.strip()
                     })
+                except Exception as e:
+                    print(f"Error parsing line: {line}, Error: {str(e)}")
+                    continue
+
         return safety_issues
 
 
 def create_monitor_interface():
+    """Create the Gradio interface for the safety monitoring system."""
     monitor = RobustSafetyMonitor()
 
     with gr.Blocks() as demo:
-        gr.Markdown("#
+        gr.Markdown("# Workplace Safety Analysis System")
+        gr.Markdown("Powered by Groq LLaVA Vision and YOLOv5")
 
         with gr.Row():
-            input_image = gr.Image(label="Upload Image")
-            output_image = gr.Image(label="Safety Analysis")
-
-        analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)
-
+            input_image = gr.Image(label="Upload Workplace Image", type="numpy")
+            output_image = gr.Image(label="Safety Analysis Visualization")
+
+        analysis_text = gr.Textbox(label="Detailed Safety Analysis", lines=5)
 
         def analyze_image(image):
             if image is None:
-                return None, "
+                return None, "Please upload an image"
             try:
                 processed_frame, analysis = monitor.process_frame(image)
                 return processed_frame, analysis
             except Exception as e:
-                print(f"
-                return None, f"Error
+                print(f"Analysis error: {str(e)}")
+                return None, f"Error analyzing image: {str(e)}"
 
         input_image.upload(
             fn=analyze_image,
```
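The `Risk:` contract between the prompt and `parse_safety_analysis` can be exercised without calling the API. The sample response below is invented; the method never touches `self`, so it is called unbound here to skip the heavyweight `__init__`.

```python
# Invented sample of a model response following the 'Risk:' format.
sample = """Analysis of the workplace image:
Risk: Ladder - Unsecured ladder leaning against shelving
Risk: Floor - Liquid spill in the main walkway"""

# parse_safety_analysis does not use self, so passing None avoids
# constructing the Groq client and downloading YOLOv5 weights.
for issue in RobustSafetyMonitor.parse_safety_analysis(None, sample):
    print(issue["object"], "->", issue["description"])
# ladder -> unsecured ladder leaning against shelving   (the parser lowercases lines)
# floor -> liquid spill in the main walkway
```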
```diff
@@ -196,14 +245,20 @@ def create_monitor_interface():
         )
 
         gr.Markdown("""
-        ## Instructions
-        1. Upload
-        2. View
-        3. Read detailed
+        ## Instructions
+        1. Upload a workplace image for safety analysis
+        2. View detected hazards and their locations in the visualization
+        3. Read the detailed safety analysis below the images
+
+        ## Features
+        - Real-time object detection
+        - AI-powered safety risk analysis
+        - Visual risk highlighting
+        - Detailed safety recommendations
         """)
 
     return demo
 
 if __name__ == "__main__":
     demo = create_monitor_interface()
-    demo.launch()
+    demo.launch(share=True)
```
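The `inputs`/`outputs` arguments of the `upload` call (new lines 243-244) are elided from the diff. A self-contained sketch of the same wiring pattern, with an echo handler standing in for `monitor.process_frame` and the elided arguments assumed:

```python
import gradio as gr

with gr.Blocks() as demo:
    input_image = gr.Image(label="Upload Workplace Image", type="numpy")
    output_image = gr.Image(label="Safety Analysis Visualization")
    analysis_text = gr.Textbox(label="Detailed Safety Analysis", lines=5)

    def analyze_image(image):
        # Stand-in handler; the real app calls monitor.process_frame(image).
        return image, "placeholder analysis"

    # inputs/outputs are assumed; the diff elides these two lines.
    input_image.upload(fn=analyze_image, inputs=input_image,
                       outputs=[output_image, analysis_text])

if __name__ == "__main__":
    demo.launch()  # the commit passes share=True for a public link
```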