capradeepgujaran committed
Commit a5f647b · verified · 1 Parent(s): f6cffbc

Update app.py

Files changed (1):
  1. app.py (+38 -50)
app.py CHANGED
@@ -5,20 +5,18 @@ from groq import Groq
 import time
 from PIL import Image as PILImage
 import io
-import os
 import base64
 import torch
 
-
-class SafetyMonitor:
+class RobustSafetyMonitor:
     def __init__(self):
-        """Initialize Safety Monitor with configuration."""
+        """Initialize the robust safety detection tool with configuration."""
         self.client = Groq()
         self.model_name = "llama-3.2-90b-vision-preview"
         self.max_image_size = (800, 800)
         self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
 
-        # Load YOLOv5 model for object detection
+        # Load YOLOv5 model for general object detection
         self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
 
     def preprocess_image(self, frame):
@@ -61,14 +59,15 @@ class SafetyMonitor:
         return bbox_data, labels
 
     def analyze_frame(self, frame):
-        """Perform safety analysis on the frame."""
+        """Perform safety analysis on the frame using Llama Vision 3.2."""
         if frame is None:
             return "No frame received", {}
 
         frame = self.preprocess_image(frame)
-        image_url = self.encode_image(frame)
+        image_base64 = self.encode_image(frame)
 
         try:
+            # Use Llama Vision 3.2 to analyze the context of the image and detect risks
             completion = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=[
@@ -77,19 +76,21 @@ class SafetyMonitor:
                         "content": [
                             {
                                 "type": "text",
-                                "text": "Identify and list safety concerns in this workplace image. For each issue found, include its location and specific safety concern. Look for hazards related to PPE, ergonomics, equipment, environment, and work procedures."
+                                "text": """Analyze this workplace image and identify any potential safety risks.
+                                Consider the positioning of workers, the equipment, materials, and environment.
+                                Flag risks like improper equipment use, worker proximity to danger zones, unstable materials, and environmental hazards."""
                             },
                             {
                                 "type": "image_url",
                                 "image_url": {
-                                    "url": image_url
+                                    "url": f"data:image/jpeg;base64,{image_base64}"  # Use base64 for image
                                 }
                             }
                         ]
                     }
                 ],
                 temperature=0.7,
-                max_tokens=500,
+                max_tokens=1024,
                 stream=False
             )
             return completion.choices[0].message.content, {}
@@ -97,11 +98,12 @@ class SafetyMonitor:
             print(f"Analysis error: {str(e)}")
             return f"Analysis Error: {str(e)}", {}
 
-    def draw_bounding_boxes(self, image, bboxes, labels):
-        """Draw bounding boxes around detected objects."""
+    def draw_bounding_boxes(self, image, bboxes, labels, safety_issues):
+        """Draw bounding boxes around objects based on safety issues flagged by Llama Vision."""
         font = cv2.FONT_HERSHEY_SIMPLEX
         font_scale = 0.5
         thickness = 2
+
         for idx, bbox in enumerate(bboxes):
             x1, y1, x2, y2, conf, class_id = bbox
             label = labels[int(class_id)]
@@ -110,68 +112,55 @@ class SafetyMonitor:
             # Draw bounding box
             cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness)
 
-            # Draw label
-            label_text = f"{label} {conf:.2f}"
-            cv2.putText(image, label_text, (int(x1), int(y1) - 10), font, font_scale, color, thickness)
-
+            # Link detected object to potential risks based on Llama Vision analysis
+            if any(safety_issue.lower() in label.lower() for safety_issue in safety_issues):
+                label_text = f"Risk: {label}"
+                cv2.putText(image, label_text, (int(x1), int(y1) - 10), font, font_scale, (0, 0, 255), thickness)
+            else:
+                label_text = f"{label} {conf:.2f}"
+                cv2.putText(image, label_text, (int(x1), int(y1) - 10), font, font_scale, color, thickness)
+
         return image
 
     def process_frame(self, frame):
-        """Main processing pipeline for dynamic safety analysis."""
+        """Main processing pipeline for dynamic, robust safety analysis."""
         if frame is None:
            return None, "No image provided"
 
         try:
             # Detect objects dynamically in the image using YOLO
             bbox_data, labels = self.detect_objects(frame)
-            frame_with_boxes = self.draw_bounding_boxes(frame, bbox_data, labels)
-
-            # Get dynamic safety analysis from Groq's model
+            frame_with_boxes = self.draw_bounding_boxes(frame, bbox_data, labels, [])
+
+            # Get dynamic safety analysis from Llama Vision 3.2
             analysis, _ = self.analyze_frame(frame)
-
-            # Dynamically parse the analysis to find any safety issues flagged
+
+            # Dynamically parse the analysis to identify safety issues flagged
             safety_issues = self.parse_safety_analysis(analysis)
 
-            # Dynamically link detected objects to safety issues
-            for issue in safety_issues:
-                if 'helmet' in issue.lower():
-                    for idx, bbox in enumerate(bbox_data):
-                        x1, y1, x2, y2, conf, class_id = bbox
-                        if labels[int(class_id)] == 'person':
-                            # Dynamically label the missing helmet issue for detected persons
-                            cv2.putText(frame_with_boxes, "No Helmet!", (int(x1), int(y1) - 20),
-                                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
-                            cv2.rectangle(frame_with_boxes, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
-
-                # Add more dynamic checks here for gloves, boots, etc.
-                if 'glove' in issue.lower():
-                    for idx, bbox in enumerate(bbox_data):
-                        x1, y1, x2, y2, conf, class_id = bbox
-                        if labels[int(class_id)] == 'person':
-                            # Dynamically label missing gloves for detected persons
-                            cv2.putText(frame_with_boxes, "No Gloves!", (int(x1), int(y1) - 20),
-                                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
-                            cv2.rectangle(frame_with_boxes, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 255), 2)
-
-            return frame_with_boxes, analysis
+            # Update the frame with bounding boxes based on safety issues flagged
+            annotated_frame = self.draw_bounding_boxes(frame_with_boxes, bbox_data, labels, safety_issues)
+
+            return annotated_frame, analysis
 
         except Exception as e:
             print(f"Processing error: {str(e)}")
             return None, f"Error processing image: {str(e)}"
 
     def parse_safety_analysis(self, analysis):
-        """Dynamically parse the safety analysis to identify issues."""
+        """Dynamically parse the safety analysis to identify contextual issues."""
         safety_issues = []
         for line in analysis.split('\n'):
-            if "missing" in line.lower() or "no" in line.lower():
+            if "risk" in line.lower() or "hazard" in line.lower():
                 safety_issues.append(line.strip())
         return safety_issues
 
+
 def create_monitor_interface():
-    monitor = SafetyMonitor()
+    monitor = RobustSafetyMonitor()
 
     with gr.Blocks() as demo:
-        gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
+        gr.Markdown("# Robust Safety Analysis System powered by Llama Vision 3.2")
 
         with gr.Row():
             input_image = gr.Image(label="Upload Image")
@@ -199,7 +188,7 @@ def create_monitor_interface():
         ## Instructions:
         1. Upload any workplace/safety-related image
         2. View identified hazards and their locations
-        3. Read detailed analysis of safety concerns
+        3. Read detailed analysis of safety concerns based on the image
         """)
 
     return demo
@@ -207,4 +196,3 @@ def create_monitor_interface():
 if __name__ == "__main__":
     demo = create_monitor_interface()
     demo.launch()
-
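
Note: both versions of the class call two helpers that sit outside the changed hunks, encode_image and detect_objects, so they never appear in this diff. The sketch below is a hypothetical reconstruction inferred only from the call sites: the new code expects encode_image to return a bare base64 string (analyze_frame now prepends data:image/jpeg;base64, itself), and detect_objects to return (bbox_data, labels) whose rows unpack as x1, y1, x2, y2, conf, class_id.

    # Hypothetical reconstruction of the two RobustSafetyMonitor helpers
    # referenced by the diff; they are not part of this commit's hunks.
    def encode_image(self, frame):
        """Encode the frame as a bare base64 JPEG string; analyze_frame
        adds the 'data:image/jpeg;base64,' prefix at the call site."""
        img = PILImage.fromarray(frame)
        buffer = io.BytesIO()
        img.save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def detect_objects(self, frame):
        """Run YOLOv5 on the frame and return (bbox_data, labels):
        bbox_data rows unpack as x1, y1, x2, y2, conf, class_id, and
        labels maps a class_id to its class name."""
        results = self.yolo_model(frame)
        bbox_data = results.xyxy[0].cpu().numpy()
        labels = results.names
        return bbox_data, labels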
 
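Two quirks in the new process_frame are worth flagging. First, draw_bounding_boxes runs twice on the same frame: once with an empty safety_issues list (drawing plain label/confidence text) and again with the parsed issues, so risk labels are painted over the earlier text. Second, the containment test in draw_bounding_boxes asks whether an entire analysis line is a substring of a short YOLO class name, which will rarely match; the reversed test, label.lower() in safety_issue.lower(), is presumably the intent. A minimal single-pass sketch of process_frame under that assumption:

    def process_frame(self, frame):
        """Hypothetical single-pass variant: analyze first, then draw once."""
        if frame is None:
            return None, "No image provided"
        try:
            analysis, _ = self.analyze_frame(frame)
            safety_issues = self.parse_safety_analysis(analysis)
            bbox_data, labels = self.detect_objects(frame)
            # One drawing pass; draw_bounding_boxes (with the reversed
            # substring test) decides per box whether to mark it as a risk.
            annotated_frame = self.draw_bounding_boxes(frame, bbox_data, labels, safety_issues)
            return annotated_frame, analysis
        except Exception as e:
            print(f"Processing error: {str(e)}")
            return None, f"Error processing image: {str(e)}"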