capradeepgujaran committed
Commit d2c67f3 · verified · 1 Parent(s): 160a45b

Update app.py

Files changed (1)
  app.py  +138 -83
app.py CHANGED
@@ -6,21 +6,32 @@ from PIL import Image as PILImage
 import io
 import base64
 import torch
+import warnings
+from typing import Tuple, List, Dict, Optional
 
+# Suppress the CUDA autocast warning
+warnings.filterwarnings('ignore', category=FutureWarning)
 
 class RobustSafetyMonitor:
     def __init__(self):
         """Initialize the robust safety detection tool with configuration."""
         self.client = Groq()
-        self.model_name = "llama-3.2-90b-vision-preview"
+        self.model_name = "llama-3.2-11b-vision-preview"  # Updated to use the correct model
         self.max_image_size = (800, 800)
         self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
 
         # Load YOLOv5 model for general object detection
         self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
+
+        # Force CPU inference if CUDA is causing issues
+        self.yolo_model.cpu()
+        self.yolo_model.eval()
 
-    def preprocess_image(self, frame):
+    def preprocess_image(self, frame: np.ndarray) -> np.ndarray:
         """Process image for analysis."""
+        if frame is None:
+            raise ValueError("No image provided")
+
         if len(frame.shape) == 2:
             frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
         elif len(frame.shape) == 3 and frame.shape[2] == 4:
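The model switch and the explicit `.cpu()`/`.eval()` calls are easy to sanity-check in isolation. A minimal sketch (not part of the commit) that loads the same YOLOv5 hub model on CPU and runs it on a blank frame:

```python
# Standalone check of the YOLOv5 setup used in __init__. Assumes torch and
# network access for torch.hub; the dummy frame is invented for illustration.
import numpy as np
import torch

yolo = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
yolo.cpu()   # force CPU inference, as in the commit
yolo.eval()  # disable dropout / batch-norm updates

dummy = np.zeros((640, 640, 3), dtype=np.uint8)  # blank RGB frame
with torch.no_grad():
    results = yolo(dummy)
print(results.xyxy[0].shape)  # (n_detections, 6): x1, y1, x2, y2, conf, class_id
```

Note that `warnings.filterwarnings('ignore', category=FutureWarning)` mutes every FutureWarning in the process, not only the CUDA autocast one the comment mentions.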
@@ -28,7 +39,7 @@ class RobustSafetyMonitor:
 
         return self.resize_image(frame)
 
-    def resize_image(self, image):
+    def resize_image(self, image: np.ndarray) -> np.ndarray:
         """Resize image while maintaining aspect ratio."""
         height, width = image.shape[:2]
         if height > self.max_image_size[1] or width > self.max_image_size[0]:
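The body of `resize_image` between this hunk and the next is unchanged, so the diff elides it. For context, a self-contained sketch of the usual aspect-preserving pattern the visible lines imply; the scale computation here is an assumption, not the commit's exact code:

```python
import cv2
import numpy as np

def resize_keep_aspect(image: np.ndarray, max_size=(800, 800)) -> np.ndarray:
    """Shrink image so neither side exceeds max_size, keeping aspect ratio."""
    height, width = image.shape[:2]
    if height > max_size[1] or width > max_size[0]:
        # One shared scale factor preserves the aspect ratio (assumed logic)
        scale = min(max_size[0] / width, max_size[1] / height)
        new_width, new_height = int(width * scale), int(height * scale)
        return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
    return image  # already within bounds
```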
@@ -42,35 +53,37 @@
             return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
         return image
 
-    def encode_image(self, frame):
+    def encode_image(self, frame: np.ndarray) -> str:
         """Convert image to base64 encoding with proper formatting."""
-        frame_pil = PILImage.fromarray(frame)
-        buffered = io.BytesIO()
-        frame_pil.save(buffered, format="JPEG", quality=95)  # Ensure JPEG format
-        img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-
-        # Prepend the required image data type to base64 string
-        img_base64_str = f"data:image/jpeg;base64,{img_base64}"
-        return img_base64_str
+        try:
+            frame_pil = PILImage.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+            buffered = io.BytesIO()
+            frame_pil.save(buffered, format="JPEG", quality=95)
+            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
+            return f"data:image/jpeg;base64,{img_base64}"
+        except Exception as e:
+            raise ValueError(f"Error encoding image: {str(e)}")
 
-    def detect_objects(self, frame):
+    def detect_objects(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
         """Detect objects using YOLOv5."""
-        results = self.yolo_model(frame)
-        # Extract bounding boxes, class labels, and confidence scores
-        bbox_data = results.xyxy[0].numpy()  # Bounding box coordinates
-        labels = results.names  # Class names
-        return bbox_data, labels
-
-    def analyze_frame(self, frame):
-        """Perform safety analysis on the frame using Llama Vision 3.2."""
+        try:
+            with torch.no_grad():
+                results = self.yolo_model(frame)
+            bbox_data = results.xyxy[0].cpu().numpy()
+            labels = results.names
+            return bbox_data, labels
+        except Exception as e:
+            raise ValueError(f"Error detecting objects: {str(e)}")
+
+    def analyze_frame(self, frame: np.ndarray) -> Tuple[List[Dict], str]:
+        """Perform safety analysis on the frame using Llama Vision."""
         if frame is None:
-            return "No frame received", {}
+            return [], "No frame received"
 
-        frame = self.preprocess_image(frame)
-        image_base64 = self.encode_image(frame)
-
         try:
-            # Use Llama Vision 3.2 to analyze the context of the image and detect risks
+            frame = self.preprocess_image(frame)
+            image_base64 = self.encode_image(frame)
+
             completion = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=[
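One caveat worth flagging: Gradio's numpy images arrive as RGB, so the added `cv2.COLOR_BGR2RGB` conversion in `encode_image` will swap the red and blue channels unless the frame really is BGR (as it would be from `cv2.imread` or `cv2.VideoCapture`). The data-URL format itself can be verified with a quick round trip; a sketch with an invented sample frame:

```python
# Round-trip check for the data-URL string encode_image now returns.
import base64
import io

import numpy as np
from PIL import Image as PILImage

frame = np.full((100, 160, 3), 128, dtype=np.uint8)  # flat gray RGB frame

buffered = io.BytesIO()
PILImage.fromarray(frame).save(buffered, format="JPEG", quality=95)
data_url = "data:image/jpeg;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")

# Split the header off and confirm the payload decodes back to a valid JPEG
header, payload = data_url.split(",", 1)
assert header == "data:image/jpeg;base64"
img = PILImage.open(io.BytesIO(base64.b64decode(payload)))
print(img.format, img.size)  # JPEG (160, 100)
```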
@@ -80,13 +93,13 @@
                     {
                         "type": "text",
                         "text": """Analyze this workplace image and identify any potential safety risks.
-                        Consider the positioning of workers, the equipment, materials, and environment.
-                        Flag risks like improper equipment use, worker proximity to danger zones, unstable materials, and environmental hazards."""
+                        List each risk on a new line starting with 'Risk:'.
+                        Format: Risk: [Object/Area] - [Description of hazard]"""
                     },
                     {
                         "type": "image_url",
                         "image_url": {
-                            "url": image_base64  # Corrected: Now sending the proper data format
+                            "url": image_base64
                        }
                    }
                ]
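The reworked prompt pins the model to a parseable `Risk: [Object/Area] - [Description]` line format, which `parse_safety_analysis` (next hunk) depends on. For reference, the full chat payload has roughly this shape; the prompt text is from the diff, while the `"role": "user"` wrapper sits in unchanged lines outside the hunk and is inferred here:

```python
# Sketch of the message payload sent to the Groq chat completions endpoint.
# The role/content wrapper is an inference from the surrounding unchanged code.
prompt = (
    "Analyze this workplace image and identify any potential safety risks.\n"
    "List each risk on a new line starting with 'Risk:'.\n"
    "Format: Risk: [Object/Area] - [Description of hazard]"
)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
        ],
    }
]
```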
@@ -96,98 +109,134 @@
                 max_tokens=1024,
                 stream=False
             )
-            # Process and parse the response correctly
-            response = completion.choices[0].message['content']
-            return self.parse_safety_analysis(response), response  # Return parsed analysis and full response
+
+            # Get the response content safely
+            try:
+                response = completion.choices[0].message.content
+            except AttributeError:
+                response = str(completion.choices[0].message)
+
+            safety_issues = self.parse_safety_analysis(response)
+            return safety_issues, response
 
         except Exception as e:
             print(f"Analysis error: {str(e)}")
-            return f"Analysis Error: {str(e)}", {}
+            return [], f"Analysis Error: {str(e)}"
 
-    def draw_bounding_boxes(self, image, bboxes, labels, safety_issues):
-        """Draw bounding boxes around objects based on safety issues flagged by Llama Vision."""
+    def draw_bounding_boxes(self, image: np.ndarray, bboxes: np.ndarray,
+                            labels: Dict, safety_issues: List[Dict]) -> np.ndarray:
+        """Draw bounding boxes around objects based on safety issues."""
+        image_copy = image.copy()
         font = cv2.FONT_HERSHEY_SIMPLEX
         font_scale = 0.5
         thickness = 2
 
         for idx, bbox in enumerate(bboxes):
-            x1, y1, x2, y2, conf, class_id = bbox
-            label = labels[int(class_id)]
-            color = self.colors[idx % len(self.colors)]
-
-            # Draw bounding box
-            cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness)
-
-            # Link detected object to potential risks based on Llama Vision analysis
-            for safety_issue in safety_issues:
-                if safety_issue['object'].lower() in label.lower():
-                    label_text = f"Risk: {safety_issue['description']}"
-                    cv2.putText(image, label_text, (int(x1), int(y1) - 10), font, font_scale, (0, 0, 255), thickness)
-                    break
-            else:
-                label_text = f"{label} {conf:.2f}"
-                cv2.putText(image, label_text, (int(x1), int(y1) - 10), font, font_scale, color, thickness)
+            try:
+                x1, y1, x2, y2, conf, class_id = bbox
+                label = labels[int(class_id)]
+                color = self.colors[idx % len(self.colors)]
+
+                # Convert coordinates to integers
+                x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
+
+                # Draw bounding box
+                cv2.rectangle(image_copy, (x1, y1), (x2, y2), color, thickness)
+
+                # Check if object is associated with any safety issues
+                risk_found = False
+                for safety_issue in safety_issues:
+                    if safety_issue.get('object', '').lower() in label.lower():
+                        label_text = f"Risk: {safety_issue.get('description', '')}"
+                        y_pos = max(y1 - 10, 20)
+                        cv2.putText(image_copy, label_text, (x1, y_pos), font,
+                                    font_scale, (0, 0, 255), thickness)
+                        risk_found = True
+                        break
+
+                if not risk_found:
+                    label_text = f"{label} {conf:.2f}"
+                    y_pos = max(y1 - 10, 20)
+                    cv2.putText(image_copy, label_text, (x1, y_pos), font,
+                                font_scale, color, thickness)
+            except Exception as e:
+                print(f"Error drawing box: {str(e)}")
+                continue
 
-        return image
+        return image_copy
 
-    def process_frame(self, frame):
-        """Main processing pipeline for dynamic, robust safety analysis."""
+    def process_frame(self, frame: np.ndarray) -> Tuple[Optional[np.ndarray], str]:
+        """Main processing pipeline for safety analysis."""
         if frame is None:
             return None, "No image provided"
 
         try:
-            # Detect objects dynamically in the image using YOLO
+            # Detect objects
             bbox_data, labels = self.detect_objects(frame)
-            frame_with_boxes = self.draw_bounding_boxes(frame, bbox_data, labels, [])
-
-            # Get dynamic safety analysis from Llama Vision 3.2
+
+            # Get safety analysis
             safety_issues, analysis = self.analyze_frame(frame)
-
-            # Update the frame with bounding boxes based on safety issues flagged
-            annotated_frame = self.draw_bounding_boxes(frame_with_boxes, bbox_data, labels, safety_issues)
-
+
+            # Draw annotations
+            annotated_frame = self.draw_bounding_boxes(frame, bbox_data, labels, safety_issues)
+
             return annotated_frame, analysis
 
         except Exception as e:
             print(f"Processing error: {str(e)}")
             return None, f"Error processing image: {str(e)}"
 
-    def parse_safety_analysis(self, analysis):
-        """Parse the safety analysis to identify contextual issues and link to objects."""
+    def parse_safety_analysis(self, analysis: str) -> List[Dict]:
+        """Parse the safety analysis text into structured data."""
         safety_issues = []
+
+        if not isinstance(analysis, str):
+            return safety_issues
+
         for line in analysis.split('\n'):
-            if "risk" in line.lower() or "hazard" in line.lower():
-                # Extract object involved and description
-                parts = line.split(':', 1)
-                if len(parts) == 2:
+            if "risk:" in line.lower():
+                try:
+                    # Extract object and description
+                    parts = line.lower().split('risk:', 1)[1].strip()
+                    if '-' in parts:
+                        obj, desc = parts.split('-', 1)
+                    else:
+                        obj, desc = parts, parts
+
                     safety_issues.append({
-                        "object": parts[0].strip(),
-                        "description": parts[1].strip()
+                        "object": obj.strip(),
+                        "description": desc.strip()
                     })
+                except Exception as e:
+                    print(f"Error parsing line: {line}, Error: {str(e)}")
+                    continue
+
         return safety_issues
 
 
 def create_monitor_interface():
+    """Create the Gradio interface for the safety monitoring system."""
     monitor = RobustSafetyMonitor()
 
     with gr.Blocks() as demo:
-        gr.Markdown("# Robust Safety Analysis System powered by Llama Vision 3.2")
+        gr.Markdown("# Workplace Safety Analysis System")
+        gr.Markdown("Powered by Groq LLaVA Vision and YOLOv5")
 
         with gr.Row():
-            input_image = gr.Image(label="Upload Image")
-            output_image = gr.Image(label="Safety Analysis")
-
-        analysis_text = gr.Textbox(label="Detailed Analysis", lines=5)
-
+            input_image = gr.Image(label="Upload Workplace Image", type="numpy")
+            output_image = gr.Image(label="Safety Analysis Visualization")
+
+        analysis_text = gr.Textbox(label="Detailed Safety Analysis", lines=5)
+
         def analyze_image(image):
             if image is None:
-                return None, "No image provided"
+                return None, "Please upload an image"
             try:
                 processed_frame, analysis = monitor.process_frame(image)
                 return processed_frame, analysis
             except Exception as e:
-                print(f"Processing error: {str(e)}")
-                return None, f"Error processing image: {str(e)}"
+                print(f"Analysis error: {str(e)}")
+                return None, f"Error analyzing image: {str(e)}"
 
         input_image.upload(
             fn=analyze_image,
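The stricter `risk:` trigger plus the `-` split make the parser easy to verify offline; the substring match in `draw_bounding_boxes` (`safety_issue.get('object', '').lower() in label.lower()`) then links each parsed object to a YOLO label. A quick test with an invented model response:

```python
# Exercise the parsing rule from parse_safety_analysis on a made-up response.
sample = """Risk: Ladder - worker standing on the top rung
Risk: Floor area - unsecured cables across the walkway
No other hazards observed."""

issues = []
for line in sample.split('\n'):
    if "risk:" in line.lower():
        # Everything after 'risk:' is lowercased, then split on the first '-'
        parts = line.lower().split('risk:', 1)[1].strip()
        obj, desc = parts.split('-', 1) if '-' in parts else (parts, parts)
        issues.append({"object": obj.strip(), "description": desc.strip()})

print(issues)
# [{'object': 'ladder', 'description': 'worker standing on the top rung'},
#  {'object': 'floor area', 'description': 'unsecured cables across the walkway'}]
```

Note the lowercasing applies to the whole line, so descriptions come out lowercase too; that is harmless for the substring match against YOLO labels but visible in the drawn annotations.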
@@ -196,14 +245,20 @@ def create_monitor_interface():
         )
 
         gr.Markdown("""
-        ## Instructions:
-        1. Upload any workplace/safety-related image
-        2. View identified hazards and their locations
-        3. Read detailed analysis of safety concerns based on the image
+        ## Instructions
+        1. Upload a workplace image for safety analysis
+        2. View detected hazards and their locations in the visualization
+        3. Read the detailed safety analysis below the images
+
+        ## Features
+        - Real-time object detection
+        - AI-powered safety risk analysis
+        - Visual risk highlighting
+        - Detailed safety recommendations
         """)
 
     return demo
 
 if __name__ == "__main__":
     demo = create_monitor_interface()
-    demo.launch()
+    demo.launch(share=True)
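Taken together, the UI changes are easiest to see in a stripped-down, runnable sketch of the same Blocks wiring. The stub handler and its canned output are invented here, and the `inputs`/`outputs` arguments to `upload` sit in unchanged lines outside the hunks, so they are inferred from `analyze_image`'s signature and return values:

```python
import gradio as gr

def analyze_image(image):
    # Stub for monitor.process_frame: echo the frame plus a canned analysis.
    if image is None:
        return None, "Please upload an image"
    return image, "Risk: demo object - placeholder analysis"

with gr.Blocks() as demo:
    gr.Markdown("# Workplace Safety Analysis System")
    with gr.Row():
        input_image = gr.Image(label="Upload Workplace Image", type="numpy")
        output_image = gr.Image(label="Safety Analysis Visualization")
    analysis_text = gr.Textbox(label="Detailed Safety Analysis", lines=5)

    # Wire the upload event to the handler (inputs/outputs inferred, see above)
    input_image.upload(
        fn=analyze_image,
        inputs=input_image,
        outputs=[output_image, analysis_text],
    )

if __name__ == "__main__":
    demo.launch()  # the commit adds share=True for a public tunnel link
```

On Hugging Face Spaces the app is already served publicly, so `share=True` is typically unnecessary there.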
 