damerajee committed on
Commit 98463c7 · verified · 1 Parent(s): 5b8f163

Update modeling_gpt2vision.py

Files changed (1)
  1. modeling_gpt2vision.py +23 -33
modeling_gpt2vision.py CHANGED
@@ -1,10 +1,8 @@
 import torch
-from torch import nn
-from transformers import PreTrainedModel,AutoTokenizer
-import re
-
-from .vision_encoder import VisionEncoder
+import torch.nn as nn
+from transformers import PreTrainedModel, AutoTokenizer
 from .configuration_gpt2vision import GPT2VisionConfig
+from .vision_encoder import VisionEncoder
 from .modeling_gpt2 import GPT2LMHeadModel
 
 IMAGE_TOKEN = "<image>"
@@ -36,7 +34,7 @@ class MLP(nn.Module):
         x = self.dropout(x)
         x = self.fc2(x)
         return x
-
+
 class GPT2Vision(PreTrainedModel):
     config_class = GPT2VisionConfig
 
@@ -49,7 +47,6 @@ class GPT2Vision(PreTrainedModel):
         self.tokenizer = tokenizer
         tokenizer.pad_token = tokenizer.eos_token
         self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
-        self.img_tokens = 197
 
     @property
     def device(self):
@@ -60,48 +57,41 @@ class GPT2Vision(PreTrainedModel):
         images = batch['image']
         if isinstance(text, str):
             text = [text]
-
+        input_texts = [f"{IMAGE_TOKEN}{t}" for t in text]
         text_inputs = self.tokenizer(
-            text,
+            input_texts,
             padding='max_length',
             truncation=True,
             max_length=768,
             return_tensors="pt",
+            pad_to_multiple_of=8,
         ).to(device)
-
-        # Adjust attention mask to account for image tokens and the extra <image> token
-        batch_size = text_inputs.input_ids.shape[0]
-        img_attention = torch.ones((batch_size, self.img_tokens + 1), dtype=torch.long, device=device)
-        attention_mask = torch.cat([img_attention, text_inputs.attention_mask[:, 1:]], dim=1)
-
+        pixel_values = self.vision_encoder(images, device)
         return {
             "input_ids": text_inputs.input_ids,
-            "attention_mask": attention_mask,
-            "images": images
+            "attention_mask": text_inputs.attention_mask,
+            "pixel_values": pixel_values
         }
-
-    def preprocess_inputs(self, batch):
-        images = batch['images']
-        input_ids = batch['input_ids'].to(self.device)
-        attention_mask = batch['attention_mask'].to(self.device)
 
-        img_embs = self.vision_encoder(images, device=self.device)
-        img_embs = self.mlp(img_embs)
+    def preprocess_inputs(self, batch):
+        pixel_values = batch['pixel_values'].squeeze(1)
+        input_ids = batch['input_ids'].squeeze(1)
+        attention_mask = batch['attention_mask'].squeeze(1)
+        input_ids = input_ids.to(self.device)
+        attention_mask = attention_mask.to(self.device)
+        pixel_values = pixel_values.to(self.device)
+        img_embs = self.mlp(pixel_values)
         tok_embs = self.language_model.get_input_embeddings()(input_ids)
-
         inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
+        img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device)
+        attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)
+        return inputs_embeds, attention_mask, input_ids
 
-        # Ensure the attention mask aligns with the inputs_embeds
-        assert inputs_embeds.shape[1] == attention_mask.shape[1], f"Mismatch between embeddings ({inputs_embeds.shape[1]}) and attention mask length ({attention_mask.shape[1]})."
-
-        return inputs_embeds, attention_mask
-
-
     def generate(self, question, image, max_new_tokens=30, **kwargs):
-        prompt = f"\n\nQuestion: {question}\n\nAnswer:"
+        prompt = f"Question: {question}\nAnswer:"
         batch = {"image": [image], "text": prompt}
         encoded_batch = self.tokenize_encode(batch, self.device)
-        inputs_embeds, attention_mask = self.preprocess_inputs(encoded_batch)
+        inputs_embeds, attention_mask, input_ids = self.preprocess_inputs(encoded_batch)
         output_sequences = self.language_model.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
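
For orientation, a minimal usage sketch of the revised call path (not part of the commit). It assumes the custom GPT2Vision class is exposed through AutoModelForCausalLM via the config's auto_map with trust_remote_code=True; the repo id is a placeholder, and whether generate() returns token ids or decoded text depends on code outside this hunk.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM

# Placeholder repo id; assumes the config maps AutoModelForCausalLM to GPT2Vision.
model = AutoModelForCausalLM.from_pretrained("username/gpt2vision", trust_remote_code=True)

image = Image.open("example.jpg")
# The updated generate() builds the prompt "Question: ...\nAnswer:", prepends the
# <image> token during tokenization, runs the vision encoder inside tokenize_encode,
# and splices the MLP-projected image embeddings in right after the first text token
# before calling language_model.generate().
with torch.no_grad():
    output = model.generate("What is in this picture?", image, max_new_tokens=30)
print(output)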