dahwinsingularity committed
Commit edf99e6 · verified · 1 parent: 19ad6d3

Update modeling_minicpmv.py

Files changed (1):
  1. modeling_minicpmv.py  +114 -21
modeling_minicpmv.py CHANGED
@@ -1,4 +1,5 @@
-import math
+
+import math
 from typing import List, Optional
 import json
 import torch
@@ -22,7 +23,7 @@ class MiniCPMVPreTrainedModel(LlamaPreTrainedModel):
 class MiniCPMV(MiniCPMVPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
-
+
         self.llm = LlamaForCausalLM(config)
         self.vpm = self.init_vision_module()
         self.vision_dim = self.vpm.embed_dim
@@ -30,6 +31,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
         self.transform = self.init_transform()
 
+        # Wrap the model with DataParallel
+        if torch.cuda.device_count() > 1:
+            self.llm = nn.DataParallel(self.llm)
+            self.vpm = nn.DataParallel(self.vpm)
+            self.resampler = nn.DataParallel(self.resampler)
+
+
     def init_vision_module(self):
         # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
         model = Idefics2VisionTransformer(self.config.vision_config)
@@ -61,9 +69,12 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         )
 
     def get_vllm_embedding(self, data):
+        # Similar processing as before but make sure to handle DataParallel wrapping if needed
+        # Make sure all tensors are moved to the correct device if needed
        if 'vision_hidden_states' not in data:
-            dtype = self.vpm.embeddings.position_embedding.weight.dtype
-            device = self.vpm.embeddings.position_embedding.weight.device
+            dtype = self.vpm.module.embeddings.position_embedding.weight.dtype if isinstance(self.vpm, nn.DataParallel) else self.vpm.embeddings.position_embedding.weight.dtype
+            device = self.vpm.module.embeddings.position_embedding.weight.device if isinstance(self.vpm, nn.DataParallel) else self.vpm.embeddings.position_embedding.weight.device
+
             tgt_sizes = data['tgt_sizes']
             pixel_values_list = data['pixel_values']
             vision_hidden_states = []
@@ -126,14 +137,89 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
         else:
             vision_hidden_states = data['vision_hidden_states']
+    def chat(self, image, msgs, tokenizer, vision_hidden_states=None, max_new_tokens=1024, sampling=True, max_inp_length=2048, **kwargs):
+        if isinstance(msgs, str):
+            msgs = json.loads(msgs)
+
+        copy_msgs = deepcopy(msgs)
+        assert len(copy_msgs) > 0, 'msgs is empty'
+
+        if image is not None and isinstance(copy_msgs[0]['content'], str):
+            copy_msgs[0]['content'] = [image, copy_msgs[0]['content']]
+
+        images = []
+        tgt_sizes = []
+        for i, msg in enumerate(copy_msgs):
+            role = msg["role"]
+            content = msg["content"]
+            assert role in ["user", "assistant"]
+            if i == 0:
+                assert role == "user", "The role of first msg should be user"
+            if isinstance(content, str):
+                content = [content]
+
+            cur_msgs = []
+            for c in content:
+                if isinstance(c, Image.Image):
+                    image = c
+                    if self.config.slice_mode:
+                        slice_images, image_placeholder = self.get_slice_image_placeholder(image, tokenizer)
+                        cur_msgs.append(image_placeholder)
+                        for slice_image in slice_images:
+                            slice_image = self.transform(slice_image)
+                            H, W = slice_image.shape[1:]
+                            images.append(self.reshape_by_patch(slice_image))
+                            tgt_sizes.append(torch.Tensor([H // self.config.patch_size, W // self.config.patch_size]).type(torch.int32))
+                    else:
+                        images.append(self.transform(image))
+                        cur_msgs.append(tokenizer.im_start + tokenizer.unk_token * self.config.query_num + tokenizer.im_end)
+                elif isinstance(c, str):
+                    cur_msgs.append(c)
+
+            msg['content'] = '\n'.join(cur_msgs)
+        if tgt_sizes:
+            tgt_sizes = torch.vstack(tgt_sizes)
+
+        input_ids = tokenizer.apply_chat_template(copy_msgs, tokenize=True, add_generation_prompt=False)
+
+        if sampling:
+            generation_config = {
+                "top_p": 0.8,
+                "top_k": 100,
+                "temperature": 0.7,
+                "do_sample": True,
+                "repetition_penalty": 1.05
+            }
+        else:
+            generation_config = {
+                "num_beams": 3,
+                "repetition_penalty": 1.2,
+            }
+
+        generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
+
+        with torch.inference_mode():
+            res, vision_hidden_states = self.generate(
+                input_id_list=[input_ids],
+                max_inp_length=max_inp_length,
+                img_list=[images],
+                tgt_sizes=[tgt_sizes],
+                tokenizer=tokenizer,
+                max_new_tokens=max_new_tokens,
+                vision_hidden_states=vision_hidden_states,
+                return_vision_hidden_states=True,
+                **generation_config
+            )
+        answer = res[0]
+
+        return answer
 
         if hasattr(self.llm.config, 'scale_emb'):
-            vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
+            vllm_embedding = self.llm.module.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb if isinstance(self.llm, nn.DataParallel) else self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
         else:
-            vllm_embedding = self.llm.model.embed_tokens(data['input_ids'])
+            vllm_embedding = self.llm.module.model.embed_tokens(data['input_ids']) if isinstance(self.llm, nn.DataParallel) else self.llm.model.embed_tokens(data['input_ids'])
 
-        vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(
-            i, torch.Tensor) else i for i in vision_hidden_states]
+        vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states]
 
         bs = len(data['input_ids'])
         for i in range(bs):
@@ -142,29 +228,36 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             cur_vllm_emb = vllm_embedding[i]
             cur_image_bound = data['image_bound'][i]
             if len(cur_image_bound) > 0:
-                image_indices = torch.stack(
-                    [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
-                ).to(vllm_embedding.device)
-
-                cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
-                                      cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
+                image_indices = torch.stack([torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]).to(vllm_embedding.device)
+                cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]), cur_vs_hs.view(-1, cur_vs_hs.shape[-1]))
             elif self.training:
                 cur_vllm_emb += cur_vs_hs[0].mean() * 0
 
         return vllm_embedding, vision_hidden_states
-
     def forward(self, data, **kwargs):
         vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)
         position_ids = data["position_ids"]
         if position_ids.dtype != torch.int64:
             position_ids = position_ids.long()
 
-        return self.llm(
-            input_ids=None,
-            position_ids=position_ids,
-            inputs_embeds=vllm_embedding,
-            **kwargs
-        )
+        # Handle DataParallel model
+        if isinstance(self.llm, nn.DataParallel):
+            outputs = self.llm.module(
+                input_ids=None,
+                position_ids=position_ids,
+                inputs_embeds=vllm_embedding,
+                **kwargs
+            )
+        else:
+            outputs = self.llm(
+                input_ids=None,
+                position_ids=position_ids,
+                inputs_embeds=vllm_embedding,
+                **kwargs
+            )
+
+        return outputs
+
 
     def _convert_to_tensors(
         self, tokenizer, input_ids, max_inp_length: Optional[int] = None
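
For reference, the pattern this commit applies throughout is that torch.nn.DataParallel wraps a module and exposes the original object under .module, so attribute access (embeddings, weights, sub-modules) has to branch on the wrapper, while forward() is normally called on the wrapper itself. Below is a minimal, self-contained sketch of that pattern; the Toy module and its names are hypothetical stand-ins for self.llm / self.vpm / self.resampler and are not part of modeling_minicpmv.py.

import torch
import torch.nn as nn

# Hypothetical stand-in for the wrapped sub-models (self.llm, self.vpm, self.resampler).
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = nn.Embedding(16, 8)

    def forward(self, ids):
        return self.embed_tokens(ids)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Toy().to(device)
if torch.cuda.device_count() > 1:
    # DataParallel replicates the module across GPUs inside forward();
    # named attributes now live on the inner module, reachable via .module.
    model = nn.DataParallel(model)

# The same isinstance(...) branch the commit adds before attribute access.
inner = model.module if isinstance(model, nn.DataParallel) else model
weight = inner.embed_tokens.weight
print(weight.dtype, weight.device)

# forward() is usually called on the wrapper so DataParallel can scatter the batch;
# calling inner(...) directly (as the new forward() does via self.llm.module(...))
# runs on a single device and skips the parallel split.
out = model(torch.tensor([[1, 2, 3]], device=device))
print(out.shape)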