Vision-CAIR committed on
Commit ad13eac · verified · 1 Parent(s): 328394f

Push model using huggingface_hub.

Files changed (2):
  1. config.json +8 -0
  2. mini_gpt4_llama_v2.py +880 -0
config.json CHANGED
@@ -1,5 +1,12 @@
 {
   "arch": "mini_gpt4_llama_v2",
+  "architectures": [
+    "MiniGPT4_Video"
+  ],
+  "auto_map": {
+    "AutoConfig": "mini_gpt4_llama_v2.minigpt4_video_config",
+    "AutoModel": "mini_gpt4_llama_v2.MiniGPT4_Video"
+  },
   "chat_template": true,
   "ckpt": "checkpoints/video_mistral_all_checkpoint_last.pth",
   "device": "cuda",
@@ -27,6 +34,7 @@
   "prompt_path": "",
   "remove_template": false,
   "token_pooling": true,
+  "torch_dtype": "float32",
   "transformers_version": "4.42.3",
   "use_grad_checkpoint": true,
   "use_grad_checkpoint_llm": true,
mini_gpt4_llama_v2.py ADDED
@@ -0,0 +1,880 @@
1
+ import logging
2
+ import random
3
+ import torch
4
+ import webvtt
5
+ import os
6
+ import cv2
7
+ from torchvision import transforms
8
+ import soundfile as sf
9
+ import moviepy.editor as mp
10
+ from PIL import Image
11
+ from moviepy.editor import VideoFileClip
12
+ import torch
13
+ import random
14
+ import torch.backends.cudnn as cudnn
15
+ import torch
16
+ from torch.cuda.amp import autocast as autocast
17
+ import torch.nn as nn
18
+
19
+ from minigpt4_video.registry import registry
20
+ from minigpt4_video.blip2 import Blip2Base, disabled_train
21
+ from minigpt4_video.conversation import Conversation, SeparatorStyle, StoppingCriteriaList, StoppingCriteriaSub
22
+ from transformers import LlamaTokenizer
23
+ from transformers import BitsAndBytesConfig
24
+ from transformers import AutoConfig, AutoTokenizer
25
+ from peft import (
26
+ LoraConfig,
27
+ get_peft_model,
28
+ get_peft_model_state_dict,
29
+ prepare_model_for_int8_training,
30
+ set_peft_model_state_dict,
31
+ )
32
+ import time
33
+ import numpy as np
34
+ import os
35
+ from transformers import PretrainedConfig
36
+ from transformers import PreTrainedModel
37
+ from minigpt4_video.conversation import CONV_VISION
38
+ import cv2
39
+ def extract_audio(video_path, audio_path):
40
+ video_clip = mp.VideoFileClip(video_path)
41
+ audio_clip = video_clip.audio
42
+ audio_clip.write_audiofile(audio_path, codec="libmp3lame", bitrate="320k")
43
+
44
+ def generate_subtitles(video_path):
45
+ video_id=video_path.split('/')[-1].split('.')[0]
46
+ audio_path = f"workspace/inference_subtitles/mp3/{video_id}"+'.mp3'
47
+ os.makedirs("workspace/inference_subtitles/mp3",exist_ok=True)
48
+ try:
49
+ extract_audio(video_path,audio_path)
50
+ print("successfully extracted")
51
+ os.system(f"whisper {audio_path} --language English --model large --output_format vtt --output_dir workspace/inference_subtitles/")
52
+ # remove the audio file
53
+ os.system(f"rm {audio_path}")
54
+ print("subtitle successfully generated")
55
+ return f"workspace/inference_subtitles/{video_id}"+'.vtt'
56
+ except Exception as e:
57
+ print("error",e)
58
+ print("error",video_path)
59
+ return None
60
+
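# Note on the helpers above: generate_subtitles shells out to the `whisper` CLI, and extract_audio
# relies on ffmpeg through moviepy, so both must be installed and on PATH. Illustrative call with a
# hypothetical path:
#   subtitle_path = generate_subtitles("workspace/videos/demo.mp4")
#   # -> "workspace/inference_subtitles/demo.vtt" on success, or None if extraction fails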
61
+ class minigpt4_video_config(PretrainedConfig):
62
+ model_type="minigpt4_video"
63
+ PRETRAINED_MODEL_CONFIG_DICT = {
64
+ "minigpt4_video": "configs/models/minigpt4.yaml",
65
+ }
66
+ def __init__(
67
+ self,
68
+ omg_config:dict = {},
69
+ **kwargs,
70
+ ):
71
+ for key, value in omg_config.items():
72
+ setattr(self, key, value)
73
+ super().__init__(**kwargs)
74
+
75
+
76
+ @registry.register_model("mini_gpt4_llama_v2")
77
+ class MiniGPT4_Video(Blip2Base, PreTrainedModel):
78
+ """
79
+ BLIP2 GPT-LLAMA model.
80
+ """
81
+ PRETRAINED_MODEL_CONFIG_DICT = {
82
+ "minigpt4_video": "minigpt4/configs/models/minigpt4.yaml",
83
+ }
84
+ config_class=minigpt4_video_config
85
+
86
+ def __init__(
87
+ self,
88
+ cfg={},
89
+ ):
90
+ ## loop through the minigpt4_video_config object and set its attributes on the model
91
+ # if isinstance(cfg, minigpt4_video_config):
92
+ try:
93
+ cfg = cfg.to_dict()
94
+ except:
95
+ pass
96
+ for key, value in cfg.items():
97
+ try:
98
+ setattr(self, key, value)
99
+ except:
100
+ print(f"Error setting attribute {key} with value {value}")
101
+ PreTrainedModel.__init__(self, minigpt4_video_config(cfg))
102
+ Blip2Base.__init__(self)
103
+
104
+ vis_processor_cfg = {"name": "blip2_image_train","image_size": 224}
105
+ self.vis_processor = registry.get_processor_class(vis_processor_cfg["name"]).from_config(vis_processor_cfg)
106
+ self.CONV_VISION = CONV_VISION
107
+ if "Mistral" in self.llama_model:
108
+ from minigpt4_video.modeling_mistral import MistralForCausalLM as llm_model
109
+ print("Mistral model")
110
+ self.model_type = "Mistral"
111
+ else:
112
+ from minigpt4_video.modeling_llama_v2 import LlamaForCausalLM as llm_model
113
+ print("Llama model")
114
+ self.model_type = "Llama"
115
+ self.tokenizer = self.init_tokenizer()
116
+
117
+ print("token pooling", self.token_pooling)
118
+ if self.freeze_vit:
119
+ # self.vit_precision="fp32"
120
+ print("vit precision", self.vit_precision)
121
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
122
+ self.vit_model, self.img_size, self.drop_path_rate, self.use_grad_checkpoint, self.vit_precision
123
+ )
124
+ for name, param in self.visual_encoder.named_parameters():
125
+ param.requires_grad = False
126
+ self.visual_encoder = self.visual_encoder.eval()
127
+ self.visual_encoder.train = disabled_train
128
+ for name, param in self.ln_vision.named_parameters():
129
+ param.requires_grad = False
130
+ self.ln_vision = self.ln_vision.eval()
131
+ self.ln_vision.train = disabled_train
132
+ logging.info("freeze vision encoder")
133
+ print("freeze the vision encoder")
134
+
135
+ else:
136
+ self.vit_precision="fp32"
137
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
138
+ self.vit_model, self.img_size, self.drop_path_rate, self.use_grad_checkpoint, self.vit_precision
139
+ )
140
+
141
+ print("unfreeze the vision encoder")
142
+ print('Loading VIT Done')
143
+
144
+ print('Loading LLAMA')
145
+
146
+ self.B_SYS, self.E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
147
+ token=os.environ.get("HF_TKN")
148
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(self.llama_model,use_fast=False,token=token) #
149
+ self.llama_tokenizer.pad_token = "$$"
150
+ # use fastv
151
+ self.use_fastv = False
152
+ print("self.low_resource",self.low_resource)
153
+ if self.low_resource:
154
+ self.llama_model = llm_model.from_pretrained(
155
+ self.llama_model,
156
+ torch_dtype=torch.float16,
157
+ # torch_dtype = torch.bfloat16,
158
+ load_in_8bit=True,
159
+ # device_map = "balanced"
160
+ # device_map="auto",
161
+ device_map={'':torch.cuda.current_device()},token=token
162
+ # device_map={'':0}
163
+
164
+ )
165
+ else:
166
+ self.llama_model = llm_model.from_pretrained(
167
+ self.llama_model,
168
+ torch_dtype=torch.float16,token=token
169
+ )
170
+
171
+ # self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
172
+ self.llama_model = prepare_model_for_int8_training(self.llama_model)
173
+ loraconfig = LoraConfig(
174
+ r=self.lora_r,
175
+ lora_alpha=self.lora_alpha,
176
+ target_modules=self.lora_target_modules,
177
+ lora_dropout=self.lora_dropout,
178
+ bias="none",
179
+ task_type="CAUSAL_LM"
180
+ )
181
+ self.llama_model = get_peft_model(self.llama_model, loraconfig)
182
+
183
+ self.llama_model.print_trainable_parameters()
184
+
185
+ if self.use_grad_checkpoint_llm:
186
+ self.llama_model.gradient_checkpointing_enable()
187
+
188
+ print('Loading LLAMA Done')
189
+
190
+
191
+ if self.token_pooling:
192
+ self.llama_proj = nn.Linear(
193
+ 1408*4, self.llama_model.config.hidden_size
194
+ )
195
+ else:
196
+ self.llama_proj = nn.Linear(
197
+ 1408, self.llama_model.config.hidden_size
198
+ )
199
+ if self.prompt_path:
200
+ with open(self.prompt_path, 'r') as f:
201
+ raw_prompts = f.read().splitlines()
202
+ filtered_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
203
+ self.prompt_list = [self.prompt_template.format(p) for p in filtered_prompts]
204
+ print('Load {} training prompts'.format(len(self.prompt_list)))
205
+ print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
206
+ else:
207
+ self.prompt_list = []
208
+ def prepare_input(self,video_path,subtitle_path,instruction):
209
+ cap = cv2.VideoCapture(video_path)
210
+ if subtitle_path is not None:
211
+ # Load the VTT subtitle file
212
+ vtt_file = webvtt.read(subtitle_path)
213
+ print("subtitle loaded successfully")
214
+ clip = VideoFileClip(video_path)
215
+ total_num_frames = int(clip.duration * clip.fps)
216
+ # print("Video duration = ",clip.duration)
217
+ clip.close()
218
+ else :
219
+ # calculate the total number of frames in the video using opencv
220
+ total_num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
221
+ if self.model_type == "Mistral":
222
+ max_images_length = 90
223
+ max_sub_len = 800
224
+ else:
225
+ max_images_length = 45
226
+ max_sub_len = 400
227
+ images = []
228
+ frame_count = 0
229
+ sampling_interval = int(total_num_frames / max_images_length)
230
+ if sampling_interval == 0:
231
+ sampling_interval = 1
232
+ img_placeholder = ""
233
+ subtitle_text_in_interval = ""
234
+ history_subtitles = {}
235
+ raw_frames=[]
236
+ number_of_words=0
237
+ transform=transforms.Compose([
238
+ transforms.ToPILImage(),
239
+ ])
240
+ while cap.isOpened():
241
+ ret, frame = cap.read()
242
+ if not ret:
243
+ break
244
+ # Find the corresponding subtitle for the frame and combine the interval subtitles into one subtitle
245
+ # we choose 1 frame every 2 seconds, so we combine the subtitles that fall within each 2-second interval
246
+ if subtitle_path is not None:
247
+ for subtitle in vtt_file:
248
+ sub=subtitle.text.replace('\n',' ')
249
+ if (subtitle.start_in_seconds <= (frame_count / int(clip.fps)) <= subtitle.end_in_seconds) and sub not in subtitle_text_in_interval:
250
+ if not history_subtitles.get(sub,False):
251
+ subtitle_text_in_interval+=sub+" "
252
+ history_subtitles[sub]=True
253
+ break
254
+ if frame_count % sampling_interval == 0:
255
+ raw_frames.append(Image.fromarray(cv2.cvtColor(frame.copy(), cv2.COLOR_BGR2RGB)))
256
+ frame = transform(frame[:,:,::-1]) # convert to RGB
257
+ frame = self.vis_processor(frame)
258
+ images.append(frame)
259
+ img_placeholder += '<Img><ImageHere>'
260
+ if subtitle_path is not None and subtitle_text_in_interval != "" and number_of_words< max_sub_len:
261
+ img_placeholder+=f'<Cap>{subtitle_text_in_interval}'
262
+ number_of_words+=len(subtitle_text_in_interval.split(' '))
263
+ subtitle_text_in_interval = ""
264
+ frame_count += 1
265
+
266
+ if len(images) >= max_images_length:
267
+ break
268
+
269
+ while len(images) < max_images_length:
270
+ images.append(images[-1])
271
+ img_placeholder += '<Img><ImageHere>'
272
+
273
+ cap.release()
274
+ cv2.destroyAllWindows()
275
+ if len(images) == 0:
276
+ # skip the video if no frame is extracted
277
+ return None,None
278
+ images = torch.stack(images)
279
+ instruction = img_placeholder + '\n' + instruction
280
+ return images,instruction
281
+ def encode_img(self, image):
282
+ device = image.device
283
+ if len(image.shape) > 4:
284
+ image = image.reshape(-1, *image.shape[-3:]) # for video input flatten the batch and time dimension (4,50,3,224,224) -> (200,3,224,224)
285
+ with self.maybe_autocast():
286
+ image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) # (200,3,224,224) -> (200,257,1408)
287
+ image_embeds = image_embeds[:,1:,:] # remove the first token (CLS) (200,256,1408)
288
+ bs, pn, hs = image_embeds.shape
289
+ if self.token_pooling: # concat each group of 4 tokens into one token (200,64,5632)
290
+ image_embeds = image_embeds.view(bs, int(pn/4), int(hs*4)) # (200,64,5632)
291
+
292
+ inputs_llama = self.llama_proj(image_embeds) # project to llama input size (200,64,5632) -> (200,64,4096)
293
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
294
+ return inputs_llama, atts_llama
295
+
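# Shape walk-through for encode_img above: the EVA-ViT encoder at 224x224 yields 257 tokens of width
# 1408 per frame; dropping the CLS token leaves 256. With token_pooling, every 4 consecutive tokens
# are folded into one, (B*T, 256, 1408) -> (B*T, 64, 5632), which is why llama_proj is
# nn.Linear(1408*4, hidden_size) in that case; without pooling the projection maps 1408 directly.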
296
+ def get_context_emb(self, prompt, img_list):
297
+ img_device = img_list[0].device
298
+ prompt_segs = prompt.split('<ImageHere>')
299
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
300
+ seg_tokens = [
301
+ self.llama_tokenizer(
302
+ seg, return_tensors="pt", add_special_tokens=i==0).to(img_device).input_ids # only add bos to the first seg
303
+ for i, seg in enumerate(prompt_segs)
304
+ ]
305
+
306
+ seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
307
+
308
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
309
+
310
+ mixed_embs = torch.cat(mixed_embs, dim=1)
311
+
312
+ return mixed_embs
313
+
314
+ def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
315
+ if prompts is None or len(prompts) == 0:
316
+ # prompts is not provided, just return the original image embedding
317
+ return img_embeds, atts_img
318
+ elif img_embeds is None:
319
+ # prompt is provided but there is no image embedding. return the prompt embedding in right padding
320
+ self.llama_tokenizer.padding_side = "right"
321
+ prompt_tokens = self.llama_tokenizer(
322
+ prompts,
323
+ return_tensors="pt",
324
+ padding="max_length",
325
+ add_special_tokens=False
326
+ ).to(self.device)
327
+ prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
328
+ atts_prompt = prompt_tokens.attention_mask
329
+ return prompt_embeds, atts_prompt
330
+
331
+ else:
332
+ # return the multi-modal embedding in right padding
333
+ emb_lists = []
334
+ if type(prompts) == str:
335
+ prompts = [prompts] * len(img_embeds)
336
+ for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
337
+ pn = each_img_embed.shape[-2]
338
+ if lengths is not None:
339
+ each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
340
+ each_img_embed = each_img_embed[:lengths[idx] * pn]
341
+
342
+ p_segs = each_prompt.split('<ImageHere>')
343
+ interleave_emb = []
344
+ for idx, seg in enumerate(p_segs[:-1]):
345
+ p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
346
+ p_embed = self.embed_tokens(p_tokens.input_ids)
347
+
348
+ interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx*pn:(idx+1)*pn]], dim=1))
349
+
350
+ wrapped_emb = torch.cat(interleave_emb, dim=1)
351
+ p_tokens = self.llama_tokenizer(p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
352
+ p_embed = self.embed_tokens(p_tokens.input_ids)
353
+ wrapped_emb = torch.cat([wrapped_emb,p_embed], dim=1)
354
+ emb_lists.append(wrapped_emb)
355
+
356
+ emb_lens = [emb.shape[1] for emb in emb_lists]
357
+ pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device))
358
+
359
+ # max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len
360
+ max_length = self.max_context_len
361
+ wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone()
362
+ wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device)
363
+
364
+ for i, emb in enumerate(emb_lists):
365
+ length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len
366
+ wrapped_embs[i, :length] = emb[:, :length]
367
+ wrapped_atts[i, :length] = 1
368
+
369
+ return wrapped_embs, wrapped_atts
370
+
371
+ def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
372
+ """
373
+ Concatenate the batched input embedding and batched output embedding together.
374
+ Both the input and the output embedding should be right padded.
375
+ """
376
+
377
+ input_lens = []
378
+ cat_embs = []
379
+ cat_atts = []
380
+
381
+ for i in range(input_embs.size(0)):
382
+ input_len = input_atts[i].sum()
383
+ input_lens.append(input_len)
384
+
385
+ cat_embs.append(
386
+ torch.cat([
387
+ input_embs[i][:input_len],
388
+ output_embs[i],
389
+ input_embs[i][input_len:]
390
+ ])
391
+ )
392
+ cat_atts.append(
393
+ torch.cat([
394
+ input_atts[i][:input_len],
395
+ output_atts[i],
396
+ input_atts[i][input_len:]
397
+ ])
398
+ )
399
+
400
+ cat_embs = torch.stack(cat_embs)
401
+ cat_atts = torch.stack(cat_atts)
402
+ return cat_embs, cat_atts, input_lens
403
+
404
+ def get_conv_emb(self, conv_q, conv_a, conv_img):
405
+ """concatenate conversation and make sure the model is only trained to regress the answer"""
406
+
407
+ regress_embs_list = []
408
+ targets_list = []
409
+
410
+ batch_size = len(conv_q)
411
+ for batch_idx in range(batch_size):
412
+ questions, answers = conv_q[batch_idx], conv_a[batch_idx]
413
+ assigned_imgs = conv_img[batch_idx]
414
+ questions = [self.prompt_wrap(
415
+ img_embeds=img,
416
+ atts_img=None,
417
+ prompts=[q],
418
+ lengths=[img.shape[1]] if img is not None else None) for q, img in zip(questions, assigned_imgs)]
419
+ q_embs = [emb for emb, _ in questions]
420
+
421
+ answers = [self.llama_tokenizer(a, return_tensors="pt", add_special_tokens=False).to(self.device) for a in answers]
422
+ cur_emb = []
423
+ cur_target = []
424
+ for i in range(len(questions)):
425
+ cur_emb.append(q_embs[i])
426
+ cur_target.append(torch.ones_like(q_embs[i][..., 0], dtype=torch.int) * -100)
427
+
428
+ cur_emb.append(self.embed_tokens(answers[i].input_ids))
429
+ cur_target.append(answers[i].input_ids)
430
+
431
+ cur_emb = torch.cat(cur_emb, dim=1)
432
+ cur_target = torch.cat(cur_target, dim=1)
433
+
434
+ regress_embs_list.append(cur_emb)
435
+ targets_list.append(cur_target)
436
+
437
+ max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
438
+
439
+ regress_embeds = torch.zeros([batch_size, max_len, cur_emb.shape[-1]], device=self.device)
440
+ regress_attn = torch.zeros([batch_size, max_len], dtype=torch.int, device=self.device)
441
+ targets = torch.ones([batch_size, max_len], dtype=torch.long, device=self.device) * -100
442
+
443
+ for batch_idx in range(batch_size):
444
+ cur_len = regress_embs_list[batch_idx].shape[1]
445
+ regress_embeds[batch_idx, :cur_len] = regress_embs_list[batch_idx][0, :max_len]
446
+ regress_attn[batch_idx, :cur_len] = 1
447
+ targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
448
+
449
+ return regress_embeds, regress_attn, targets
450
+
451
+ def preparing_embedding(self, samples):
452
+ def remove_special_tokens(data):
453
+
454
+ # if "instruction_input" in data:
455
+ data = [instruct.replace(" [caption]","") for instruct in data]
456
+ data = [instruct.replace(" [vqa]","") for instruct in data]
457
+ data = [instruct.replace(" [grounding]","") for instruct in data]
458
+ data = [instruct.replace(" [identify]","") for instruct in data]
459
+ data = [instruct.replace(" [refer]","") for instruct in data]
460
+ return data
461
+
462
+ ### prepare input tokens
463
+ if 'image' in samples:
464
+ img_embeds, img_atts = self.encode_img(samples["image"])
465
+ else:
466
+ img_embeds = img_atts = None
467
+
468
+ if 'conv_q' in samples:
469
+ # handling conversation datasets
470
+ conv_q, conv_a = samples['conv_q'], samples['conv_a']
471
+
472
+ connect_sym = samples['connect_sym'][0]
473
+ conv_q = [q.split(connect_sym)for q in conv_q]
474
+ conv_a = [a.split(connect_sym) for a in conv_a]
475
+ conv_img = assign_imgs(conv_q, img_embeds)
476
+
477
+ if self.chat_template:
478
+ conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
479
+
480
+ regress_embeds, regress_atts, part_targets = self.get_conv_emb(conv_q, conv_a, conv_img)
481
+ cond_embeds, cond_atts = regress_embeds[:, :0], regress_atts[:, :0]
482
+
483
+ else:
484
+ if "instruction_input" in samples:
485
+ instruction = samples["instruction_input"]
486
+ elif len(self.prompt_list) > 1:
487
+ instruction = random.choice(self.prompt_list)
488
+ else:
489
+ instruction = None
490
+
491
+ if self.remove_template:
492
+ instruction = remove_special_tokens(instruction)
493
+
494
+ if self.chat_template:
495
+ instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
496
+
497
+ if 'length' in samples:
498
+ # the input is a sequence of images (like videos)
499
+ bsz, pn, hs = img_embeds.shape
500
+ img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs) # (200,64,4096) -> (4,50,64,4096)
501
+ cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
502
+ else:
503
+ cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, instruction)
504
+
505
+ ### prepare target tokens
506
+ self.llama_tokenizer.padding_side = "right"
507
+ text = [t + self.end_sym for t in samples["answer"]]
508
+
509
+ regress_tokens = self.llama_tokenizer(
510
+ text,
511
+ return_tensors="pt",
512
+ padding="max_length",
513
+ truncation=True,
514
+ max_length=self.max_txt_len,
515
+ add_special_tokens=False
516
+ ).to(self.device)
517
+
518
+ regress_token_ids = regress_tokens.input_ids
519
+ regress_atts = regress_tokens.attention_mask
520
+ part_targets = regress_token_ids.masked_fill(
521
+ regress_token_ids == self.llama_tokenizer.pad_token_id, -100
522
+ )
523
+
524
+ regress_embeds = self.embed_tokens(regress_token_ids)
525
+
526
+ return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets
527
+
528
+ def forward(self, samples, reduction="mean"):
529
+ # prepare the embedding to condition and the embedding to regress
530
+ cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets = \
531
+ self.preparing_embedding(samples)
532
+
533
+ # concat the embedding to condition and the embedding to regress
534
+ inputs_embeds, attention_mask, input_lens = \
535
+ self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
536
+ # get bos token embedding
537
+ bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
538
+ bos_embeds = self.embed_tokens(bos)
539
+ bos_atts = attention_mask[:, :1]
540
+
541
+ # add bos token at the beginning
542
+ inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
543
+ attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
544
+
545
+ targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
546
+ dtype=torch.long).to(self.device).fill_(-100)
547
+ for i, target in enumerate(part_targets):
548
+ targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos
549
+
550
+ with self.maybe_autocast():
551
+ outputs = self.llama_model(
552
+ inputs_embeds=inputs_embeds,
553
+ attention_mask=attention_mask,
554
+ return_dict=True,
555
+ labels=targets,
556
+ reduction=reduction,
557
+ use_fastv=self.use_fastv
558
+ )
559
+ loss = outputs.loss
560
+
561
+ return {"loss": loss}
562
+
563
+ @torch.no_grad()
564
+ def generate(
565
+ self,
566
+ images,
567
+ texts,
568
+ use_nucleus_sampling=False,
569
+ num_beams=1,
570
+ max_new_tokens=20,
571
+ min_length=1,
572
+ top_p=0.9,
573
+ repetition_penalty=1.5,
574
+ length_penalty=1,
575
+ temperature=1,
576
+ do_sample=False,
577
+ stop_words_ids=[2],
578
+ lengths=None,
579
+ return_video_temporal_features=False,
580
+ img_embeds=None,
581
+ ):
582
+ '''
583
+ Generation function used at inference/test time.
584
+ '''
585
+
586
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
587
+ stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
588
+ if img_embeds is None:
589
+ img_embeds, atts_img = self.encode_img(images.to(self.device))
590
+ else:
591
+ # Use image features from the input (4,45,64,5632)
592
+ img_embeds = img_embeds.reshape(-1, *img_embeds.shape[-2:])
593
+ img_embeds= img_embeds.to(self.device)
594
+ img_embeds = self.llama_proj(img_embeds) # project to llama input size (200,64,5632) -> (200,64,4096)
595
+ atts_img = torch.ones(img_embeds.size()[:-1], dtype=torch.long).to(self.device)
596
+
597
+ if lengths is not None:
598
+ image_lists = []
599
+ img_embeds = img_embeds.reshape(len(lengths), -1, img_embeds.shape[-2], img_embeds.shape[-1])
600
+ for idx, img_embed in enumerate(img_embeds):
601
+ image_lists.append([img_embed[i][None] for i in range(lengths[idx])])
602
+ else:
603
+ image_lists = [[image_emb[None]] for image_emb in img_embeds]
604
+ assert len(texts) == len(image_lists)
605
+ batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
606
+
607
+ batch_size = len(batch_embs)
608
+ max_len = max([emb.shape[1] for emb in batch_embs])
609
+ emb_dim = batch_embs[0].shape[2]
610
+ dtype = batch_embs[0].dtype
611
+ device = batch_embs[0].device
612
+
613
+ embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
614
+ attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
615
+ for i, emb in enumerate(batch_embs):
616
+ emb_len = emb.shape[1]
617
+ embs[i, -emb_len:] = emb[0]
618
+ attn_mask[i, -emb_len:] = 1
619
+ # check whether the input embeddings fit within the model context window; if not, truncate to the last context_window tokens
620
+ if self.model_type == "Llama":
621
+ context_window = 3700
622
+ else:
623
+ context_window = 7500
624
+ if embs.shape[1] > context_window:
625
+ embs = embs[:, -context_window:]
626
+ attn_mask = attn_mask[:, -context_window:]
627
+ with self.maybe_autocast():
628
+ if return_video_temporal_features:
629
+ last_hidden_state = self.llama_model(
630
+ inputs_embeds=embs,
631
+ attention_mask=attn_mask,
632
+ output_hidden_states=True,
633
+ ).hidden_states[-1]
634
+ video_temporal_features = last_hidden_state.mean(dim=1)
635
+ # normalize the temporal features using L2 norm
636
+ # video_temporal_features = video_temporal_features / video_temporal_features.norm(dim=-1, keepdim=True)
637
+ outputs = self.llama_model.generate(
638
+ inputs_embeds=embs,
639
+ attention_mask=attn_mask,
640
+ max_new_tokens=max_new_tokens,
641
+ num_beams=num_beams,
642
+ do_sample=do_sample,
643
+ temperature=temperature,
644
+ repetition_penalty=repetition_penalty,
645
+ # stopping_criteria=stopping_criteria,
646
+ use_fastv=False,
647
+ )
648
+
649
+ answers = []
650
+ for output_token in outputs:
651
+ if output_token[0] == 0:
652
+ output_token = output_token[1:]
653
+ output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
654
+ output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
655
+ output_texts = output_texts.replace("<s>", "")
656
+ output_texts = output_texts.split(r'[/INST]')[-1].strip()
657
+ answers.append(output_texts)
658
+ if return_video_temporal_features:
659
+ return answers, video_temporal_features
660
+ else:
661
+ return answers
662
+ def inference_fun(self, video_path, instruction, gen_subtitles=True):
663
+ if gen_subtitles:
664
+ subtitle_path=generate_subtitles(video_path)
665
+ else :
666
+ subtitle_path=None
667
+ prepared_images,prepared_instruction=self.prepare_input(video_path,subtitle_path,instruction)
668
+ if prepared_images is None:
669
+ return "Video cann't be open ,check the video path again"
670
+ length=len(prepared_images)
671
+ prepared_images=prepared_images.unsqueeze(0)
672
+ conv = self.CONV_VISION.copy()
673
+ conv.system = ""
674
+ # to keep a multi-turn conversation, comment out the 2 lines above and make conv a global variable
675
+ conv.append_message(conv.roles[0], prepared_instruction)
676
+ conv.append_message(conv.roles[1], None)
677
+ prompt = [conv.get_prompt()]
678
+ answers = self.generate(prepared_images, prompt, max_new_tokens=512, do_sample=True, lengths=[length],num_beams=1)
679
+ return answers[0]
680
+ @torch.no_grad()
681
+ def generate_text_only(
682
+ self,
683
+ images,
684
+ seg_tokens,
685
+ use_nucleus_sampling=False,
686
+ num_beams=1,
687
+ max_new_tokens=20,
688
+ min_length=1,
689
+ top_p=0.9,
690
+ repetition_penalty=1.5,
691
+ length_penalty=1,
692
+ temperature=1,
693
+ do_sample=False,
694
+ stop_words_ids=[2],
695
+ lengths=None,
696
+ return_video_temporal_features=False,
697
+ img_embeds=None,
698
+ ):
699
+ '''
700
+ Text-only generation function used at inference/test time.
701
+ '''
702
+
703
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
704
+ stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
705
+
706
+ batch_embs = [torch.cat([self.embed_tokens(seg_t)]) for seg_t in seg_tokens]
707
+
708
+ batch_size = len(batch_embs)
709
+ max_len = max([emb.shape[1] for emb in batch_embs])
710
+ emb_dim = batch_embs[0].shape[2]
711
+ dtype = batch_embs[0].dtype
712
+ device = batch_embs[0].device
713
+
714
+ embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
715
+ attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
716
+ for i, emb in enumerate(batch_embs):
717
+ emb_len = emb.shape[1]
718
+ embs[i, -emb_len:] = emb[0]
719
+ attn_mask[i, -emb_len:] = 1
720
+
721
+ with self.maybe_autocast():
722
+ outputs = self.llama_model.generate(
723
+ inputs_embeds=embs,
724
+ attention_mask=attn_mask,
725
+ max_new_tokens=max_new_tokens,
726
+ num_beams=num_beams,
727
+ do_sample=do_sample,
728
+ temperature=temperature,
729
+ repetition_penalty=repetition_penalty,
730
+ # stopping_criteria=stopping_criteria,
731
+ )
732
+
733
+ answers = []
734
+ for output_token in outputs:
735
+ if output_token[0] == 0:
736
+ output_token = output_token[1:]
737
+ output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
738
+ output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
739
+ output_texts = output_texts.replace("<s>", "")
740
+ output_texts = output_texts.split(r'[/INST]')[-1].strip()
741
+ answers.append(output_texts)
742
+ return answers
743
+
744
+
745
+
746
+ @torch.no_grad()
747
+ def multi_select(self, images, texts, answers, num_cand=None):
748
+ all_losses = []
749
+ for answer in answers:
750
+ choice_samples = {
751
+ 'image': images,
752
+ 'instruction_input': texts,
753
+ 'answer': answer
754
+ }
755
+ loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
756
+ all_losses.append(loss)
757
+ torch.cuda.empty_cache()
758
+ all_losses = torch.cat(all_losses, dim=-1)
759
+ if num_cand is not None:
760
+ for i in range(all_losses.shape[0]):
761
+ all_losses[i, num_cand[i]:] = 9999
762
+ output_class_ranks = torch.argsort(all_losses, dim=-1)
763
+ return output_class_ranks.tolist()
764
+
765
+ def predict_answers(
766
+ self,
767
+ samples,
768
+ num_beams=5,
769
+ inference_method="generate",
770
+ max_len=10,
771
+ min_len=1,
772
+ num_ans_candidates=128,
773
+ answer_list=None,
774
+ prompt="",
775
+ length_penalty=0,
776
+ **kwargs
777
+ ):
778
+ '''
779
+ function for open-ended VQA
780
+ '''
781
+ images = samples["image"].cuda()
782
+ texts = samples["instruction_input"]
783
+
784
+ output_text = self.generate(
785
+ images=images,
786
+ texts=texts,
787
+ num_beams=num_beams,
788
+ max_new_tokens=max_len,
789
+ min_length=min_len,
790
+ length_penalty=length_penalty
791
+ )
792
+
793
+ if "apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]:
794
+ output_text = self._lemmatize(output_text)
795
+
796
+ return output_text
797
+
798
+ def predict_class(
799
+ self,
800
+ samples,
801
+ num_beams=5,
802
+ inference_method="generate",
803
+ max_len=10,
804
+ min_len=1,
805
+ num_ans_candidates=5,
806
+ answer_list=None,
807
+ prompt="",
808
+ length_penalty=0,
809
+ **kwargs
810
+ ):
811
+ '''
812
+ function for multi-choice VQA
813
+ '''
814
+
815
+ image = samples["image"].cuda()
816
+ instruction = samples['instruction_input']
817
+ answers = samples["choices"]
818
+ num_cand = samples["num_choices"]
819
+
820
+ ranks = self.multi_select(image, instruction, answers, num_cand)
821
+
822
+ pred_ans = []
823
+ for i, rank in enumerate(ranks):
824
+ pred = answers[rank[0]][i]
825
+ pred_ans.append(pred)
826
+ return pred_ans
827
+
828
+ def embed_tokens(self, token_ids):
829
+ try:
830
+ embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
831
+ except AttributeError:
832
+ embeds = self.llama_model.model.embed_tokens(token_ids)
833
+
834
+ return embeds
835
+
836
+ @classmethod
837
+ def from_config(cls, cfg):
838
+ model = cls(
839
+ cfg=cfg,
840
+ )
841
+ ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
842
+ if ckpt_path:
843
+ print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
844
+ ckpt = torch.load(ckpt_path, map_location="cpu")
845
+ msg = model.load_state_dict(ckpt['model'], strict=False)
846
+ # push the model to the hub with its metadata and config file
847
+ model.to('cuda')
848
+ model.push_to_hub("Vision-CAIR/MiniGPT4-video-mistral-hf")
849
+ video_config = minigpt4_video_config(cfg)
850
+ # video_config.save_pretrained("minigpt4_video_config")
851
+ # print("Save Minigpt-4-LLM Config: minigpt4_video_config")
852
+ video_config.push_to_hub("Vision-CAIR/MiniGPT4-video-mistral-hf")
853
+ return model
854
+
855
+
856
+ def assign_imgs(batched_instruct_list, batched_img_embeds):
857
+ '''this function is used when the data is interleaved.
858
+ the interleaved data is separated, and this function assigns
859
+ corresponding image embeddings to each segment'''
860
+ if len(batched_img_embeds.shape) == 3:
861
+ batched_img_embeds = batched_img_embeds[:, None]
862
+
863
+ batched_assigned = []
864
+
865
+ for instruct_list, img_embeds in zip(batched_instruct_list, batched_img_embeds):
866
+ img_idx = 0
867
+ assigned_img = []
868
+ n_assigned = []
869
+ for instruct in instruct_list:
870
+ n_img = instruct.count('<ImageHere>')
871
+ if n_img > 0: # this instruction include images.
872
+ assigned_img.append(img_embeds[None, img_idx:img_idx+n_img])
873
+ img_idx += n_img
874
+ n_assigned.append(n_img)
875
+ else: # this instruction doesn't include images
876
+ assigned_img.append(None)
877
+ n_assigned.append(None)
878
+ batched_assigned.append(assigned_img)
879
+
880
+ return batched_assigned
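Taken together, from_config documents how the checkpoint was pushed; loading it back and querying a video reduces to a short script. A minimal end-to-end sketch, assuming the repo id used in push_to_hub above, a hypothetical local demo.mp4, a CUDA device, HF_TKN set for the gated LLM backbone, and the whisper CLI only if subtitles are requested:

import os
import torch
from transformers import AutoModel

os.environ["HF_TKN"] = "<your HF access token>"  # read by MiniGPT4_Video.__init__ for the LLM backbone

model = AutoModel.from_pretrained(
    "Vision-CAIR/MiniGPT4-video-mistral-hf",  # repo id taken from push_to_hub above
    trust_remote_code=True,
)
model = model.to("cuda").eval()  # skip .to() if the LLM was loaded in 8-bit / with a device_map

with torch.no_grad():
    # gen_subtitles=False skips the whisper step; pass True to condition on auto-generated subtitles
    answer = model.inference_fun("demo.mp4", "Describe what happens in this video.", gen_subtitles=False)
print(answer)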