Update app.py
Browse files
app.py
CHANGED
|
@@ -143,168 +143,173 @@ def show_mask(mask, ax, random_color=False):
|
|
| 143 |
|
| 144 |
|
| 145 |
def process_image_detection(image, target_label, surprise_rating):
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
|
| 309 |
def process_and_analyze(image):
|
| 310 |
if image is None:
|
|
|
|
| 143 |
|
| 144 |
|
| 145 |
def process_image_detection(image, target_label, surprise_rating):
|
| 146 |
+
try:
|
| 147 |
+
# Handle different image input types
|
| 148 |
+
if isinstance(image, tuple):
|
| 149 |
+
if len(image) > 0 and image[0] is not None:
|
| 150 |
+
if isinstance(image[0], np.ndarray):
|
| 151 |
+
image = Image.fromarray(image[0])
|
| 152 |
+
else:
|
| 153 |
+
image = image[0]
|
| 154 |
+
else:
|
| 155 |
+
raise ValueError("Invalid image tuple provided")
|
| 156 |
+
elif isinstance(image, np.ndarray):
|
| 157 |
+
image = Image.fromarray(image)
|
| 158 |
+
elif isinstance(image, str):
|
| 159 |
+
image = Image.open(image)
|
| 160 |
+
|
| 161 |
+
# Ensure image is in PIL Image format
|
| 162 |
+
if not isinstance(image, Image.Image):
|
| 163 |
+
raise ValueError(f"Input must be a PIL Image, got {type(image)}")
|
| 164 |
+
|
| 165 |
+
# Ensure image is in RGB mode
|
| 166 |
+
if image.mode != 'RGB':
|
| 167 |
+
image = image.convert('RGB')
|
| 168 |
+
|
| 169 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 170 |
+
print(f"Using device: {device}")
|
| 171 |
+
|
| 172 |
+
# Get original image DPI and size
|
| 173 |
+
original_dpi = image.info.get('dpi', (72, 72))
|
| 174 |
+
original_size = image.size
|
| 175 |
+
print(f"Image size: {original_size}")
|
| 176 |
+
|
| 177 |
+
# Calculate relative font size based on image dimensions
|
| 178 |
+
base_fontsize = min(original_size) / 40
|
| 179 |
+
|
| 180 |
+
print("Loading models...")
|
| 181 |
+
owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
|
| 182 |
+
owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
|
| 183 |
+
sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
|
| 184 |
+
sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
|
| 185 |
+
|
| 186 |
+
print("Running object detection...")
|
| 187 |
+
inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
|
| 188 |
+
with torch.no_grad():
|
| 189 |
+
outputs = owlv2_model(**inputs)
|
| 190 |
+
|
| 191 |
+
target_sizes = torch.tensor([image.size[::-1]]).to(device)
|
| 192 |
+
results = owlv2_processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
|
| 193 |
+
|
| 194 |
+
dpi = 300
|
| 195 |
+
figsize = (original_size[0] / dpi, original_size[1] / dpi)
|
| 196 |
+
fig = plt.figure(figsize=figsize, dpi=dpi)
|
| 197 |
+
ax = plt.Axes(fig, [0., 0., 1., 1.])
|
| 198 |
+
fig.add_axes(ax)
|
| 199 |
+
ax.imshow(image)
|
| 200 |
+
|
| 201 |
+
scores = results["scores"]
|
| 202 |
+
if len(scores) > 0:
|
| 203 |
+
max_score_idx = scores.argmax().item()
|
| 204 |
+
max_score = scores[max_score_idx].item()
|
| 205 |
+
|
| 206 |
+
if max_score > 0.2:
|
| 207 |
+
print("Processing detection results...")
|
| 208 |
+
box = results["boxes"][max_score_idx].cpu().numpy()
|
| 209 |
+
|
| 210 |
+
print("Running SAM model...")
|
| 211 |
+
# Convert image to numpy array if needed for SAM
|
| 212 |
+
if isinstance(image, Image.Image):
|
| 213 |
+
image_np = np.array(image)
|
| 214 |
+
else:
|
| 215 |
+
image_np = image
|
| 216 |
+
|
| 217 |
+
sam_inputs = sam_processor(
|
| 218 |
+
image_np,
|
| 219 |
+
input_boxes=[[[box[0], box[1], box[2], box[3]]]],
|
| 220 |
+
return_tensors="pt"
|
| 221 |
+
).to(device)
|
| 222 |
+
|
| 223 |
+
with torch.no_grad():
|
| 224 |
+
sam_outputs = sam_model(**sam_inputs)
|
| 225 |
+
|
| 226 |
+
masks = sam_processor.image_processor.post_process_masks(
|
| 227 |
+
sam_outputs.pred_masks.cpu(),
|
| 228 |
+
sam_inputs["original_sizes"].cpu(),
|
| 229 |
+
sam_inputs["reshaped_input_sizes"].cpu()
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
print(f"Mask type: {type(masks)}, Mask shape: {len(masks)}")
|
| 233 |
+
mask = masks[0]
|
| 234 |
+
if isinstance(mask, torch.Tensor):
|
| 235 |
+
mask = mask.numpy()
|
| 236 |
+
|
| 237 |
+
show_mask(mask, ax=ax)
|
| 238 |
+
|
| 239 |
+
rect = patches.Rectangle(
|
| 240 |
+
(box[0], box[1]),
|
| 241 |
+
box[2] - box[0],
|
| 242 |
+
box[3] - box[1],
|
| 243 |
+
linewidth=max(2, min(original_size) / 500),
|
| 244 |
+
edgecolor='red',
|
| 245 |
+
facecolor='none'
|
| 246 |
+
)
|
| 247 |
+
ax.add_patch(rect)
|
| 248 |
+
|
| 249 |
+
plt.text(
|
| 250 |
+
box[0], box[1] - base_fontsize,
|
| 251 |
+
f'{max_score:.2f}',
|
| 252 |
+
color='red',
|
| 253 |
+
fontsize=base_fontsize,
|
| 254 |
+
fontweight='bold',
|
| 255 |
+
bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2)
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
plt.text(
|
| 259 |
+
box[2] + base_fontsize / 2, box[1],
|
| 260 |
+
f'Unexpected (Rating: {surprise_rating}/5)\n{target_label}',
|
| 261 |
+
color='red',
|
| 262 |
+
fontsize=base_fontsize,
|
| 263 |
+
fontweight='bold',
|
| 264 |
+
bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=2),
|
| 265 |
+
verticalalignment='bottom'
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
plt.axis('off')
|
| 269 |
+
|
| 270 |
+
print("Saving final image...")
|
| 271 |
+
try:
|
| 272 |
+
# Save directly to buffer using savefig
|
| 273 |
+
buf = io.BytesIO()
|
| 274 |
+
fig.savefig(buf,
|
| 275 |
+
format='png',
|
| 276 |
+
dpi=dpi,
|
| 277 |
+
bbox_inches='tight',
|
| 278 |
+
pad_inches=0)
|
| 279 |
+
buf.seek(0)
|
| 280 |
+
|
| 281 |
+
# Open as PIL Image
|
| 282 |
+
output_image = Image.open(buf)
|
| 283 |
+
|
| 284 |
+
# Convert to RGB if needed
|
| 285 |
+
if output_image.mode != 'RGB':
|
| 286 |
+
output_image = output_image.convert('RGB')
|
| 287 |
+
|
| 288 |
+
# Resize to original size if needed
|
| 289 |
+
if output_image.size != original_size:
|
| 290 |
+
output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
|
| 291 |
+
|
| 292 |
+
# Save to final buffer
|
| 293 |
+
final_buf = io.BytesIO()
|
| 294 |
+
output_image.save(final_buf, format='PNG', dpi=original_dpi)
|
| 295 |
+
final_buf.seek(0)
|
| 296 |
+
|
| 297 |
+
# Cleanup
|
| 298 |
+
plt.close(fig)
|
| 299 |
+
buf.close()
|
| 300 |
+
|
| 301 |
+
return final_buf
|
| 302 |
+
|
| 303 |
+
except Exception as e:
|
| 304 |
+
print(f"Save error details: {str(e)}")
|
| 305 |
+
print(f"Figure type: {type(fig)}")
|
| 306 |
+
print(f"Canvas type: {type(fig.canvas)}")
|
| 307 |
+
raise
|
| 308 |
+
|
| 309 |
+
except Exception as e:
|
| 310 |
+
print(f"Process image detection error: {str(e)}")
|
| 311 |
+
print(f"Error occurred at line {e.__traceback__.tb_lineno}")
|
| 312 |
+
raise
|
| 313 |
|
| 314 |
def process_and_analyze(image):
|
| 315 |
if image is None:
|