import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
from PIL import Image
from torchvision.transforms.functional import crop
import gradio as gr
# Load models during initialization
def init():
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Step 1: Load the custom YOLOv5 detection model via torch.hub
    try:
        # trust_repo=True skips the untrusted-repo prompt; the custom weights
        # are assumed to be available locally at yolov5/weights/best14.pt
        object_detection_model = torch.hub.load(
            'ultralytics/yolov5', 'custom',
            path='yolov5/weights/best14.pt', trust_repo=True
        )
        print("YOLOv5 model loaded successfully.")
    except Exception as e:
        print(f"Error loading YOLOv5 model: {e}")

    # Step 2: Load the ViT-GPT2 captioning model from Hugging Face
    try:
        captioning_model = VisionEncoderDecoderModel.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        tokenizer = AutoTokenizer.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        captioning_processor = AutoImageProcessor.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        print("ViT-GPT2 model loaded successfully.")
    except Exception as e:
        print(f"Error loading captioning model: {e}")

# Utility function to crop objects from the image based on bounding boxes
def crop_objects(image, boxes):
    cropped_images = []
    for box in boxes:
        # box is (x1, y1, x2, y2); torchvision's crop expects (top, left, height, width)
        cropped_image = crop(image, int(box[1]), int(box[0]), int(box[3] - box[1]), int(box[2] - box[0]))
        cropped_images.append(cropped_image)
    return cropped_images

# Gradio interface function
def process_image(image):
    try:
        # Step 1: Perform object detection with YOLOv5
        results = object_detection_model(image)
        detections = results.xyxy[0]
        boxes = detections[:, :4].cpu().numpy()  # Bounding boxes (x1, y1, x2, y2)
        labels = [results.names[int(class_id)] for class_id in detections[:, 5].cpu().numpy().astype(int)]  # Class names
        scores = detections[:, 4].cpu().numpy()  # Confidence scores

        # Step 2: Generate a caption for the whole image
        original_inputs = captioning_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            original_caption_ids = captioning_model.generate(**original_inputs)
        original_caption = tokenizer.decode(original_caption_ids[0], skip_special_tokens=True)

        # Step 3: Crop detected objects and generate a caption for each one
        cropped_images = crop_objects(image, boxes)
        captions = []
        for cropped_image in cropped_images:
            inputs = captioning_processor(images=cropped_image, return_tensors="pt")
            with torch.no_grad():
                caption_ids = captioning_model.generate(**inputs)
            caption = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            captions.append(caption)

        # Collect per-object results for the JSON output
        detection_results = []
        for label, box, score, caption in zip(labels, boxes, scores, captions):
            detection_results.append({
                "label": label,
                "caption": caption,
                "bounding_box": [float(coord) for coord in box],  # Convert numpy floats for JSON serialization
                "confidence_score": float(score)
            })

        # Return the rendered image with detections, the per-object results, and the whole-image caption
        return results.render()[0], detection_results, original_caption
    except Exception as e:
        return None, {"error": str(e)}, None

# Initialize models
init()
# Define the Gradio interface using the updated component syntax
interface = gr.Interface(
    fn=process_image,                                        # Function to run
    inputs=gr.Image(type="pil"),                             # Input: image upload
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),      # Output 1: image with bounding boxes
        gr.JSON(label="Object Captions & Bounding Boxes"),   # Output 2: JSON results for each object
        gr.Textbox(label="Whole Image Caption"),             # Output 3: caption for the whole image
    ],
    live=True,
)
# Launch the Gradio app
interface.launch()