# generate.py
import argparse

import torch
from PIL import Image
from transformers import ViTImageProcessor, CLIPProcessor, AutoTokenizer

from vit_captioning.models.encoder import ViTEncoder, CLIPEncoder
from vit_captioning.models.decoder import TransformerDecoder


class CaptionGenerator:
    def __init__(self, model_type: str, checkpoint_path: str, quantized=False, runAsContainer=False):
        print(f"Loading {model_type} | Quantized: {quantized}")

        # Set up the device: prefer CUDA, then Apple MPS, then fall back to CPU
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Using NVIDIA CUDA GPU acceleration.")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
            print("Using Apple MPS GPU acceleration.")
        else:
            self.device = torch.device("cpu")
            print("No GPU found, falling back to CPU.")
        # Load the tokenizer: from a local path baked into the container image,
        # or from the Hugging Face Hub when running outside a container
        if runAsContainer:
            self.tokenizer = AutoTokenizer.from_pretrained('/models/bert-tokenizer')
        else:
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        # Select encoder, image processor, and encoder output dimension
        if model_type == "ViTEncoder":
            self.encoder = ViTEncoder().to(self.device)
            self.encoder_dim = 768
            # Unlike the CLIP branch below, the ViT processor is always
            # fetched from the Hub, even when running as a container
            self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        elif model_type == "CLIPEncoder":
            self.encoder = CLIPEncoder().to(self.device)
            self.encoder_dim = 512
            if runAsContainer:
                self.processor = CLIPProcessor.from_pretrained("/models/clip")
            else:
                self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        else:
            raise ValueError(f"Unknown model type: {model_type}")
        if quantized:
            print("Applying dynamic quantization to encoder...")
            self.encoder = torch.ao.quantization.quantize_dynamic(
                self.encoder,
                {torch.nn.Linear},
                dtype=torch.qint8
            )
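
        # Note: quantization is applied before load_state_dict below, so the
        # checkpoint must have been saved from an encoder quantized the same
        # way. quantize_dynamic swaps nn.Linear modules for quantized ones
        # with different state-dict keys, so a plain float checkpoint would
        # not load here.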
        # Initialize decoder (30522 is the bert-base-uncased vocabulary size,
        # matching the tokenizer loaded above)
        self.decoder = TransformerDecoder(
            vocab_size=30522,
            hidden_dim=self.encoder_dim,
            encoder_dim=self.encoder_dim
        ).to(self.device)

        # Load the checkpoint and put both halves in inference mode
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
        self.encoder.eval()
        self.decoder.eval()
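
    # A compatible checkpoint is a dict with the two keys used above. A
    # training script (not shown in this file) would produce one roughly
    # like this:
    #
    #   torch.save({
    #       'encoder_state_dict': encoder.state_dict(),
    #       'decoder_state_dict': decoder.state_dict(),
    #   }, 'checkpoint.pth')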
    def generate_caption(self, image_path: str) -> dict:
        image = Image.open(image_path).convert("RGB")
        encoding = self.processor(images=image, return_tensors='pt')
        pixel_values = encoding['pixel_values'].to(self.device)

        captions = {}
        with torch.no_grad():
            encoder_outputs = self.encoder(pixel_values)

            # Greedy: take the argmax token at each step (deterministic)
            caption_ids = self.decoder.generate(encoder_outputs, mode="greedy")
            captions['greedy'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)

            # Top-k: sample from the 30 most likely tokens at each step
            caption_ids = self.decoder.generate(encoder_outputs, mode="topk", top_k=30)
            captions['topk'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)

            # Top-p (nucleus): sample from the smallest token set whose
            # cumulative probability exceeds 0.92
            caption_ids = self.decoder.generate(encoder_outputs, mode="topp", top_p=0.92)
            captions['topp'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)

        return captions
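

# Programmatic usage sketch (the checkpoint and image paths are illustrative):
#
#   generator = CaptionGenerator("CLIPEncoder", "artifacts/clip_model.pth")
#   captions = generator.generate_caption("sample.jpg")
#   print(captions["greedy"], captions["topk"], captions["topp"])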
if __name__ == "__main__":
    # CLI usage
    parser = argparse.ArgumentParser(description="Generate caption using ViT or CLIP.")
    parser.add_argument("--model", type=str, default="ViTEncoder",
                        choices=["ViTEncoder", "CLIPEncoder"],
                        help="Choose encoder: ViTEncoder or CLIPEncoder")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to the .pth checkpoint file")
    parser.add_argument("--image", type=str, required=True,
                        help="Path to input image file")
    parser.add_argument("--quantized", action="store_true",
                        help="Load encoder with dynamic quantization")
    args = parser.parse_args()

    generator = CaptionGenerator(
        model_type=args.model,
        checkpoint_path=args.checkpoint,
        quantized=args.quantized,  # pass the --quantized flag through
        runAsContainer=True        # this entry point assumes the container's /models layout
    )
    captions = generator.generate_caption(args.image)

    print(f"Greedy-argmax (deterministic, factual): {captions['greedy']}")
    print(f"Top-k (diverse, creative): {captions['topk']}")
    print(f"Top-p (diverse, human-like): {captions['topp']}")