capradeepgujaran committed
Commit 95ca446 · verified · 1 Parent(s): a217968

Update app.py

Files changed (1)
  1. app.py +117 -23
app.py CHANGED
@@ -9,8 +9,9 @@ import torch
 import warnings
 from typing import Tuple, List, Dict, Optional
 
-# Suppress the CUDA autocast warning
+# Suppress warnings
 warnings.filterwarnings('ignore', category=FutureWarning)
+warnings.filterwarnings('ignore', category=UserWarning)
 
 class RobustSafetyMonitor:
     def __init__(self):
@@ -20,15 +21,59 @@ class RobustSafetyMonitor:
         self.max_image_size = (800, 800)
         self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
 
-        # Load YOLOv5 model with improved configuration
+        # Load YOLOv5 with optimized settings
         self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
-        self.yolo_model.conf = 0.25  # Lower confidence threshold for more detections
+        self.yolo_model.conf = 0.25  # Lower confidence threshold
         self.yolo_model.iou = 0.45  # Adjusted IOU threshold
         self.yolo_model.classes = None  # Detect all classes
         self.yolo_model.max_det = 50  # Increased maximum detections
         self.yolo_model.cpu()
         self.yolo_model.eval()
 
+        # Construction-specific keywords
+        self.construction_keywords = [
+            'person', 'worker', 'helmet', 'tool', 'machine', 'equipment',
+            'brick', 'block', 'pile', 'stack', 'surface', 'floor', 'ground',
+            'construction', 'building', 'structure'
+        ]
+
+    def preprocess_image(self, frame: np.ndarray) -> np.ndarray:
+        """Process image for analysis."""
+        if frame is None:
+            raise ValueError("No image provided")
+
+        if len(frame.shape) == 2:
+            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
+        elif len(frame.shape) == 3 and frame.shape[2] == 4:
+            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
+
+        return self.resize_image(frame)
+
+    def resize_image(self, image: np.ndarray) -> np.ndarray:
+        """Resize image while maintaining aspect ratio."""
+        height, width = image.shape[:2]
+        if height > self.max_image_size[1] or width > self.max_image_size[0]:
+            aspect = width / height
+            if width > height:
+                new_width = self.max_image_size[0]
+                new_height = int(new_width / aspect)
+            else:
+                new_height = self.max_image_size[1]
+                new_width = int(new_height * aspect)
+            return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
+        return image
+
+    def encode_image(self, frame: np.ndarray) -> str:
+        """Convert image to base64 encoding."""
+        try:
+            frame_pil = PILImage.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+            buffered = io.BytesIO()
+            frame_pil.save(buffered, format="JPEG", quality=95)
+            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
+            return f"data:image/jpeg;base64,{img_base64}"
+        except Exception as e:
+            raise ValueError(f"Error encoding image: {str(e)}")
+
     def detect_objects(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
         """Enhanced object detection using YOLOv5."""
         try:
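Note: the new resize_image keeps the aspect ratio by scaling the longer side down to the 800 px bound. A minimal standalone check of that logic, as a sketch (bounded_resize is a hypothetical extraction of the same code, not a name in app.py):

    import numpy as np
    import cv2

    def bounded_resize(image: np.ndarray, max_size=(800, 800)) -> np.ndarray:
        # Mirrors RobustSafetyMonitor.resize_image: shrink the longer side
        # to the bound and derive the shorter side from the aspect ratio.
        height, width = image.shape[:2]
        if height > max_size[1] or width > max_size[0]:
            aspect = width / height
            if width > height:
                new_width = max_size[0]
                new_height = int(new_width / aspect)
            else:
                new_height = max_size[1]
                new_width = int(new_height * aspect)
            return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
        return image

    # A 1600x1200 frame should come back as 800x600 (4:3 preserved).
    frame = np.zeros((1200, 1600, 3), dtype=np.uint8)
    print(bounded_resize(frame).shape)  # (600, 800, 3)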
@@ -40,7 +85,7 @@ class RobustSafetyMonitor:
 
             # Run inference with augmentation
             with torch.no_grad():
-                results = self.yolo_model(frame, augment=True)  # Enable test-time augmentation
+                results = self.yolo_model(frame, augment=True)
 
             # Get detections
             bbox_data = results.xyxy[0].cpu().numpy()
@@ -50,8 +95,7 @@ class RobustSafetyMonitor:
             processed_boxes = []
             for box in bbox_data:
                 x1, y1, x2, y2, conf, cls = box
-                # Additional filtering for construction site objects
-                if conf > 0.25:  # Keep lower confidence threshold for more detections
+                if conf > 0.25:  # Keep lower confidence threshold
                     processed_boxes.append(box)
 
             return np.array(processed_boxes), labels
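Note: results.xyxy[0] from a YOLOv5 hub model is an (N, 6) tensor of [x1, y1, x2, y2, confidence, class]; the loop above keeps rows over the 0.25 confidence floor. The same filter, vectorized, as a sketch with made-up detections:

    import numpy as np

    # Mimics results.xyxy[0].cpu().numpy(): [x1, y1, x2, y2, conf, cls]
    bbox_data = np.array([
        [ 10.0,  20.0, 110.0, 220.0, 0.91,  0.0],  # confident detection, kept
        [300.0,  40.0, 380.0, 120.0, 0.12, 39.0],  # below threshold, dropped
    ])
    processed_boxes = bbox_data[bbox_data[:, 4] > 0.25]
    print(processed_boxes.shape)  # (1, 6)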
@@ -59,6 +103,59 @@ class RobustSafetyMonitor:
             print(f"Error in object detection: {str(e)}")
             return np.array([]), {}
 
+    def analyze_frame(self, frame: np.ndarray) -> Tuple[List[Dict], str]:
+        """Perform safety analysis using Llama Vision."""
+        if frame is None:
+            return [], "No frame received"
+
+        try:
+            frame = self.preprocess_image(frame)
+            image_base64 = self.encode_image(frame)
+
+            completion = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": [
+                            {
+                                "type": "text",
+                                "text": """Analyze this workplace image for safety risks. Focus on:
+                                1. Worker posture and positioning
+                                2. Equipment and tool safety
+                                3. Environmental hazards
+                                4. PPE compliance
+                                5. Material handling
+
+                                List each risk on a new line starting with 'Risk:'.
+                                Format: Risk: [Object/Area] - [Detailed description of hazard]"""
+                            },
+                            {
+                                "type": "image_url",
+                                "image_url": {
+                                    "url": image_base64
+                                }
+                            }
+                        ]
+                    }
+                ],
+                temperature=0.7,
+                max_tokens=1024,
+                stream=False
+            )
+
+            try:
+                response = completion.choices[0].message.content
+            except AttributeError:
+                response = str(completion.choices[0].message)
+
+            safety_issues = self.parse_safety_analysis(response)
+            return safety_issues, response
+
+        except Exception as e:
+            print(f"Analysis error: {str(e)}")
+            return [], f"Analysis Error: {str(e)}"
+
     def draw_bounding_boxes(self, image: np.ndarray, bboxes: np.ndarray,
                             labels: Dict, safety_issues: List[Dict]) -> np.ndarray:
         """Improved bounding box visualization."""
@@ -67,28 +164,21 @@ class RobustSafetyMonitor:
         font_scale = 0.5
         thickness = 2
 
-        # Define construction-related keywords for better object association
-        construction_keywords = [
-            'person', 'worker', 'helmet', 'tool', 'machine', 'equipment',
-            'brick', 'block', 'pile', 'stack', 'surface', 'floor', 'ground',
-            'construction', 'building', 'structure'
-        ]
-
         for idx, bbox in enumerate(bboxes):
             try:
                 x1, y1, x2, y2, conf, class_id = bbox
                 label = labels[int(class_id)]
 
                 # Check if object is construction-related
-                is_relevant = any(keyword in label.lower() for keyword in construction_keywords)
+                is_relevant = any(keyword in label.lower() for keyword in self.construction_keywords)
 
-                if is_relevant or conf > 0.35:  # Higher threshold for non-construction objects
+                if is_relevant or conf > 0.35:
                     color = self.colors[idx % len(self.colors)]
 
                     # Convert coordinates to integers
                     x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
 
-                    # Draw thicker bounding box for better visibility
+                    # Draw bounding box
                     cv2.rectangle(image_copy, (x1, y1), (x2, y2), color, thickness)
 
                     # Check for associated safety issues
@@ -108,8 +198,8 @@ class RobustSafetyMonitor:
                     y_pos = max(y1 - 10, 20)
                     cv2.putText(image_copy, label_text, (x1, y_pos), font,
                                 font_scale, color, thickness)
-
-                    # Draw additional markers for high-risk areas
+
+                    # Mark high-risk areas
                     if conf > 0.5 and any(risk_word in label.lower() for risk_word in
                                           ['worker', 'person', 'equipment', 'machine']):
                         cv2.circle(image_copy, (int((x1 + x2)/2), int((y1 + y2)/2)),
@@ -143,7 +233,7 @@ class RobustSafetyMonitor:
             return None, f"Error processing image: {str(e)}"
 
     def parse_safety_analysis(self, analysis: str) -> List[Dict]:
-        """Parse the safety analysis text into structured data."""
+        """Parse the safety analysis text."""
         safety_issues = []
 
         if not isinstance(analysis, str):
@@ -152,7 +242,6 @@ class RobustSafetyMonitor:
         for line in analysis.split('\n'):
             if "risk:" in line.lower():
                 try:
-                    # Extract object and description
                     parts = line.lower().split('risk:', 1)[1].strip()
                     if '-' in parts:
                         obj, desc = parts.split('-', 1)
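Note: the parser keys on the exact 'Risk: [Object/Area] - [Description]' format that the analyze_frame prompt requests, so prompt and parser must stay in sync. What one well-formed line reduces to under this logic:

    line = "Risk: ladder - positioned on uneven ground near the excavation"
    parts = line.lower().split('risk:', 1)[1].strip()
    obj, desc = parts.split('-', 1)
    print(obj.strip())   # ladder
    print(desc.strip())  # positioned on uneven ground near the excavation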
@@ -171,10 +260,10 @@ class RobustSafetyMonitor:
 
 
 def create_monitor_interface():
-    """Create the Gradio interface for the safety monitoring system."""
+    """Create the Gradio interface."""
     monitor = RobustSafetyMonitor()
 
-    with gr.Blocks() as demo:
+    with gr.Blocks(theme=gr.themes.Base()) as demo:
         gr.Markdown("# Workplace Safety Analysis System")
         gr.Markdown("Powered by Groq LLaVA Vision and YOLOv5")
 
@@ -182,7 +271,12 @@ def create_monitor_interface():
             input_image = gr.Image(label="Upload Workplace Image", type="numpy")
             output_image = gr.Image(label="Safety Analysis Visualization")
 
-        analysis_text = gr.Textbox(label="Detailed Safety Analysis", lines=5)
+        with gr.Row():
+            analysis_text = gr.Textbox(
+                label="Detailed Safety Analysis",
+                lines=8,
+                placeholder="Safety analysis will appear here..."
+            )
 
         def analyze_image(image):
             if image is None:
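Note: the diff ends inside analyze_image, so the callback body is not shown. A hedged sketch of how the pieces added or touched above would chain together (the wiring is an assumption; only the method names match the class API in this diff):

    def analyze_image(image):
        if image is None:
            return None, "No image provided"
        try:
            bboxes, labels = monitor.detect_objects(image)           # YOLOv5 detections
            safety_issues, analysis = monitor.analyze_frame(image)   # Groq vision analysis
            annotated = monitor.draw_bounding_boxes(image, bboxes, labels, safety_issues)
            return annotated, analysis
        except Exception as e:
            return None, f"Error processing image: {str(e)}"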