reab5555 committed · Commit e1b87b0 · verified · 1 Parent(s): aff928e

Update app.py

Files changed (1)
  1. app.py +167 -162
app.py CHANGED
@@ -143,168 +143,173 @@ def show_mask(mask, ax, random_color=False):
 
 
 def process_image_detection(image, target_label, surprise_rating):
-    try:
-        # Handle different image input types
-        if isinstance(image, tuple):
-            if len(image) > 0 and image[0] is not None:
-                if isinstance(image[0], np.ndarray):
-                    image = Image.fromarray(image[0])
-                else:
-                    image = image[0]
-            else:
-                raise ValueError("Invalid image tuple provided")
-        elif isinstance(image, np.ndarray):
-            image = Image.fromarray(image)
-        elif isinstance(image, str):
-            image = Image.open(image)
-
-        # Ensure image is in PIL Image format
-        if not isinstance(image, Image.Image):
-            raise ValueError(f"Input must be a PIL Image, got {type(image)}")
-
-        # Ensure image is in RGB mode
-        if image.mode != 'RGB':
-            image = image.convert('RGB')
-
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        print(f"Using device: {device}")
-
-        # Get original image DPI and size
-        original_dpi = image.info.get('dpi', (72, 72))
-        original_size = image.size
-        print(f"Image size: {original_size}")
-
-        # Calculate relative font size based on image dimensions
-        base_fontsize = min(original_size) / 40
-
-        print("Loading models...")
-        owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
-        owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
-        sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
-        sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
-
-        print("Running object detection...")
-        inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
-        with torch.no_grad():
-            outputs = owlv2_model(**inputs)
-
-        target_sizes = torch.tensor([image.size[::-1]]).to(device)
-        results = owlv2_processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
-
-        dpi = 300
-        figsize = (original_size[0] / dpi, original_size[1] / dpi)
-        fig = plt.figure(figsize=figsize, dpi=dpi)
-        ax = plt.Axes(fig, [0., 0., 1., 1.])
-        fig.add_axes(ax)
-        ax.imshow(image)
-
-        scores = results["scores"]
-        if len(scores) > 0:
-            max_score_idx = scores.argmax().item()
-            max_score = scores[max_score_idx].item()
-
-            if max_score > 0.2:
-                print("Processing detection results...")
-                box = results["boxes"][max_score_idx].cpu().numpy()
-
-                print("Running SAM model...")
-                # Convert image to numpy array if needed for SAM
-                if isinstance(image, Image.Image):
-                    image_np = np.array(image)
-                else:
-                    image_np = image
-
-                sam_inputs = sam_processor(
-                    image_np,
-                    input_boxes=[[[box[0], box[1], box[2], box[3]]]],
-                    return_tensors="pt"
-                ).to(device)
-
-                with torch.no_grad():
-                    sam_outputs = sam_model(**sam_inputs)
-
-                masks = sam_processor.image_processor.post_process_masks(
-                    sam_outputs.pred_masks.cpu(),
-                    sam_inputs["original_sizes"].cpu(),
-                    sam_inputs["reshaped_input_sizes"].cpu()
-                )
-
-                print(f"Mask type: {type(masks)}, Mask shape: {len(masks)}")
-                mask = masks[0]
-                if isinstance(mask, torch.Tensor):
-                    mask = mask.numpy()
-
-                show_mask(mask, ax=ax)
-
-                rect = patches.Rectangle(
-                    (box[0], box[1]),
-                    box[2] - box[0],
-                    box[3] - box[1],
-                    linewidth=max(2, min(original_size) / 500),
-                    edgecolor='red',
-                    facecolor='none'
-                )
-                ax.add_patch(rect)
-
-                plt.text(
-                    box[0], box[1] - base_fontsize,
-                    f'{max_score:.2f}',
-                    color='red',
-                    fontsize=base_fontsize,
-                    fontweight='bold',
-                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
-                )
-
-                plt.text(
-                    box[2] + base_fontsize / 2, box[1],
-                    f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
-                    color='red',
-                    fontsize=base_fontsize,
-                    fontweight='bold',
-                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2),
-                    verticalalignment='bottom'
-                )
-
-        plt.axis('off')
-
-        print("Saving final image...")
-        try:
-            # Save directly to buffer using savefig
-            buf = io.BytesIO()
-            fig.savefig(buf,
-                        format='png',
-                        dpi=dpi,
-                        bbox_inches='tight',
-                        pad_inches=0)
-            buf.seek(0)
-
-            # Open as PIL Image
-            output_image = Image.open(buf)
-
-            # Convert to RGB if needed
-            if output_image.mode != 'RGB':
-                output_image = output_image.convert('RGB')
-
-            # Resize to original size if needed
-            if output_image.size != original_size:
-                output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
-
-            # Save to final buffer
-            final_buf = io.BytesIO()
-            output_image.save(final_buf, format='PNG', dpi=original_dpi)
-            final_buf.seek(0)
-
-            # Cleanup
-            plt.close(fig)
-            buf.close()
-
-            return final_buf
-
-        except Exception as e:
-            print(f"Save error details: {str(e)}")
-            print(f"Figure type: {type(fig)}")
-            print(f"Canvas type: {type(fig.canvas)}")
-            raise
+    try:
+        # Handle different image input types
+        if isinstance(image, tuple):
+            if len(image) > 0 and image[0] is not None:
+                if isinstance(image[0], np.ndarray):
+                    image = Image.fromarray(image[0])
+                else:
+                    image = image[0]
+            else:
+                raise ValueError("Invalid image tuple provided")
+        elif isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+        elif isinstance(image, str):
+            image = Image.open(image)
+
+        # Ensure image is in PIL Image format
+        if not isinstance(image, Image.Image):
+            raise ValueError(f"Input must be a PIL Image, got {type(image)}")
+
+        # Ensure image is in RGB mode
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {device}")
+
+        # Get original image DPI and size
+        original_dpi = image.info.get('dpi', (72, 72))
+        original_size = image.size
+        print(f"Image size: {original_size}")
+
+        # Calculate relative font size based on image dimensions
+        base_fontsize = min(original_size) / 40
+
+        print("Loading models...")
+        owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
+        owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
+        sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
+        sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
+
+        print("Running object detection...")
+        inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
+        with torch.no_grad():
+            outputs = owlv2_model(**inputs)
+
+        target_sizes = torch.tensor([image.size[::-1]]).to(device)
+        results = owlv2_processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
+
+        dpi = 300
+        figsize = (original_size[0] / dpi, original_size[1] / dpi)
+        fig = plt.figure(figsize=figsize, dpi=dpi)
+        ax = plt.Axes(fig, [0., 0., 1., 1.])
+        fig.add_axes(ax)
+        ax.imshow(image)
+
+        scores = results["scores"]
+        if len(scores) > 0:
+            max_score_idx = scores.argmax().item()
+            max_score = scores[max_score_idx].item()
+
+            if max_score > 0.2:
+                print("Processing detection results...")
+                box = results["boxes"][max_score_idx].cpu().numpy()
+
+                print("Running SAM model...")
+                # Convert image to numpy array if needed for SAM
+                if isinstance(image, Image.Image):
+                    image_np = np.array(image)
+                else:
+                    image_np = image
+
+                sam_inputs = sam_processor(
+                    image_np,
+                    input_boxes=[[[box[0], box[1], box[2], box[3]]]],
+                    return_tensors="pt"
+                ).to(device)
+
+                with torch.no_grad():
+                    sam_outputs = sam_model(**sam_inputs)
+
+                masks = sam_processor.image_processor.post_process_masks(
+                    sam_outputs.pred_masks.cpu(),
+                    sam_inputs["original_sizes"].cpu(),
+                    sam_inputs["reshaped_input_sizes"].cpu()
+                )
+
+                print(f"Mask type: {type(masks)}, Mask shape: {len(masks)}")
+                mask = masks[0]
+                if isinstance(mask, torch.Tensor):
+                    mask = mask.numpy()
+
+                show_mask(mask, ax=ax)
+
+                rect = patches.Rectangle(
+                    (box[0], box[1]),
+                    box[2] - box[0],
+                    box[3] - box[1],
+                    linewidth=max(2, min(original_size) / 500),
+                    edgecolor='red',
+                    facecolor='none'
+                )
+                ax.add_patch(rect)
+
+                plt.text(
+                    box[0], box[1] - base_fontsize,
+                    f'{max_score:.2f}',
+                    color='red',
+                    fontsize=base_fontsize,
+                    fontweight='bold',
+                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
+                )
+
+                plt.text(
+                    box[2] + base_fontsize / 2, box[1],
+                    f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
+                    color='red',
+                    fontsize=base_fontsize,
+                    fontweight='bold',
+                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2),
+                    verticalalignment='bottom'
+                )
+
+        plt.axis('off')
+
+        print("Saving final image...")
+        try:
+            # Save directly to buffer using savefig
+            buf = io.BytesIO()
+            fig.savefig(buf,
+                        format='png',
+                        dpi=dpi,
+                        bbox_inches='tight',
+                        pad_inches=0)
+            buf.seek(0)
+
+            # Open as PIL Image
+            output_image = Image.open(buf)
+
+            # Convert to RGB if needed
+            if output_image.mode != 'RGB':
+                output_image = output_image.convert('RGB')
+
+            # Resize to original size if needed
+            if output_image.size != original_size:
+                output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
+
+            # Save to final buffer
+            final_buf = io.BytesIO()
+            output_image.save(final_buf, format='PNG', dpi=original_dpi)
+            final_buf.seek(0)
+
+            # Cleanup
+            plt.close(fig)
+            buf.close()
+
+            return final_buf
+
+        except Exception as e:
+            print(f"Save error details: {str(e)}")
+            print(f"Figure type: {type(fig)}")
+            print(f"Canvas type: {type(fig.canvas)}")
+            raise
+
+    except Exception as e:
+        print(f"Process image detection error: {str(e)}")
+        print(f"Error occurred at line {e.__traceback__.tb_lineno}")
+        raise
 
 def process_and_analyze(image):
     if image is None:
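
The substance of the commit is the handler added at the end of the function: the previous revision opened a function-level try: with no matching except at function scope, so the new block both closes that structure and logs the error message plus the traceback line number before re-raising. A minimal runnable sketch of that reporting pattern, with an illustrative stand-in body rather than the real OWLv2 + SAM pipeline:

def process_image_detection(image):
    try:
        return 1 / len(image)  # stand-in for the detection/segmentation body
    except Exception as e:
        # Same pattern as the commit: print the message and the traceback
        # line number, then re-raise so the caller still sees the failure.
        print(f"Process image detection error: {str(e)}")
        print(f"Error occurred at line {e.__traceback__.tb_lineno}")
        raise

process_image_detection([])  # prints both messages, then raises ZeroDivisionError

Because the exception is re-raised, the caller (process_and_analyze in app.py) can still surface the failure while the log retains the diagnostic details.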