prithivMLmods committed on
Commit c58c051 · verified · 1 Parent(s): b830986

Delete app.py

Files changed (1)
  1. app.py +0 -720
app.py DELETED
@@ -1,720 +0,0 @@
1
- import gradio as gr
2
- import spaces
3
- from PIL import Image
4
- from transformers import AutoProcessor, WhisperForConditionalGeneration, WhisperProcessor, CLIPProcessor, CLIPModel
5
- import copy
6
- from decord import VideoReader, cpu
7
- import numpy as np
8
- import json
9
- from tqdm import tqdm
10
- import os
11
- import easyocr
12
- import re
13
- import ast
14
- import socket
15
- import pickle
16
- import ffmpeg
17
- import torchaudio
18
- import torch
19
- import warnings
20
- import shutil
21
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
22
- from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
23
- from llava.utils import rank0_print
24
- from llava.model import *
25
- from llava.model.language_model.llava_qwen import LlavaQwenForCausalLM, LlavaQwenConfig
26
- from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
27
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
28
- from llava.conversation import conv_templates, SeparatorStyle
29
-
30
- # The helpers imported below (retrieve_documents_with_dynamic, filter_keywords,
31
- # generate_scene_graph_description) are not defined in this single-file app.
32
- # They live in the local `tools` package, which must be importable in the
33
- # runtime environment for these imports to succeed; no inline fallback is
34
- # provided here.
35
-
36
- from tools.rag_retriever_dynamic import retrieve_documents_with_dynamic
37
- from tools.filter_keywords import filter_keywords
38
- from tools.scene_graph import generate_scene_graph_description
39
-
40
- # From builder.py
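- # Picks the right LLaVA variant (LLaMA, Mistral, Mixtral, Gemma, Qwen, Qwen-MoE) from the
- # model name, optionally merges LoRA weights, loads the vision tower, and returns
- # (tokenizer, model, image_processor, context_len).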
41
- def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", torch_dtype="float16", attn_implementation=None, customized_config=None, overwrite_config=None, **kwargs):
42
- kwargs["device_map"] = device_map
43
-
44
- if load_8bit:
45
- kwargs["load_in_8bit"] = True
46
- elif load_4bit:
47
- kwargs["load_in_4bit"] = True
48
- kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
49
- elif torch_dtype == "float16":
50
- kwargs["torch_dtype"] = torch.float16
51
- elif torch_dtype == "bfloat16":
52
- kwargs["torch_dtype"] = torch.bfloat16
53
- else:
54
- import pdb;pdb.set_trace()
55
-
56
- if customized_config is not None:
57
- kwargs["config"] = customized_config
58
-
59
- if "multimodal" in kwargs:
60
- if kwargs["multimodal"] is True:
61
- is_multimodal = True
62
- kwargs.pop("multimodal")
63
- else:
64
- is_multimodal = False
65
-
66
- if "llava" in model_name.lower() or is_multimodal:
67
- # Load LLaVA model
68
- if "lora" in model_name.lower() and model_base is None:
69
- warnings.warn(
70
- "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
71
- )
72
- if "lora" in model_name.lower() and model_base is not None:
73
- lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
74
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
75
- rank0_print("Loading LLaVA from base model...")
76
- if "mixtral" in model_name.lower():
77
- from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
78
-
79
- lora_cfg_pretrained = LlavaMixtralConfig.from_pretrained(model_path)
80
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
81
- model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
82
- elif "mistral" in model_name.lower():
83
- from llava.model.language_model.llava_mistral import LlavaMistralConfig
84
-
85
- lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
86
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
87
- model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
88
- elif "gemma" in model_name.lower():
89
- from llava.model.language_model.llava_gemma import LlavaGemmaConfig
90
-
91
- lora_cfg_pretrained = LlavaGemmaConfig.from_pretrained(model_path)
92
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
93
- model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
94
- else:
95
- from llava.model.language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM
96
-
97
- lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
98
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
99
- model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
100
-
101
- token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
102
- if model.lm_head.weight.shape[0] != token_num:
103
- model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
104
- model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
105
-
106
- rank0_print("Loading additional LLaVA weights...")
107
- if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
108
- non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
109
- else:
110
- # this is probably from HF Hub
111
- from huggingface_hub import hf_hub_download
112
-
113
- def load_from_hf(repo_id, filename, subfolder=None):
114
- cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
115
- return torch.load(cache_file, map_location="cpu")
116
-
117
- non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
118
- non_lora_trainables = {(k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()}
119
- if any(k.startswith("model.model.") for k in non_lora_trainables):
120
- non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
121
- model.load_state_dict(non_lora_trainables, strict=False)
122
-
123
- from peft import PeftModel
124
-
125
- rank0_print("Loading LoRA weights...")
126
- model = PeftModel.from_pretrained(model, model_path)
127
- rank0_print("Merging LoRA weights...")
128
- model = model.merge_and_unload()
129
- rank0_print("Model is loaded...")
130
- elif model_base is not None: # this may be mm projector only; load the projector with a preset language model
131
- rank0_print(f"Loading LLaVA from base model {model_base}...")
132
- if "mixtral" in model_name.lower():
133
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
134
- cfg_pretrained = AutoConfig.from_pretrained(model_path)
135
- model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
136
- elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
137
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
138
- cfg_pretrained = AutoConfig.from_pretrained(model_path)
139
- model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
140
- elif "gemma" in model_name.lower():
141
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
142
- cfg_pretrained = AutoConfig.from_pretrained(model_path)
143
- model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
144
- elif (
145
- "wizardlm-2" in model_name.lower()
146
- and "vicuna" in model_name.lower()
147
- or "llama" in model_name.lower()
148
- or "yi" in model_name.lower()
149
- or "nous-hermes" in model_name.lower()
150
- or "llava-v1.6-34b" in model_name.lower()
151
- or "llava-v1.5" in model_name.lower()
152
- ):
153
- from llava.model.language_model.llava_llama import LlavaConfig
154
-
155
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
156
- if customized_config is None:
157
- llava_cfg = LlavaConfig.from_pretrained(model_path)
158
- if "v1.5" in model_name.lower():
159
- llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
160
- else:
161
- llava_cfg = customized_config
162
-
163
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
164
- llava_cfg = LlavaConfig.from_pretrained(model_path)
165
- model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=llava_cfg, **kwargs)
166
- else:
167
- raise ValueError(f"Model {model_name} not supported")
168
-
169
- mm_projector_weights = torch.load(os.path.join(model_path, "mm_projector.bin"), map_location="cpu")
170
- mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
171
- model.load_state_dict(mm_projector_weights, strict=False)
172
- else:
173
- rank0_print(f"Loaded LLaVA model: {model_path}")
174
- if "mixtral" in model_name.lower():
175
- from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
176
-
177
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
178
- if customized_config is None:
179
- llava_cfg = LlavaMixtralConfig.from_pretrained(model_path)
180
- else:
181
- llava_cfg = customized_config
182
-
183
- if overwrite_config is not None:
184
- rank0_print(f"Overwriting config with {overwrite_config}")
185
- for k, v in overwrite_config.items():
186
- setattr(llava_cfg, k, v)
187
-
188
- tokenizer = AutoTokenizer.from_pretrained(model_path)
189
- model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
190
-
191
- elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
192
- tokenizer = AutoTokenizer.from_pretrained(model_path)
193
- model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
194
- elif (
195
- "wizardlm-2" in model_name.lower()
196
- and "vicuna" in model_name.lower()
197
- or "llama" in model_name.lower()
198
- or "yi" in model_name.lower()
199
- or "nous-hermes" in model_name.lower()
200
- or "llava-v1.6-34b" in model_name.lower()
201
- or "llava-v1.5" in model_name.lower()
202
- ):
203
- from llava.model.language_model.llava_llama import LlavaConfig
204
-
205
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
206
- if customized_config is None:
207
- llava_cfg = LlavaConfig.from_pretrained(model_path)
208
- if "v1.5" in model_path.lower():
209
- llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
210
- else:
211
- llava_cfg = customized_config
212
-
213
- if overwrite_config is not None:
214
- rank0_print(f"Overwriting config with {overwrite_config}")
215
- for k, v in overwrite_config.items():
216
- setattr(llava_cfg, k, v)
217
-
218
- model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
219
-
220
- elif "qwen" in model_name.lower() or "quyen" in model_name.lower():
221
- tokenizer = AutoTokenizer.from_pretrained(model_path)
222
- if "moe" in model_name.lower() or "A14B" in model_name.lower():
223
- from llava.model.language_model.llava_qwen_moe import LlavaQwenMoeConfig
224
- if overwrite_config is not None:
225
- llava_cfg = LlavaQwenMoeConfig.from_pretrained(model_path)
226
- rank0_print(f"Overwriting config with {overwrite_config}")
227
- for k, v in overwrite_config.items():
228
- setattr(llava_cfg, k, v)
229
- model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
230
- else:
231
- model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
232
-
233
- else:
234
- from llava.model.language_model.llava_qwen import LlavaQwenConfig
235
- if overwrite_config is not None:
236
- llava_cfg = LlavaQwenConfig.from_pretrained(model_path)
237
- rank0_print(f"Overwriting config with {overwrite_config}")
238
- for k, v in overwrite_config.items():
239
- setattr(llava_cfg, k, v)
240
- model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
241
- else:
242
- model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
243
-
244
- elif "gemma" in model_name.lower():
245
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
246
- cfg_pretrained = AutoConfig.from_pretrained(model_path)
247
- model = LlavaGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
248
- else:
249
- try:
250
- from llava.model.language_model.llava_llama import LlavaConfig
251
-
252
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
253
- if customized_config is None:
254
- llava_cfg = LlavaConfig.from_pretrained(model_path)
255
- if "v1.5" in model_path.lower():
256
- llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
257
- else:
258
- llava_cfg = customized_config
259
-
260
- if overwrite_config is not None:
261
- rank0_print(f"Overwriting config with {overwrite_config}")
262
- for k, v in overwrite_config.items():
263
- setattr(llava_cfg, k, v)
264
- model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
265
- except:
266
- raise ValueError(f"Model {model_name} not supported")
267
-
268
- else:
269
- # Load language model
270
- if model_base is not None:
271
- # PEFT model
272
- from peft import PeftModel
273
-
274
- tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
275
- model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
276
- print(f"Loading LoRA weights from {model_path}")
277
- model = PeftModel.from_pretrained(model, model_path)
278
- print(f"Merging weights")
279
- model = model.merge_and_unload()
280
- print("Convert to FP16...")
281
- model.to(torch.float16)
282
- else:
283
- use_fast = False
284
- if "mpt" in model_name.lower().replace("prompt", ""):
285
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
286
- model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
287
- else:
288
- tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
289
- model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
290
-
291
- rank0_print(f"Model Class: {model.__class__.__name__}")
292
- image_processor = None
293
-
294
- if "llava" in model_name.lower() or is_multimodal:
295
- mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
296
- mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
297
- if mm_use_im_patch_token:
298
- tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
299
- if mm_use_im_start_end:
300
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
301
- model.resize_token_embeddings(len(tokenizer))
302
-
303
- vision_tower = model.get_vision_tower()
304
- if not vision_tower.is_loaded:
305
- vision_tower.load_model(device_map=device_map)
306
- if device_map != "auto":
307
- vision_tower.to(device="cuda", dtype=torch.float16)
308
- image_processor = vision_tower.image_processor
309
-
310
- if hasattr(model.config, "max_sequence_length"):
311
- context_len = model.config.max_sequence_length
312
- elif hasattr(model.config, "max_position_embeddings"):
313
- context_len = model.config.max_position_embeddings
314
- elif hasattr(model.config, "tokenizer_model_max_length"):
315
- context_len = model.config.tokenizer_model_max_length
316
- else:
317
- context_len = 2048
318
-
319
- return tokenizer, model, image_processor, context_len
320
-
321
- # From vidrag_pipeline.py
322
- max_frames_num = 32
323
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14-336", torch_dtype=torch.float16, device_map="auto")
324
- clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
325
- # whisper_model = WhisperForConditionalGeneration.from_pretrained(
326
- # "openai/whisper-large",
327
- # torch_dtype=torch.float16,
328
- # device_map="auto"
329
- # )
330
- # whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large")
331
-
332
- @spaces.GPU
333
- def process_video(video_path, max_frames_num, fps=1, force_sample=False):
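- # Samples roughly `fps` frames per second from the video; if that exceeds max_frames_num
- # (or force_sample is set), it falls back to max_frames_num uniformly spaced frames.
- # Returns the sampled frames, their timestamps, and the total video duration in seconds.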
334
- if max_frames_num == 0:
335
- return np.zeros((1, 336, 336, 3))
336
- vr = VideoReader(video_path, ctx=cpu(),num_threads=1)
337
- total_frame_num = len(vr)
338
- video_time = total_frame_num / vr.get_avg_fps()
339
- fps = round(vr.get_avg_fps()/fps)
340
- frame_idx = [i for i in range(0, len(vr), fps)]
341
- frame_time = [i/fps for i in frame_idx]
342
- if len(frame_idx) > max_frames_num or force_sample:
343
- sample_fps = max_frames_num
344
- uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
345
- frame_idx = uniform_sampled_frames.tolist()
346
- frame_time = [i/vr.get_avg_fps() for i in frame_idx]
347
- frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
348
- spare_frames = vr.get_batch(frame_idx).asnumpy()
349
-
350
- return spare_frames, frame_time, video_time
351
-
352
- def extract_audio(video_path, audio_path):
353
- if not os.path.exists(audio_path):
354
- ffmpeg.input(video_path).output(audio_path, acodec='pcm_s16le', ac=1, ar='16k').run()
355
-
356
- def chunk_audio(audio_path, chunk_length_s=30):
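- # Loads the audio file, downmixes it to mono, resamples it to 16 kHz, and splits the
- # waveform into chunk_length_s-second chunks (intended for the commented-out Whisper
- # transcription path).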
357
- speech, sr = torchaudio.load(audio_path)
358
- speech = speech.mean(dim=0)
359
- speech = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(speech)
360
- num_samples_per_chunk = chunk_length_s * 16000
361
- chunks = []
362
- for i in range(0, len(speech), num_samples_per_chunk):
363
- chunks.append(speech[i:i + num_samples_per_chunk])
364
- return chunks
365
-
366
- # def transcribe_chunk(chunk):
367
-
368
- # inputs = whisper_processor(chunk, return_tensors="pt")
369
- # inputs["input_features"] = inputs["input_features"].to(whisper_model.device, torch.float16)
370
- # with torch.no_grad():
371
- # predicted_ids = whisper_model.generate(
372
- # inputs["input_features"],
373
- # no_repeat_ngram_size=2,
374
- # early_stopping=True
375
- # )
376
- # transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
377
- # return transcription
378
-
379
- # def get_asr_docs(video_path, audio_path):
380
-
381
- # full_transcription = []
382
-
383
- # try:
384
- # extract_audio(video_path, audio_path)
385
- # except:
386
- # return full_transcription
387
- # audio_chunks = chunk_audio(audio_path, chunk_length_s=30)
388
-
389
- # for chunk in audio_chunks:
390
- # transcription = transcribe_chunk(chunk)
391
- # full_transcription.append(transcription)
392
-
393
- # return full_transcription
394
-
395
- def get_ocr_docs(frames):
396
- reader = easyocr.Reader(['en'])
397
- text_set = []
398
- ocr_docs = []
399
- for img in frames:
400
- ocr_results = reader.readtext(img)
401
- det_info = ""
402
- for result in ocr_results:
403
- text = result[1]
404
- confidence = result[2]
405
- if confidence > 0.5 and text not in text_set:
406
- det_info += f"{text}; "
407
- text_set.append(text)
408
- if len(det_info) > 0:
409
- ocr_docs.append(det_info)
410
-
411
- return ocr_docs
412
-
413
-
414
- def save_frames(frames):
415
- file_paths = []
416
- for i, frame in enumerate(frames):
417
- img = Image.fromarray(frame)
418
- file_path = f'restore/frame_{i}.png'
419
- img.save(file_path)
420
- file_paths.append(file_path)
421
- return file_paths
422
-
423
- def get_det_docs(frames, prompt):
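- # Saves the selected frames to disk, then sends (frame_paths, prompt) as a pickle to an
- # external object-detection (APE) service expected to listen on port 9999, and returns the
- # unpickled per-frame detection strings (an empty list if the reply cannot be decoded).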
424
- prompt = ",".join(prompt)
425
- frames_path = save_frames(frames)
426
- res = []
427
- if len(frames) > 0:
428
- client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
429
- client_socket.connect(('0.0.0.0', 9999))
430
- data = (frames_path, prompt)
431
- client_socket.send(pickle.dumps(data))
432
- result_data = client_socket.recv(4096)
433
- try:
434
- res = pickle.loads(result_data)
435
- except:
436
- res = []
437
- return res
438
-
439
- def det_preprocess(det_docs, location, relation, number):
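- # Each det_doc_per_frame is expected to look like "person: [x1, y1, x2, y2]; dog: [x1, y1, x2, y2]".
- # The "label: bbox" pairs are parsed into object dicts and passed to
- # generate_scene_graph_description together with the requested location/relation/number flags.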
440
-
441
- scene_descriptions = []
442
-
443
- for det_doc_per_frame in det_docs:
444
- objects = []
445
- scene_description = ""
446
- if len(det_doc_per_frame) > 0:
447
- for obj_id, objs in enumerate(det_doc_per_frame.split(";")):
448
- obj_name = objs.split(":")[0].strip()
449
- obj_bbox = objs.split(":")[1].strip()
450
- obj_bbox = ast.literal_eval(obj_bbox)
451
- objects.append({"id": obj_id, "label": obj_name, "bbox": obj_bbox})
452
-
453
- scene_description = generate_scene_graph_description(objects, location, relation, number)
454
- scene_descriptions.append(scene_description)
455
-
456
- return scene_descriptions
457
-
458
- # load your VLM
459
- device = "cuda"
460
- overwrite_config = {}
461
- tokenizer, model, image_processor, max_length = load_pretrained_model(
462
- "lmms-lab/LLaVA-Video-7B-Qwen2",
463
- None,
464
- "llava_qwen",
465
- torch_dtype="bfloat16",
466
- device_map="auto",
467
- offload_buffers=True,
468
- overwrite_config=overwrite_config) # add any other llava_model_args you want to pass here
469
-
470
-
471
- # Check that the tokenizer and model vocab sizes match, and fix BEFORE dispatching
472
- vsz_model = model.get_input_embeddings().weight.shape[0]
473
- vsz_tok = len(tokenizer)
474
- if vsz_tok != vsz_model:
475
- print(f"[fix] resizing embeddings: model={vsz_model} -> tokenizer={vsz_tok}")
476
- model.resize_token_embeddings(vsz_tok)
477
- # optional: init new rows
478
- with torch.no_grad():
479
- added = vsz_tok - vsz_model
480
- if added > 0:
481
- emb = model.get_input_embeddings().weight
482
- emb[-added:].normal_(mean=0.0, std=0.02)
483
-
484
-
485
- model.eval()
486
- conv_template = "qwen_2" # Make sure you use correct chat template for different models
487
-
488
-
489
- # The inference function of your VLM
490
- def llava_inference(qs, video):
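- # Builds a qwen_2 conversation prompt (prepending the image token when video frames are
- # given) and samples up to 512 new tokens, passing the frames through `images` with
- # modalities=["video"] when present.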
491
- if video is not None:
492
- question = DEFAULT_IMAGE_TOKEN + qs
493
- else:
494
- question = qs
495
- conv = copy.deepcopy(conv_templates[conv_template])
496
- conv.append_message(conv.roles[0], question)
497
- conv.append_message(conv.roles[1], None)
498
- prompt_question = conv.get_prompt()
499
-
500
- input_ids = tokenizer_image_token(
501
- prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
502
- ).unsqueeze(0).to(device)
503
- # input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
504
- # cont = model.generate(
505
- # input_ids,
506
- # video=video,
507
- # modalities= ["video"],
508
- # do_sample=True,
509
- # temperature=0.7,
510
- # max_new_tokens=4096,
511
- # )
512
-
513
- if video is not None:
514
- cont = model.generate(
515
- input_ids,
516
- images=video,
517
- modalities=["video"],
518
- do_sample=True,
519
- temperature=0.7,
520
- max_new_tokens=512
521
- )
522
- else:
523
- cont = model.generate(
524
- input_ids,
525
- do_sample=True,
526
- temperature=0.7,
527
- max_new_tokens=512
528
- )
529
-
530
- text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip()
531
- return text_outputs
532
-
533
-
534
- # hyper-parameter settings
535
- rag_threshold = 0.3
536
- clip_threshold = 0.3
537
- beta = 3.0
538
-
539
- # Choose the auxiliary texts you want
540
- USE_OCR = True
541
- USE_ASR = False
542
- USE_DET = True
543
- print(f"---------------OCR{rag_threshold}: {USE_OCR}-----------------")
544
- print(f"---------------ASR{rag_threshold}: {USE_ASR}-----------------")
545
- print(f"---------------DET{beta}-{clip_threshold}: {USE_DET}-----------------")
546
- print(f"---------------Frames: {max_frames_num}-----------------")
547
-
548
- # Create directories
549
- os.makedirs("restore/audio", exist_ok=True)
550
- os.makedirs("restore", exist_ok=True)
551
-
552
- def process_query(video_path, question):
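- # End-to-end pipeline for one query: sample frames, optionally gather OCR / ASR / detection
- # documents (per the USE_* flags), ask the model for a JSON retrieval request, retrieve the
- # matching documents, and answer the question with the retrieved context prepended to the prompt.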
553
- if video_path is None:
554
- return "Please upload a video."
555
-
556
- frames, frame_time, video_time = process_video(video_path, max_frames_num, 1, force_sample=True)
557
- raw_video = [f for f in frames]
558
-
559
- video = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].cuda().bfloat16()
560
- video = [video]
561
-
562
- if USE_DET:
563
- video_tensor = []
564
- for frame in raw_video:
565
- processed = clip_processor(images=frame, return_tensors="pt")["pixel_values"].to(clip_model.device, dtype=torch.float16)
566
- video_tensor.append(processed.squeeze(0))
567
- video_tensor = torch.stack(video_tensor, dim=0)
568
-
569
- if USE_OCR:
570
- ocr_docs_total = get_ocr_docs(frames)
571
-
572
- if USE_ASR:
573
- if os.path.exists(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt")):
574
- with open(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'r', encoding='utf-8') as f:
575
- asr_docs_total = f.readlines()
576
- # else:
577
- # audio_path = os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".wav")
578
- # # asr_docs_total = get_asr_docs(video_path, audio_path)
579
- # with open(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'w', encoding='utf-8') as f:
580
- # for doc in asr_docs_total:
581
- # f.write(doc + '\n')
582
-
583
- # step 0: ask the model for a chain-of-thought (CoT) retrieval request
584
- retrieve_pmt_0 = "Question: " + question
585
- # you can change this decoupled retrieval prompt to fit your requirements
586
- retrieve_pmt_0 += "\nTo answer the question step by step, you can provide a retrieval request, in the following JSON format, to assist you:"
587
- retrieve_pmt_0 += '''{
588
- "ASR": Optional[str]. The subtitles of the video that may relavent to the question you want to retrieve, in two sentences. If you no need for this information, please return null.
589
- "DET": Optional[list]. (The output must include only physical entities, not abstract concepts, less than five entities) All the physical entities and their location related to the question you want to retrieve, not abstract concepts. If you no need for this information, please return null.
590
- "TYPE": Optional[list]. (The output must be specified as null or a list containing only one or more of the following strings: 'location', 'number', 'relation'. No other values are valid for this field) The information you want to obtain about the detected objects. If you need the object location in the video frame, output "location"; if you need the number of specific object, output "number"; if you need the positional relationship between objects, output "relation".
591
- }
592
- ## Example 1:
593
- Question: How many blue balloons are over the long table in the middle of the room at the end of this video? A. 1. B. 2. C. 3. D. 4.
594
- Your retrieval request can be:
595
- {
596
- "ASR": "The location and the color of balloons, the number of the blue balloons.",
597
- "DET": ["blue ballons", "long table"],
598
- "TYPE": ["relation", "number"]
599
- }
600
- ## Example 2:
601
- Question: In the lower left corner of the video, what color is the woman wearing on the right side of the man in black clothes? A. Blue. B. White. C. Red. D. Yellow.
602
- Your retrieval request can be:
603
- {
604
- "ASR": null,
605
- "DET": ["the man in black", "woman"],
606
- "TYPE": ["location", "relation"]
607
- }
608
- ## Example 3:
609
- Question: In which country is the comedy featured in the video recognized worldwide? A. China. B. UK. C. Germany. D. United States.
610
- Your retrieval request can be:
611
- {
612
- "ASR": "The country recognized worldwide for its comedy.",
613
- "DET": null,
614
- "TYPE": null
615
- }
616
- Note that you don't need to answer the question in this step, so you don't need any information about the video or images. You only need to provide your retrieval request (it's optional), and I will help you retrieve the information you want. Please provide your request in the JSON format above.'''
617
-
618
- json_request = llava_inference(retrieve_pmt_0, None)
619
-
620
- # step 1: get docs information
621
- query = [question]
622
-
623
- # APE fetch
624
- if USE_DET:
625
- det_docs = []
626
- try:
627
- request_det = json.loads(json_request)["DET"]
628
- request_det = filter_keywords(request_det)
629
- clip_text = ["A picture of " + txt for txt in request_det]
630
- if len(clip_text) == 0:
631
- clip_text = ["A picture of object"]
632
- except:
633
- request_det = None
634
- clip_text = ["A picture of object"]
635
-
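- # Score each sampled frame against the requested object texts with CLIP, rescale the
- # similarities so they sum to beta * (num_frames / 16), and keep only frames whose score
- # exceeds clip_threshold as inputs to the detector.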
636
- clip_inputs = clip_processor(text=clip_text, return_tensors="pt", padding=True, truncation=True).to(clip_model.device)
637
- clip_img_feats = clip_model.get_image_features(video_tensor)
638
- with torch.no_grad():
639
- text_features = clip_model.get_text_features(**clip_inputs)
640
- similarities = (clip_img_feats @ text_features.T).squeeze(0).mean(1).cpu()
641
- similarities = np.array(similarities, dtype=np.float64)
642
- alpha = beta * (len(similarities) / 16)
643
- similarities = similarities * alpha / np.sum(similarities)
644
-
645
- del clip_inputs, clip_img_feats, text_features
646
- torch.cuda.empty_cache()
647
-
648
- det_top_idx = [idx for idx in range(max_frames_num) if similarities[idx] > clip_threshold]
649
-
650
- if request_det is not None and len(request_det) > 0:
651
- det_docs = get_det_docs(frames[det_top_idx], request_det)
652
-
653
- L, R, N = False, False, False
654
- try:
655
- det_retrieve_info = json.loads(json_request)["TYPE"]
656
- except:
657
- det_retrieve_info = None
658
- if det_retrieve_info is not None:
659
- if "location" in det_retrieve_info:
660
- L = True
661
- if "relation" in det_retrieve_info:
662
- R = True
663
- if "number" in det_retrieve_info:
664
- N = True
665
- det_docs = det_preprocess(det_docs, location=L, relation=R, number=N) # pre-process the APE detection information
666
-
667
-
668
- # OCR fetch
669
- if USE_OCR:
670
- try:
671
- request_det = json.loads(json_request)["DET"]
672
- request_det = filter_keywords(request_det)
673
- except:
674
- request_det = None
675
- ocr_docs = []
676
- if len(ocr_docs_total) > 0:
677
- ocr_query = query.copy()
678
- if request_det is not None and len(request_det) > 0:
679
- ocr_query.extend(request_det)
680
- ocr_docs, _ = retrieve_documents_with_dynamic(ocr_docs_total, ocr_query, threshold=rag_threshold)
681
-
682
- # ASR fetch
683
- if USE_ASR:
684
- asr_docs = []
685
- try:
686
- request_asr = json.loads(json_request)["ASR"]
687
- except:
688
- request_asr = None
689
- if len(asr_docs_total) > 0:
690
- asr_query = query.copy()
691
- if request_asr is not None:
692
- asr_query.append(request_asr)
693
- asr_docs, _ = retrieve_documents_with_dynamic(asr_docs_total, asr_query, threshold=rag_threshold)
694
-
695
- qs = ""
696
- if USE_DET and len(det_docs) > 0:
697
- for i, info in enumerate(det_docs):
698
- if len(info) > 0:
699
- qs += f"Frame {str(det_top_idx[i]+1)}: " + info + "\n"
700
- if len(qs) > 0:
701
- qs = f"\nVideo have {str(max_frames_num)} frames in total, the detected objects' information in specific frames: " + qs
702
- if USE_ASR and len(asr_docs) > 0:
703
- qs += "\nVideo Automatic Speech Recognition information (given in chronological order of the video): " + " ".join(asr_docs)
704
- if USE_OCR and len(ocr_docs) > 0:
705
- qs += "\nVideo OCR information (given in chronological order of the video): " + "; ".join(ocr_docs)
706
- qs += "Select the best answer to the following multiple-choice question based on the video and the information (if given). Respond with only the letter (A, B, C, or D) of the correct option. Question: " + question # you can change this prompt
707
-
708
- res = llava_inference(qs, video)
709
- return res
710
-
711
- demo = gr.Interface(
712
- fn=process_query,
713
- inputs=[gr.Video(label="Upload Video"), gr.Textbox(label="Question")],
714
- outputs=gr.Textbox(label="Answer"),
715
- title="Video Question Answering with LLaVA",
716
- description="Upload a video and ask a question to get a summary or answer."
717
- )
718
-
719
- if __name__ == "__main__":
720
- demo.launch()