Update app.py
app.py CHANGED
@@ -167,23 +167,23 @@ def process_image_detection(image, target_label, surprise_rating):
         image = image.convert('RGB')
 
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f"Using device: {device}")
+        print(f"Using device: {device}")
 
         # Get original image DPI and size
         original_dpi = image.info.get('dpi', (72, 72))
         original_size = image.size
-        print(f"Image size: {original_size}")
+        print(f"Image size: {original_size}")
 
         # Calculate relative font size based on image dimensions
         base_fontsize = min(original_size) / 40
 
-        print("Loading models...")
+        print("Loading models...")
         owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
         owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
         sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
         sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
 
-        print("Running object detection...")
+        print("Running object detection...")
         inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
         with torch.no_grad():
             outputs = owlv2_model(**inputs)
@@ -204,10 +204,10 @@ def process_image_detection(image, target_label, surprise_rating):
         max_score = scores[max_score_idx].item()
 
         if max_score > 0.2:
-            print("Processing detection results...")
+            print("Processing detection results...")
             box = results["boxes"][max_score_idx].cpu().numpy()
 
-            print("Running SAM model...")
+            print("Running SAM model...")
            # Convert image to numpy array if needed for SAM
            if isinstance(image, Image.Image):
                image_np = np.array(image)
@@ -215,7 +215,7 @@ def process_image_detection(image, target_label, surprise_rating):
                image_np = image
 
            sam_inputs = sam_processor(
-                image_np,
+                image_np,
                input_boxes=[[[box[0], box[1], box[2], box[3]]]],
                return_tensors="pt"
            ).to(device)
@@ -229,7 +229,7 @@ def process_image_detection(image, target_label, surprise_rating):
                sam_inputs["reshaped_input_sizes"].cpu()
            )
 
-            print(f"Mask type: {type(masks)}, Mask shape: {len(masks)}")
+            print(f"Mask type: {type(masks)}, Mask shape: {len(masks)}")
            mask = masks[0]
            if isinstance(mask, torch.Tensor):
                mask = mask.numpy()
@@ -266,24 +266,23 @@ def process_image_detection(image, target_label, surprise_rating):
            )
 
            plt.axis('off')
-
-            print("Saving final image...")
+
+            print("Saving final image...")
            try:
-                #
-                buf = io.BytesIO()
-
-                # Force figure to be in a format we can save
+                # Force figure to be rendered
                fig.canvas.draw()
 
-                # Get the
-
-
+                # Get the RGBA buffer from the figure
+                w, h = fig.canvas.get_width_height()
+                buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+                buf.shape = (h, w, 3)
 
-                #
-                output_image = Image.fromarray(
+                # Create PIL Image from buffer
+                output_image = Image.fromarray(buf)
 
-                # Resize if needed
-                output_image
+                # Resize to original size if needed
+                if output_image.size != original_size:
+                    output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
 
                # Save to final buffer
                final_buf = io.BytesIO()
@@ -294,13 +293,17 @@ def process_image_detection(image, target_label, surprise_rating):
                plt.close(fig)
 
                return final_buf
-
+
            except Exception as e:
-                print(f"Save error details: {str(e)}")
+                print(f"Save error details: {str(e)}")
                print(f"Figure type: {type(fig)}")
                print(f"Canvas type: {type(fig.canvas)}")
                raise
 
+    except Exception as e:
+        print(f"Process image detection error: {str(e)}")
+        print(f"Error occurred at line {e.__traceback__.tb_lineno}")
+        raise
 
 def process_and_analyze(image):
     if image is None:
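For context, the new save path renders the Matplotlib figure into an in-memory RGB array and hands it to Pillow, instead of saving the figure straight to a buffer. Below is a minimal standalone sketch of the same round trip; it is not taken from app.py, and original_size is a placeholder. One caveat: fig.canvas.tostring_rgb(), as used in the patch, has been deprecated and later removed in newer Matplotlib releases, so the sketch uses buffer_rgba() instead.

# Standalone sketch of the figure-to-PIL round trip used in the patch.
# Assumes the Agg backend; buffer_rgba() replaces the now-removed tostring_rgb().
import io

import matplotlib
matplotlib.use("Agg")  # headless rendering, as on a server
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

original_size = (640, 480)  # placeholder for the input image's size

fig, ax = plt.subplots()
ax.imshow(np.zeros((64, 64, 3), dtype=np.uint8))
ax.axis('off')

fig.canvas.draw()  # force a render so the canvas buffer is populated
buf = np.asarray(fig.canvas.buffer_rgba()).copy()  # (h, w, 4) uint8 array
output_image = Image.fromarray(buf).convert('RGB')  # drop the alpha channel
plt.close(fig)

# Resize back to the source resolution, as the patch does
if output_image.size != original_size:
    output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)

final_buf = io.BytesIO()
output_image.save(final_buf, format='PNG')
final_buf.seek(0)

Resizing after rendering keeps the output at the source image's resolution regardless of the figure's canvas size or DPI.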