# generate.py
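"""Generate image captions with a ViT or CLIP encoder and a Transformer decoder.

Loads a trained checkpoint, runs the chosen encoder on an input image, and
prints greedy, top-k, and top-p decoded captions.

Example CLI usage (checkpoint and image paths below are illustrative):
    python -m vit_captioning.generate --model CLIPEncoder \
        --checkpoint checkpoints/clip_run.pth --image sample.jpg
"""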
import argparse

import torch
from PIL import Image
from transformers import ViTImageProcessor, CLIPProcessor, AutoTokenizer

from vit_captioning.models.encoder import ViTEncoder, CLIPEncoder
from vit_captioning.models.decoder import TransformerDecoder
class CaptionGenerator:
def __init__(self, model_type: str, checkpoint_path: str, quantized=False, runAsContainer=False):
print(f"Loading {model_type} | Quantized: {quantized}")
# Setup device
if torch.cuda.is_available():
self.device = torch.device("cuda")
print("Using NVIDIA CUDA GPU acceleration.")
elif torch.backends.mps.is_available():
self.device = torch.device("mps")
print("Using Apple MPS GPU acceleration.")
else:
self.device = torch.device("cpu")
print("No GPU found, falling back to CPU.")
        # Load tokenizer: use the locally mounted copy inside the container,
        # otherwise pull bert-base-uncased from the Hugging Face Hub.
        if runAsContainer:
            self.tokenizer = AutoTokenizer.from_pretrained('/models/bert-tokenizer')
        else:
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# Select encoder, processor, output dim
if model_type == "ViTEncoder":
self.encoder = ViTEncoder().to(self.device)
self.encoder_dim = 768
self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
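            # Note: unlike the CLIP branch below, the ViT image processor is always
            # fetched from the Hub, even when runAsContainer is set.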
elif model_type == "CLIPEncoder":
self.encoder = CLIPEncoder().to(self.device)
self.encoder_dim = 512
            if runAsContainer:
                self.processor = CLIPProcessor.from_pretrained("/models/clip")
            else:
                self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
else:
            raise ValueError(f"Unknown model type: {model_type}")
if quantized:
print("Applying dynamic quantization to encoder...")
self.encoder = torch.ao.quantization.quantize_dynamic(
self.encoder,
{torch.nn.Linear},
dtype=torch.qint8
)
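            # Dynamic quantization replaces nn.Linear modules with int8 equivalents,
            # so the checkpoint loaded below is presumably one saved from an
            # already-quantized encoder (a float state dict would not match the keys).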
# Initialize decoder
self.decoder = TransformerDecoder(
vocab_size=30522,
hidden_dim=self.encoder_dim,
encoder_dim=self.encoder_dim
).to(self.device)
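        # The decoder width is tied to the encoder output dim (768 for ViT, 512 for CLIP);
        # vocab_size 30522 matches the bert-base-uncased tokenizer used above.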
# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
self.encoder.eval()
self.decoder.eval()
def generate_caption(self, image_path: str) -> dict:
image = Image.open(image_path).convert("RGB")
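        # The image processor handles resizing and normalization to the encoder's expected input.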
encoding = self.processor(images=image, return_tensors='pt')
pixel_values = encoding['pixel_values'].to(self.device)
captions = {}
with torch.no_grad():
encoder_outputs = self.encoder(pixel_values)
            # Greedy decoding: take the argmax token at each step (deterministic)
caption_ids = self.decoder.generate(encoder_outputs, mode="greedy")
captions['greedy'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            # Top-k sampling: sample from the 30 most likely tokens at each step
caption_ids = self.decoder.generate(encoder_outputs, mode="topk", top_k=30)
captions['topk'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            # Top-p (nucleus) sampling: sample from the smallest token set whose cumulative probability reaches 0.92
caption_ids = self.decoder.generate(encoder_outputs, mode="topp", top_p=0.92)
captions['topp'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)
return captions
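
# Example programmatic use (paths below are illustrative, not shipped with the repo):
#   gen = CaptionGenerator("CLIPEncoder", "checkpoints/clip_run.pth", runAsContainer=False)
#   print(gen.generate_caption("sample.jpg")["greedy"])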
if __name__ == "__main__":
# CLI usage
parser = argparse.ArgumentParser(description="Generate caption using ViT or CLIP.")
parser.add_argument("--model", type=str, default="ViTEncoder",
choices=["ViTEncoder", "CLIPEncoder"],
help="Choose encoder: ViTEncoder or CLIPEncoder")
parser.add_argument("--checkpoint", type=str, required=True,
help="Path to the .pth checkpoint file")
parser.add_argument("--image", type=str, required=True,
help="Path to input image file")
parser.add_argument(
"--quantized",
action="store_true",
help="Load encoder with dynamic quantization"
    )
args = parser.parse_args()
    generator = CaptionGenerator(
        model_type=args.model,
        checkpoint_path=args.checkpoint,
        quantized=args.quantized,
        runAsContainer=True
    )
captions = generator.generate_caption(args.image)
print(f"Greedy-argmax (deterministic, factual): {captions['greedy']}")
print(f"Top-k (diverse, creative): {captions['topk']}")
print(f"Top-p (diverse, human-like): {captions['topp']}")