damerajee commited on
Commit
0dd73f3
·
verified ·
1 Parent(s): f966943

Create modeling_gpt2vision.py

Browse files
Files changed (1) hide show
  1. modeling_gpt2vision.py +171 -0
modeling_gpt2vision.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, AutoTokenizer
4
+ from .configuration_gpt2vision import GPT2VisionConfig ,GPT2Config
5
+ from .modeling_gpt2 import GPT2LMHeadModel
6
+ from .vision_encoder import VisionEncoder
7
+
8
+ IMAGE_TOKEN = "<image>"
9
+ ANSWER_EOS = "<|endoftext|>"
10
+
11
+ def resize_token_embeds(model_name="openai-community/gpt2"):
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
13
+ tokenizer.add_special_tokens({"additional_special_tokens": [IMAGE_TOKEN]})
14
+ return tokenizer
15
+
16
+ tokenizer = resize_token_embeds()
17
+
18
+ print("tokenizer",tokenizer)
19
+ def create_labels(input_ids, tokenizer, attention_mask):
20
+ labels = input_ids.clone()
21
+
22
+ labels[attention_mask == 0] = -100
23
+
24
+ answer_start_tokens = tokenizer.encode("Answer:", add_special_tokens=False)
25
+
26
+ for i, seq in enumerate(input_ids):
27
+ # Find the start of the answer
28
+ answer_start = (seq == answer_start_tokens[0]).nonzero(as_tuple=True)[0]
29
+ if len(answer_start) > 0:
30
+ answer_start = answer_start[0]
31
+ if seq[answer_start:answer_start+len(answer_start_tokens)].tolist() == answer_start_tokens:
32
+ # Mask out everything before the answer
33
+ labels[i, :answer_start] = -100
34
+
35
+ # Find the end of the sequence (last non-padding token)
36
+ sequence_end = attention_mask[i].nonzero(as_tuple=True)[0][-1]
37
+
38
+ # Keep the last token (EOS) as part of the label
39
+ labels[i, sequence_end+1:] = -100
40
+
41
+ return labels
42
+
43
+ class MLP(nn.Module):
44
+ def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None):
45
+ super().__init__()
46
+ out_features = out_features or in_features
47
+ hidden_features = hidden_features or in_features
48
+ self.fc1 = nn.Linear(in_features, hidden_features)
49
+ self.act = nn.GELU(approximate="tanh")
50
+ self.fc2 = nn.Linear(hidden_features, out_features)
51
+ self.dropout = nn.Dropout(p=0.1)
52
+
53
+ # Initialize weights
54
+ nn.init.xavier_normal_(self.fc1.weight)
55
+ nn.init.zeros_(self.fc1.bias)
56
+ nn.init.xavier_normal_(self.fc2.weight)
57
+ nn.init.zeros_(self.fc2.bias)
58
+
59
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
60
+ x = self.fc1(x)
61
+ x = self.act(x)
62
+ x = self.dropout(x)
63
+ x = self.fc2(x)
64
+ return x
65
+
66
+
67
+ class GPT2Vision(PreTrainedModel):
68
+ config_class = GPT2VisionConfig
69
+
70
+ def __init__(self, config):
71
+ super().__init__(config)
72
+ self.vision_encoder = VisionEncoder()
73
+ self.mlp = MLP(in_features=768, hidden_features=768 * 4, out_features=768)
74
+
75
+ self.language_model = GPT2LMHeadModel(config.gpt2_config)
76
+
77
+ self.language_model.resize_token_embeddings(len(tokenizer))
78
+
79
+ self.tokenizer = tokenizer
80
+ tokenizer.pad_token = tokenizer.eos_token
81
+
82
+ self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
83
+
84
+ @property
85
+ def device(self):
86
+ return next(self.language_model.parameters()).device
87
+
88
+ def freeze_model_components(self, freeze_vision=True, freeze_language=True,freeze_mlp=True):
89
+ for param in self.vision_encoder.parameters():
90
+ param.requires_grad = not freeze_vision
91
+ for param in self.language_model.parameters():
92
+ param.requires_grad = not freeze_language
93
+ for param in self.mlp.parameters():
94
+ param.requires_grad = not freeze_mlp
95
+
96
+ def tokenize_encode(self, batch, device):
97
+ text = batch['text']
98
+ images = batch['image']
99
+
100
+ if isinstance(text, str):
101
+ text = [text]
102
+
103
+ input_texts = [f"{IMAGE_TOKEN}{self.tokenizer.bos_token}{t}" for t in text]
104
+ text_inputs = self.tokenizer(
105
+ input_texts,
106
+ padding='max_length',
107
+ truncation=True,
108
+ max_length=384,
109
+ return_tensors="pt",
110
+ pad_to_multiple_of=8,
111
+ ).to(device)
112
+
113
+ pixel_values = self.vision_encoder(images,device)
114
+
115
+ return {
116
+ "input_ids": text_inputs.input_ids,
117
+ "attention_mask": text_inputs.attention_mask,
118
+ "pixel_values": pixel_values
119
+ }
120
+
121
+ def preprocess_inputs(self, batch):
122
+ pixel_values = batch['pixel_values'].squeeze(1)
123
+ input_ids = batch['input_ids'].squeeze(1)
124
+ attention_mask = batch['attention_mask'].squeeze(1)
125
+
126
+ input_ids = input_ids.to(self.device)
127
+ attention_mask = attention_mask.to(self.device)
128
+ pixel_values = pixel_values.to(self.device)
129
+
130
+ labels = create_labels(input_ids, self.tokenizer, attention_mask)
131
+ labels = labels.to(self.device)
132
+
133
+ img_embs = self.mlp(pixel_values)
134
+ tok_embs = self.language_model.get_input_embeddings()(input_ids)
135
+
136
+ inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
137
+
138
+ img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device)
139
+ attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)
140
+
141
+ img_labels = torch.full((labels.size(0), img_embs.size(1)), fill_value=-100, dtype=torch.long, device=self.device)
142
+ labels = torch.cat((labels[:, 0:1], img_labels, labels[:, 1:]), dim=1)
143
+ return inputs_embeds, attention_mask, input_ids, labels
144
+
145
+ def forward(self, batch, **kwargs):
146
+ inputs_embeds, attention_mask, input_ids, labels = self.preprocess_inputs(batch)
147
+ outputs = self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
148
+
149
+
150
+ return outputs
151
+
152
+ def generate(self, question, image, max_new_tokens=30, **kwargs):
153
+ prompt = prompt = f"Question: {question}\nAnswer:"
154
+ batch = {"image": [image], "text": prompt}
155
+ encoded_batch = self.tokenize_encode(batch, self.device)
156
+ inputs_embeds, attention_mask, input_ids, _ = self.preprocess_inputs(encoded_batch)
157
+
158
+
159
+
160
+
161
+ output_sequences = self.language_model.generate(
162
+ inputs_embeds=inputs_embeds,
163
+ attention_mask=attention_mask,
164
+ max_new_tokens=max_new_tokens,
165
+ pad_token_id=self.tokenizer.eos_token_id,
166
+ eos_token_id=self.tokenizer.eos_token_id,
167
+ **kwargs
168
+ )
169
+
170
+ output = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
171
+ return output