cosmo3769's picture
Create app.py
06e0fe9 verified
raw
history blame
3.49 kB
import torch
import gradio as gr
import re
import cv2
from PIL import ImageDraw, Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
# Hugging Face model id of the PaliGemma checkpoint this app serves.
mix_model_id = "google/paligemma-3b-mix-224"
# Model is loaded once at import time; process_image reuses it for every call.
mix_model = PaliGemmaForConditionalGeneration.from_pretrained(mix_model_id)
# Processor from the same checkpoint id; used to build model inputs and decode outputs.
mix_processor = AutoProcessor.from_pretrained(mix_model_id)
# Helper function to parse multiple <loc> tags and return a list of coordinate sets and labels
def parse_multiple_locations(decoded_output):
    """Extract every detected object from PaliGemma's decoded text.

    A detection is four consecutive ``<locNNNN>`` tags followed by a label,
    e.g. ``<loc0256><loc0512><loc0768><loc1023> cat``. The four numbers are
    read in the order y1, x1, y2, x2. Several detections may appear in one
    string, typically separated by ``;``.

    Args:
        decoded_output: Text decoded from the model's generated tokens.
            ``None`` or an empty string yields an empty result.

    Returns:
        A list of dicts in order of appearance. Each dict has ``'label'``
        (the label text with surrounding whitespace removed) and ``'bbox'``
        (``[y1, x1, y2, x2]`` as fractions of the image height/width).
    """
    if not decoded_output:
        return []

    # PaliGemma location tokens run from <loc0000> to <loc1023>, so the
    # values are fractions of 1024 (dividing by 1000 lets boxes exceed 1.0).
    loc_bins = 1024

    # Four <locNNNN> tags, optional whitespace, then a label that runs until
    # the next tag, a ';' separator or a line break. This keeps multi-word
    # labels such as "traffic light" intact.
    loc_pattern = (
        r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>"
        r"\s*([^<;\s][^<;\n]*)"
    )

    coords_and_labels = []
    for *raw_coords, raw_label in re.findall(loc_pattern, decoded_output):
        y1, x1, y2, x2 = (int(value) / loc_bins for value in raw_coords)
        coords_and_labels.append({
            'label': raw_label.strip(),
            'bbox': [y1, x1, y2, x2],
        })
    return coords_and_labels
# Helper function to draw bounding boxes and labels for all objects on the image
def draw_multiple_bounding_boxes(image, coords_and_labels):
    """Draw a red rectangle and text label on ``image`` for each detection.

    Args:
        image: PIL image to annotate. It is drawn on in place.
        coords_and_labels: Iterable of dicts with ``'label'`` (text to draw)
            and ``'bbox'`` (``[y1, x1, y2, x2]`` as fractions of the image
            height/width).

    Returns:
        The same ``image`` object that was passed in, now annotated.
    """
    draw = ImageDraw.Draw(image)
    width, height = image.size
    for obj in coords_and_labels:
        y1, x1, y2, x2 = obj['bbox']
        # Order the corners before drawing: the model can emit a box whose
        # second corner precedes the first, and recent Pillow releases raise
        # ValueError from rectangle() in that case.
        left, right = sorted((x1 * width, x2 * width))
        top, bottom = sorted((y1 * height, y2 * height))
        draw.rectangle([left, top, right, bottom], outline="red", width=3)
        # Label is anchored at the box's top-left corner.
        draw.text((left, top), obj['label'], fill="red")
    return image
# Define inference function
def process_image(image, prompt):
    """Run the PaliGemma model on an image/prompt pair and draw its boxes.

    Args:
        image: PIL image from the Gradio image input, or ``None`` when the
            user submitted without uploading one.
        prompt: Text prompt passed to the processor alongside the image.

    Returns:
        A ``(image, text)`` pair, one value per Gradio output component.
        On success the image carries the drawn boxes and the text lists one
        ``Label: ..., Coordinates: ...`` line per detection. When nothing is
        detected, or an ``IndexError`` occurs, the image is returned as
        received together with a status message. When no image was
        supplied, the image slot is ``None``.
    """
    # Every return path must yield two values: the interface declares two
    # outputs, and Gradio rejects a handler that returns fewer.
    if image is None:
        return None, "Please upload an image."

    # Keyword arguments keep text and image from being swapped; the
    # processor's positional parameter order differs between transformers
    # releases.
    model_inputs = mix_processor(
        text=prompt,
        images=image.convert("RGB"),
        return_tensors="pt",
    )
    try:
        # Generate output from the model
        output = mix_model.generate(**model_inputs, max_new_tokens=100)
        # Decode the output from the model
        decoded_output = mix_processor.decode(output[0], skip_special_tokens=True)
        # Extract bounding box coordinates and labels
        coords_and_labels = parse_multiple_locations(decoded_output)
        if not coords_and_labels:
            return image, "No bounding boxes detected."
        # Draw bounding boxes and labels on the image
        image_with_boxes = draw_multiple_bounding_boxes(image, coords_and_labels)
        # Prepare the coordinates and labels for the UI
        labels_and_coords = "\n".join(
            f"Label: {obj['label']}, Coordinates: {obj['bbox']}"
            for obj in coords_and_labels
        )
        return image_with_boxes, labels_and_coords
    except IndexError as e:
        print(f"IndexError: {e}")
        return image, "An error occurred during processing."
# Define the Gradio interface
# Input components, in the order process_image receives them: image, prompt.
inputs = [
    gr.Image(type="pil"),
    gr.Textbox(label="Prompt", placeholder="Enter your question"),
]

# Output components, in the order process_image returns them: image, text.
outputs = [
    gr.Image(label="Output Image with Bounding Boxes"),
    gr.Textbox(label="Bounding Box Coordinates and Labels"),
]

# Build the Gradio app around process_image.
demo = gr.Interface(
    fn=process_image,
    inputs=inputs,
    outputs=outputs,
    title="Object Detection with Mix PaliGemma Model",
    description="Upload an image and get object detections with bounding boxes and labels.",
)

# Start the app with Gradio's debug flag enabled.
demo.launch(debug=True)