damerajee commited on
Commit
fdd8533
·
verified ·
1 Parent(s): e5b4a7d

Update modeling_gpt2vision.py

Browse files
Files changed (1) hide show
  1. modeling_gpt2vision.py +21 -29
modeling_gpt2vision.py CHANGED
@@ -45,42 +45,21 @@ class GPT2Vision(PreTrainedModel):
45
  self.language_model = GPT2LMHeadModel(config.gpt2_config)
46
  self.language_model.resize_token_embeddings(len(tokenizer))
47
  self.tokenizer = tokenizer
48
- tokenizer.pad_token = tokenizer.eos_token
49
  self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
50
 
51
  @property
52
  def device(self):
53
  return next(self.language_model.parameters()).device
54
 
55
- def tokenize_encode(self, batch, device):
56
- text = batch['text']
57
- images = batch['image']
58
- if isinstance(text, str):
59
- text = [text]
60
- input_texts = [f"{IMAGE_TOKEN}{t}" for t in text]
61
- text_inputs = self.tokenizer(
62
- input_texts,
63
- padding='max_length',
64
- truncation=True,
65
- max_length=768,
66
- return_tensors="pt",
67
- pad_to_multiple_of=8,
68
- ).to(device)
69
- pixel_values = self.vision_encoder(images, device)
70
- return {
71
- "input_ids": text_inputs.input_ids,
72
- "attention_mask": text_inputs.attention_mask,
73
- "pixel_values": pixel_values
74
- }
75
-
76
  def preprocess_inputs(self, batch):
77
- pixel_values = batch['pixel_values']
78
  input_ids = batch['input_ids']
79
  attention_mask = batch['attention_mask']
80
  input_ids = input_ids.to(self.device)
81
  attention_mask = attention_mask.to(self.device)
82
- pixel_values = pixel_values.to(self.device)
83
- img_embs = self.mlp(pixel_values)
84
  tok_embs = self.language_model.get_input_embeddings()(input_ids)
85
  inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
86
  img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device)
@@ -88,10 +67,22 @@ class GPT2Vision(PreTrainedModel):
88
  return inputs_embeds, attention_mask, input_ids
89
 
90
  def generate(self, question, image, max_new_tokens=30, **kwargs):
91
- prompt = f"Question: {question}\nAnswer:"
92
- batch = {"image": [image], "text": prompt}
93
- encoded_batch = self.tokenize_encode(batch, self.device)
94
- inputs_embeds, attention_mask, input_ids = self.preprocess_inputs(encoded_batch)
 
 
 
 
 
 
 
 
 
 
 
 
95
  output_sequences = self.language_model.generate(
96
  inputs_embeds=inputs_embeds,
97
  attention_mask=attention_mask,
@@ -100,5 +91,6 @@ class GPT2Vision(PreTrainedModel):
100
  max_new_tokens=max_new_tokens,
101
  **kwargs
102
  )
 
103
  output = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
104
  return output
 
45
  self.language_model = GPT2LMHeadModel(config.gpt2_config)
46
  self.language_model.resize_token_embeddings(len(tokenizer))
47
  self.tokenizer = tokenizer
48
+ self.tokenizer.pad_token = self.tokenizer.eos_token
49
  self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
50
 
51
  @property
52
  def device(self):
53
  return next(self.language_model.parameters()).device
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def preprocess_inputs(self, batch):
56
+ img_embs = batch['pixel_values']
57
  input_ids = batch['input_ids']
58
  attention_mask = batch['attention_mask']
59
  input_ids = input_ids.to(self.device)
60
  attention_mask = attention_mask.to(self.device)
61
+ img_embs = img_embs.to(self.device)
62
+
63
  tok_embs = self.language_model.get_input_embeddings()(input_ids)
64
  inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
65
  img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device)
 
67
  return inputs_embeds, attention_mask, input_ids
68
 
69
  def generate(self, question, image, max_new_tokens=30, **kwargs):
70
+ # Process the image
71
+ img_embs = self.vision_encoder(image.unsqueeze(0), device=self.device)
72
+ img_embs = self.mlp(img_embs)
73
+
74
+ # Tokenize the question
75
+ prompt = f"{IMAGE_TOKEN}Question: {question}\nAnswer:"
76
+ encoded_input = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
77
+
78
+ batch = {
79
+ "pixel_values": img_embs,
80
+ "input_ids": encoded_input.input_ids,
81
+ "attention_mask": encoded_input.attention_mask
82
+ }
83
+
84
+ inputs_embeds, attention_mask, input_ids = self.preprocess_inputs(batch)
85
+
86
  output_sequences = self.language_model.generate(
87
  inputs_embeds=inputs_embeds,
88
  attention_mask=attention_mask,
 
91
  max_new_tokens=max_new_tokens,
92
  **kwargs
93
  )
94
+
95
  output = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
96
  return output