# generate.py
import torch
from PIL import Image
from transformers import ViTImageProcessor, CLIPProcessor, AutoTokenizer
from vit_captioning.models.encoder import ViTEncoder, CLIPEncoder
from vit_captioning.models.decoder import TransformerDecoder
import argparse


class CaptionGenerator:
    def __init__(self, model_type: str, checkpoint_path: str, quantized=False, runAsContainer=False):
        print(f"Loading {model_type} | Quantized: {quantized}")

        # Setup device
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            print("Using NVIDIA CUDA GPU acceleration.")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
            print("Using Apple MPS GPU acceleration.")
        else:
            self.device = torch.device("cpu")
            print("No GPU found, falling back to CPU.")

        # Load tokenizer
        if runAsContainer:
            self.tokenizer = AutoTokenizer.from_pretrained('/models/bert-tokenizer')
        else:
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

        # Select encoder, processor, output dim
        if model_type == "ViTEncoder":
            self.encoder = ViTEncoder().to(self.device)
            self.encoder_dim = 768
            self.processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        elif model_type == "CLIPEncoder":
            self.encoder = CLIPEncoder().to(self.device)
            self.encoder_dim = 512
            if runAsContainer:
                self.processor = CLIPProcessor.from_pretrained("/models/clip")
            else:
                self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        else:
            raise ValueError("Unknown model type")

        if quantized:
            print("Applying dynamic quantization to encoder...")
            self.encoder = torch.ao.quantization.quantize_dynamic(
                self.encoder,
                {torch.nn.Linear},
                dtype=torch.qint8
            )

        # Initialize decoder
        self.decoder = TransformerDecoder(
            vocab_size=30522,
            hidden_dim=self.encoder_dim,
            encoder_dim=self.encoder_dim
        ).to(self.device)

        # Load checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
        self.encoder.eval()
        self.decoder.eval()
    def generate_caption(self, image_path: str) -> dict:
        image = Image.open(image_path).convert("RGB")
        encoding = self.processor(images=image, return_tensors='pt')
        pixel_values = encoding['pixel_values'].to(self.device)

        captions = {}
        with torch.no_grad():
            encoder_outputs = self.encoder(pixel_values)

            # Greedy
            caption_ids = self.decoder.generate(encoder_outputs, mode="greedy")
            captions['greedy'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)

            # Top-k
            caption_ids = self.decoder.generate(encoder_outputs, mode="topk", top_k=30)
            captions['topk'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)

            # Top-p
            caption_ids = self.decoder.generate(encoder_outputs, mode="topp", top_p=0.92)
            captions['topp'] = self.tokenizer.decode(caption_ids[0], skip_special_tokens=True)

        return captions
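

# Example programmatic usage (a minimal sketch; the checkpoint and image paths
# below are illustrative placeholders, not files shipped with this repo):
#
#   generator = CaptionGenerator(
#       model_type="CLIPEncoder",
#       checkpoint_path="artifacts/CLIPEncoder_best.pth",
#       quantized=False,
#       runAsContainer=False,
#   )
#   captions = generator.generate_caption("samples/example.jpg")
#   print(captions["greedy"], captions["topk"], captions["topp"])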


if __name__ == "__main__":
    # CLI usage
    parser = argparse.ArgumentParser(description="Generate caption using ViT or CLIP.")
    parser.add_argument("--model", type=str, default="ViTEncoder",
                        choices=["ViTEncoder", "CLIPEncoder"],
                        help="Choose encoder: ViTEncoder or CLIPEncoder")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to the .pth checkpoint file")
    parser.add_argument("--image", type=str, required=True,
                        help="Path to input image file")
    parser.add_argument("--quantized", action="store_true",
                        help="Load encoder with dynamic quantization")
    args = parser.parse_args()

    generator = CaptionGenerator(
        model_type=args.model,
        checkpoint_path=args.checkpoint,
        quantized=args.quantized,
        runAsContainer=True  # load tokenizer/processor from the local /models paths
    )
    captions = generator.generate_caption(args.image)

    print(f"Greedy-argmax (deterministic, factual): {captions['greedy']}")
    print(f"Top-k (diverse, creative): {captions['topk']}")
    print(f"Top-p (diverse, human-like): {captions['topp']}")