capradeepgujaran committed
Commit a217968 · verified · 1 Parent(s): d2c67f3

Update app.py

Files changed (1)
  1. app.py +69 -113
app.py CHANGED
@@ -14,151 +14,107 @@ warnings.filterwarnings('ignore', category=FutureWarning)
 
 class RobustSafetyMonitor:
     def __init__(self):
-        """Initialize the robust safety detection tool with configuration."""
         self.client = Groq()
-        self.model_name = "llama-3.2-11b-vision-preview"  # Updated to use the correct model
         self.max_image_size = (800, 800)
         self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
 
-        # Load YOLOv5 model for general object detection
         self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
-
-        # Force CPU inference if CUDA is causing issues
         self.yolo_model.cpu()
         self.yolo_model.eval()
 
-    def preprocess_image(self, frame: np.ndarray) -> np.ndarray:
-        """Process image for analysis."""
-        if frame is None:
-            raise ValueError("No image provided")
-
-        if len(frame.shape) == 2:
-            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
-        elif len(frame.shape) == 3 and frame.shape[2] == 4:
-            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
-
-        return self.resize_image(frame)
-
-    def resize_image(self, image: np.ndarray) -> np.ndarray:
-        """Resize image while maintaining aspect ratio."""
-        height, width = image.shape[:2]
-        if height > self.max_image_size[1] or width > self.max_image_size[0]:
-            aspect = width / height
-            if width > height:
-                new_width = self.max_image_size[0]
-                new_height = int(new_width / aspect)
-            else:
-                new_height = self.max_image_size[1]
-                new_width = int(new_height * aspect)
-            return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
-        return image
-
-    def encode_image(self, frame: np.ndarray) -> str:
-        """Convert image to base64 encoding with proper formatting."""
-        try:
-            frame_pil = PILImage.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-            buffered = io.BytesIO()
-            frame_pil.save(buffered, format="JPEG", quality=95)
-            img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-            return f"data:image/jpeg;base64,{img_base64}"
-        except Exception as e:
-            raise ValueError(f"Error encoding image: {str(e)}")
-
     def detect_objects(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
-        """Detect objects using YOLOv5."""
         try:
             with torch.no_grad():
-                results = self.yolo_model(frame)
             bbox_data = results.xyxy[0].cpu().numpy()
             labels = results.names
-            return bbox_data, labels
-        except Exception as e:
-            raise ValueError(f"Error detecting objects: {str(e)}")
-
-    def analyze_frame(self, frame: np.ndarray) -> Tuple[List[Dict], str]:
-        """Perform safety analysis on the frame using Llama Vision."""
-        if frame is None:
-            return [], "No frame received"
-
-        try:
-            frame = self.preprocess_image(frame)
-            image_base64 = self.encode_image(frame)
-
-            completion = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": """Analyze this workplace image and identify any potential safety risks.
-                                List each risk on a new line starting with 'Risk:'.
-                                Format: Risk: [Object/Area] - [Description of hazard]"""
-                            },
-                            {
-                                "type": "image_url",
-                                "image_url": {
-                                    "url": image_base64
-                                }
-                            }
-                        ]
-                    }
-                ],
-                temperature=0.7,
-                max_tokens=1024,
-                stream=False
-            )
-
-            # Get the response content safely
-            try:
-                response = completion.choices[0].message.content
-            except AttributeError:
-                response = str(completion.choices[0].message)
-
-            safety_issues = self.parse_safety_analysis(response)
-            return safety_issues, response
-
         except Exception as e:
-            print(f"Analysis error: {str(e)}")
-            return [], f"Analysis Error: {str(e)}"
 
     def draw_bounding_boxes(self, image: np.ndarray, bboxes: np.ndarray,
                             labels: Dict, safety_issues: List[Dict]) -> np.ndarray:
-        """Draw bounding boxes around objects based on safety issues."""
         image_copy = image.copy()
         font = cv2.FONT_HERSHEY_SIMPLEX
         font_scale = 0.5
         thickness = 2
 
         for idx, bbox in enumerate(bboxes):
             try:
                 x1, y1, x2, y2, conf, class_id = bbox
                 label = labels[int(class_id)]
-                color = self.colors[idx % len(self.colors)]
-
-                # Convert coordinates to integers
-                x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
 
-                # Draw bounding box
-                cv2.rectangle(image_copy, (x1, y1), (x2, y2), color, thickness)
 
-                # Check if object is associated with any safety issues
-                risk_found = False
-                for safety_issue in safety_issues:
-                    if safety_issue.get('object', '').lower() in label.lower():
-                        label_text = f"Risk: {safety_issue.get('description', '')}"
                         y_pos = max(y1 - 10, 20)
                         cv2.putText(image_copy, label_text, (x1, y_pos), font,
-                                    font_scale, (0, 0, 255), thickness)
-                        risk_found = True
-                        break
 
-                if not risk_found:
-                    label_text = f"{label} {conf:.2f}"
-                    y_pos = max(y1 - 10, 20)
-                    cv2.putText(image_copy, label_text, (x1, y_pos), font,
-                                font_scale, color, thickness)
 
             except Exception as e:
                 print(f"Error drawing box: {str(e)}")
                 continue
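
Note that the removed analyze_frame calls self.parse_safety_analysis, which is defined outside this hunk. For context, a minimal sketch of what such a parser could look like, assuming the model follows the "Risk: [Object/Area] - [Description of hazard]" format requested in the prompt (only the function name comes from the call site; the body is illustrative):

from typing import Dict, List

def parse_safety_analysis(response: str) -> List[Dict]:
    """Parse lines of the form 'Risk: <object> - <description>' into dicts."""
    issues = []
    for raw_line in response.splitlines():
        line = raw_line.strip()
        if not line.lower().startswith('risk:'):
            continue
        body = line[len('risk:'):].strip()
        # Split on the first ' - '; description stays empty if the separator is missing
        obj, _, desc = body.partition(' - ')
        issues.append({'object': obj.strip(), 'description': desc.strip()})
    return issues

The dict keys match how draw_bounding_boxes reads each issue via safety_issue.get('object') and safety_issue.get('description').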
 
 
 class RobustSafetyMonitor:
     def __init__(self):
+        """Initialize the safety detection tool with improved configuration."""
         self.client = Groq()
+        self.model_name = "llama-3.2-11b-vision-preview"
         self.max_image_size = (800, 800)
         self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
 
+        # Load YOLOv5 model with improved configuration
         self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
+        self.yolo_model.conf = 0.25  # Lower confidence threshold for more detections
+        self.yolo_model.iou = 0.45  # Adjusted IOU threshold
+        self.yolo_model.classes = None  # Detect all classes
+        self.yolo_model.max_det = 50  # Increased maximum detections
         self.yolo_model.cpu()
         self.yolo_model.eval()
 
     def detect_objects(self, frame: np.ndarray) -> Tuple[np.ndarray, Dict]:
+        """Enhanced object detection using YOLOv5."""
         try:
+            # Ensure proper image format
+            if len(frame.shape) == 2:
+                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
+            elif frame.shape[2] == 4:
+                frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
+
+            # Run inference with augmentation
             with torch.no_grad():
+                results = self.yolo_model(frame, augment=True)  # Enable test-time augmentation
+
+            # Get detections
             bbox_data = results.xyxy[0].cpu().numpy()
             labels = results.names
 
+            # Filter and process detections
+            processed_boxes = []
+            for box in bbox_data:
+                x1, y1, x2, y2, conf, cls = box
+                # Additional filtering for construction site objects
+                if conf > 0.25:  # Keep lower confidence threshold for more detections
+                    processed_boxes.append(box)
+
+            return np.array(processed_boxes), labels
         except Exception as e:
+            print(f"Error in object detection: {str(e)}")
+            return np.array([]), {}
 
     def draw_bounding_boxes(self, image: np.ndarray, bboxes: np.ndarray,
                             labels: Dict, safety_issues: List[Dict]) -> np.ndarray:
+        """Improved bounding box visualization."""
         image_copy = image.copy()
         font = cv2.FONT_HERSHEY_SIMPLEX
         font_scale = 0.5
         thickness = 2
 
+        # Define construction-related keywords for better object association
+        construction_keywords = [
+            'person', 'worker', 'helmet', 'tool', 'machine', 'equipment',
+            'brick', 'block', 'pile', 'stack', 'surface', 'floor', 'ground',
+            'construction', 'building', 'structure'
+        ]
+
         for idx, bbox in enumerate(bboxes):
             try:
                 x1, y1, x2, y2, conf, class_id = bbox
                 label = labels[int(class_id)]
 
+                # Check if object is construction-related
+                is_relevant = any(keyword in label.lower() for keyword in construction_keywords)
+
+                if is_relevant or conf > 0.35:  # Higher threshold for non-construction objects
+                    color = self.colors[idx % len(self.colors)]
+
+                    # Convert coordinates to integers
+                    x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
+
+                    # Draw bounding box for better visibility
+                    cv2.rectangle(image_copy, (x1, y1), (x2, y2), color, thickness)
+
+                    # Check for associated safety issues
+                    risk_found = False
+                    for safety_issue in safety_issues:
+                        issue_keywords = safety_issue.get('object', '').lower().split()
+                        if any(keyword in label.lower() for keyword in issue_keywords):
+                            label_text = f"Risk: {safety_issue.get('description', '')}"
+                            y_pos = max(y1 - 10, 20)
+                            cv2.putText(image_copy, label_text, (x1, y_pos), font,
+                                        font_scale, (0, 0, 255), thickness)
+                            risk_found = True
+                            break
+
+                    if not risk_found:
+                        label_text = f"{label} {conf:.2f}"
                         y_pos = max(y1 - 10, 20)
                         cv2.putText(image_copy, label_text, (x1, y_pos), font,
+                                    font_scale, color, thickness)
 
+                    # Draw additional markers for high-risk areas
+                    if conf > 0.5 and any(risk_word in label.lower() for risk_word in
+                                          ['worker', 'person', 'equipment', 'machine']):
+                        cv2.circle(image_copy, (int((x1 + x2)/2), int((y1 + y2)/2)),
+                                   5, (0, 0, 255), -1)
+
             except Exception as e:
                 print(f"Error drawing box: {str(e)}")
                 continue
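
For reference, a minimal usage sketch of the updated class, not part of the commit: it assumes GROQ_API_KEY is set in the environment (the Groq client reads it at construction) and that "site_photo.jpg" is a hypothetical local test image.

import cv2

monitor = RobustSafetyMonitor()
frame = cv2.imread("site_photo.jpg")  # hypothetical input image
bboxes, labels = monitor.detect_objects(frame)

# In app.py the issues come from the Llama Vision analysis; stubbed here
safety_issues = [{"object": "person", "description": "worker without helmet"}]
annotated = monitor.draw_bounding_boxes(frame, bboxes, labels, safety_issues)
cv2.imwrite("annotated_frame.jpg", annotated)

With the new defaults (conf 0.25, max_det 50, test-time augmentation), detect_objects trades speed for recall, which matches the commit's goal of surfacing more candidate objects for safety labeling.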