!pip install openai-clip from transformers import AutoTokenizer, AutoModel import clip import skimage.io as io import PIL.Image from IPython.display import Image from transformers import AutoTokenizer, AutoModel import skimage.io as io import PIL.Image from IPython.display import Image import pandas as pd import numpy as np import time import json import nltk nltk.download('punkt') class ClipGPT2Model(nn.Module): def __init__(self, img_feature_length, img_feature_size = 512): super(ClipGPT2Model, self).__init__() torch.cuda.empty_cache() gc.collect() self.img_feature_length = img_feature_length self.gpt = GPT2LMHeadModel.from_pretrained('gpt2') self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1] self.clip_project = Adapter((img_feature_size, (self.gpt_embedding_size * img_feature_length) // 2, self.gpt_embedding_size * img_feature_length)) torch.cuda.empty_cache() def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor: return torch.zeros(batch_size, self.img_feature_length, dtype=torch.int64, device=device) def forward(self, tokens: torch.Tensor, feature: torch.Tensor, mask = None, labels = None): torch.cuda.empty_cache() gc.collect() embedding_text = self.gpt.transformer.wte(tokens) feature_projections = self.clip_project(feature).view(-1, self.img_feature_length, self.gpt_embedding_size) embedding_cat = torch.cat((feature_projections, embedding_text), dim=1) if labels is not None: dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device) labels = torch.cat((dummy_token, tokens), dim=1) out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask) return out def generate_beam( model, tokenizer, beam_size: int = 10, prompt=None, embed=None, entry_length=76, temperature=0.9, stop_token: str = ".", ): model.eval() stop_token_index = tokenizer.encode(stop_token)[0] tokens = None scores = None device = next(model.parameters()).device seq_lengths = torch.ones(beam_size, device=device) is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool) with torch.no_grad(): if embed is not None: generated = embed else: if tokens is None: tokens = torch.tensor(tokenizer.encode(prompt)) tokens = tokens.unsqueeze(0).to(device) generated = model.gpt.transformer.wte(tokens) for i in range(entry_length): outputs = model.gpt(inputs_embeds=generated) logits = outputs.logits logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) logits = logits.softmax(-1).log() if scores is None: scores, next_tokens = logits.topk(beam_size, -1) generated = generated.expand(beam_size, *generated.shape[1:]) next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) if tokens is None: tokens = next_tokens else: tokens = tokens.expand(beam_size, *tokens.shape[1:]) tokens = torch.cat((tokens, next_tokens), dim=1) else: logits[is_stopped] = -float(np.inf) logits[is_stopped, 0] = 0 scores_sum = scores[:, None] + logits seq_lengths[~is_stopped] += 1 scores_sum_average = scores_sum / seq_lengths[:, None] scores_sum_average, next_tokens = scores_sum_average.view(-1).topk( beam_size, -1 ) next_tokens_source = next_tokens // scores_sum.shape[1] seq_lengths = seq_lengths[next_tokens_source] next_tokens = next_tokens % scores_sum.shape[1] next_tokens = next_tokens.unsqueeze(1) tokens = tokens[next_tokens_source] tokens = torch.cat((tokens, next_tokens), dim=1) generated = generated[next_tokens_source] scores = scores_sum_average * seq_lengths is_stopped = is_stopped[next_tokens_source] next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view( generated.shape[0], 1, -1 ) generated = torch.cat((generated, next_token_embed), dim=1) is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze() if is_stopped.all(): break scores = scores / seq_lengths output_list = tokens.cpu().numpy() output_texts = [ tokenizer.decode(output[: int(length)]) for output, length in zip(output_list, seq_lengths) ] order = scores.argsort(descending=True) output_texts = [output_texts[i] for i in order] return output_texts def generate_caption_clipgpt(img): prefix_length = 10 model = ClipGPT2Model(prefix_length, img_feature_size = feature_dim) model.load_state_dict(torch.load('model_train_best_run_clipGPT.pt')) model = model.eval() device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) clip_model, preprocess = clip.load('ViT-B/32', device, jit=False) tokenizer = GPT2Tokenizer.from_pretrained("gpt2") start_time = time.time() image = io.imread(img) pil_image = PIL.Image.fromarray(image) image = preprocess(pil_image).unsqueeze(0).to(device) with torch.no_grad(): prefix = clip_model.encode_image(image).to(device, dtype=torch.float32) prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1) beam_caption = generate_beam(model, tokenizer, embed=prefix_embed)[0] end_time = time.time() print("--- Time taken to generate: %s seconds ---" % (end_time - start_time)) return beam_caption