Update modeling_gpt2vision.py

modeling_gpt2vision.py (+52 -2) CHANGED

@@ -18,6 +18,30 @@ def resize_token_embeds(model_name="openai-community/gpt2"):
 
 tokenizer = resize_token_embeds()
 
+def create_labels(input_ids, tokenizer, attention_mask):
+    labels = input_ids.clone()
+
+    labels[attention_mask == 0] = -100
+
+    answer_start_tokens = tokenizer.encode("Answer:", add_special_tokens=False)
+
+    for i, seq in enumerate(input_ids):
+        # Find the start of the answer
+        answer_start = (seq == answer_start_tokens[0]).nonzero(as_tuple=True)[0]
+        if len(answer_start) > 0:
+            answer_start = answer_start[0]
+            if seq[answer_start:answer_start+len(answer_start_tokens)].tolist() == answer_start_tokens:
+                # Mask out everything before the answer
+                labels[i, :answer_start] = -100
+
+        # Find the end of the sequence (last non-padding token)
+        sequence_end = attention_mask[i].nonzero(as_tuple=True)[0][-1]
+
+        # Keep the last token (EOS) as part of the label
+        labels[i, sequence_end+1:] = -100
+
+    return labels
+
 class MLP(nn.Module):
     def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None):
         super().__init__()
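For readers skimming the diff: create_labels builds a GPT-2 label tensor in which padding and everything before the "Answer:" marker are set to -100, so only the answer span through the last non-padding token is scored by the cross-entropy loss. Below is a minimal usage sketch, assuming the function can be imported from this module and using a plain GPT-2 tokenizer with EOS reused as the pad token; both of those details are assumptions, not part of the commit.

# Hypothetical usage sketch -- import path and tokenizer setup are assumptions.
import torch
from transformers import AutoTokenizer
from modeling_gpt2vision import create_labels  # assumed module layout

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

enc = tokenizer(
    "Question: What color is the sky?\nAnswer: blue",
    padding="max_length",
    truncation=True,
    max_length=24,
    return_tensors="pt",
)

labels = create_labels(enc["input_ids"], tokenizer, enc["attention_mask"])

# Positions still holding real token ids (not -100) are the only ones that
# contribute to the loss; here that should be roughly the "Answer: blue" suffix.
kept = enc["input_ids"][0][labels[0] != -100]
print(tokenizer.decode(kept))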
@@ -62,7 +86,7 @@ class GPT2Vision(PreTrainedModel):
             input_texts,
             padding='max_length',
             truncation=True,
-            max_length=
+            max_length=768,
             return_tensors="pt",
         ).to(device)
         pixel_values = self.vision_encoder(images, device)
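One implication of the new max_length=768 that the diff does not spell out: preprocess_inputs later prepends the projected image embeddings to these text tokens, and GPT-2's context window is 1024 positions, so the image prefix has to fit in the remaining slots. A rough sanity check follows; the image-token count is left as a parameter because it depends on the vision encoder and MLP, and the helper itself is illustrative, not part of the commit.

def check_context_budget(num_image_tokens: int,
                         text_max_length: int = 768,
                         n_positions: int = 1024) -> int:
    """Return the positions left over after text plus image prefix; raise if negative.

    Hypothetical helper: 768 mirrors the tokenizer call above, 1024 is GPT-2's
    n_positions.
    """
    leftover = n_positions - text_max_length - num_image_tokens
    if leftover < 0:
        raise ValueError(
            f"{text_max_length + num_image_tokens} tokens exceed GPT-2's {n_positions}"
        )
    return leftover

# e.g. a 196-token image prefix still fits: 768 + 196 = 964 <= 1024
print(check_context_budget(num_image_tokens=196))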
@@ -72,20 +96,46 @@ class GPT2Vision(PreTrainedModel):
             "pixel_values": pixel_values
         }
 
+    def freeze_model_components(self, freeze_vision=True, freeze_language=True, freeze_mlp=False):
+        for param in self.vision_model.parameters():
+            param.requires_grad = not freeze_vision
+        for param in self.language_model.parameters():
+            param.requires_grad = not freeze_language
+        for param in self.mlp.parameters():
+            param.requires_grad = not freeze_mlp
+
     def preprocess_inputs(self, batch):
         pixel_values = batch['pixel_values'].squeeze(1)
         input_ids = batch['input_ids'].squeeze(1)
         attention_mask = batch['attention_mask'].squeeze(1)
+
         input_ids = input_ids.to(self.device)
         attention_mask = attention_mask.to(self.device)
         pixel_values = pixel_values.to(self.device)
+
+        labels = create_labels(input_ids, self.tokenizer, attention_mask)
+        labels = labels.to(self.device)
+
         img_embs = self.mlp(pixel_values)
         tok_embs = self.language_model.get_input_embeddings()(input_ids)
+
         inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
+
         img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device)
         attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)
-        return inputs_embeds, attention_mask, input_ids
 
+        img_labels = torch.full((labels.size(0), img_embs.size(1)), fill_value=-100, dtype=torch.long, device=self.device)
+        labels = torch.cat((labels[:, 0:1], img_labels, labels[:, 1:]), dim=1)
+
+        return inputs_embeds, attention_mask, input_ids, labels
+
+
+    def forward(self, batch, **kwargs):
+        inputs_embeds, attention_mask, input_ids, labels = self.preprocess_inputs(batch)
+
+        outputs = self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
+        return outputs
+
     def generate(self, question, image, max_new_tokens=30, **kwargs):
         prompt = f"Question: {question}\nAnswer:"
         batch = {"image": [image], "text": prompt}
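Taken together, the additions give the model a training path: freeze_model_components decides which blocks receive gradients, and forward runs preprocess_inputs (which now builds masked labels) and passes everything to GPT-2, so outputs.loss is populated. A minimal training-step sketch under those assumptions; the dataloader contents and optimizer choice are illustrative, not part of the commit.

# Hypothetical fine-tuning loop; only the MLP projector receives gradients.
import torch

def train_projector(model, dataloader, lr=1e-4):
    """One-epoch sketch using the new freeze_model_components/forward pair."""
    model.freeze_model_components(freeze_vision=True, freeze_language=True, freeze_mlp=False)
    model.train()
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=lr
    )

    for batch in dataloader:        # batches with 'input_ids', 'attention_mask', 'pixel_values'
        outputs = model(batch)      # forward() builds labels internally via create_labels
        outputs.loss.backward()     # loss exists because labels reach the language model
        optimizer.step()
        optimizer.zero_grad()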