import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
from PIL import Image
from torchvision.transforms.functional import crop
import gradio as gr
# Load models during initialization
def init():
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Step 1: Load the YOLOv5 model with custom weights (assumed to be available
    # locally at yolov5/weights/best14.pt); trust_repo=True skips the trust prompt
    try:
        object_detection_model = torch.hub.load(
            'ultralytics/yolov5', 'custom',
            path='yolov5/weights/best14.pt', trust_repo=True
        )
        print("YOLOv5 model loaded successfully.")
    except Exception as e:
        print(f"Error loading YOLOv5 model: {e}")

    # Step 2: Load the ViT-GPT2 captioning model from Hugging Face
    try:
        captioning_model = VisionEncoderDecoderModel.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        tokenizer = AutoTokenizer.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        captioning_processor = AutoImageProcessor.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        print("ViT-GPT2 model loaded successfully.")
    except Exception as e:
        print(f"Error loading captioning model: {e}")
# Utility function to crop objects from the image based on bounding boxes
def crop_objects(image, boxes):
    cropped_images = []
    for box in boxes:
        # torchvision's crop takes (top, left, height, width); boxes are [x1, y1, x2, y2]
        cropped_image = crop(image, int(box[1]), int(box[0]),
                             int(box[3] - box[1]), int(box[2] - box[0]))
        cropped_images.append(cropped_image)
    return cropped_images
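
# A minimal sanity check for crop_objects, kept commented out so it does not run
# inside the Space (the dummy image and box below are illustrative assumptions):
#
#   dummy = Image.new("RGB", (100, 100))
#   crops = crop_objects(dummy, [[10, 20, 60, 80]])  # box is [x1, y1, x2, y2]
#   assert crops[0].size == (50, 60)                 # width = 60 - 10, height = 80 - 20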
# Gradio interface function
def process_image(image):
    try:
        # Step 1: Perform object detection with YOLOv5
        results = object_detection_model(image)
        boxes = results.xyxy[0][:, :4].cpu().numpy()  # Bounding boxes (x1, y1, x2, y2)
        labels = [results.names[int(class_id)]
                  for class_id in results.xyxy[0][:, 5].cpu().numpy().astype(int)]  # Class names
        scores = results.xyxy[0][:, 4].cpu().numpy()  # Confidence scores

        # Step 2: Generate a caption for the whole image
        original_inputs = captioning_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            original_caption_ids = captioning_model.generate(**original_inputs)
        original_caption = tokenizer.decode(original_caption_ids[0], skip_special_tokens=True)

        # Step 3: Crop detected objects and generate a caption for each one
        cropped_images = crop_objects(image, boxes)
        captions = []
        for cropped_image in cropped_images:
            inputs = captioning_processor(images=cropped_image, return_tensors="pt")
            with torch.no_grad():
                caption_ids = captioning_model.generate(**inputs)
            caption = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            captions.append(caption)

        # Prepare the results for visualization
        detection_results = []
        for label, box, score, caption in zip(labels, boxes, scores, captions):
            detection_results.append({
                "label": label,
                "caption": caption,
                "bounding_box": [float(coord) for coord in box],  # numpy floats -> JSON-serializable
                "confidence_score": float(score),
            })

        # Return the annotated image, the per-object results, and the whole-image caption
        return results.render()[0], detection_results, original_caption
    except Exception as e:
        return None, {"error": str(e)}, None
# Initialize models
init()
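
# Optional local smoke test, commented out so the Space only serves the Gradio
# app ("sample.jpg" is a hypothetical file, not shipped with this repo):
#
#   img = Image.open("sample.jpg")
#   annotated, objects, caption = process_image(img)
#   print(caption)
#   print(objects)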
# Gradio interface (updated gr.Image / gr.JSON / gr.Textbox component syntax)
interface = gr.Interface(
    fn=process_image,             # Function to run
    inputs=gr.Image(type="pil"),  # Input: image upload
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),     # Image with bounding boxes
        gr.JSON(label="Object Captions & Bounding Boxes"),  # Per-object JSON results
        gr.Textbox(label="Whole Image Caption"),            # Caption for the whole image
    ],
    live=True,
)
# Launch the Gradio app
interface.launch()
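
# If CPU inference makes live=True feel sluggish, Gradio's built-in request
# queue is one option (standard Gradio API; swap the launch call above for):
#   interface.queue().launch()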