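# Hugging Face Spaces app: YOLOv5 object detection combined with ViT-GPT2 image captioning.
# The Gradio interface returns the annotated image, per-object captions with bounding boxes
# and confidence scores, and a caption for the whole image.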
import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
from PIL import Image
from torchvision.transforms.functional import crop
import gradio as gr


# Load models during initialization
def init():
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Step 1: Load the YOLOv5 model with custom weights (expected locally at yolov5/weights/best14.pt)
    try:
        # trust_repo=True skips the trust prompt/warning when fetching the hub repo
        object_detection_model = torch.hub.load('ultralytics/yolov5', 'custom',
                                                path='yolov5/weights/best14.pt', trust_repo=True)
        print("YOLOv5 model loaded successfully.")
    except Exception as e:
        print(f"Error loading YOLOv5 model: {e}")

    # Step 2: Load the ViT-GPT2 captioning model from Hugging Face
    try:
        captioning_model = VisionEncoderDecoderModel.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        tokenizer = AutoTokenizer.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        captioning_processor = AutoImageProcessor.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        print("ViT-GPT2 model loaded successfully.")
    except Exception as e:
        print(f"Error loading captioning model: {e}")
# Utility function to crop objects from the image based on bounding boxes
def crop_objects(image, boxes):
    cropped_images = []
    for box in boxes:
        # Boxes are [x1, y1, x2, y2]; torchvision's crop takes (top, left, height, width)
        cropped_image = crop(image, int(box[1]), int(box[0]), int(box[3] - box[1]), int(box[2] - box[0]))
        cropped_images.append(cropped_image)
    return cropped_images
# Gradio interface function
def process_image(image):
    try:
        # Step 1: Perform object detection with YOLOv5
        results = object_detection_model(image)
        boxes = results.xyxy[0][:, :4].cpu().numpy()  # Bounding boxes (x1, y1, x2, y2)
        labels = [results.names[int(class_id)] for class_id in results.xyxy[0][:, 5].cpu().numpy().astype(int)]  # Class names
        scores = results.xyxy[0][:, 4].cpu().numpy()  # Confidence scores

        # Step 2: Generate a caption for the whole image
        original_inputs = captioning_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            original_caption_ids = captioning_model.generate(**original_inputs)
        original_caption = tokenizer.decode(original_caption_ids[0], skip_special_tokens=True)

        # Step 3: Crop detected objects and generate a caption for each object
        cropped_images = crop_objects(image, boxes)
        captions = []
        for cropped_image in cropped_images:
            inputs = captioning_processor(images=cropped_image, return_tensors="pt")
            with torch.no_grad():
                caption_ids = captioning_model.generate(**inputs)
            caption = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            captions.append(caption)

        # Prepare the results for visualization
        detection_results = []
        for label, box, score, caption in zip(labels, boxes, scores, captions):
            detection_results.append({
                "label": label,
                "caption": caption,
                "bounding_box": [float(coord) for coord in box],  # Convert to JSON-serializable floats
                "confidence_score": float(score)
            })

        # Return the annotated image, the per-object results, and the whole-image caption
        return results.render()[0], detection_results, original_caption
    except Exception as e:
        return None, {"error": str(e)}, None
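# Minimal usage sketch outside Gradio (the filename "example.jpg" is illustrative,
# not part of the original Space):
#   annotated_image, detections, caption = process_image(Image.open("example.jpg"))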
# Initialize models
init()

# Gradio interface with the updated component syntax
interface = gr.Interface(
    fn=process_image,                 # Function to run
    inputs=gr.Image(type="pil"),      # Input: image upload
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),     # Output 1: image with bounding boxes
        gr.JSON(label="Object Captions & Bounding Boxes"),   # Output 2: JSON results for each object
        gr.Textbox(label="Whole Image Caption"),             # Output 3: caption for the whole image
    ],
    live=True,
)

# Launch the Gradio app
interface.launch()
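# Note (assumption, not stated in the original Space): on Hugging Face Spaces, launch()
# with no arguments is sufficient; when running locally, a temporary public link can be
# requested with interface.launch(share=True).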