Update modeling_gpt2vision.py

modeling_gpt2vision.py CHANGED (+21 -29)
@@ -45,42 +45,21 @@ class GPT2Vision(PreTrainedModel):
         self.language_model = GPT2LMHeadModel(config.gpt2_config)
         self.language_model.resize_token_embeddings(len(tokenizer))
         self.tokenizer = tokenizer
-        tokenizer.pad_token = tokenizer.eos_token
+        self.tokenizer.pad_token = self.tokenizer.eos_token
         self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

     @property
     def device(self):
         return next(self.language_model.parameters()).device

-    def tokenize_encode(self, batch, device):
-        text = batch['text']
-        images = batch['image']
-        if isinstance(text, str):
-            text = [text]
-        input_texts = [f"{IMAGE_TOKEN}{t}" for t in text]
-        text_inputs = self.tokenizer(
-            input_texts,
-            padding='max_length',
-            truncation=True,
-            max_length=768,
-            return_tensors="pt",
-            pad_to_multiple_of=8,
-        ).to(device)
-        pixel_values = self.vision_encoder(images, device)
-        return {
-            "input_ids": text_inputs.input_ids,
-            "attention_mask": text_inputs.attention_mask,
-            "pixel_values": pixel_values
-        }
-
     def preprocess_inputs(self, batch):
-
+        img_embs = batch['pixel_values']
         input_ids = batch['input_ids']
         attention_mask = batch['attention_mask']
         input_ids = input_ids.to(self.device)
         attention_mask = attention_mask.to(self.device)
-
-
+        img_embs = img_embs.to(self.device)
+
         tok_embs = self.language_model.get_input_embeddings()(input_ids)
         inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
         img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device)
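The splice above is the heart of preprocess_inputs: it keeps the first token embedding (the IMAGE_TOKEN placeholder), inserts the projected image embeddings right after it, and appends the remaining text embeddings. A minimal sketch with made-up shapes (batch 2, 10 text tokens, 4 image positions, hidden size 768); in the model the tensors come from GPT-2's embedding table and the vision encoder, and the mask splice at the end is an assumption, since the hunk cuts off before the line that merges the masks:

import torch

tok_embs = torch.randn(2, 10, 768)  # stand-in for GPT-2 embeddings of [IMAGE_TOKEN, t1, ..., t9]
img_embs = torch.randn(2, 4, 768)   # stand-in for projected image patch embeddings

# First position (the IMAGE_TOKEN placeholder), then the image embeddings, then the rest of the text
inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
print(inputs_embeds.shape)  # torch.Size([2, 14, 768]) -- 10 text + 4 image positions

# Image positions are always attended to, hence a mask of ones sized from img_embs
img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long)

# Assumed: the text mask is spliced the same way as the embeddings (not shown in the hunk)
attention_mask = torch.ones(2, 10, dtype=torch.long)
attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)
print(attention_mask.shape)  # torch.Size([2, 14])

The extended mask has to line up position-for-position with inputs_embeds, which is why img_attention is sized from the batch and sequence dimensions of img_embs.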
@@ -88,10 +67,22 @@ class GPT2Vision(PreTrainedModel):
         return inputs_embeds, attention_mask, input_ids

     def generate(self, question, image, max_new_tokens=30, **kwargs):
-
-
-
-
+        # Process the image
+        img_embs = self.vision_encoder(image.unsqueeze(0), device=self.device)
+        img_embs = self.mlp(img_embs)
+
+        # Tokenize the question
+        prompt = f"{IMAGE_TOKEN}Question: {question}\nAnswer:"
+        encoded_input = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+
+        batch = {
+            "pixel_values": img_embs,
+            "input_ids": encoded_input.input_ids,
+            "attention_mask": encoded_input.attention_mask
+        }
+
+        inputs_embeds, attention_mask, input_ids = self.preprocess_inputs(batch)
+
         output_sequences = self.language_model.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
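For orientation, a usage sketch of the reworked generate. Everything here beyond the generate call is an assumption: model stands for an already-constructed GPT2Vision instance, and the file name, the 224x224 size, and the preprocessing are placeholders for whatever the vision encoder actually expects:

from PIL import Image
import torchvision.transforms as T

to_tensor = T.Compose([T.Resize((224, 224)), T.ToTensor()])  # assumed preprocessing
image = to_tensor(Image.open("example.jpg").convert("RGB"))  # (C, H, W); generate() adds the batch dim itself

answer = model.generate("What is shown in the picture?", image, max_new_tokens=30)
print(answer)

Note that the underlying GPT2LMHeadModel.generate is driven with inputs_embeds only; in recent transformers versions that makes the returned sequences contain just the newly generated tokens, so decoding output_sequences[0] yields the answer without the prompt.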
@@ -100,5 +91,6 @@ class GPT2Vision(PreTrainedModel):
             max_new_tokens=max_new_tokens,
             **kwargs
         )
+
         output = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
         return output
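One small point about the remaining change in the constructor: the pad-token assignment moved from tokenizer.pad_token to self.tokenizer.pad_token, but both names alias the same object after self.tokenizer = tokenizer, so this is a readability fix rather than a behavior change. The assignment itself matters because GPT-2's tokenizer ships without a pad token, which the padding=True call in generate relies on:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
print(tokenizer.pad_token)   # None -- GPT-2 has no pad token out of the box
tokenizer.pad_token = tokenizer.eos_token
print(tokenizer.pad_token)   # '<|endoftext|>' now doubles as the padding token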
|