Update app.py
Browse files
app.py
CHANGED
@@ -16,16 +16,15 @@ def create_monitor_interface():
|
|
16 |
def __init__(self):
|
17 |
self.client = Groq()
|
18 |
self.model_name = "llama-3.2-90b-vision-preview"
|
19 |
-
self.max_image_size = (800, 800)
|
20 |
self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
|
21 |
self.last_analysis_time = 0
|
22 |
-
self.analysis_interval = 2
|
23 |
-
self.last_observations = []
|
24 |
|
25 |
def resize_image(self, image):
|
26 |
height, width = image.shape[:2]
|
27 |
|
28 |
-
# Only resize if image is too large
|
29 |
if height > self.max_image_size[1] or width > self.max_image_size[0]:
|
30 |
aspect = width / height
|
31 |
if width > height:
|
@@ -50,11 +49,10 @@ def create_monitor_interface():
|
|
50 |
frame = self.resize_image(frame)
|
51 |
frame_pil = PILImage.fromarray(frame)
|
52 |
|
53 |
-
# Convert to base64 with better quality
|
54 |
buffered = io.BytesIO()
|
55 |
frame_pil.save(buffered,
|
56 |
format="JPEG",
|
57 |
-
quality=85,
|
58 |
optimize=True)
|
59 |
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
60 |
image_url = f"data:image/jpeg;base64,{img_base64}"
|
@@ -111,12 +109,10 @@ def create_monitor_interface():
|
|
111 |
'bottom-right': (2*width//3, 2*height//3, width, height)
|
112 |
}
|
113 |
|
114 |
-
# Find the best matching region
|
115 |
for region_name, coords in regions.items():
|
116 |
if region_name in position.lower():
|
117 |
return coords
|
118 |
|
119 |
-
# Default to center if no match
|
120 |
return regions['center']
|
121 |
|
122 |
def draw_observations(self, image, observations):
|
@@ -128,7 +124,6 @@ def create_monitor_interface():
|
|
128 |
for idx, obs in enumerate(observations):
|
129 |
color = self.colors[idx % len(self.colors)]
|
130 |
|
131 |
-
# Try to extract position from observation
|
132 |
parts = obs.split(':')
|
133 |
if len(parts) >= 2:
|
134 |
position = parts[0]
|
@@ -137,17 +132,13 @@ def create_monitor_interface():
|
|
137 |
position = 'center'
|
138 |
description = obs
|
139 |
|
140 |
-
# Get coordinates based on position
|
141 |
x1, y1, x2, y2 = self.get_region_coordinates(position, image.shape)
|
142 |
|
143 |
-
# Draw rectangle
|
144 |
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
|
145 |
|
146 |
-
# Add label with background
|
147 |
label = description[:50] + "..." if len(description) > 50 else description
|
148 |
label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
|
149 |
|
150 |
-
# Ensure label stays within image bounds
|
151 |
label_x = max(0, min(x1, width - label_size[0]))
|
152 |
label_y = max(20, y1 - 5)
|
153 |
|
@@ -164,12 +155,10 @@ def create_monitor_interface():
|
|
164 |
|
165 |
current_time = time.time()
|
166 |
|
167 |
-
# Only perform analysis if enough time has passed
|
168 |
if current_time - self.last_analysis_time >= self.analysis_interval:
|
169 |
analysis = self.analyze_frame(frame)
|
170 |
self.last_analysis_time = current_time
|
171 |
|
172 |
-
# Parse observations
|
173 |
observations = []
|
174 |
for line in analysis.split('\n'):
|
175 |
line = line.strip()
|
@@ -183,7 +172,6 @@ def create_monitor_interface():
|
|
183 |
|
184 |
self.last_observations = observations
|
185 |
|
186 |
-
# Draw observations on the frame
|
187 |
display_frame = frame.copy()
|
188 |
annotated_frame = self.draw_observations(display_frame, self.last_observations)
|
189 |
|
@@ -196,12 +184,12 @@ def create_monitor_interface():
|
|
196 |
gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
|
197 |
|
198 |
with gr.Row():
|
199 |
-
|
200 |
output_image = gr.Image(label="Analysis")
|
201 |
|
202 |
analysis_text = gr.Textbox(label="Safety Concerns", lines=5)
|
203 |
|
204 |
-
def
|
205 |
if image is None:
|
206 |
return None, "No image provided"
|
207 |
try:
|
@@ -211,12 +199,19 @@ def create_monitor_interface():
|
|
211 |
print(f"Processing error: {str(e)}")
|
212 |
return None, f"Error processing image: {str(e)}"
|
213 |
|
214 |
-
|
215 |
-
fn=
|
216 |
-
|
217 |
-
|
218 |
)
|
219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
return demo
|
221 |
|
222 |
demo = create_monitor_interface()
|
|
|
16 |
def __init__(self):
|
17 |
self.client = Groq()
|
18 |
self.model_name = "llama-3.2-90b-vision-preview"
|
19 |
+
self.max_image_size = (800, 800)
|
20 |
self.colors = [(0, 0, 255), (255, 0, 0), (0, 255, 0), (255, 255, 0), (255, 0, 255)]
|
21 |
self.last_analysis_time = 0
|
22 |
+
self.analysis_interval = 2
|
23 |
+
self.last_observations = []
|
24 |
|
25 |
def resize_image(self, image):
|
26 |
height, width = image.shape[:2]
|
27 |
|
|
|
28 |
if height > self.max_image_size[1] or width > self.max_image_size[0]:
|
29 |
aspect = width / height
|
30 |
if width > height:
|
|
|
49 |
frame = self.resize_image(frame)
|
50 |
frame_pil = PILImage.fromarray(frame)
|
51 |
|
|
|
52 |
buffered = io.BytesIO()
|
53 |
frame_pil.save(buffered,
|
54 |
format="JPEG",
|
55 |
+
quality=85,
|
56 |
optimize=True)
|
57 |
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
58 |
image_url = f"data:image/jpeg;base64,{img_base64}"
|
|
|
109 |
'bottom-right': (2*width//3, 2*height//3, width, height)
|
110 |
}
|
111 |
|
|
|
112 |
for region_name, coords in regions.items():
|
113 |
if region_name in position.lower():
|
114 |
return coords
|
115 |
|
|
|
116 |
return regions['center']
|
117 |
|
118 |
def draw_observations(self, image, observations):
|
|
|
124 |
for idx, obs in enumerate(observations):
|
125 |
color = self.colors[idx % len(self.colors)]
|
126 |
|
|
|
127 |
parts = obs.split(':')
|
128 |
if len(parts) >= 2:
|
129 |
position = parts[0]
|
|
|
132 |
position = 'center'
|
133 |
description = obs
|
134 |
|
|
|
135 |
x1, y1, x2, y2 = self.get_region_coordinates(position, image.shape)
|
136 |
|
|
|
137 |
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
|
138 |
|
|
|
139 |
label = description[:50] + "..." if len(description) > 50 else description
|
140 |
label_size = cv2.getTextSize(label, font, font_scale, thickness)[0]
|
141 |
|
|
|
142 |
label_x = max(0, min(x1, width - label_size[0]))
|
143 |
label_y = max(20, y1 - 5)
|
144 |
|
|
|
155 |
|
156 |
current_time = time.time()
|
157 |
|
|
|
158 |
if current_time - self.last_analysis_time >= self.analysis_interval:
|
159 |
analysis = self.analyze_frame(frame)
|
160 |
self.last_analysis_time = current_time
|
161 |
|
|
|
162 |
observations = []
|
163 |
for line in analysis.split('\n'):
|
164 |
line = line.strip()
|
|
|
172 |
|
173 |
self.last_observations = observations
|
174 |
|
|
|
175 |
display_frame = frame.copy()
|
176 |
annotated_frame = self.draw_observations(display_frame, self.last_observations)
|
177 |
|
|
|
184 |
gr.Markdown("# Safety Analysis System powered by Llama 3.2 90b vision")
|
185 |
|
186 |
with gr.Row():
|
187 |
+
input_image = gr.Image(label="Upload Image")
|
188 |
output_image = gr.Image(label="Analysis")
|
189 |
|
190 |
analysis_text = gr.Textbox(label="Safety Concerns", lines=5)
|
191 |
|
192 |
+
def analyze_image(image):
|
193 |
if image is None:
|
194 |
return None, "No image provided"
|
195 |
try:
|
|
|
199 |
print(f"Processing error: {str(e)}")
|
200 |
return None, f"Error processing image: {str(e)}"
|
201 |
|
202 |
+
input_image.change(
|
203 |
+
fn=analyze_image,
|
204 |
+
inputs=input_image,
|
205 |
+
outputs=[output_image, analysis_text]
|
206 |
)
|
207 |
|
208 |
+
gr.Markdown("""
|
209 |
+
## Instructions:
|
210 |
+
1. Upload an image to analyze safety concerns
|
211 |
+
2. View annotated results and detailed analysis
|
212 |
+
3. Each box highlights a potential safety issue
|
213 |
+
""")
|
214 |
+
|
215 |
return demo
|
216 |
|
217 |
demo = create_monitor_interface()
|