# generate.py
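# Generate a caption for a single image with a ViT or CLIP encoder and a
# Transformer decoder, comparing greedy, top-k, and top-p decoding.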

import torch
from PIL import Image
from transformers import ViTImageProcessor, CLIPProcessor, AutoTokenizer
from vit_captioning.models.encoder import ViTEncoder, CLIPEncoder
from vit_captioning.models.decoder import TransformerDecoder

import argparse


class CaptionGenerator:
    def __init__(self, model_type: str, checkpoint_path: str, quantized=False, runAsContainer=False):
        print(f"Loading {model_type} | Quantized: {quantized}")
        # Setup device
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Using NVIDIA CUDA GPU acceleration.")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
            print("Using Apple MPS GPU acceleration.")
        else:
            self.device = torch.device("cpu")
            print("No GPU found, falling back to CPU.")

        # Load tokenizer (local copy inside the container, Hugging Face Hub otherwise)
        if runAsContainer:
            self.tokenizer = AutoTokenizer.from_pretrained('/models/bert-tokenizer')
        else:
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

        # Select encoder, image processor, and encoder output dimension
        if model_type == "ViTEncoder":
            self.encoder = ViTEncoder().to(self.device)
            self.encoder_dim = 768
            self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        elif model_type == "CLIPEncoder":
            self.encoder = CLIPEncoder().to(self.device)
            self.encoder_dim = 512
            if runAsContainer:
                self.processor = CLIPProcessor.from_pretrained("/models/clip")
            else:
                self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        if quantized:
            # Dynamic quantization stores nn.Linear weights as int8,
            # shrinking the encoder and speeding up CPU inference.
            print("Applying dynamic quantization to encoder...")
            self.encoder = torch.ao.quantization.quantize_dynamic(
                self.encoder,
                {torch.nn.Linear},
                dtype=torch.qint8
            )

        # Initialize decoder (vocab size matches the bert-base-uncased tokenizer)
        self.decoder = TransformerDecoder(
            vocab_size=30522,
            hidden_dim=self.encoder_dim,
            encoder_dim=self.encoder_dim
        ).to(self.device)

        # Load trained weights and switch both models to eval mode
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
        self.encoder.eval()
        self.decoder.eval()

    def generate_caption(self, image_path: str) -> dict:
        image = Image.open(image_path).convert("RGB")
        encoding = self.processor(images=image, return_tensors='pt')
        pixel_values = encoding['pixel_values'].to(self.device)

        captions = {}

        with torch.no_grad():
            encoder_outputs = self.encoder(pixel_values)

            # Greedy: always pick the highest-probability token (deterministic)
            caption_ids = self.decoder.generate(encoder_outputs, mode="greedy")
            captions['greedy'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)

            # Top-k: sample from the 30 most probable tokens at each step
            caption_ids = self.decoder.generate(encoder_outputs, mode="topk", top_k=30)
            captions['topk'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)

            # Top-p (nucleus): sample from the smallest token set whose cumulative probability exceeds 0.92
            caption_ids = self.decoder.generate(encoder_outputs, mode="topp", top_p=0.92)
            captions['topp'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)

        return captions


if __name__ == "__main__":
    # CLI usage
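    # Example (checkpoint and image paths are illustrative):
    #   python generate.py --model CLIPEncoder \
    #       --checkpoint artifacts/CLIPEncoder_best.pth --image sample.jpg --quantized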
    parser = argparse.ArgumentParser(description="Generate caption using ViT or CLIP.")
    parser.add_argument("--model", type=str, default="ViTEncoder",
                        choices=["ViTEncoder", "CLIPEncoder"],
                        help="Choose encoder: ViTEncoder or CLIPEncoder")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to the .pth checkpoint file")
    parser.add_argument("--image", type=str, required=True,
                        help="Path to input image file")
    parser.add_argument(
        "--quantized",
        action="store_true",
        help="Load encoder with dynamic quantization"
    )

    args = parser.parse_args()

    generator = CaptionGenerator(
        model_type=args.model,
        checkpoint_path=args.checkpoint,
        quantized=args.quantized,
        runAsContainer=True
    )

    captions = generator.generate_caption(args.image)

    print(f"Greedy-argmax (deterministic, factual): {captions['greedy']}")
    print(f"Top-k (diverse, creative): {captions['topk']}")
    print(f"Top-p (diverse, human-like): {captions['topp']}")