prithivMLmods committed on
Commit f9f220e · verified · 1 Parent(s): 3596c45

Update app.py

Files changed (1)
  1. app.py +701 -294
app.py CHANGED
@@ -1,313 +1,720 @@
 
1
  import spaces
2
  import json
3
- import math
4
  import os
5
- import traceback
6
- from io import BytesIO
7
- from typing import Any, Dict, List, Optional, Tuple
8
  import re
9
- import time
10
- from threading import Thread
11
-
12
- import gradio as gr
13
- import requests
14
  import torch
15
- from PIL import Image
16
 
17
- from transformers import (
18
- Qwen2VLForConditionalGeneration,
19
- Qwen2_5_VLForConditionalGeneration,
20
- AutoModelForImageTextToText,
21
- AutoProcessor,
22
- TextIteratorStreamer,
23
- AutoModel,
24
- AutoTokenizer,
25
- )
26
 
27
- from transformers.image_utils import load_image
28
-
29
- # --- Constants and Model Setup ---
30
- MAX_INPUT_TOKEN_LENGTH = 4096
31
- # Note: The following line correctly falls back to CPU if CUDA is not available.
32
- # Let the environment (e.g., Hugging Face Spaces) determine the device.
33
- # This avoids conflicts with the CUDA environment setup by the platform.
34
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
-
36
- print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
37
- print("torch.__version__ =", torch.__version__)
38
- print("torch.version.cuda =", torch.version.cuda)
39
- print("cuda available:", torch.cuda.is_available())
40
- print("cuda device count:", torch.cuda.device_count())
41
- if torch.cuda.is_available():
42
- print("current device:", torch.cuda.current_device())
43
- print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
44
-
45
- print("Using device:", device)
46
-
47
- # --- Model Loading ---
48
-
49
- # --- Prompts for Different Tasks ---
50
- layout_prompt = """Please output the layout information from the image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
51
-
52
- 1. Bbox format: [x1, y1, x2, y2]
53
- 2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
54
- 3. Text Extraction & Formatting Rules:
55
- - For tables, provide the content in a structured JSON format.
56
- - For all other elements, provide the plain text.
57
- 4. Constraints:
58
- - The output must be the original text from the image.
59
- - All layout elements must be sorted according to human reading order.
60
- 5. Final Output: The entire output must be a single JSON object wrapped in ```json ... ```.
61
- """
62
-
63
- ocr_prompt = "Perform precise OCR on the image. Extract all text content, maintaining the original structure, paragraphs, and tables as formatted markdown."
64
-
65
- # --- Model Loading ---
66
- MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-080125"
67
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
68
- model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
69
- MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
70
- ).to(device).eval()
71
-
72
- MODEL_ID_T = "prithivMLmods/Megalodon-OCR-Sync-0713"
73
- processor_t = AutoProcessor.from_pretrained(MODEL_ID_T, trust_remote_code=True)
74
- model_t = Qwen2_5_VLForConditionalGeneration.from_pretrained(
75
- MODEL_ID_T, trust_remote_code=True, torch_dtype=torch.float16
76
- ).to(device).eval()
77
-
78
- MODEL_ID_C = "nanonets/Nanonets-OCR-s"
79
- processor_c = AutoProcessor.from_pretrained(MODEL_ID_C, trust_remote_code=True)
80
- model_c = Qwen2_5_VLForConditionalGeneration.from_pretrained(
81
- MODEL_ID_C, trust_remote_code=True, torch_dtype=torch.float16
82
- ).to(device).eval()
83
-
84
- MODEL_ID_G = "echo840/MonkeyOCR"
85
- SUBFOLDER = "Recognition"
86
- processor_g = AutoProcessor.from_pretrained(
87
- MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER
88
- )
89
- model_g = Qwen2_5_VLForConditionalGeneration.from_pretrained(
90
- MODEL_ID_G, trust_remote_code=True, subfolder=SUBFOLDER, torch_dtype=torch.float16
91
- ).to(device).eval()
92
-
93
- MODEL_ID_I = "allenai/olmOCR-7B-0725"
94
- processor_i = AutoProcessor.from_pretrained(MODEL_ID_I, trust_remote_code=True)
95
- model_i = Qwen2_5_VLForConditionalGeneration.from_pretrained(
96
- MODEL_ID_I, trust_remote_code=True, torch_dtype=torch.float16
97
- ).to(device).eval()
98
-
99
- # --- Utility Functions ---
100
- def layoutjson2md(layout_data: Any) -> str:
101
- """
102
- FIXED: Converts the structured JSON from Layout Analysis into formatted Markdown.
103
- This version is robust against malformed JSON from the model.
104
- """
105
- markdown_lines = []
106
-
107
- # If the model wraps the list in a dictionary, find and extract the list.
108
- if isinstance(layout_data, dict):
109
- found_list = None
110
- for value in layout_data.values():
111
- if isinstance(value, list):
112
- found_list = value
113
- break
114
- if found_list is not None:
115
- layout_data = found_list
116
  else:
117
- return "### Error: Could not find a list of layout items in the JSON object."
118
-
119
- if not isinstance(layout_data, list):
120
- return f"### Error: Expected a list of layout items, but received type {type(layout_data).__name__}."
121
-
122
- try:
123
- # Filter out any non-dictionary items and sort by reading order.
124
- valid_items = [item for item in layout_data if isinstance(item, dict)]
125
- sorted_items = sorted(valid_items, key=lambda x: (x.get('bbox', [0, 0, 0, 0])[1], x.get('bbox', [0, 0, 0, 0])[0]))
126
-
127
- for item in sorted_items:
128
- category = item.get('category', 'Text') # Default to 'Text' if no category
129
- text = item.get('text', '')
130
- if not text:
131
- continue
132
-
133
- if category == 'Title':
134
- markdown_lines.append(f"# {text}\n")
135
- elif category == 'Section-header':
136
- markdown_lines.append(f"## {text}\n")
137
- elif category == 'Table':
138
- if isinstance(text, dict) and 'header' in text and 'rows' in text:
139
- header = '| ' + ' | '.join(map(str, text['header'])) + ' |'
140
- separator = '| ' + ' | '.join(['---'] * len(text['header'])) + ' |'
141
- rows = ['| ' + ' | '.join(map(str, row)) + ' |' for row in text['rows']]
142
- markdown_lines.extend([header, separator] + rows)
143
- markdown_lines.append("\n")
144
- else: # Fallback for simple text or malformed tables
145
- markdown_lines.append(f"{text}\n")
146
  else:
147
- markdown_lines.append(f"{text}\n")
148
 
149
- except Exception as e:
150
- print(f"Error converting to markdown: {e}")
151
- traceback.print_exc()
152
- return "### Error: An unexpected error occurred while converting JSON to Markdown."
153
 
154
- return "\n".join(markdown_lines)
155
 
156
 
157
- # --- Core Application Logic ---
158
  @spaces.GPU
159
- def process_document_stream(model_name: str, task_choice: str, image: Image.Image, max_new_tokens: int):
160
- """
161
- Main generator function that handles both OCR and Layout Analysis tasks.
162
- """
163
- if image is None:
164
- yield "Please upload an image.", "Please upload an image.", None
165
- return
166
-
167
- # 1. Select prompt based on user's task choice
168
- text_prompt = ocr_prompt if task_choice == "Content Extraction" else layout_prompt
169
-
170
- # 2. Select model and processor
171
- if model_name == "Camel-Doc-OCR-080125": processor, model = processor_m, model_m
172
- elif model_name == "Megalodon-OCR-Sync-0713": processor, model = processor_t, model_t
173
- elif model_name == "Nanonets-OCR-s": processor, model = processor_c, model_c
174
- elif model_name == "MonkeyOCR-Recognition": processor, model = processor_g, model_g
175
- elif model_name == "olmOCR-7B-0725": processor, model = processor_i, model_i
176
- else:
177
- yield "Invalid model selected.", "Invalid model selected.", None
178
- return
179
-
180
- # 3. Prepare model inputs and streamer
181
- messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text_prompt}]}]
182
- prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
183
- inputs = processor(text=[prompt_full], images=[image], return_tensors="pt", padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH).to(device)
184
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
185
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
186
-
187
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
188
- thread.start()
189
-
190
- # 4. Stream raw output to the UI in real-time
191
- buffer = ""
192
- for new_text in streamer:
193
- buffer += new_text
194
- buffer = buffer.replace("<|im_end|>", "")
195
- time.sleep(0.01)
196
- yield buffer , "⏳ Processing...", {"status": "streaming"}
197
-
198
- # 5. Post-process the final buffer based on the selected task
199
- if task_choice == "Content Extraction":
200
- # For OCR, the buffer is the final result.
201
- yield buffer, buffer, None
202
- else: # Layout Analysis
203
  try:
204
- json_match = re.search(r'```json\s*([\s\S]+?)\s*```', buffer)
205
- if not json_match:
206
- # If no JSON block is found, try to parse the whole buffer as a fallback.
207
- try:
208
- layout_data = json.loads(buffer)
209
- markdown_content = layoutjson2md(layout_data)
210
- yield buffer, markdown_content, layout_data
211
- return
212
- except json.JSONDecodeError:
213
- raise ValueError("JSON object not found in the model's output.")
214
-
215
- json_str = json_match.group(1)
216
- layout_data = json.loads(json_str)
217
- markdown_content = layoutjson2md(layout_data)
218
-
219
- yield buffer, markdown_content, layout_data
220
- except Exception as e:
221
- error_md = f" **Error:** Failed to parse Layout JSON.\n\n**Details:**\n`{str(e)}`\n\n**Raw Output:**\n```\n{buffer}\n```"
222
- error_json = {"error": "ProcessingError", "details": str(e), "raw_output": buffer}
223
- yield buffer, error_md, error_json
224
-
225
-
226
- # --- Gradio UI Definition ---
227
- def create_gradio_interface():
228
- """Builds and returns the Gradio web interface."""
229
- css = """
230
- .main-container { max-width: 1400px; margin: 0 auto; }
231
- .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
232
- .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
233
- """
234
- with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
235
- gr.HTML("""
236
- <div class="title" style="text-align: center">
237
- <h1>Tiny VLMs Lab🧪</h1>
238
- <p style="font-size: 1.1em; color: #6b7280; margin-bottom: 0.6em;">
239
- Advanced Vision-Language Model for Image Content and Layout Extraction
240
- </p>
241
- </div>
242
- """)
243
-
244
- with gr.Row():
245
- # Left Column (Inputs)
246
- with gr.Column(scale=1):
247
- model_choice = gr.Dropdown(
248
- choices=["Camel-Doc-OCR-080125",
249
- "MonkeyOCR-Recognition",
250
- "olmOCR-7B-0725",
251
- "Nanonets-OCR-s",
252
- "Megalodon-OCR-Sync-0713"
253
- ],
254
- label="Select Model",
255
- value="Nanonets-OCR-s"
256
- )
257
- task_choice = gr.Dropdown(
258
- choices=["Content Extraction",
259
- "Layout Analysis(.json)"],
260
- label="Select Task", value="Content Extraction"
261
- )
262
- image_input = gr.Image(label="Upload Image", type="pil", sources=['upload'])
263
- with gr.Accordion("Advanced Settings", open=False):
264
- max_new_tokens = gr.Slider(minimum=512, maximum=8192, value=4096, step=256, label="Max New Tokens")
265
-
266
- process_btn = gr.Button("🚀 Process Document", variant="primary", elem_classes=["process-button"], size="lg")
267
- clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
268
-
269
- # Right Column (Outputs)
270
- with gr.Column(scale=2):
271
- with gr.Tabs() as tabs:
272
- with gr.Tab("📝 Extracted Content"):
273
- raw_output_stream = gr.Textbox(label="Raw Model Output Stream", interactive=False, lines=13, show_copy_button=True)
274
- with gr.Row():
275
- examples = gr.Examples(
276
- examples=["examples/1.png", "examples/2.png", "examples/3.png", "examples/4.png", "examples/5.png"],
277
- inputs=image_input,
278
- label="Examples"
279
- )
280
- gr.Markdown("[Report-Bug💻](https://huggingface.co/spaces/prithivMLmods/OCR-Comparator/discussions)")
281
- with gr.Tab("📰 README.md"):
282
- with gr.Accordion("(Formatted Result)", open=True):
283
- markdown_output = gr.Markdown(label="Formatted Markdown")
284
-
285
- with gr.Tab("📋 Layout Analysis Results"):
286
- json_output = gr.JSON(label="Structured Layout Data (JSON)")
287
-
288
- # Event Handlers
289
- def clear_all_outputs():
290
- return None, "Raw output will appear here.", "Formatted results will appear here.", None
291
-
292
- process_btn.click(
293
- fn=process_document_stream,
294
- inputs=[model_choice,
295
- task_choice,
296
- image_input,
297
- max_new_tokens],
298
- outputs=[raw_output_stream,
299
- markdown_output,
300
- json_output]
301
  )
302
- clear_btn.click(
303
- clear_all_outputs,
304
- outputs=[image_input,
305
- raw_output_stream,
306
- markdown_output,
307
- json_output]
308
  )
309
- return demo
310
 
311
  if __name__ == "__main__":
312
- demo = create_gradio_interface()
313
- demo.queue(max_size=50).launch(share=True, ssr_mode=False, show_error=True)
 
1
+ import gradio as gr
2
  import spaces
3
+ from PIL import Image
4
+ from transformers import AutoProcessor, WhisperForConditionalGeneration, WhisperProcessor, CLIPProcessor, CLIPModel
5
+ import copy
6
+ from decord import VideoReader, cpu
7
+ import numpy as np
8
  import json
9
+ from tqdm import tqdm
10
  import os
11
+ import easyocr
12
  import re
13
+ import ast
14
+ import socket
15
+ import pickle
16
+ import ffmpeg
17
+ import torchaudio
18
  import torch
19
+ import warnings
20
+ import shutil
21
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
22
+ from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
23
+ from llava.utils import rank0_print
24
+ from llava.model import *
25
+ from llava.model.language_model.llava_qwen import LlavaQwenForCausalLM, LlavaQwenConfig
26
+ from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
27
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
28
+ from llava.conversation import conv_templates, SeparatorStyle
29
+
30
+ # The tool helpers used below come from separate modules under tools/:
31
+ # retrieve_documents_with_dynamic, filter_keywords, and generate_scene_graph_description.
32
+ # They are imported as-is and are assumed to be available in the environment.
33
+ # To run this as a single self-contained file, inline their implementations here
34
+ # instead of importing them.
35
+
36
+ from tools.rag_retriever_dynamic import retrieve_documents_with_dynamic
37
+ from tools.filter_keywords import filter_keywords
38
+ from tools.scene_graph import generate_scene_graph_description
39
+
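The three tools.* modules above are not part of this commit, so their interfaces can only be inferred from the call sites further down. Below is a minimal, hypothetical stand-in (for example a local tools_stub.py) that could be used when those modules are missing; the function names match the imports, but the bodies are simplified sketches, not the real Video-RAG implementations.

```python
# Hypothetical fallback stubs for the tools.* helpers, with signatures inferred from the
# call sites in this file. Illustrative only; the real modules do proper dynamic RAG
# scoring, keyword filtering, and scene-graph construction.
from typing import List, Optional, Tuple

def retrieve_documents_with_dynamic(docs: List[str], queries: List[str],
                                    threshold: float = 0.3) -> Tuple[List[str], Optional[list]]:
    # Naive keyword-overlap retrieval standing in for the dynamic retriever.
    query_terms = {w.lower() for q in queries for w in q.split()}
    kept = []
    for doc in docs:
        doc_terms = {w.lower() for w in doc.split()}
        overlap = len(doc_terms & query_terms) / max(len(doc_terms), 1)
        if overlap >= threshold:
            kept.append(doc)
    return kept, None

def filter_keywords(keywords: Optional[List[str]]) -> List[str]:
    # Deduplicate and strip entries; the real helper also removes abstract concepts.
    seen, out = set(), []
    for kw in keywords or []:
        kw = kw.strip()
        if kw and kw.lower() not in seen:
            seen.add(kw.lower())
            out.append(kw)
    return out

def generate_scene_graph_description(objects: List[dict], location: bool,
                                     relation: bool, number: bool) -> str:
    # Flatten detections into plain text (relation handling omitted in this sketch);
    # the real helper builds a scene graph with spatial relations.
    parts = [f"{o['label']} at bbox {o['bbox']}" if location else str(o['label']) for o in objects]
    if number:
        parts.append(f"object count: {len(objects)}")
    return "; ".join(parts)
```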
40
+ # From builder.py
41
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", torch_dtype="float16", attn_implementation=None, customized_config=None, overwrite_config=None, **kwargs):
42
+ kwargs["device_map"] = device_map
43
+
44
+ if load_8bit:
45
+ kwargs["load_in_8bit"] = True
46
+ elif load_4bit:
47
+ kwargs["load_in_4bit"] = True
48
+ kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
49
+ elif torch_dtype == "float16":
50
+ kwargs["torch_dtype"] = torch.float16
51
+ elif torch_dtype == "bfloat16":
52
+ kwargs["torch_dtype"] = torch.bfloat16
53
+ else:
54
+ import pdb;pdb.set_trace()
55
 
56
+ if customized_config is not None:
57
+ kwargs["config"] = customized_config
58
 
59
+ if "multimodal" in kwargs:
60
+ if kwargs["multimodal"] is True:
61
+ is_multimodal = True
62
+ kwargs.pop("multimodal")
63
+ else:
64
+ is_multimodal = False
65
+
66
+ if "llava" in model_name.lower() or is_multimodal:
67
+ # Load LLaVA model
68
+ if "lora" in model_name.lower() and model_base is None:
69
+ warnings.warn(
70
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
71
+ )
72
+ if "lora" in model_name.lower() and model_base is not None:
73
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
74
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
75
+ rank0_print("Loading LLaVA from base model...")
76
+ if "mixtral" in model_name.lower():
77
+ from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
78
+
79
+ lora_cfg_pretrained = LlavaMixtralConfig.from_pretrained(model_path)
80
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
81
+ model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
82
+ elif "mistral" in model_name.lower():
83
+ from llava.model.language_model.llava_mistral import LlavaMistralConfig
84
+
85
+ lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
86
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
87
+ model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
88
+ elif "gemma" in model_name.lower():
89
+ from llava.model.language_model.llava_gemma import LlavaGemmaConfig
90
+
91
+ lora_cfg_pretrained = LlavaGemmaConfig.from_pretrained(model_path)
92
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
93
+ model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
94
+ else:
95
+ from llava.model.language_model.llava_llama import LlavaConfig, LlavaLlamaForCausalLM
96
+
97
+ lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
98
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
99
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
100
+
101
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
102
+ if model.lm_head.weight.shape[0] != token_num:
103
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
104
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
105
+
106
+ rank0_print("Loading additional LLaVA weights...")
107
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
108
+ non_lora_trainables = torch.load(os.path.join(model_path, "non_lora_trainables.bin"), map_location="cpu")
109
+ else:
110
+ # this is probably from HF Hub
111
+ from huggingface_hub import hf_hub_download
112
+
113
+ def load_from_hf(repo_id, filename, subfolder=None):
114
+ cache_file = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
115
+ return torch.load(cache_file, map_location="cpu")
116
+
117
+ non_lora_trainables = load_from_hf(model_path, "non_lora_trainables.bin")
118
+ non_lora_trainables = {(k[11:] if k.startswith("base_model.") else k): v for k, v in non_lora_trainables.items()}
119
+ if any(k.startswith("model.model.") for k in non_lora_trainables):
120
+ non_lora_trainables = {(k[6:] if k.startswith("model.") else k): v for k, v in non_lora_trainables.items()}
121
+ model.load_state_dict(non_lora_trainables, strict=False)
122
+
123
+ from peft import PeftModel
124
+
125
+ rank0_print("Loading LoRA weights...")
126
+ model = PeftModel.from_pretrained(model, model_path)
127
+ rank0_print("Merging LoRA weights...")
128
+ model = model.merge_and_unload()
129
+ rank0_print("Model is loaded...")
130
+ elif model_base is not None: # this may be mm projector only, loading projector with preset language model
131
+ rank0_print(f"Loading LLaVA from base model {model_base}...")
132
+ if "mixtral" in model_name.lower():
133
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
134
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
135
+ model = LlavaMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
136
+ elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
137
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
138
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
139
+ model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
140
+ elif "gemma" in model_name.lower():
141
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
142
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
143
+ model = LlavaGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
144
+ elif (
145
+ "wizardlm-2" in model_name.lower()
146
+ and "vicuna" in model_name.lower()
147
+ or "llama" in model_name.lower()
148
+ or "yi" in model_name.lower()
149
+ or "nous-hermes" in model_name.lower()
150
+ or "llava-v1.6-34b" in model_name.lower()
151
+ or "llava-v1.5" in model_name.lower()
152
+ ):
153
+ from llava.model.language_model.llava_llama import LlavaConfig
154
+
155
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
156
+ if customized_config is None:
157
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
158
+ if "v1.5" in model_name.lower():
159
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
160
+ else:
161
+ llava_cfg = customized_config
162
+
163
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
164
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
165
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=llava_cfg, **kwargs)
166
+ else:
167
+ raise ValueError(f"Model {model_name} not supported")
168
+
169
+ mm_projector_weights = torch.load(os.path.join(model_path, "mm_projector.bin"), map_location="cpu")
170
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
171
+ model.load_state_dict(mm_projector_weights, strict=False)
172
  else:
173
+ rank0_print(f"Loaded LLaVA model: {model_path}")
174
+ if "mixtral" in model_name.lower():
175
+ from llava.model.language_model.llava_mixtral import LlavaMixtralConfig
176
+
177
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
178
+ if customized_config is None:
179
+ llava_cfg = LlavaMixtralConfig.from_pretrained(model_path)
180
+ else:
181
+ llava_cfg = customized_config
182
+
183
+ if overwrite_config is not None:
184
+ rank0_print(f"Overwriting config with {overwrite_config}")
185
+ for k, v in overwrite_config.items():
186
+ setattr(llava_cfg, k, v)
187
+
188
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
189
+ model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
190
+
191
+ elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
192
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
193
+ model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
194
+ elif (
195
+ "wizardlm-2" in model_name.lower()
196
+ and "vicuna" in model_name.lower()
197
+ or "llama" in model_name.lower()
198
+ or "yi" in model_name.lower()
199
+ or "nous-hermes" in model_name.lower()
200
+ or "llava-v1.6-34b" in model_name.lower()
201
+ or "llava-v1.5" in model_name.lower()
202
+ ):
203
+ from llava.model.language_model.llava_llama import LlavaConfig
204
+
205
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
206
+ if customized_config is None:
207
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
208
+ if "v1.5" in model_path.lower():
209
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
210
+ else:
211
+ llava_cfg = customized_config
212
+
213
+ if overwrite_config is not None:
214
+ rank0_print(f"Overwriting config with {overwrite_config}")
215
+ for k, v in overwrite_config.items():
216
+ setattr(llava_cfg, k, v)
217
+
218
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
219
+
220
+ elif "qwen" in model_name.lower() or "quyen" in model_name.lower():
221
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
222
+ if "moe" in model_name.lower() or "A14B" in model_name.lower():
223
+ from llava.model.language_model.llava_qwen_moe import LlavaQwenMoeConfig
224
+ if overwrite_config is not None:
225
+ llava_cfg = LlavaQwenMoeConfig.from_pretrained(model_path)
226
+ rank0_print(f"Overwriting config with {overwrite_config}")
227
+ for k, v in overwrite_config.items():
228
+ setattr(llava_cfg, k, v)
229
+ model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
230
+ else:
231
+ model = LlavaQwenMoeForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
232
+
233
+ else:
234
+ from llava.model.language_model.llava_qwen import LlavaQwenConfig
235
+ if overwrite_config is not None:
236
+ llava_cfg = LlavaQwenConfig.from_pretrained(model_path)
237
+ rank0_print(f"Overwriting config with {overwrite_config}")
238
+ for k, v in overwrite_config.items():
239
+ setattr(llava_cfg, k, v)
240
+ model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
241
+ else:
242
+ model = LlavaQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, **kwargs)
243
+
244
+ elif "gemma" in model_name.lower():
245
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
246
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
247
+ model = LlavaGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
248
  else:
249
+ try:
250
+ from llava.model.language_model.llava_llama import LlavaConfig
251
+
252
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
253
+ if customized_config is None:
254
+ llava_cfg = LlavaConfig.from_pretrained(model_path)
255
+ if "v1.5" in model_path.lower():
256
+ llava_cfg.delay_load = True # a workaround for correctly loading v1.5 models
257
+ else:
258
+ llava_cfg = customized_config
259
+
260
+ if overwrite_config is not None:
261
+ rank0_print(f"Overwriting config with {overwrite_config}")
262
+ for k, v in overwrite_config.items():
263
+ setattr(llava_cfg, k, v)
264
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=llava_cfg, **kwargs)
265
+ except:
266
+ raise ValueError(f"Model {model_name} not supported")
267
 
268
+ else:
269
+ # Load language model
270
+ if model_base is not None:
271
+ # PEFT model
272
+ from peft import PeftModel
273
+
274
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
275
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
276
+ print(f"Loading LoRA weights from {model_path}")
277
+ model = PeftModel.from_pretrained(model, model_path)
278
+ print(f"Merging weights")
279
+ model = model.merge_and_unload()
280
+ print("Convert to FP16...")
281
+ model.to(torch.float16)
282
+ else:
283
+ use_fast = False
284
+ if "mpt" in model_name.lower().replace("prompt", ""):
285
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
286
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
287
+ else:
288
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
289
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
290
+
291
+ rank0_print(f"Model Class: {model.__class__.__name__}")
292
+ image_processor = None
293
+
294
+ if "llava" in model_name.lower() or is_multimodal:
295
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
296
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
297
+ if mm_use_im_patch_token:
298
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
299
+ if mm_use_im_start_end:
300
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
301
+ model.resize_token_embeddings(len(tokenizer))
302
+
303
+ vision_tower = model.get_vision_tower()
304
+ if not vision_tower.is_loaded:
305
+ vision_tower.load_model(device_map=device_map)
306
+ if device_map != "auto":
307
+ vision_tower.to(device="cuda", dtype=torch.float16)
308
+ image_processor = vision_tower.image_processor
309
+
310
+ if hasattr(model.config, "max_sequence_length"):
311
+ context_len = model.config.max_sequence_length
312
+ elif hasattr(model.config, "max_position_embeddings"):
313
+ context_len = model.config.max_position_embeddings
314
+ elif hasattr(model.config, "tokenizer_model_max_length"):
315
+ context_len = model.config.tokenizer_model_max_length
316
+ else:
317
+ context_len = 2048
318
 
319
+ return tokenizer, model, image_processor, context_len
320
 
321
+ # From vidrag_pipeline.py
322
+ max_frames_num = 32
323
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14-336", torch_dtype=torch.float16, device_map="auto")
324
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
325
+ # whisper_model = WhisperForConditionalGeneration.from_pretrained(
326
+ # "openai/whisper-large",
327
+ # torch_dtype=torch.float16,
328
+ # device_map="auto"
329
+ # )
330
+ # whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large")
331
 
 
332
  @spaces.GPU
333
+ def process_video(video_path, max_frames_num, fps=1, force_sample=False):
334
+ if max_frames_num == 0:
335
+ return np.zeros((1, 336, 336, 3))
336
+ vr = VideoReader(video_path, ctx=cpu(), num_threads=1)
337
+ total_frame_num = len(vr)
338
+ video_time = total_frame_num / vr.get_avg_fps()
339
+ fps = round(vr.get_avg_fps()/fps)
340
+ frame_idx = [i for i in range(0, len(vr), fps)]
341
+ frame_time = [i/fps for i in frame_idx]
342
+ if len(frame_idx) > max_frames_num or force_sample:
343
+ sample_fps = max_frames_num
344
+ uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
345
+ frame_idx = uniform_sampled_frames.tolist()
346
+ frame_time = [i/vr.get_avg_fps() for i in frame_idx]
347
+ frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
348
+ spare_frames = vr.get_batch(frame_idx).asnumpy()
349
+
350
+ return spare_frames, frame_time, video_time
351
+
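For orientation, a small usage sketch of the sampler above; the video path is a placeholder, not a file shipped with this Space.

```python
# Illustrative only: with force_sample=True the sampler always returns exactly
# max_frames_num uniformly spaced frames, a comma-joined timestamp string, and the
# video duration in seconds.
if os.path.exists("restore/example.mp4"):  # placeholder path, not part of this commit
    demo_frames, demo_frame_times, demo_duration = process_video(
        "restore/example.mp4", max_frames_num, fps=1, force_sample=True
    )
    print(demo_frames.shape, demo_duration, demo_frame_times)
```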
352
+ def extract_audio(video_path, audio_path):
353
+ if not os.path.exists(audio_path):
354
+ ffmpeg.input(video_path).output(audio_path, acodec='pcm_s16le', ac=1, ar='16k').run()
355
+
356
+ def chunk_audio(audio_path, chunk_length_s=30):
357
+ speech, sr = torchaudio.load(audio_path)
358
+ speech = speech.mean(dim=0)
359
+ speech = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(speech)
360
+ num_samples_per_chunk = chunk_length_s * 16000
361
+ chunks = []
362
+ for i in range(0, len(speech), num_samples_per_chunk):
363
+ chunks.append(speech[i:i + num_samples_per_chunk])
364
+ return chunks
365
+
366
+ # def transcribe_chunk(chunk):
367
+
368
+ # inputs = whisper_processor(chunk, return_tensors="pt")
369
+ # inputs["input_features"] = inputs["input_features"].to(whisper_model.device, torch.float16)
370
+ # with torch.no_grad():
371
+ # predicted_ids = whisper_model.generate(
372
+ # inputs["input_features"],
373
+ # no_repeat_ngram_size=2,
374
+ # early_stopping=True
375
+ # )
376
+ # transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
377
+ # return transcription
378
+
379
+ # def get_asr_docs(video_path, audio_path):
380
+
381
+ # full_transcription = []
382
+
383
+ # try:
384
+ # extract_audio(video_path, audio_path)
385
+ # except:
386
+ # return full_transcription
387
+ # audio_chunks = chunk_audio(audio_path, chunk_length_s=30)
388
+
389
+ # for chunk in audio_chunks:
390
+ # transcription = transcribe_chunk(chunk)
391
+ # full_transcription.append(transcription)
392
+
393
+ # return full_transcription
394
+
395
+ def get_ocr_docs(frames):
396
+ reader = easyocr.Reader(['en'])
397
+ text_set = []
398
+ ocr_docs = []
399
+ for img in frames:
400
+ ocr_results = reader.readtext(img)
401
+ det_info = ""
402
+ for result in ocr_results:
403
+ text = result[1]
404
+ confidence = result[2]
405
+ if confidence > 0.5 and text not in text_set:
406
+ det_info += f"{text}; "
407
+ text_set.append(text)
408
+ if len(det_info) > 0:
409
+ ocr_docs.append(det_info)
410
+
411
+ return ocr_docs
412
+
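One optional note, not part of this commit: get_ocr_docs constructs a new easyocr.Reader (and loads its detection and recognition weights) on every call. A sketch of a simple module-level cache, assuming the English EasyOCR weights are available:

```python
# Illustrative only: cache one easyocr.Reader instead of rebuilding it per call.
_OCR_READER = None

def get_cached_ocr_reader():
    global _OCR_READER
    if _OCR_READER is None:
        _OCR_READER = easyocr.Reader(['en'])  # loads English weights once
    return _OCR_READER
```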
413
+
414
+ def save_frames(frames):
415
+ file_paths = []
416
+ for i, frame in enumerate(frames):
417
+ img = Image.fromarray(frame)
418
+ file_path = f'restore/frame_{i}.png'
419
+ img.save(file_path)
420
+ file_paths.append(file_path)
421
+ return file_paths
422
+
423
+ def get_det_docs(frames, prompt):
424
+ prompt = ",".join(prompt)
425
+ frames_path = save_frames(frames)
426
+ res = []
427
+ if len(frames) > 0:
428
+ client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
429
+ client_socket.connect(('0.0.0.0', 9999))
430
+ data = (frames_path, prompt)
431
+ client_socket.send(pickle.dumps(data))
432
+ result_data = client_socket.recv(4096)
433
  try:
434
+ res = pickle.loads(result_data)
435
+ except:
436
+ res = []
437
+ return res
438
+
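A caveat on get_det_docs above: a single recv(4096) returns at most 4 KiB, so a large pickled detection result can arrive truncated and fail to unpickle. A length-prefixed framing sketch is shown below; the APE detector server expected on port 9999 is not included in this commit and would need the matching change, so this is illustrative only.

```python
# Illustrative only: length-prefixed send/recv helpers that avoid truncating large
# pickled detection results. Both the client and the detector server would use them.
import struct

def send_msg(sock, payload: bytes) -> None:
    sock.sendall(struct.pack(">I", len(payload)) + payload)

def recv_msg(sock) -> bytes:
    header = b""
    while len(header) < 4:
        chunk = sock.recv(4 - len(header))
        if not chunk:
            raise ConnectionError("socket closed while reading length header")
        header += chunk
    (length,) = struct.unpack(">I", header)
    data = b""
    while len(data) < length:
        chunk = sock.recv(min(65536, length - len(data)))
        if not chunk:
            raise ConnectionError("socket closed while reading payload")
        data += chunk
    return data
```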
439
+ def det_preprocess(det_docs, location, relation, number):
440
+
441
+ scene_descriptions = []
442
+
443
+ for det_doc_per_frame in det_docs:
444
+ objects = []
445
+ scene_description = ""
446
+ if len(det_doc_per_frame) > 0:
447
+ for obj_id, objs in enumerate(det_doc_per_frame.split(";")):
448
+ obj_name = objs.split(":")[0].strip()
449
+ obj_bbox = objs.split(":")[1].strip()
450
+ obj_bbox = ast.literal_eval(obj_bbox)
451
+ objects.append({"id": obj_id, "label": obj_name, "bbox": obj_bbox})
452
+
453
+ scene_description = generate_scene_graph_description(objects, location, relation, number)
454
+ scene_descriptions.append(scene_description)
455
+
456
+ return scene_descriptions
457
+
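det_preprocess above relies on a specific text format for each per-frame detection string: "label: [x1, y1, x2, y2]" entries joined by semicolons, which is what the split(';') / split(':') / ast.literal_eval chain parses. A small illustrative example with made-up detections:

```python
# Illustrative only: the APE server is expected to return per-frame strings shaped like
# "label: [x1, y1, x2, y2]; label: [x1, y1, x2, y2]" (bounding boxes in pixel coordinates).
example_det_docs = [
    "blue balloon: [120, 40, 180, 110]; long table: [60, 300, 560, 420]",
    "",  # a frame with no detections yields an empty scene description
]
example_descriptions = det_preprocess(example_det_docs, location=True, relation=True, number=True)
```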
458
+ # load your VLM
459
+ device = "cuda"
460
+ overwrite_config = {}
461
+ tokenizer, model, image_processor, max_length = load_pretrained_model(
462
+ "lmms-lab/LLaVA-Video-7B-Qwen2",
463
+ None,
464
+ "llava_qwen",
465
+ torch_dtype="bfloat16",
466
+ device_map="auto",
467
+ offload_buffers=True,
468
+ overwrite_config=overwrite_config) # Add anything else you want to pass in llava_model_args
469
+
470
+
471
+ # 2) Check vocab sizes and fix BEFORE dispatching
472
+ vsz_model = model.get_input_embeddings().weight.shape[0]
473
+ vsz_tok = len(tokenizer)
474
+ if vsz_tok != vsz_model:
475
+ print(f"[fix] resizing embeddings: model={vsz_model} -> tokenizer={vsz_tok}")
476
+ model.resize_token_embeddings(vsz_tok)
477
+ # optional: init new rows
478
+ with torch.no_grad():
479
+ added = vsz_tok - vsz_model
480
+ if added > 0:
481
+ emb = model.get_input_embeddings().weight
482
+ emb[-added:].normal_(mean=0.0, std=0.02)
483
+
484
+
485
+ model.eval()
486
+ conv_template = "qwen_2" # Make sure you use correct chat template for different models
487
+
488
+
489
+ # The inference function of your VLM
490
+ def llava_inference(qs, video):
491
+ if video is not None:
492
+ question = DEFAULT_IMAGE_TOKEN + qs
493
+ else:
494
+ question = qs
495
+ conv = copy.deepcopy(conv_templates[conv_template])
496
+ conv.append_message(conv.roles[0], question)
497
+ conv.append_message(conv.roles[1], None)
498
+ prompt_question = conv.get_prompt()
499
+
500
+ input_ids = tokenizer_image_token(
501
+ prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
502
+ ).unsqueeze(0).to(device)
503
+ # input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
504
+ # cont = model.generate(
505
+ # input_ids,
506
+ # video=video,
507
+ # modalities= ["video"],
508
+ # do_sample=True,
509
+ # temperature=0.7,
510
+ # max_new_tokens=4096,
511
+ # )
512
+
513
+ if video is not None:
514
+ cont = model.generate(
515
+ input_ids,
516
+ images=video,
517
+ modalities=["video"],
518
+ do_sample=True,
519
+ temperature=0.7,
520
+ max_new_tokens=512
521
  )
522
+ else:
523
+ cont = model.generate(
524
+ input_ids,
525
+ do_sample=True,
526
+ temperature=0.7,
527
+ max_new_tokens=512
528
  )
529
+
530
+ text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip()
531
+ return text_outputs
532
+
533
+
534
+ # super-parameters setting
535
+ rag_threshold = 0.3
536
+ clip_threshold = 0.3
537
+ beta = 3.0
538
+
539
+ # Choose the auxiliary texts you want
540
+ USE_OCR = True
541
+ USE_ASR = False
542
+ USE_DET = True
543
+ print(f"---------------OCR{rag_threshold}: {USE_OCR}-----------------")
544
+ print(f"---------------ASR{rag_threshold}: {USE_ASR}-----------------")
545
+ print(f"---------------DET{beta}-{clip_threshold}: {USE_DET}-----------------")
546
+ print(f"---------------Frames: {max_frames_num}-----------------")
547
+
548
+ # Create directories
549
+ os.makedirs("restore/audio", exist_ok=True)
550
+ os.makedirs("restore", exist_ok=True)
551
+
552
+ def process_query(video_path, question):
553
+ if video_path is None:
554
+ return "Please upload a video."
555
+
556
+ frames, frame_time, video_time = process_video(video_path, max_frames_num, 1, force_sample=True)
557
+ raw_video = [f for f in frames]
558
+
559
+ video = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].cuda().bfloat16()
560
+ video = [video]
561
+
562
+ if USE_DET:
563
+ video_tensor = []
564
+ for frame in raw_video:
565
+ processed = clip_processor(images=frame, return_tensors="pt")["pixel_values"].to(clip_model.device, dtype=torch.float16)
566
+ video_tensor.append(processed.squeeze(0))
567
+ video_tensor = torch.stack(video_tensor, dim=0)
568
+
569
+ if USE_OCR:
570
+ ocr_docs_total = get_ocr_docs(frames)
571
+
572
+ if USE_ASR:
573
+ if os.path.exists(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt")):
574
+ with open(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'r', encoding='utf-8') as f:
575
+ asr_docs_total = f.readlines()
576
+ # else:
577
+ # audio_path = os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".wav")
578
+ # # asr_docs_total = get_asr_docs(video_path, audio_path)
579
+ # with open(os.path.join("restore/audio", os.path.basename(video_path).split(".")[0] + ".txt"), 'w', encoding='utf-8') as f:
580
+ # for doc in asr_docs_total:
581
+ # f.write(doc + '\n')
582
+
583
+ # step 0: get cot information
584
+ retrieve_pmt_0 = "Question: " + question
585
+ # you can change this retrieval-decoupling prompt to fit your requirements
586
+ retrieve_pmt_0 += "\nTo answer the question step by step, you can provide your retrieve request to assist you by the following json format:"
587
+ retrieve_pmt_0 += '''{
588
+ "ASR": Optional[str]. The subtitles of the video that may be relevant to the question you want to retrieve, in two sentences. If you do not need this information, please return null.
589
+ "DET": Optional[list]. (The output must include only physical entities, not abstract concepts, and fewer than five entities) All the physical entities and their locations related to the question you want to retrieve, not abstract concepts. If you do not need this information, please return null.
590
+ "TYPE": Optional[list]. (The output must be specified as null or a list containing only one or more of the following strings: 'location', 'number', 'relation'. No other values are valid for this field) The information you want to obtain about the detected objects. If you need the object location in the video frame, output "location"; if you need the number of a specific object, output "number"; if you need the positional relationship between objects, output "relation".
591
+ }
592
+ ## Example 1:
593
+ Question: How many blue balloons are over the long table in the middle of the room at the end of this video? A. 1. B. 2. C. 3. D. 4.
594
+ Your retrieve can be:
595
+ {
596
+ "ASR": "The location and the color of balloons, the number of the blue balloons.",
597
+ "DET": ["blue balloons", "long table"],
598
+ "TYPE": ["relation", "number"]
599
+ }
600
+ ## Example 2:
601
+ Question: In the lower left corner of the video, what color is the woman wearing on the right side of the man in black clothes? A. Blue. B. White. C. Red. D. Yellow.
602
+ Your retrieve can be:
603
+ {
604
+ "ASR": null,
605
+ "DET": ["the man in black", "woman"],
606
+ "TYPE": ["location", "relation"]
607
+ }
608
+ ## Example 3:
609
+ Question: In which country is the comedy featured in the video recognized worldwide? A. China. B. UK. C. Germany. D. United States.
610
+ Your retrieve can be:
611
+ {
612
+ "ASR": "The country recognized worldwide for its comedy.",
613
+ "DET": null,
614
+ "TYPE": null
615
+ }
616
+ Note that you don't need to answer the question in this step, so you don't need any information about the video or image. You only need to provide your retrieval request (it's optional), and I will help you retrieve the information you want. Please provide the json format.'''
617
+
618
+ json_request = llava_inference(retrieve_pmt_0, None)
619
+
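The reply produced by llava_inference here is expected to be a JSON object with ASR, DET and TYPE keys, as requested by the prompt, but the model may wrap it in prose or a code fence. A small illustrative parser (an assumption, not part of this commit) that the later per-key try/except blocks could lean on:

```python
# Illustrative only: extract the first JSON object from the model's reply, tolerating
# code fences or surrounding prose; the existing try/except fallbacks below still apply.
def parse_retrieval_request(reply: str) -> dict:
    match = re.search(r"\{[\s\S]*\}", reply)
    if not match:
        return {}
    try:
        parsed = json.loads(match.group(0))
    except json.JSONDecodeError:
        return {}
    return parsed if isinstance(parsed, dict) else {}

# Example: parse_retrieval_request(json_request).get("DET") might give ["blue balloons", "long table"].
```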
620
+ # step 1: get docs information
621
+ query = [question]
622
+
623
+ # APE fetch
624
+ if USE_DET:
625
+ det_docs = []
626
+ try:
627
+ request_det = json.loads(json_request)["DET"]
628
+ request_det = filter_keywords(request_det)
629
+ clip_text = ["A picture of " + txt for txt in request_det]
630
+ if len(clip_text) == 0:
631
+ clip_text = ["A picture of object"]
632
+ except:
633
+ request_det = None
634
+ clip_text = ["A picture of object"]
635
+
636
+ clip_inputs = clip_processor(text=clip_text, return_tensors="pt", padding=True, truncation=True).to(clip_model.device)
637
+ clip_img_feats = clip_model.get_image_features(video_tensor)
638
+ with torch.no_grad():
639
+ text_features = clip_model.get_text_features(**clip_inputs)
640
+ similarities = (clip_img_feats @ text_features.T).squeeze(0).mean(1).cpu()
641
+ similarities = np.array(similarities, dtype=np.float64)
642
+ alpha = beta * (len(similarities) / 16)
643
+ similarities = similarities * alpha / np.sum(similarities)
644
+
645
+ del clip_inputs, clip_img_feats, text_features
646
+ torch.cuda.empty_cache()
647
+
648
+ det_top_idx = [idx for idx in range(max_frames_num) if similarities[idx] > clip_threshold]
649
+
650
+ if request_det is not None and len(request_det) > 0:
651
+ det_docs = get_det_docs(frames[det_top_idx], request_det)
652
+
653
+ L, R, N = False, False, False
654
+ try:
655
+ det_retrieve_info = json.loads(json_request)["TYPE"]
656
+ except:
657
+ det_retrieve_info = None
658
+ if det_retrieve_info is not None:
659
+ if "location" in det_retrieve_info:
660
+ L = True
661
+ if "relation" in det_retrieve_info:
662
+ R = True
663
+ if "number" in det_retrieve_info:
664
+ N = True
665
+ det_docs = det_preprocess(det_docs, location=L, relation=R, number=N) # pre-process of APE information
666
+
667
+
668
+ # OCR fetch
669
+ if USE_OCR:
670
+ try:
671
+ request_det = json.loads(json_request)["DET"]
672
+ request_det = filter_keywords(request_det)
673
+ except:
674
+ request_det = None
675
+ ocr_docs = []
676
+ if len(ocr_docs_total) > 0:
677
+ ocr_query = query.copy()
678
+ if request_det is not None and len(request_det) > 0:
679
+ ocr_query.extend(request_det)
680
+ ocr_docs, _ = retrieve_documents_with_dynamic(ocr_docs_total, ocr_query, threshold=rag_threshold)
681
+
682
+ # ASR fetch
683
+ if USE_ASR:
684
+ asr_docs = []
685
+ try:
686
+ request_asr = json.loads(json_request)["ASR"]
687
+ except:
688
+ request_asr = None
689
+ if len(asr_docs_total) > 0:
690
+ asr_query = query.copy()
691
+ if request_asr is not None:
692
+ asr_query.append(request_asr)
693
+ asr_docs, _ = retrieve_documents_with_dynamic(asr_docs_total, asr_query, threshold=rag_threshold)
694
+
695
+ qs = ""
696
+ if USE_DET and len(det_docs) > 0:
697
+ for i, info in enumerate(det_docs):
698
+ if len(info) > 0:
699
+ qs += f"Frame {str(det_top_idx[i]+1)}: " + info + "\n"
700
+ if len(qs) > 0:
701
+ qs = f"\nThe video has {str(max_frames_num)} frames in total; the detected objects' information in specific frames: " + qs
702
+ if USE_ASR and len(asr_docs) > 0:
703
+ qs += "\nVideo Automatic Speech Recognition information (given in chronological order of the video): " + " ".join(asr_docs)
704
+ if USE_OCR and len(ocr_docs) > 0:
705
+ qs += "\nVideo OCR information (given in chronological order of the video): " + "; ".join(ocr_docs)
706
+ qs += "Select the best answer to the following multiple-choice question based on the video and the information (if given). Respond with only the letter (A, B, C, or D) of the correct option. Question: " + question # you can change this prompt
707
+
708
+ res = llava_inference(qs, video)
709
+ return res
710
+
711
+ demo = gr.Interface(
712
+ fn=process_query,
713
+ inputs=[gr.Video(label="Upload Video"), gr.Textbox(label="Question")],
714
+ outputs=gr.Textbox(label="Answer"),
715
+ title="Video Question Answering with LLaVA",
716
+ description="Upload a video and ask a question to get a summary or answer."
717
+ )
718
 
719
  if __name__ == "__main__":
720
+ demo.launch()