File size: 4,119 Bytes
3ce1aa7
 
 
 
 
 
 
 
 
 
 
 
eed6859
 
 
3ce1aa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eed6859
3ce1aa7
 
eed6859
 
 
 
3ce1aa7
 
 
eed6859
3ce1aa7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import math
import traceback

import gradio as gr
import torch
from PIL import Image
from torchvision.transforms.functional import crop
from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor

# Load models during initialization
def init(
    yolo_weights="yolov5/weights/best14.pt",
    caption_model_id="motheecreator/ViT-GPT2-Image-Captioning",
):
    """Load the detection and captioning models into module-level globals.

    Sets ``object_detection_model``, ``captioning_model``, ``tokenizer`` and
    ``captioning_processor``. Each loading step is best-effort: a failure is
    printed and swallowed, so some of those globals may be left unset.

    Args:
        yolo_weights: Path to the custom YOLOv5 weights file.
        caption_model_id: Hugging Face Hub id (or local path) of the
            vision-encoder-decoder captioning checkpoint. The same id is used
            for the model, its tokenizer and its image processor.
    """
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Step 1: YOLOv5 object detector. torch.hub fetches the model code from the
    # ``ultralytics/yolov5`` repository (GitHub by default, not the Hugging Face
    # Hub) and loads the custom weights from the local ``yolo_weights`` path.
    # trust_repo=True skips torch.hub's untrusted-repository prompt/warning.
    try:
        object_detection_model = torch.hub.load(
            "ultralytics/yolov5", "custom", path=yolo_weights, trust_repo=True
        )
        print("YOLOv5 model loaded successfully.")
    except Exception as e:
        # Best-effort: report and continue so the rest of the app still starts.
        print(f"Error loading YOLOv5 model: {e}")

    # Step 2: ViT-GPT2 captioning model plus its tokenizer and image processor,
    # all loaded from the same checkpoint id.
    try:
        captioning_model = VisionEncoderDecoderModel.from_pretrained(caption_model_id)
        tokenizer = AutoTokenizer.from_pretrained(caption_model_id)
        captioning_processor = AutoImageProcessor.from_pretrained(caption_model_id)
        print("ViT-GPT2 model loaded successfully.")
    except Exception as e:
        print(f"Error loading captioning model: {e}")

# Utility function to crop objects from the image based on bounding boxes
def crop_objects(image, boxes):
    """Crop one sub-image per bounding box.

    Args:
        image: The source image (a PIL image in this app; torchvision's
            ``crop`` also accepts tensors).
        boxes: Iterable of boxes whose first four entries are pixel
            coordinates ``(x1, y1, x2, y2)``, e.g. rows of the first four
            columns of YOLOv5's ``results.xyxy[0]``.

    Returns:
        A list of cropped images, in the same order as ``boxes``.
    """
    cropped_images = []
    for box in boxes:
        x1, y1, x2, y2 = (float(v) for v in box[:4])
        # Round outward so fractional box edges are fully included (truncating
        # the float width/height clipped them), and keep at least one pixel
        # per side so a degenerate box still yields a valid, non-empty image.
        left = math.floor(x1)
        top = math.floor(y1)
        width = max(math.ceil(x2) - left, 1)
        height = max(math.ceil(y2) - top, 1)
        # torchvision's crop takes (top, left, height, width).
        cropped_images.append(crop(image, top, left, height, width))
    return cropped_images

# Gradio interface function
def process_image(image):
    """Detect objects in ``image`` and caption the whole image and each object.

    Relies on the module-level models loaded by ``init()``.

    Args:
        image: The uploaded image (a PIL image as configured in the Gradio
            interface), or ``None`` when the input has been cleared.

    Returns:
        A 3-tuple ``(annotated_image, detections, caption)``:
        ``annotated_image`` is YOLOv5's rendered image with detections drawn;
        ``detections`` is a list of dicts with keys ``label``, ``caption``,
        ``bounding_box`` (``[x1, y1, x2, y2]`` floats) and
        ``confidence_score``; ``caption`` is the whole-image caption.
        For a ``None`` input, returns ``(None, None, None)`` so the outputs are
        cleared. On any failure, prints the traceback and returns
        ``(None, {"error": message}, None)``.
    """
    # With live=True, clearing the input re-invokes this with None; there is
    # nothing to process, so clear the outputs instead of reporting an error.
    if image is None:
        return None, None, None

    try:
        # Step 1: object detection. Convert the first image's detection tensor
        # once; its columns are x1, y1, x2, y2, confidence, class id.
        results = object_detection_model(image)
        detections = results.xyxy[0].cpu().numpy()
        boxes = detections[:, :4]
        scores = detections[:, 4]
        labels = [results.names[int(class_id)] for class_id in detections[:, 5]]

        # Step 2: caption for the whole image.
        original_caption = _generate_caption(image)

        # Step 3: caption each detected object's crop.
        captions = [_generate_caption(cropped) for cropped in crop_objects(image, boxes)]

        # Plain Python floats so the result is JSON-serializable.
        detection_results = [
            {
                "label": label,
                "caption": caption,
                "bounding_box": [float(coord) for coord in box],
                "confidence_score": float(score),
            }
            for label, box, score, caption in zip(labels, boxes, scores, captions)
        ]

        return results.render()[0], detection_results, original_caption

    except Exception as e:
        # Top-level UI boundary: keep the app alive, but don't lose the
        # traceback when surfacing the message in the JSON output.
        traceback.print_exc()
        return None, {"error": str(e)}, None


def _generate_caption(img):
    """Return the ViT-GPT2 caption for a single image (inference only, no grad)."""
    inputs = captioning_processor(images=img, return_tensors="pt")
    with torch.no_grad():
        caption_ids = captioning_model.generate(**inputs)
    return tokenizer.decode(caption_ids[0], skip_special_tokens=True)

# Load every model once, before the UI is constructed.
init()

# Gradio UI: a single PIL image goes in; three outputs come back, in the order
# process_image returns them -- the annotated detection image, the per-object
# JSON (label, caption, box, score), and the caption for the whole image.
# live=True re-runs the pipeline whenever the input changes.
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),
        gr.JSON(label="Object Captions & Bounding Boxes"),
        gr.Textbox(label="Whole Image Caption"),
    ],
    live=True,
)


# Start the web server for the app.
interface.launch()