reab5555 committed on
Commit
605c9e7
·
verified ·
1 Parent(s): 5fbcf60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -167,23 +167,23 @@ def process_image_detection(image, target_label, surprise_rating):
167
  image = image.convert('RGB')
168
 
169
  device = "cuda" if torch.cuda.is_available() else "cpu"
170
- print(f"Using device: {device}") # Debug print
171
 
172
  # Get original image DPI and size
173
  original_dpi = image.info.get('dpi', (72, 72))
174
  original_size = image.size
175
- print(f"Image size: {original_size}") # Debug print
176
 
177
  # Calculate relative font size based on image dimensions
178
  base_fontsize = min(original_size) / 40
179
 
180
- print("Loading models...") # Debug print
181
  owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
182
  owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
183
  sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
184
  sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
185
 
186
- print("Running object detection...") # Debug print
187
  inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
188
  with torch.no_grad():
189
  outputs = owlv2_model(**inputs)
@@ -204,10 +204,10 @@ def process_image_detection(image, target_label, surprise_rating):
204
  max_score = scores[max_score_idx].item()
205
 
206
  if max_score > 0.2:
207
- print("Processing detection results...") # Debug print
208
  box = results["boxes"][max_score_idx].cpu().numpy()
209
 
210
- print("Running SAM model...") # Debug print
211
  # Convert image to numpy array if needed for SAM
212
  if isinstance(image, Image.Image):
213
  image_np = np.array(image)
@@ -215,7 +215,7 @@ def process_image_detection(image, target_label, surprise_rating):
215
  image_np = image
216
 
217
  sam_inputs = sam_processor(
218
- image_np, # Use numpy array here
219
  input_boxes=[[[box[0], box[1], box[2], box[3]]]],
220
  return_tensors="pt"
221
  ).to(device)
@@ -229,7 +229,7 @@ def process_image_detection(image, target_label, surprise_rating):
229
  sam_inputs["reshaped_input_sizes"].cpu()
230
  )
231
 
232
- print(f"Mask type: {type(masks)}, Mask shape: {len(masks)}") # Debug print
233
  mask = masks[0]
234
  if isinstance(mask, torch.Tensor):
235
  mask = mask.numpy()
@@ -266,24 +266,23 @@ def process_image_detection(image, target_label, surprise_rating):
266
  )
267
 
268
  plt.axis('off')
269
-
270
- print("Saving final image...") # Debug print
271
  try:
272
- # Save to buffer
273
- buf = io.BytesIO()
274
-
275
- # Force figure to be in a format we can save
276
  fig.canvas.draw()
277
 
278
- # Get the image data from the figure
279
- plot_data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
280
- plot_data = plot_data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
 
281
 
282
- # Convert to PIL Image
283
- output_image = Image.fromarray(plot_data)
284
 
285
- # Resize if needed
286
- output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
 
287
 
288
  # Save to final buffer
289
  final_buf = io.BytesIO()
@@ -294,13 +293,17 @@ def process_image_detection(image, target_label, surprise_rating):
294
  plt.close(fig)
295
 
296
  return final_buf
297
-
298
  except Exception as e:
299
- print(f"Save error details: {str(e)}") # Debug print
300
  print(f"Figure type: {type(fig)}")
301
  print(f"Canvas type: {type(fig.canvas)}")
302
  raise
303
 
 
 
 
 
304
 
305
  def process_and_analyze(image):
306
  if image is None:
 
167
  image = image.convert('RGB')
168
 
169
  device = "cuda" if torch.cuda.is_available() else "cpu"
170
+ print(f"Using device: {device}")
171
 
172
  # Get original image DPI and size
173
  original_dpi = image.info.get('dpi', (72, 72))
174
  original_size = image.size
175
+ print(f"Image size: {original_size}")
176
 
177
  # Calculate relative font size based on image dimensions
178
  base_fontsize = min(original_size) / 40
179
 
180
+ print("Loading models...")
181
  owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
182
  owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
183
  sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
184
  sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
185
 
186
+ print("Running object detection...")
187
  inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
188
  with torch.no_grad():
189
  outputs = owlv2_model(**inputs)
 
204
  max_score = scores[max_score_idx].item()
205
 
206
  if max_score > 0.2:
207
+ print("Processing detection results...")
208
  box = results["boxes"][max_score_idx].cpu().numpy()
209
 
210
+ print("Running SAM model...")
211
  # Convert image to numpy array if needed for SAM
212
  if isinstance(image, Image.Image):
213
  image_np = np.array(image)
 
215
  image_np = image
216
 
217
  sam_inputs = sam_processor(
218
+ image_np,
219
  input_boxes=[[[box[0], box[1], box[2], box[3]]]],
220
  return_tensors="pt"
221
  ).to(device)
 
229
  sam_inputs["reshaped_input_sizes"].cpu()
230
  )
231
 
232
+ print(f"Mask type: {type(masks)}, Mask shape: {len(masks)}")
233
  mask = masks[0]
234
  if isinstance(mask, torch.Tensor):
235
  mask = mask.numpy()
 
266
  )
267
 
268
  plt.axis('off')
269
+
270
+ print("Saving final image...")
271
  try:
272
+ # Force figure to be rendered
 
 
 
273
  fig.canvas.draw()
274
 
275
+ # Get the RGBA buffer from the figure
276
+ w, h = fig.canvas.get_width_height()
277
+ buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
278
+ buf.shape = (h, w, 3)
279
 
280
+ # Create PIL Image from buffer
281
+ output_image = Image.fromarray(buf)
282
 
283
+ # Resize to original size if needed
284
+ if output_image.size != original_size:
285
+ output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
286
 
287
  # Save to final buffer
288
  final_buf = io.BytesIO()
 
293
  plt.close(fig)
294
 
295
  return final_buf
296
+
297
  except Exception as e:
298
+ print(f"Save error details: {str(e)}")
299
  print(f"Figure type: {type(fig)}")
300
  print(f"Canvas type: {type(fig.canvas)}")
301
  raise
302
 
303
+ except Exception as e:
304
+ print(f"Process image detection error: {str(e)}")
305
+ print(f"Error occurred at line {e.__traceback__.tb_lineno}")
306
+ raise
307
 
308
  def process_and_analyze(image):
309
  if image is None: