damerajee committed · Commit 56bf954 · verified · 1 Parent(s): bfd6014

Update modeling_gpt2vision.py

Files changed (1)
  1. modeling_gpt2vision.py +52 -2
modeling_gpt2vision.py CHANGED
@@ -18,6 +18,30 @@ def resize_token_embeds(model_name="openai-community/gpt2"):
 
 tokenizer = resize_token_embeds()
 
+def create_labels(input_ids, tokenizer, attention_mask):
+    labels = input_ids.clone()
+
+    labels[attention_mask == 0] = -100
+
+    answer_start_tokens = tokenizer.encode("Answer:", add_special_tokens=False)
+
+    for i, seq in enumerate(input_ids):
+        # Find the start of the answer
+        answer_start = (seq == answer_start_tokens[0]).nonzero(as_tuple=True)[0]
+        if len(answer_start) > 0:
+            answer_start = answer_start[0]
+            if seq[answer_start:answer_start+len(answer_start_tokens)].tolist() == answer_start_tokens:
+                # Mask out everything before the answer
+                labels[i, :answer_start] = -100
+
+        # Find the end of the sequence (last non-padding token)
+        sequence_end = attention_mask[i].nonzero(as_tuple=True)[0][-1]
+
+        # Keep the last token (EOS) as part of the label
+        labels[i, sequence_end+1:] = -100
+
+    return labels
+
 class MLP(nn.Module):
     def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None):
         super().__init__()
@@ -62,7 +86,7 @@ class GPT2Vision(PreTrainedModel):
             input_texts,
             padding='max_length',
             truncation=True,
-            max_length=384,
+            max_length=768,
             return_tensors="pt",
         ).to(device)
         pixel_values = self.vision_encoder(images, device)
@@ -72,20 +96,46 @@ class GPT2Vision(PreTrainedModel):
             "pixel_values": pixel_values
         }
 
+    def freeze_model_components(self, freeze_vision=True, freeze_language=True, freeze_mlp=False):
+        for param in self.vision_model.parameters():
+            param.requires_grad = not freeze_vision
+        for param in self.language_model.parameters():
+            param.requires_grad = not freeze_language
+        for param in self.mlp.parameters():
+            param.requires_grad = not freeze_mlp
+
     def preprocess_inputs(self, batch):
         pixel_values = batch['pixel_values'].squeeze(1)
         input_ids = batch['input_ids'].squeeze(1)
         attention_mask = batch['attention_mask'].squeeze(1)
+
         input_ids = input_ids.to(self.device)
         attention_mask = attention_mask.to(self.device)
         pixel_values = pixel_values.to(self.device)
+
+        labels = create_labels(input_ids, self.tokenizer, attention_mask)
+        labels = labels.to(self.device)
+
         img_embs = self.mlp(pixel_values)
         tok_embs = self.language_model.get_input_embeddings()(input_ids)
+
         inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
+
         img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device)
         attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)
-        return inputs_embeds, attention_mask, input_ids
 
+        img_labels = torch.full((labels.size(0), img_embs.size(1)), fill_value=-100, dtype=torch.long, device=self.device)
+        labels = torch.cat((labels[:, 0:1], img_labels, labels[:, 1:]), dim=1)
+
+        return inputs_embeds, attention_mask, input_ids, labels
+
+
+    def forward(self, batch, **kwargs):
+        inputs_embeds, attention_mask, input_ids, labels = self.preprocess_inputs(batch)
+
+        outputs = self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+        return outputs
+
     def generate(self, question, image, max_new_tokens=30, **kwargs):
         prompt = f"Question: {question}\nAnswer:"
         batch = {"image": [image], "text": prompt}
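
For context, a minimal sketch of what the new create_labels masking does on a single padded prompt. The prompt text, max_length=24, and the pad-token choice below are illustrative assumptions, not values from this commit; only create_labels itself comes from modeling_gpt2vision.py.

from transformers import AutoTokenizer

from modeling_gpt2vision import create_labels  # assumes the file is importable as a module

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token = tokenizer.eos_token  # assumption: GPT-2 ships without a pad token

prompt = "Question: What animal is in the picture?\nAnswer: a dog"
enc = tokenizer(prompt, padding="max_length", truncation=True, max_length=24, return_tensors="pt")

labels = create_labels(enc["input_ids"], tokenizer, enc["attention_mask"])

# Positions before "Answer:" and all padding positions become -100,
# so only the answer tokens contribute to the language-modeling loss.
print(labels)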
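
Likewise, a hedged sketch of how the new pieces might be driven together: freeze_model_components to train only the MLP projector while the vision encoder and GPT-2 stay frozen, and forward(batch), which now returns a loss because labels are built inside preprocess_inputs. The optimizer, learning rate, and the assumption that each dataloader batch already carries input_ids, attention_mask, and pixel_values are illustrative, not part of the commit.

import torch

def train_one_epoch(model, dataloader, lr=1e-4):
    # Assumed recipe: tune only the projector; leave vision and language weights frozen.
    model.freeze_model_components(freeze_vision=True, freeze_language=True, freeze_mlp=False)
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=lr
    )

    model.train()
    for batch in dataloader:
        optimizer.zero_grad()
        outputs = model(batch)    # GPT2Vision.forward -> language-model output with .loss
        outputs.loss.backward()   # -100 positions (question, image, padding) are ignored by the loss
        optimizer.step()
    return model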