reab5555 committed
Commit 6b3604d · verified · Parent(s): e674f2c

Update app.py

Files changed (1):
  app.py +148 -132

app.py CHANGED
@@ -119,138 +119,154 @@ def show_mask(mask, ax, random_color=False):
 
 
 def process_image_detection(image, target_label, surprise_rating):
-    # Handle different image input types
-    if isinstance(image, tuple):
-        if len(image) > 0 and image[0] is not None:
-            image = Image.fromarray(image[0])
-        else:
-            raise ValueError("Invalid image tuple provided")
-    elif isinstance(image, np.ndarray):
-        image = Image.fromarray(image)
-    elif isinstance(image, str):
-        image = Image.open(image)
-
-    # Ensure image is in PIL Image format
-    if not isinstance(image, Image.Image):
-        raise ValueError("Input must be a PIL Image, numpy array, or valid image path")
-
-    # Ensure image is in RGB mode
-    if image.mode != 'RGB':
-        image = image.convert('RGB')
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    # Get original image DPI and size
-    original_dpi = image.info.get('dpi', (72, 72))
-    original_size = image.size
-
-    # Calculate relative font size based on image dimensions
-    base_fontsize = min(original_size) / 40
-
-    owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
-    owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
-
-    sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
-    sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
-
-    image_np = np.array(image)
-
-    inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
-    with torch.no_grad():
-        outputs = owlv2_model(**inputs)
-
-    target_sizes = torch.tensor([image.size[::-1]]).to(device)
-    results = owlv2_processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
-
-    dpi = 300
-    figsize = (original_size[0] / dpi, original_size[1] / dpi)
-    fig = plt.figure(figsize=figsize, dpi=dpi)
-
-    ax = plt.Axes(fig, [0., 0., 1., 1.])
-    fig.add_axes(ax)
-
-    plt.imshow(image)
-
-    scores = results["scores"]
-    if len(scores) > 0:
-        max_score_idx = scores.argmax().item()
-        max_score = scores[max_score_idx].item()
-
-        if max_score > 0.2:
-            box = results["boxes"][max_score_idx].cpu().numpy()
-
-            sam_inputs = sam_processor(
-                image,
-                input_boxes=[[[box[0], box[1], box[2], box[3]]]],
-                return_tensors="pt"
-            ).to(device)
-
-            with torch.no_grad():
-                sam_outputs = sam_model(**sam_inputs)
-
-            masks = sam_processor.image_processor.post_process_masks(
-                sam_outputs.pred_masks.cpu(),
-                sam_inputs["original_sizes"].cpu(),
-                sam_inputs["reshaped_input_sizes"].cpu()
-            )
-
-            mask = masks[0].numpy() if isinstance(masks[0], torch.Tensor) else masks[0]
-            show_mask(mask, ax=ax)
-
-            # Draw rectangle with increased line width
-            rect = patches.Rectangle(
-                (box[0], box[1]),
-                box[2] - box[0],
-                box[3] - box[1],
-                linewidth=max(2, min(original_size) / 500),
-                edgecolor='red',
-                facecolor='none'
-            )
-            ax.add_patch(rect)
-
-            # Add confidence score with improved visibility
-            plt.text(
-                box[0], box[1] - base_fontsize,
-                f'{max_score:.2f}',
-                color='red',
-                fontsize=base_fontsize,
-                fontweight='bold',
-                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
-            )
-
-            # Add label and rating with improved visibility
-            plt.text(
-                box[2] + base_fontsize / 2, box[1],
-                f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
-                color='red',
-                fontsize=base_fontsize,
-                fontweight='bold',
-                bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2),
-                verticalalignment='bottom'
-            )
-
-    plt.axis('off')
-
-    # Save with high DPI
-    buf = io.BytesIO()
-    plt.savefig(buf,
-                format='png',
-                dpi=dpi,
-                bbox_inches='tight',
-                pad_inches=0,
-                metadata={'dpi': original_dpi})
-    buf.seek(0)
-    plt.close()
-
-    # Process final image
-    output_image = Image.open(buf)
-    output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
-
-    final_buf = io.BytesIO()
-    output_image.save(final_buf, format='PNG', dpi=original_dpi)
-    final_buf.seek(0)
-
-    return final_buf
+    try:
+        # Handle different image input types
+        if isinstance(image, tuple):
+            if len(image) > 0 and image[0] is not None:
+                if isinstance(image[0], np.ndarray):
+                    image = Image.fromarray(image[0])
+                else:
+                    image = image[0]
+            else:
+                raise ValueError("Invalid image tuple provided")
+        elif isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+        elif isinstance(image, str):
+            image = Image.open(image)
+
+        # Ensure image is in PIL Image format
+        if not isinstance(image, Image.Image):
+            raise ValueError(f"Input must be a PIL Image, got {type(image)}")
+
+        # Ensure image is in RGB mode
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {device}")  # Debug print
+
+        # Get original image DPI and size
+        original_dpi = image.info.get('dpi', (72, 72))
+        original_size = image.size
+        print(f"Image size: {original_size}")  # Debug print
+
+        # Calculate relative font size based on image dimensions
+        base_fontsize = min(original_size) / 40
+
+        print("Loading models...")  # Debug print
+        owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
+        owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
+        sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
+        sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
+
+        print("Running object detection...")  # Debug print
+        inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
+        with torch.no_grad():
+            outputs = owlv2_model(**inputs)
+
+        target_sizes = torch.tensor([image.size[::-1]]).to(device)
+        results = owlv2_processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
+
+        dpi = 300
+        figsize = (original_size[0] / dpi, original_size[1] / dpi)
+        fig = plt.figure(figsize=figsize, dpi=dpi)
+        ax = plt.Axes(fig, [0., 0., 1., 1.])
+        fig.add_axes(ax)
+        ax.imshow(image)
+
+        scores = results["scores"]
+        if len(scores) > 0:
+            max_score_idx = scores.argmax().item()
+            max_score = scores[max_score_idx].item()
+
+            if max_score > 0.2:
+                print("Processing detection results...")  # Debug print
+                box = results["boxes"][max_score_idx].cpu().numpy()
+
+                print("Running SAM model...")  # Debug print
+                # Convert image to numpy array if needed for SAM
+                if isinstance(image, Image.Image):
+                    image_np = np.array(image)
+                else:
+                    image_np = image
+
+                sam_inputs = sam_processor(
+                    image_np,  # Use numpy array here
+                    input_boxes=[[[box[0], box[1], box[2], box[3]]]],
+                    return_tensors="pt"
+                ).to(device)
+
+                with torch.no_grad():
+                    sam_outputs = sam_model(**sam_inputs)
+
+                masks = sam_processor.image_processor.post_process_masks(
+                    sam_outputs.pred_masks.cpu(),
+                    sam_inputs["original_sizes"].cpu(),
+                    sam_inputs["reshaped_input_sizes"].cpu()
+                )
+
+                print(f"Mask type: {type(masks)}, Mask shape: {len(masks)}")  # Debug print
+                mask = masks[0]
+                if isinstance(mask, torch.Tensor):
+                    mask = mask.numpy()
+
+                show_mask(mask, ax=ax)
+
+                rect = patches.Rectangle(
+                    (box[0], box[1]),
+                    box[2] - box[0],
+                    box[3] - box[1],
+                    linewidth=max(2, min(original_size) / 500),
+                    edgecolor='red',
+                    facecolor='none'
+                )
+                ax.add_patch(rect)
+
+                plt.text(
+                    box[0], box[1] - base_fontsize,
+                    f'{max_score:.2f}',
+                    color='red',
+                    fontsize=base_fontsize,
+                    fontweight='bold',
+                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
+                )
+
+                plt.text(
+                    box[2] + base_fontsize / 2, box[1],
+                    f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
+                    color='red',
+                    fontsize=base_fontsize,
+                    fontweight='bold',
+                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2),
+                    verticalalignment='bottom'
+                )
+
+        plt.axis('off')
+
+        print("Saving final image...")  # Debug print
+        buf = io.BytesIO()
+        plt.savefig(buf,
+                    format='png',
+                    dpi=dpi,
+                    bbox_inches='tight',
+                    pad_inches=0,
+                    metadata={'dpi': original_dpi})
+        buf.seek(0)
+        plt.close()
+
+        output_image = Image.open(buf)
+        output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
+
+        final_buf = io.BytesIO()
+        output_image.save(final_buf, format='PNG', dpi=original_dpi)
+        final_buf.seek(0)
+
+        return final_buf
+
+    except Exception as e:
+        print(f"Process image detection error: {str(e)}")  # Debug print
+        print(f"Error occurred at line {e.__traceback__.tb_lineno}")  # Debug print
+        raise
 
 
 def process_and_analyze(image):
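
Usage note: after this commit, process_image_detection accepts a PIL image, a numpy array, a tuple, or a file path, and returns the annotated result as a PNG io.BytesIO buffer. A minimal calling sketch, assuming a local example.jpg and illustrative target_label / surprise_rating values (none of this wiring is part of the commit itself):

    from PIL import Image

    # Hypothetical inputs for illustration only.
    image = Image.open("example.jpg")
    buf = process_image_detection(image, target_label="a dog", surprise_rating=4)

    # The function returns an io.BytesIO; write it out as a PNG file.
    with open("annotated.png", "wb") as f:
        f.write(buf.getvalue())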