Update modeling_minicpmv.py
modeling_minicpmv.py +114 -21
modeling_minicpmv.py
CHANGED
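In short, the commit wraps self.llm, self.vpm and self.resampler in torch.nn.DataParallel when more than one GPU is visible, guards every direct attribute access with an isinstance(..., nn.DataParallel) check that goes through .module, and adds a chat() convenience method. A minimal sketch of that unwrapping guard, assuming an nn.Linear as a stand-in for the vision tower (the unwrap helper name is ours, not part of the commit):

import torch
from torch import nn

def unwrap(m: nn.Module) -> nn.Module:
    # If the module was wrapped in nn.DataParallel, the original module lives at .module;
    # otherwise return it unchanged. This mirrors the isinstance checks used in the diff.
    return m.module if isinstance(m, nn.DataParallel) else m

# Example: the same guard the commit applies to self.vpm before reading weight dtype/device.
vpm = nn.Linear(4, 4)
if torch.cuda.device_count() > 1:
    vpm = nn.DataParallel(vpm)
dtype = unwrap(vpm).weight.dtype
device = unwrap(vpm).weight.device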
@@ -1,4 +1,5 @@
-
+
+import math
 from typing import List, Optional
 import json
 import torch
@@ -22,7 +23,7 @@ class MiniCPMVPreTrainedModel(LlamaPreTrainedModel):
 class MiniCPMV(MiniCPMVPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
-
+
         self.llm = LlamaForCausalLM(config)
         self.vpm = self.init_vision_module()
         self.vision_dim = self.vpm.embed_dim
@@ -30,6 +31,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
         self.transform = self.init_transform()
 
+        # Wrap the model with DataParallel
+        if torch.cuda.device_count() > 1:
+            self.llm = nn.DataParallel(self.llm)
+            self.vpm = nn.DataParallel(self.vpm)
+            self.resampler = nn.DataParallel(self.resampler)
+
+
     def init_vision_module(self):
         # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
         model = Idefics2VisionTransformer(self.config.vision_config)
@@ -61,9 +69,12 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         )
 
     def get_vllm_embedding(self, data):
+        # Similar processing as before but make sure to handle DataParallel wrapping if needed
+        # Make sure all tensors are moved to the correct device if needed
        if 'vision_hidden_states' not in data:
-            dtype = self.vpm.embeddings.position_embedding.weight.dtype
-            device = self.vpm.embeddings.position_embedding.weight.device
+            dtype = self.vpm.module.embeddings.position_embedding.weight.dtype if isinstance(self.vpm, nn.DataParallel) else self.vpm.embeddings.position_embedding.weight.dtype
+            device = self.vpm.module.embeddings.position_embedding.weight.device if isinstance(self.vpm, nn.DataParallel) else self.vpm.embeddings.position_embedding.weight.device
+
             tgt_sizes = data['tgt_sizes']
             pixel_values_list = data['pixel_values']
             vision_hidden_states = []
@@ -126,14 +137,89 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
         else:
             vision_hidden_states = data['vision_hidden_states']
+    def chat(self, image, msgs, tokenizer, vision_hidden_states=None, max_new_tokens=1024, sampling=True, max_inp_length=2048, **kwargs):
+        if isinstance(msgs, str):
+            msgs = json.loads(msgs)
+
+        copy_msgs = deepcopy(msgs)
+        assert len(copy_msgs) > 0, 'msgs is empty'
+
+        if image is not None and isinstance(copy_msgs[0]['content'], str):
+            copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]
+
+        images = []
+        tgt_sizes = []
+        for i, msg in enumerate(copy_msgs):
+            role = msg["role"]
+            content = msg["content"]
+            assert role in ["user", "assistant"]
+            if i == 0:
+                assert role == "user", "The role of first msg should be user"
+            if isinstance(content, str):
+                content = [content]
+
+            cur_msgs = []
+            for c in content:
+                if isinstance(c, Image.Image):
+                    image = c
+                    if self.config.slice_mode:
+                        slice_images, image_placeholder = self.get_slice_image_placeholder(image, tokenizer)
+                        cur_msgs.append(image_placeholder)
+                        for slice_image in slice_images:
+                            slice_image = self.transform(slice_image)
+                            H, W = slice_image.shape[1:]
+                            images.append(self.reshape_by_patch(slice_image))
+                            tgt_sizes.append(torch.Tensor([H // self.config.patch_size, W // self.config.patch_size]).type(torch.int32))
+                    else:
+                        images.append(self.transform(image))
+                        cur_msgs.append(tokenizer.im_start + tokenizer.unk_token * self.config.query_num + tokenizer.im_end)
+                elif isinstance(c, str):
+                    cur_msgs.append(c)
+
+            msg['content'] = '\n'.join(cur_msgs)
+        if tgt_sizes:
+            tgt_sizes = torch.vstack(tgt_sizes)
+
+        input_ids = tokenizer.apply_chat_template(copy_msgs, tokenize=True, add_generation_prompt=False)
+
+        if sampling:
+            generation_config = {
+                "top_p": 0.8,
+                "top_k": 100,
+                "temperature": 0.7,
+                "do_sample": True,
+                "repetition_penalty": 1.05
+            }
+        else:
+            generation_config = {
+                "num_beams": 3,
+                "repetition_penalty": 1.2,
+            }
+
+        generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
+
+        with torch.inference_mode():
+            res, vision_hidden_states = self.generate(
+                input_id_list=[input_ids],
+                max_inp_length=max_inp_length,
+                img_list=[images],
+                tgt_sizes=[tgt_sizes],
+                tokenizer=tokenizer,
+                max_new_tokens=max_new_tokens,
+                vision_hidden_states=vision_hidden_states,
+                return_vision_hidden_states=True,
+                **generation_config
+            )
+        answer = res[0]
+
+        return answer
 
         if hasattr(self.llm.config, 'scale_emb'):
-            vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
+            vllm_embedding = self.llm.module.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb if isinstance(self.llm, nn.DataParallel) else self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
         else:
-            vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
+            vllm_embedding = self.llm.module.model.embed_tokens(data['input_ids']) if isinstance(self.llm, nn.DataParallel) else self.llm.model.embed_tokens(data['input_ids'])
 
-        vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(
-            i, torch.Tensor) else i for i in vision_hidden_states]
+        vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states]
 
         bs = len(data['input_ids'])
         for i in range(bs):
@@ -142,29 +228,36 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             cur_vllm_emb = vllm_embedding[i]
             cur_image_bound = data['image_bound'][i]
             if len(cur_image_bound) > 0:
-                image_indices = torch.stack(
-                    [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
-                ).to(vllm_embedding.device)
-
-                cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
-                                      cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
+                image_indices = torch.stack([torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]).to(vllm_embedding.device)
+                cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]), cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
             elif self.training:
                 cur_vllm_emb += cur_vs_hs[0].mean() * 0
 
         return vllm_embedding, vision_hidden_states
-
     def forward(self, data, **kwargs):
         vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
         position_ids = data["position_ids"]
         if position_ids.dtype != torch.int64:
            position_ids = position_ids.long()
 
-        return self.llm(
-            input_ids=None,
-            position_ids=position_ids,
-            inputs_embeds=vllm_embedding,
-            **kwargs
-        )
+        # Handle DataParallel model
+        if isinstance(self.llm, nn.DataParallel):
+            outputs = self.llm.module(
+                input_ids=None,
+                position_ids=position_ids,
+                inputs_embeds=vllm_embedding,
+                **kwargs
+            )
+        else:
+            outputs = self.llm(
+                input_ids=None,
+                position_ids=position_ids,
+                inputs_embeds=vllm_embedding,
+                **kwargs
+            )
+
+        return outputs
+
 
     def _convert_to_tensors(
         self, tokenizer, input_ids, max_inp_length: Optional[int] = None
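For context, a usage sketch of the chat() method added in this commit, assuming the checkpoint is loaded with trust_remote_code; the model path, image file, and prompt below are placeholders, not part of the commit:

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

model_path = "path/to/minicpmv-checkpoint"  # placeholder path, not from the commit
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
                                  torch_dtype=torch.float16).to("cuda").eval()

image = Image.open("example.jpg").convert("RGB")               # placeholder image
msgs = [{"role": "user", "content": "Describe this image."}]   # placeholder prompt

# chat() handles image slicing, chat templating and generation; sampling=True uses the
# top_p/top_k/temperature defaults defined in the diff above.
answer = model.chat(image=image, msgs=msgs, tokenizer=tokenizer,
                    sampling=True, max_new_tokens=256)
print(answer)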