import torch
from transformers import AutoTokenizer, GPT2Config, PreTrainedModel

from .vision_encoder import VisionEncoder
from .configuration_gpt2vision import GPT2VisionConfig
from .modeling_gpt2 import GPT2LMHeadModel

# Placeholder the prompt uses to mark where image embeddings are spliced in.
IMAGE_TOKEN = "<image>"
ANSWER_EOS = "<|endoftext|>"


def resize_token_embeds(model_name="openai-community/gpt2"):
    """Build the GPT-2 tokenizer with the image placeholder registered as a special token."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    new_tokens = {
        "additional_special_tokens": [IMAGE_TOKEN]
    }
    tokenizer.add_special_tokens(new_tokens)
    return tokenizer


tokenizer = resize_token_embeds()


class GPT2Vision(PreTrainedModel):
    config_class = GPT2VisionConfig

    def __init__(self, config):
        super().__init__(config)
        self.vision_encoder = VisionEncoder()

        if isinstance(config.gpt2_config, dict):
            gpt2_config = GPT2Config(**config.gpt2_config)
        else:
            gpt2_config = config.gpt2_config
        self.text_model = GPT2LMHeadModel(gpt2_config)

        # Grow the embedding matrix to cover the added <image> token, and
        # reuse EOS as the pad token since GPT-2 ships without one.
        self.text_model.resize_token_embeddings(len(tokenizer))
        tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

    @property
    def device(self):
        return self.text_model.device

    def encode_image(self, image, device):
        return self.vision_encoder(image, device=device)

    def input_embeds(self, prompt, image_embeds, tokenizer):
        """Splice the image embeddings into the text embeddings at the <image> placeholder."""

        def _tokenize(txt):
            return tokenizer(
                txt, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(self.device)

        text_emb = self.text_model.get_input_embeddings()

        # Start the sequence with the BOS token embedding.
        embeds = [
            text_emb(torch.tensor([[tokenizer.bos_token_id]], device=self.device))
        ]

        if IMAGE_TOKEN not in prompt:
            embeds.append(text_emb(_tokenize(prompt)))
        else:
            assert prompt.count(IMAGE_TOKEN) == 1, "prompt must contain exactly one image placeholder"
            before, after = prompt.split(IMAGE_TOKEN)
            if before:
                embeds.append(text_emb(_tokenize(before)))
            embeds.append(image_embeds.to(self.device))
            if after:
                embeds.append(text_emb(_tokenize(after)))

        return torch.cat(embeds, dim=1)

    def generate(
        self,
        image_embeds,
        prompt,
        tokenizer,
        eos_text=ANSWER_EOS,
        max_new_tokens=128,
        **kwargs,
    ):
        eos_tokens = tokenizer(eos_text, add_special_tokens=False)["input_ids"]
        generate_config = {
            "eos_token_id": eos_tokens,
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.eos_token_id,
            "max_new_tokens": max_new_tokens,
            **kwargs,
        }
        with torch.no_grad():
            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds, **generate_config
            )
        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer,
        chat_history="",
        result_queue=None,
        **kwargs,
    ):
        prompt = f"{IMAGE_TOKEN}\n\n{chat_history}Question: {question}\n\nAnswer: "
        answer = self.generate(
            image_embeds,
            prompt,
            tokenizer,
            eos_text=ANSWER_EOS,
            max_new_tokens=256,
            **kwargs,
        )[0]
        return answer
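

# A minimal usage sketch, assuming a checkpoint saved with `save_pretrained`
# and a VisionEncoder that accepts a PIL image and returns patch embeddings of
# shape (1, num_patches, n_embd). The checkpoint path and image file below are
# hypothetical placeholders, not shipped with this module.
if __name__ == "__main__":
    from PIL import Image

    model = GPT2Vision.from_pretrained("checkpoints/gpt2-vision")  # hypothetical path
    image = Image.open("example.jpg")  # hypothetical image file

    # Encode the image once, then reuse the embeddings across questions.
    image_embeds = model.encode_image(image, device=model.device)
    answer = model.answer_question(
        image_embeds, "What is shown in the image?", tokenizer
    )
    print(answer)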