import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
from PIL import Image
from torchvision.transforms.functional import crop
import gradio as gr


# Load models during initialization
def init():
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Step 1: Load the custom YOLOv5 detector via torch.hub (weights are assumed to be available locally)
    try:
        # trust_repo=True skips the interactive "trust this repo?" prompt
        object_detection_model = torch.hub.load(
            'ultralytics/yolov5', 'custom',
            path='yolov5/weights/best14.pt',
            trust_repo=True
        )
        print("YOLOv5 model loaded successfully.")
    except Exception as e:
        print(f"Error loading YOLOv5 model: {e}")

    # Step 2: Load the ViT-GPT2 captioning model from Hugging Face
    try:
        captioning_model = VisionEncoderDecoderModel.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        tokenizer = AutoTokenizer.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        captioning_processor = AutoImageProcessor.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        print("ViT-GPT2 model loaded successfully.")
    except Exception as e:
        print(f"Error loading captioning model: {e}")


# Utility function to crop detected objects from the image based on bounding boxes
def crop_objects(image, boxes):
    # boxes are [x1, y1, x2, y2]; torchvision's crop() takes (top, left, height, width)
    cropped_images = []
    for box in boxes:
        cropped_image = crop(image, int(box[1]), int(box[0]), int(box[3] - box[1]), int(box[2] - box[0]))
        cropped_images.append(cropped_image)
    return cropped_images


# Gradio interface function
def process_image(image):
    try:
        # Step 1: Run YOLOv5 object detection
        results = object_detection_model(image)
        detections = results.xyxy[0]                 # (n, 6): x1, y1, x2, y2, confidence, class
        boxes = detections[:, :4].cpu().numpy()      # Bounding boxes
        labels = [results.names[int(class_id)] for class_id in detections[:, 5].cpu().numpy().astype(int)]  # Class names
        scores = detections[:, 4].cpu().numpy()      # Confidence scores

        # Step 2: Generate a caption for the whole image
        original_inputs = captioning_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            original_caption_ids = captioning_model.generate(**original_inputs)
        original_caption = tokenizer.decode(original_caption_ids[0], skip_special_tokens=True)

        # Step 3: Crop each detected object and caption it individually
        cropped_images = crop_objects(image, boxes)
        captions = []
        for cropped_image in cropped_images:
            inputs = captioning_processor(images=cropped_image, return_tensors="pt")
            with torch.no_grad():
                caption_ids = captioning_model.generate(**inputs)
            caption = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            captions.append(caption)

        # Step 4: Collect per-object results for the JSON output
        detection_results = []
        for label, box, score, caption in zip(labels, boxes, scores, captions):
            detection_results.append({
                "label": label,
                "caption": caption,
                "bounding_box": [float(coord) for coord in box],  # Convert NumPy floats for JSON serialization
                "confidence_score": float(score)
            })

        # Return the annotated image, per-object results, and the whole-image caption
        annotated_image = Image.fromarray(results.render()[0])   # render() returns a NumPy array
        return annotated_image, detection_results, original_caption
    except Exception as e:
        return None, {"error": str(e)}, None


# Initialize models
init()

# Build the Gradio interface (current component syntax)
interface = gr.Interface(
    fn=process_image,                                         # Function to run
    inputs=gr.Image(type="pil"),                              # Input: uploaded image
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),       # Output 1: image with bounding boxes
        gr.JSON(label="Object Captions & Bounding Boxes"),    # Output 2: per-object JSON results
        gr.Textbox(label="Whole Image Caption"),              # Output 3: caption for the whole image
    ],
    live=True
)
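# Optional usage sketch (not part of the original app): the pipeline can also be
# exercised directly, without the UI. "sample.jpg" is a placeholder path.
# annotated, detections, caption = process_image(Image.open("sample.jpg"))
# print(caption)
# for item in detections:
#     print(item["label"], item["confidence_score"], item["caption"])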
# Launch the Gradio app
interface.launch()
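# Note (assumption, not in the original script): to expose a temporary public link,
# Gradio's launch() also accepts share=True, e.g. interface.launch(share=True).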