capradeepgujaran commited on
Commit
aca1712
·
verified ·
1 Parent(s): 7f51b6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -90
app.py CHANGED
@@ -8,6 +8,8 @@ import io
8
  import os
9
  import base64
10
 
 
 
11
  class SafetyMonitor:
12
  def __init__(self):
13
  """Initialize Safety Monitor with configuration."""
@@ -15,6 +17,9 @@ class SafetyMonitor:
15
  self.model_name = "llama-3.2-90b-vision-preview"
16
  self.max_image_size = (800, 800)
17
  self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
 
 
 
18
 
19
  def preprocess_image(self, frame):
20
  """Process image for analysis."""
@@ -39,13 +44,13 @@ class SafetyMonitor:
39
  return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
40
  return image
41
 
42
- def encode_image(self, frame):
43
- """Convert image to base64 encoding."""
44
- frame_pil = PILImage.fromarray(frame)
45
- buffered = io.BytesIO()
46
- frame_pil.save(buffered, format="JPEG", quality=95)
47
- img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
48
- return f"data:image/jpeg;base64,{img_base64}"
49
 
50
  def analyze_frame(self, frame):
51
  """Perform safety analysis on the frame."""
@@ -54,7 +59,7 @@ class SafetyMonitor:
54
 
55
  frame = self.preprocess_image(frame)
56
  image_url = self.encode_image(frame)
57
-
58
  try:
59
  completion = self.client.chat.completions.create(
60
  model=self.model_name,
@@ -84,66 +89,22 @@ class SafetyMonitor:
84
  print(f"Analysis error: {str(e)}")
85
  return f"Analysis Error: {str(e)}", {}
86
 
87
- def get_region_coordinates(self, position, image_shape):
88
- """Convert textual position to coordinates."""
89
- height, width = image_shape[:2]
90
-
91
- # Define regions
92
- regions = {
93
- 'center': (width//3, height//3, 2*width//3, 2*height//3),
94
- 'top': (width//3, 0, 2*width//3, height//3),
95
- 'bottom': (width//3, 2*height//3, 2*width//3, height),
96
- 'left': (0, height//3, width//3, 2*height//3),
97
- 'right': (2*width//3, height//3, width, 2*height//3),
98
- 'top-left': (0, 0, width//3, height//3),
99
- 'top-right': (2*width//3, 0, width, height//3),
100
- 'bottom-left': (0, 2*height//3, width//3, height),
101
- 'bottom-right': (2*width//3, 2*height//3, width, height),
102
- 'upper': (0, 0, width, height//2),
103
- 'lower': (0, height//2, width, height),
104
- 'middle': (0, height//3, width, 2*height//3)
105
- }
106
-
107
- # Ensure the region name from the model output matches one of our predefined regions
108
- position = position.lower()
109
- return regions.get(position, (0, 0, width, height)) # Default to full image if no match
110
-
111
- def draw_observations(self, image, observations):
112
- """Draw bounding boxes and labels for safety observations."""
113
- height, width = image.shape[:2]
114
  font = cv2.FONT_HERSHEY_SIMPLEX
115
  font_scale = 0.5
116
  thickness = 2
117
- padding = 10
118
-
119
- for idx, obs in enumerate(observations):
120
  color = self.colors[idx % len(self.colors)]
121
 
122
- # Get coordinates for this observation
123
- x1, y1, x2, y2 = self.get_region_coordinates(obs['location'], image.shape)
124
- print(f"Drawing box at coordinates: ({x1}, {y1}, {x2}, {y2}) for {obs['description']}")
125
 
126
- # Draw rectangle
127
- cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
128
-
129
- # Add label with background
130
- label = obs['description'][:50] + "..." if len(obs['description']) > 50 else obs['description']
131
- label_size, _ = cv2.getTextSize(label, font, font_scale, thickness)
132
-
133
- # Position text above the box
134
- text_x = max(0, x1)
135
- text_y = max(label_size[1] + padding, y1 - padding)
136
-
137
- # Draw text background
138
- cv2.rectangle(image,
139
- (text_x, text_y - label_size[1] - padding),
140
- (text_x + label_size[0] + padding, text_y),
141
- color, -1)
142
-
143
- # Draw text
144
- cv2.putText(image, label,
145
- (text_x + padding//2, text_y - padding//2),
146
- font, font_scale, (255, 255, 255), thickness)
147
 
148
  return image
149
 
@@ -153,35 +114,13 @@ class SafetyMonitor:
153
  return None, "No image provided"
154
 
155
  try:
156
- # Get analysis
 
 
 
 
157
  analysis, _ = self.analyze_frame(frame)
158
- print(f"Raw analysis: {analysis}") # Debug print
159
- display_frame = frame.copy()
160
-
161
- # Parse observations
162
- observations = []
163
- for line in analysis.split('\n'):
164
- line = line.strip()
165
- if line.startswith('-') and '<location>' in line and '</location>' in line:
166
- start = line.find('<location>') + len('<location>')
167
- end = line.find('</location>')
168
- location_description = line[start:end].strip()
169
-
170
- if ':' in location_description:
171
- location, description = location_description.split(':', 1)
172
- observations.append({
173
- 'location': location.strip(),
174
- 'description': description.strip()
175
- })
176
-
177
- print(f"Parsed observations: {observations}") # Debug print
178
-
179
- # Draw observations
180
- if observations:
181
- annotated_frame = self.draw_observations(display_frame, observations)
182
- return annotated_frame, analysis
183
-
184
- return display_frame, analysis
185
 
186
  except Exception as e:
187
  print(f"Processing error: {str(e)}")
 
8
  import os
9
  import base64
10
 
11
+ import torch
12
+
13
  class SafetyMonitor:
14
  def __init__(self):
15
  """Initialize Safety Monitor with configuration."""
 
17
  self.model_name = "llama-3.2-90b-vision-preview"
18
  self.max_image_size = (800, 800)
19
  self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
20
+
21
+ # Load YOLOv5 model for object detection
22
+ self.yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
23
 
24
  def preprocess_image(self, frame):
25
  """Process image for analysis."""
 
44
  return cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
45
  return image
46
 
47
+ def detect_objects(self, frame):
48
+ """Detect objects using YOLOv5."""
49
+ results = self.yolo_model(frame)
50
+ # Extract bounding boxes, class labels, and confidence scores
51
+ bbox_data = results.xyxy[0].numpy() # Bounding box coordinates
52
+ labels = results.names # Class names
53
+ return bbox_data, labels
54
 
55
  def analyze_frame(self, frame):
56
  """Perform safety analysis on the frame."""
 
59
 
60
  frame = self.preprocess_image(frame)
61
  image_url = self.encode_image(frame)
62
+
63
  try:
64
  completion = self.client.chat.completions.create(
65
  model=self.model_name,
 
89
  print(f"Analysis error: {str(e)}")
90
  return f"Analysis Error: {str(e)}", {}
91
 
92
+ def draw_bounding_boxes(self, image, bboxes, labels):
93
+ """Draw bounding boxes around detected objects."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  font = cv2.FONT_HERSHEY_SIMPLEX
95
  font_scale = 0.5
96
  thickness = 2
97
+ for idx, bbox in enumerate(bboxes):
98
+ x1, y1, x2, y2, conf, class_id = bbox
99
+ label = labels[int(class_id)]
100
  color = self.colors[idx % len(self.colors)]
101
 
102
+ # Draw bounding box
103
+ cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, thickness)
 
104
 
105
+ # Draw label
106
+ label_text = f"{label} {conf:.2f}"
107
+ cv2.putText(image, label_text, (int(x1), int(y1) - 10), font, font_scale, color, thickness)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  return image
110
 
 
114
  return None, "No image provided"
115
 
116
  try:
117
+ # Detect objects in the image using YOLO
118
+ bbox_data, labels = self.detect_objects(frame)
119
+ frame_with_boxes = self.draw_bounding_boxes(frame, bbox_data, labels)
120
+
121
+ # Get analysis from Groq's model
122
  analysis, _ = self.analyze_frame(frame)
123
+ return frame_with_boxes, analysis
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  except Exception as e:
126
  print(f"Processing error: {str(e)}")