capradeepgujaran committed on
Commit
b4f3ea6
·
verified ·
1 Parent(s): 740f7c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -22
app.py CHANGED
@@ -16,16 +16,15 @@ def create_monitor_interface():
16
  def __init__(self):
17
  self.client = Groq()
18
  self.model_name = "llama-3.2-90b-vision-preview"
19
- self.max_image_size = (800, 800) # Increased size for better quality
20
  self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
21
  self.last_analysis_time = 0
22
- self.analysis_interval = 2 # Analyze every 2 seconds
23
- self.last_observations = [] # Store previous observations
24
 
25
  def resize_image(self, image):
26
  height, width = image.shape[:2]
27
 
28
- # Only resize if image is too large
29
  if height > self.max_image_size[1] or width > self.max_image_size[0]:
30
  aspect = width / height
31
  if width > height:
@@ -50,11 +49,10 @@ def create_monitor_interface():
50
  frame = self.resize_image(frame)
51
  frame_pil = PILImage.fromarray(frame)
52
 
53
- # Convert to base64 with better quality
54
  buffered = io.BytesIO()
55
  frame_pil.save(buffered,
56
  format="JPEG",
57
- quality=85, # Higher quality
58
  optimize=True)
59
  img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
60
  image_url = f"data:image/jpeg;base64,{img_base64}"
@@ -111,12 +109,10 @@ def create_monitor_interface():
111
  'bottom-right': (2*width//3, 2*height//3, width, height)
112
  }
113
 
114
- # Find the best matching region
115
  for region_name, coords in regions.items():
116
  if region_name in position.lower():
117
  return coords
118
 
119
- # Default to center if no match
120
  return regions['center']
121
 
122
  def draw_observations(self, image, observations):
@@ -128,7 +124,6 @@ def create_monitor_interface():
128
  for idx, obs in enumerate(observations):
129
  color = self.colors[idx % len(self.colors)]
130
 
131
- # Try to extract position from observation
132
  parts = obs.split(':')
133
  if len(parts) >= 2:
134
  position = parts[0]
@@ -137,17 +132,13 @@ def create_monitor_interface():
137
  position = 'center'
138
  description = obs
139
 
140
- # Get coordinates based on position
141
  x1, y1, x2, y2 = self.get_region_coordinates(position, image.shape)
142
 
143
- # Draw rectangle
144
  cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
145
 
146
- # Add label with background
147
  label = description[:50] + "..." if len(description) > 50 else description
148
  label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
149
 
150
- # Ensure label stays within image bounds
151
  label_x = max(0, min(x1, width - label_size[0]))
152
  label_y = max(20, y1 - 5)
153
 
@@ -164,12 +155,10 @@ def create_monitor_interface():
164
 
165
  current_time = time.time()
166
 
167
- # Only perform analysis if enough time has passed
168
  if current_time - self.last_analysis_time >= self.analysis_interval:
169
  analysis = self.analyze_frame(frame)
170
  self.last_analysis_time = current_time
171
 
172
- # Parse observations
173
  observations = []
174
  for line in analysis.split('\n'):
175
  line = line.strip()
@@ -183,7 +172,6 @@ def create_monitor_interface():
183
 
184
  self.last_observations = observations
185
 
186
- # Draw observations on the frame
187
  display_frame = frame.copy()
188
  annotated_frame = self.draw_observations(display_frame, self.last_observations)
189
 
@@ -196,12 +184,12 @@ def create_monitor_interface():
196
  gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
197
 
198
  with gr.Row():
199
- webcam = gr.Image(source="webcam", streaming=True, label="Live Feed")
200
  output_image = gr.Image(label="Analysis")
201
 
202
  analysis_text = gr.Textbox(label="Safety Concerns", lines=5)
203
 
204
- def analyze_stream(image):
205
  if image is None:
206
  return None, "No image provided"
207
  try:
@@ -211,12 +199,19 @@ def create_monitor_interface():
211
  print(f"Processing error: {str(e)}")
212
  return None, f"Error processing image: {str(e)}"
213
 
214
- webcam.stream(
215
- fn=analyze_stream,
216
- outputs=[output_image, analysis_text],
217
- show_progress=False
218
  )
219
 
 
 
 
 
 
 
 
220
  return demo
221
 
222
  demo = create_monitor_interface()
 
16
  def __init__(self):
17
  self.client = Groq()
18
  self.model_name = "llama-3.2-90b-vision-preview"
19
+ self.max_image_size = (800, 800)
20
  self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
21
  self.last_analysis_time = 0
22
+ self.analysis_interval = 2
23
+ self.last_observations = []
24
 
25
  def resize_image(self, image):
26
  height, width = image.shape[:2]
27
 
 
28
  if height > self.max_image_size[1] or width > self.max_image_size[0]:
29
  aspect = width / height
30
  if width > height:
 
49
  frame = self.resize_image(frame)
50
  frame_pil = PILImage.fromarray(frame)
51
 
 
52
  buffered = io.BytesIO()
53
  frame_pil.save(buffered,
54
  format="JPEG",
55
+ quality=85,
56
  optimize=True)
57
  img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
58
  image_url = f"data:image/jpeg;base64,{img_base64}"
 
109
  'bottom-right': (2*width//3, 2*height//3, width, height)
110
  }
111
 
 
112
  for region_name, coords in regions.items():
113
  if region_name in position.lower():
114
  return coords
115
 
 
116
  return regions['center']
117
 
118
  def draw_observations(self, image, observations):
 
124
  for idx, obs in enumerate(observations):
125
  color = self.colors[idx % len(self.colors)]
126
 
 
127
  parts = obs.split(':')
128
  if len(parts) >= 2:
129
  position = parts[0]
 
132
  position = 'center'
133
  description = obs
134
 
 
135
  x1, y1, x2, y2 = self.get_region_coordinates(position, image.shape)
136
 
 
137
  cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
138
 
 
139
  label = description[:50] + "..." if len(description) > 50 else description
140
  label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
141
 
 
142
  label_x = max(0, min(x1, width - label_size[0]))
143
  label_y = max(20, y1 - 5)
144
 
 
155
 
156
  current_time = time.time()
157
 
 
158
  if current_time - self.last_analysis_time >= self.analysis_interval:
159
  analysis = self.analyze_frame(frame)
160
  self.last_analysis_time = current_time
161
 
 
162
  observations = []
163
  for line in analysis.split('\n'):
164
  line = line.strip()
 
172
 
173
  self.last_observations = observations
174
 
 
175
  display_frame = frame.copy()
176
  annotated_frame = self.draw_observations(display_frame, self.last_observations)
177
 
 
184
  gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
185
 
186
  with gr.Row():
187
+ input_image = gr.Image(label="Upload Image")
188
  output_image = gr.Image(label="Analysis")
189
 
190
  analysis_text = gr.Textbox(label="Safety Concerns", lines=5)
191
 
192
+ def analyze_image(image):
193
  if image is None:
194
  return None, "No image provided"
195
  try:
 
199
  print(f"Processing error: {str(e)}")
200
  return None, f"Error processing image: {str(e)}"
201
 
202
+ input_image.change(
203
+ fn=analyze_image,
204
+ inputs=input_image,
205
+ outputs=[output_image, analysis_text]
206
  )
207
 
208
+ gr.Markdown("""
209
+ ## Instructions:
210
+ 1. Upload an image to analyze safety concerns
211
+ 2. View annotated results and detailed analysis
212
+ 3. Each box highlights a potential safety issue
213
+ """)
214
+
215
  return demo
216
 
217
  demo = create_monitor_interface()