diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..f9ec6e2fdd9173a6886cdac2517aaf8375240d0f 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,19 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+data/10309844035.mp4 filter=lfs diff=lfs merge=lfs -text
+data/13887487955.mp4 filter=lfs diff=lfs merge=lfs -text
+data/4167294363.mp4 filter=lfs diff=lfs merge=lfs -text
+data/4742652230.mp4 filter=lfs diff=lfs merge=lfs -text
+data/4766274786.mp4 filter=lfs diff=lfs merge=lfs -text
+data/5012237466.mp4 filter=lfs diff=lfs merge=lfs -text
+data/5188348585.mp4 filter=lfs diff=lfs merge=lfs -text
+data/9383140374.mp4 filter=lfs diff=lfs merge=lfs -text
+data/DTInxNfWXVc_210.0_360.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/RoripwjYFp8_210.0_360.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/UFWQKrcbhjI_360.0_510.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/Z3-IZ3HAmIA_60.0_210.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/h6QKDqomIPk_210.0_360.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/pA6Z-qYhSNg_60.0_210.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/rrTIeJRVGjg_60.0_210.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/yId2wIocTys_210.0_360.0.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..fe47536fab4a3c41da6f2b55620f067d0b32fd66
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+# Byte-compiled / optimized / DLL files
+__pycache__
+*.egg-info
+*.py[cod]
+*$py.class
+
+# Temporary data
+.DS_Store
+._*
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a5a997f4880eb71330bf2f833ae821a45ef2fc0
--- /dev/null
+++ b/app.py
@@ -0,0 +1,640 @@
+# Copyright (c) 2024 Ye Liu. Licensed under the BSD-3-Clause license.
+
+import html
+import json
+import os
+import random
+import time
+from functools import partial
+from threading import Thread
+
+import gradio as gr
+import nncore
+import torch
+from huggingface_hub import snapshot_download
+from transformers import TextIteratorStreamer
+
+from videomind.constants import GROUNDER_PROMPT, PLANNER_PROMPT, VERIFIER_PROMPT
+from videomind.dataset.utils import process_vision_info
+from videomind.model.builder import build_model
+from videomind.utils.io import get_duration
+from videomind.utils.parser import parse_query, parse_span
+
+BASE_MODEL = 'model_zoo/Qwen2-VL-2B-Instruct'
+BASE_MODEL_HF = 'Qwen/Qwen2-VL-2B-Instruct'
+
+MODEL = 'model_zoo/VideoMind-2B'
+MODEL_HF = 'yeliudev/VideoMind-2B'
+
+TITLE = 'VideoMind: A Chain-of-LoRA Agent for Long Video Reasoning'
+
+TITLE_MD = f'<h1 align="center">💡 {TITLE}</h1>'
+DESCRIPTION_MD = """VideoMind is a multi-modal agent framework that enhances video reasoning by emulating *human-like* processes, such as *breaking down tasks*, *localizing and verifying moments*, and *synthesizing answers*. This approach addresses the unique challenges of temporal-grounded reasoning in a progressive strategy. Please find more details at our Project Page, Tech Report and GitHub Repo.""" # noqa
+
+# yapf:disable
+EXAMPLES = [
+ ('data/4167294363.mp4', 'Why did the old man stand up?', ['pla', 'gnd', 'ver', 'ans']),
+ ('data/5012237466.mp4', 'How does the child in stripes react about the fountain?', ['pla', 'gnd', 'ver', 'ans']),
+ ('data/13887487955.mp4', 'What did the excavator do after it pushed the cement forward?', ['pla', 'gnd', 'ver', 'ans']),
+ ('data/5188348585.mp4', 'What did the person do before pouring the liquor?', ['pla', 'gnd', 'ver', 'ans']),
+ ('data/4766274786.mp4', 'What did the girl do after the baby lost the balloon?', ['pla', 'gnd', 'ver', 'ans']),
+ ('data/4742652230.mp4', 'Why is the girl pushing the boy only around the toy but not to other places?', ['pla', 'gnd', 'ver', 'ans']),
+ ('data/9383140374.mp4', 'How does the girl in pink control the movement of the claw?', ['pla', 'gnd', 'ver', 'ans']),
+ ('data/10309844035.mp4', 'Why are they holding up the phones?', ['pla', 'gnd', 'ver', 'ans']),
+ ('data/pA6Z-qYhSNg_60.0_210.0.mp4', 'Different types of meat products are being cut, shaped and prepared', ['gnd', 'ver']),
+ ('data/UFWQKrcbhjI_360.0_510.0.mp4', 'A man talks to the camera whilst walking along a roadside in a rural area', ['gnd', 'ver']),
+ ('data/RoripwjYFp8_210.0_360.0.mp4', 'A woman wearing glasses eating something at a street market', ['gnd', 'ver']),
+ ('data/h6QKDqomIPk_210.0_360.0.mp4', 'A toddler sits in his car seat, holding his yellow tablet', ['gnd', 'ver']),
+ ('data/Z3-IZ3HAmIA_60.0_210.0.mp4', 'A view from the window as the plane accelerates and takes off from the runway', ['gnd', 'ver']),
+ ('data/yId2wIocTys_210.0_360.0.mp4', "Temporally locate the visual content mentioned in the text query 'kids exercise in front of parked cars' within the video.", ['pla', 'gnd', 'ver']),
+ ('data/rrTIeJRVGjg_60.0_210.0.mp4', "Localize the moment that provides relevant context about 'man stands in front of a white building monologuing'.", ['pla', 'gnd', 'ver']),
+ ('data/DTInxNfWXVc_210.0_360.0.mp4', "Find the video segment that corresponds to the given textual query 'man with headphones talking'.", ['pla', 'gnd', 'ver']),
+]
+# yapf:enable
+
+CSS = """button .box { text-align: left }"""
+
+JS = """
+function init() {
+ var info = document.getElementById('role').querySelectorAll('[class^="svelte"]')[1]
+    info.innerHTML = info.innerHTML.replace(/&lt;/g, '<').replace(/&gt;/g, '>')
+}
+"""
+
+
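+# TextIteratorStreamer variant that skips the EOS token and keeps the fully
+# decoded text in `self.text_cache`, so the complete response (e.g. the
+# planner's JSON plan) can be reused after streaming finishes.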
+class CustomStreamer(TextIteratorStreamer):
+
+ def put(self, value):
+ if len(value.shape) > 1 and value.shape[0] > 1:
+ raise ValueError('TextStreamer only supports batch size 1')
+ elif len(value.shape) > 1:
+ value = value[0]
+
+ if self.skip_prompt and self.next_tokens_are_prompt:
+ self.next_tokens_are_prompt = False
+ return
+
+ self.token_cache.extend(value.tolist())
+
+ # force skipping eos token
+ if self.token_cache[-1] == self.tokenizer.eos_token_id:
+ self.token_cache = self.token_cache[:-1]
+
+ text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
+
+ # cache decoded text for future use
+ self.text_cache = text
+
+ if text.endswith('\n'):
+ printable_text = text[self.print_len:]
+ self.token_cache = []
+ self.print_len = 0
+ elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
+ printable_text = text[self.print_len:]
+ self.print_len += len(printable_text)
+ else:
+ printable_text = text[self.print_len:text.rfind(' ') + 1]
+ self.print_len += len(printable_text)
+
+ self.on_finalized_text(printable_text)
+
+
+def seconds_to_hms(seconds):
+ hours, remainder = divmod(round(seconds), 3600)
+ minutes, seconds = divmod(remainder, 60)
+ return f'{hours:02}:{minutes:02}:{seconds:02}'
+
+
+def enable_btns():
+ return (gr.Button(interactive=True), ) * 3
+
+
+def disable_btns():
+ return (gr.Button(interactive=False), ) * 3
+
+
+def update_placeholder(role):
+ placeholder = 'Ask a question about the video...' if 'ans' in role else 'Write a query to search for a moment...'
+ return gr.Textbox(placeholder=placeholder)
+
+
+def main(video, prompt, role, temperature, max_new_tokens, model, processor, streamer, device):
+ history = []
+
+ if not video:
+ gr.Warning('Please upload a video or click [Random] to sample one.')
+ return history
+
+ if not prompt:
+ gr.Warning('Please provide a prompt or click [Random] to sample one.')
+ return history
+
+ if 'gnd' not in role and 'ans' not in role:
+        gr.Warning('Please select at least one of Grounder or Answerer.')
+ return history
+
+ if 'ver' in role and 'gnd' not in role:
+ gr.Warning('Verifier cannot be used without Grounder.')
+ return history
+
+ if 'pla' in role and any(k not in role for k in ('gnd', 'ver', 'ans')):
+ gr.Warning('Planner can only be used when all other roles are selected.')
+ return history
+
+ history.append({'role': 'user', 'content': prompt})
+ yield history
+
+ duration = get_duration(video)
+
+ # do grounding and answering by default
+ do_grounding = True
+ do_answering = True
+
+ # initialize grounding query as prompt
+ query = prompt
+
+ if 'pla' in role:
+ text = PLANNER_PROMPT.format(prompt)
+
+ history.append({
+ 'metadata': {
+ 'title': '🗺️ Working as Planner...'
+ },
+ 'role': 'assistant',
+ 'content': f'##### Planner Prompt:\n\n{html.escape(text)}\n\n##### Planner Response:\n\n...'
+ })
+ yield history
+
+ start_time = time.perf_counter()
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video,
+ 'num_threads': 1,
+ 'min_pixels': 36 * 28 * 28,
+ 'max_pixels': 64 * 28 * 28,
+ 'max_frames': 100,
+ 'fps': 1.0
+ }, {
+ 'type': 'text',
+ 'text': text
+ }]
+ }]
+
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+
+ images, videos = process_vision_info(messages)
+ data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+ data = data.to(device)
+
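+        # make sure all LoRA layers are enabled, then activate the planner adapter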
+ model.base_model.disable_adapter_layers()
+ model.base_model.enable_adapter_layers()
+ model.set_adapter('planner')
+
+ generation_kwargs = dict(
+ **data,
+ streamer=streamer,
+ do_sample=temperature > 0,
+ temperature=temperature if temperature > 0 else None,
+ top_p=None,
+ top_k=None,
+ repetition_penalty=None,
+ max_new_tokens=max_new_tokens)
+
+ t = Thread(target=model.generate, kwargs=generation_kwargs)
+ t.start()
+
+ skipped = False
+ for i, text in enumerate(streamer):
+ if text and not skipped:
+ history[-1]['content'] = history[-1]['content'].rstrip('.')
+ skipped = True
+ history[-1]['content'] += text
+ yield history
+
+ elapsed_time = round(time.perf_counter() - start_time, 1)
+ history[-1]['metadata']['title'] += f' ({elapsed_time} seconds)'
+ yield history
+
+ try:
+ parsed = json.loads(streamer.text_cache)
+ action = parsed[0] if isinstance(parsed, list) else parsed
+ if action['type'].lower() == 'grounder' and action['value']:
+ query = action['value']
+ elif action['type'].lower() == 'answerer':
+ do_grounding = False
+ do_answering = True
+ except Exception:
+ pass
+
+    response = 'After browsing the video and the question, my plan to figure out the answer is as follows:\n'
+ step_idx = 1
+ if 'gnd' in role and do_grounding:
+ response += f'\n{step_idx}. Localize the relevant moment in this video using the query "{query}".'
+ step_idx += 1
+ if 'ver' in role and do_grounding:
+        response += f'\n{step_idx}. Verify the grounded moments one-by-one and select the best candidate.'
+ step_idx += 1
+ if 'ans' in role and do_answering:
+ if step_idx > 1:
+            response += f'\n{step_idx}. Crop the video segment and zoom in to a higher resolution.'
+ else:
+ response += f'\n{step_idx}. Analyze the whole video directly without cropping.'
+
+ history.append({'role': 'assistant', 'content': ''})
+ for i, text in enumerate(response.split(' ')):
+ history[-1]['content'] += ' ' + text if i > 0 else text
+ yield history
+
+ if 'gnd' in role and do_grounding:
+ query = parse_query(query)
+
+ text = GROUNDER_PROMPT.format(query)
+
+ history.append({
+ 'metadata': {
+ 'title': '🔍 Working as Grounder...'
+ },
+ 'role': 'assistant',
+ 'content': f'##### Grounder Prompt:\n\n{html.escape(text)}\n\n##### Grounder Response:\n\n...'
+ })
+ yield history
+
+ start_time = time.perf_counter()
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video,
+ 'num_threads': 1,
+ 'min_pixels': 36 * 28 * 28,
+ 'max_pixels': 64 * 28 * 28,
+ 'max_frames': 150,
+ 'fps': 1.0
+ }, {
+ 'type': 'text',
+ 'text': text
+ }]
+ }]
+
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ images, videos = process_vision_info(messages)
+ data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+ data = data.to(device)
+
+ model.base_model.disable_adapter_layers()
+ model.base_model.enable_adapter_layers()
+ model.set_adapter('grounder')
+
+ generation_kwargs = dict(
+ **data,
+ streamer=streamer,
+ do_sample=temperature > 0,
+ temperature=temperature if temperature > 0 else None,
+ top_p=None,
+ top_k=None,
+ repetition_penalty=None,
+ max_new_tokens=max_new_tokens)
+
+ t = Thread(target=model.generate, kwargs=generation_kwargs)
+ t.start()
+
+ skipped = False
+ for i, text in enumerate(streamer):
+ if text and not skipped:
+ history[-1]['content'] = history[-1]['content'].rstrip('.')
+ skipped = True
+ history[-1]['content'] += text
+ yield history
+
+ elapsed_time = round(time.perf_counter() - start_time, 1)
+ history[-1]['metadata']['title'] += f' ({elapsed_time} seconds)'
+ yield history
+
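+        # model.reg holds the grounder's candidate moments: normalized (0-1) start/end
+        # boundaries (first two columns) and a confidence score (last column)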
+ if len(model.reg) > 0:
+ # 1. extract timestamps and confidences
+ blob = model.reg[0].cpu().float()
+ pred, conf = blob[:, :2] * duration, blob[:, -1].tolist()
+
+ # 2. clamp timestamps
+ pred = pred.clamp(min=0, max=duration)
+
+            # 3. swap the start/end timestamps of reversed spans
+            inds = (pred[:, 1] - pred[:, 0] < 0).nonzero()[:, 0]
+            pred[inds] = pred[inds].roll(1, dims=-1)
+
+ # 4. convert timestamps to list
+ pred = pred.tolist()
+ else:
+ if 'ver' in role:
+ pred = [[i * duration / 6, (i + 2) * duration / 6] for i in range(5)]
+ conf = [0] * 5
+ else:
+ pred = [[0, duration]]
+ conf = [0]
+
+ response = 'The candidate moments and confidence scores are as follows:\n'
+ response += '\n| ID | Start Time | End Time | Confidence |'
+ response += '\n| :-: | :-: | :-: | :-: |'
+
+ # using top-5 predictions
+ for i, (p, c) in enumerate(zip(pred[:5], conf[:5])):
+ response += f'\n| {i} | {seconds_to_hms(p[0])} | {seconds_to_hms(p[1])} | {c:.2f} |'
+
+        response += f'\n\nTherefore, the target moment might happen from {seconds_to_hms(pred[0][0])} to {seconds_to_hms(pred[0][1])}.'
+
+ history.append({'role': 'assistant', 'content': ''})
+ for i, text in enumerate(response.split(' ')):
+ history[-1]['content'] += ' ' + text if i > 0 else text
+ yield history
+
+ if 'ver' in role and do_grounding:
+ text = VERIFIER_PROMPT.format(query)
+
+ history.append({
+ 'metadata': {
+ 'title': '📊 Working as Verifier...'
+ },
+ 'role': 'assistant',
+ 'content': f'##### Verifier Prompt:\n\n{html.escape(text)}\n\n##### Verifier Response:\n\n...'
+ })
+ yield history
+
+ start_time = time.perf_counter()
+
+ # using top-5 predictions
+ prob = []
+ for i, cand in enumerate(pred[:5]):
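+            # probe a window larger than the candidate (extended by half of its
+            # length on each side) so the verifier sees the surrounding context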
+ s0, e0 = parse_span(cand, duration, 2)
+ offset = (e0 - s0) / 2
+ s1, e1 = parse_span([s0 - offset, e0 + offset], duration)
+
+ # percentage of s0, e0 within s1, e1
+ s = (s0 - s1) / (e1 - s1)
+ e = (e0 - s1) / (e1 - s1)
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video,
+ 'num_threads': 1,
+ 'video_start': s1,
+ 'video_end': e1,
+ 'min_pixels': 36 * 28 * 28,
+ 'max_pixels': 64 * 28 * 28,
+ 'max_frames': 64,
+ 'fps': 2.0
+ }, {
+ 'type': 'text',
+ 'text': text
+ }]
+ }]
+
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ images, videos = process_vision_info(messages)
+ data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+
+ # ===== insert segment start/end tokens =====
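+            # window = visual tokens per frame after the 2x2 spatial patch merge,
+            # so the video spans num_frames * window tokens after <|vision_start|>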
+ video_grid_thw = data['video_grid_thw'][0]
+ num_frames, window = int(video_grid_thw[0]), int(video_grid_thw[1] * video_grid_thw[2] / 4)
+ assert num_frames * window * 4 == data['pixel_values_videos'].size(0)
+
+ pos_s, pos_e = round(s * num_frames), round(e * num_frames)
+ pos_s, pos_e = min(max(0, pos_s), num_frames), min(max(0, pos_e), num_frames)
+ assert pos_s <= pos_e, (num_frames, s, e)
+
+ base_idx = torch.nonzero(data['input_ids'][0] == model.config.vision_start_token_id).item()
+ pos_s, pos_e = pos_s * window + base_idx + 1, pos_e * window + base_idx + 2
+
+ input_ids = data['input_ids'][0].tolist()
+ input_ids.insert(pos_s, model.config.seg_s_token_id)
+ input_ids.insert(pos_e, model.config.seg_e_token_id)
+ data['input_ids'] = torch.LongTensor([input_ids])
+ data['attention_mask'] = torch.ones_like(data['input_ids'])
+ # ===========================================
+
+ data = data.to(device)
+
+ model.base_model.disable_adapter_layers()
+ model.base_model.enable_adapter_layers()
+ model.set_adapter('verifier')
+
+ with torch.inference_mode():
+ logits = model(**data).logits[0, -1].softmax(dim=-1)
+
+ # NOTE: magic numbers here
+ # In Qwen2-VL vocab: 9454 -> Yes, 2753 -> No
+ score = (logits[9454] - logits[2753]).sigmoid().item()
+ prob.append(score)
+
+ if i == 0:
+ history[-1]['content'] = history[-1]['content'].rstrip('.')[:-1]
+
+ response = f'\nCandidate ID {i}: P(Yes) = {score:.2f}'
+ for j, text in enumerate(response.split(' ')):
+ history[-1]['content'] += ' ' + text if j > 0 else text
+ yield history
+
+ elapsed_time = round(time.perf_counter() - start_time, 1)
+ history[-1]['metadata']['title'] += f' ({elapsed_time} seconds)'
+ yield history
+
+ ranks = torch.Tensor(prob).argsort(descending=True).tolist()
+
+ prob = [prob[idx] for idx in ranks]
+ pred = [pred[idx] for idx in ranks]
+ conf = [conf[idx] for idx in ranks]
+
+ response = 'After verification, the candidate moments are re-ranked as follows:\n'
+ response += '\n| ID | Start Time | End Time | Score |'
+ response += '\n| :-: | :-: | :-: | :-: |'
+
+ ids = list(range(len(ranks)))
+ for r, p, c in zip(ranks, pred, prob):
+ response += f'\n| {ids[r]} | {seconds_to_hms(p[0])} | {seconds_to_hms(p[1])} | {c:.2f} |'
+
+ response += f'\n\nTherefore, the target moment should be from {seconds_to_hms(pred[0][0])} to {seconds_to_hms(pred[0][1])}.'
+
+ history.append({'role': 'assistant', 'content': ''})
+ for i, text in enumerate(response.split(' ')):
+ history[-1]['content'] += ' ' + text if i > 0 else text
+ yield history
+
+ if 'ans' in role and do_answering:
+ text = f'{prompt} Please think step by step and provide your response.'
+
+ history.append({
+ 'metadata': {
+ 'title': '📝 Working as Answerer...'
+ },
+ 'role': 'assistant',
+ 'content': f'##### Answerer Prompt:\n\n{html.escape(text)}\n\n##### Answerer Response:\n\n...'
+ })
+ yield history
+
+ start_time = time.perf_counter()
+
+ # choose the potential best moment
+ selected = pred[0] if 'gnd' in role and do_grounding else [0, duration]
+ s, e = parse_span(selected, duration, 32)
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video,
+ 'num_threads': 1,
+ 'video_start': s,
+ 'video_end': e,
+ 'min_pixels': 128 * 28 * 28,
+ 'max_pixels': 256 * 28 * 28,
+ 'max_frames': 32,
+ 'fps': 2.0
+ }, {
+ 'type': 'text',
+ 'text': text
+ }]
+ }]
+
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ images, videos = process_vision_info(messages)
+ data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+ data = data.to(device)
+
+ with model.disable_adapter():
+ generation_kwargs = dict(
+ **data,
+ streamer=streamer,
+ do_sample=temperature > 0,
+ temperature=temperature if temperature > 0 else None,
+ top_p=None,
+ top_k=None,
+ repetition_penalty=None,
+ max_new_tokens=max_new_tokens)
+
+ t = Thread(target=model.generate, kwargs=generation_kwargs)
+ t.start()
+
+ skipped = False
+ for i, text in enumerate(streamer):
+ if text and not skipped:
+ history[-1]['content'] = history[-1]['content'].rstrip('.')
+ skipped = True
+ history[-1]['content'] += text
+ yield history
+
+ elapsed_time = round(time.perf_counter() - start_time, 1)
+ history[-1]['metadata']['title'] += f' ({elapsed_time} seconds)'
+ yield history
+
+ if 'gnd' in role and do_grounding:
+ response = f'After zooming in and analyzing the target moment, I finalize my answer: {streamer.text_cache}'
+ else:
+ response = f'After watching the whole video, my answer is: {streamer.text_cache}'
+
+ history.append({'role': 'assistant', 'content': ''})
+ for i, text in enumerate(response.split(' ')):
+ history[-1]['content'] += ' ' + text if i > 0 else text
+ yield history
+
+
+if __name__ == '__main__':
+ if not nncore.is_dir(BASE_MODEL):
+ snapshot_download(BASE_MODEL_HF, local_dir=BASE_MODEL)
+
+ if not nncore.is_dir(MODEL):
+ snapshot_download(MODEL_HF, local_dir=MODEL)
+
+ print('Initializing role *grounder*')
+ model, processor = build_model(MODEL)
+
+ print('Initializing role *planner*')
+ model.load_adapter(nncore.join(MODEL, 'planner'), adapter_name='planner')
+
+ print('Initializing role *verifier*')
+ model.load_adapter(nncore.join(MODEL, 'verifier'), adapter_name='verifier')
+
+ streamer = CustomStreamer(processor.tokenizer, skip_prompt=True)
+
+ device = next(model.parameters()).device
+
+ main = partial(main, model=model, processor=processor, streamer=streamer, device=device)
+
+ path = os.path.dirname(os.path.realpath(__file__))
+
+ chat = gr.Chatbot(
+ type='messages',
+ height='70vh',
+ avatar_images=[f'{path}/assets/user.png', f'{path}/assets/bot.png'],
+ placeholder='A conversation with VideoMind',
+ label='VideoMind')
+
+ prompt = gr.Textbox(label='Text Prompt', placeholder='Ask a question about the video...')
+
+ with gr.Blocks(title=TITLE, css=CSS, js=JS) as demo:
+ gr.Markdown(TITLE_MD)
+ gr.Markdown(DESCRIPTION_MD)
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ video = gr.Video()
+
+ with gr.Group():
+ role = gr.CheckboxGroup(
+ choices=[('🗺️ Planner', 'pla'), ('🔍 Grounder', 'gnd'), ('📊 Verifier', 'ver'),
+ ('📝 Answerer', 'ans')],
+ value=['pla', 'gnd', 'ver', 'ans'],
+ interactive=True,
+ elem_id='role',
+ label='Role(s) To Use',
+                    info='[Auto Planning]: Planner + Grounder + Verifier + Answerer<br>'
+                    '[Grounded Video Question-Answering]: Grounder + Verifier + Answerer<br>'
+                    '[Video Temporal Grounding]: Grounder + Verifier<br>'
+                    '[Direct Video Question-Answering]: Answerer<br>')
+ role.change(update_placeholder, role, prompt)
+
+ with gr.Accordion(label='Hyperparameters', open=False):
+ temperature = gr.Slider(
+ 0,
+ 1,
+ value=0,
+ step=0.1,
+ interactive=True,
+ label='Temperature',
+ info='Higher value leads to more creativity and randomness (Default: 0)')
+ max_new_tokens = gr.Slider(
+ 1,
+ 1024,
+ value=256,
+ interactive=True,
+ label='Max Output Tokens',
+ info='The maximum number of output tokens for each role (Default: 256)')
+
+ prompt.render()
+
+ with gr.Row():
+ random_btn = gr.Button(value='🔮 Random')
+ random_btn.click(lambda: random.choice(EXAMPLES), None, [video, prompt, role])
+
+ reset_btn = gr.ClearButton([video, prompt, chat], value='🗑️ Reset')
+ reset_btn.click(lambda: (['pla', 'gnd', 'ver', 'ans'], 0, 256), None,
+ [role, temperature, max_new_tokens])
+
+ submit_btn = gr.Button(value='🚀 Submit', variant='primary')
+ submit_ctx = submit_btn.click(disable_btns, None, [random_btn, reset_btn, submit_btn])
+ submit_ctx = submit_ctx.then(main, [video, prompt, role, temperature, max_new_tokens], chat)
+ submit_ctx.then(enable_btns, None, [random_btn, reset_btn, submit_btn])
+
+ with gr.Column(scale=5):
+ chat.render()
+
+ demo.queue()
+ demo.launch(server_name='0.0.0.0')
diff --git a/assets/bot.png b/assets/bot.png
new file mode 100644
index 0000000000000000000000000000000000000000..696a3bb5c360358ca7174f49746001cba331e6b4
Binary files /dev/null and b/assets/bot.png differ
diff --git a/assets/user.png b/assets/user.png
new file mode 100644
index 0000000000000000000000000000000000000000..e43cee4e6edd51e83ab35aa6270cf16e80e7622d
Binary files /dev/null and b/assets/user.png differ
diff --git a/data/10309844035.mp4 b/data/10309844035.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..76e90e5f9f6c97f6f5dfe19242227a489a981169
--- /dev/null
+++ b/data/10309844035.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8996ff134787d6b769c2491b9079a02c05953465ad770f07a8d9138e2668d24f
+size 4041678
diff --git a/data/13887487955.mp4 b/data/13887487955.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..1e83efb63125b1ca29cd6162eaa5843b034f5961
--- /dev/null
+++ b/data/13887487955.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5fecab1076ee42b3804718f9f64bef06cbfafd6995ad5f5ee42ba6354721429
+size 5544739
diff --git a/data/4167294363.mp4 b/data/4167294363.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..339380896781b1ec77eaea1122a2dbcb0b91fc99
--- /dev/null
+++ b/data/4167294363.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d0e0a4a381836f68e16a816d87f241fed3e31ea321f544b921743d6c1c50666
+size 6611151
diff --git a/data/4742652230.mp4 b/data/4742652230.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0e45cf00cf375d5eb97ab9fd179854daae792e27
--- /dev/null
+++ b/data/4742652230.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8733ab4b0716d13ea7a79fc4ddacaf9eede567db364f0ecddfa4582c2f237f82
+size 2200304
diff --git a/data/4766274786.mp4 b/data/4766274786.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a8e42d09d51e795f05f895e788f9022470691f3f
--- /dev/null
+++ b/data/4766274786.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afa38a9ce9e89f934293214d79755c89159664223b3ca366813fd5fe524ed013
+size 3395545
diff --git a/data/5012237466.mp4 b/data/5012237466.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..1a96cf099d8d4baabc7ee45ebb6f78a5ad11fe48
--- /dev/null
+++ b/data/5012237466.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cd1929aa93d037f809f402e9801047125dc9fe8060301e69ded9ba1f2d785cc8
+size 4822293
diff --git a/data/5188348585.mp4 b/data/5188348585.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..9e55d094776b4f180cc0fb008d1a9ac171ebd316
--- /dev/null
+++ b/data/5188348585.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b225f448a546ba2f65958f18c6731a6dde9b1f437014e90036b22eb40e9ad0a5
+size 5051675
diff --git a/data/9383140374.mp4 b/data/9383140374.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b6899d90c9ecb2e4dfc82f003c270a84e877e59d
--- /dev/null
+++ b/data/9383140374.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:30b6b3eb43f711bef194150d473a59850ff5d7fec0f5cc30e7526aa9e382303f
+size 2518081
diff --git a/data/DTInxNfWXVc_210.0_360.0.mp4 b/data/DTInxNfWXVc_210.0_360.0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a4d851d6b9c28fef1e4a5b0feba8462faacf5ed1
--- /dev/null
+++ b/data/DTInxNfWXVc_210.0_360.0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a09eee0dc404688731fb768c120d3519605f2343376b9bd727a71b91379fd9a9
+size 4999970
diff --git a/data/RoripwjYFp8_210.0_360.0.mp4 b/data/RoripwjYFp8_210.0_360.0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..7f04c0f8971239598ee92cdafbf4ba601c8e7a9c
--- /dev/null
+++ b/data/RoripwjYFp8_210.0_360.0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b39b15158dc20c0bc6f1758a9239c8f3eed20ba4a90953338eec2246fa8f1f0
+size 9287252
diff --git a/data/UFWQKrcbhjI_360.0_510.0.mp4 b/data/UFWQKrcbhjI_360.0_510.0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..e87ba5ae900d689cc0ad626b4599af6ac25cd28d
--- /dev/null
+++ b/data/UFWQKrcbhjI_360.0_510.0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8669153d9ffac4b5534c20fab8d795347f5babe588da9b8330e049d623ebb443
+size 14510618
diff --git a/data/Z3-IZ3HAmIA_60.0_210.0.mp4 b/data/Z3-IZ3HAmIA_60.0_210.0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..739af1eea71252773de9c6392548a1ef086b8ad8
--- /dev/null
+++ b/data/Z3-IZ3HAmIA_60.0_210.0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b3a342993ee61efc5f3b859cd9c1e0d360b3331eed9deb8466891e4bcacc554
+size 14397799
diff --git a/data/h6QKDqomIPk_210.0_360.0.mp4 b/data/h6QKDqomIPk_210.0_360.0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..819d3e86ec306105087248a3983c5c86cc0bd3cd
--- /dev/null
+++ b/data/h6QKDqomIPk_210.0_360.0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:103820de2b8a1a3935b39ed80d91cd08e546e5617310b3d1bb3dadb06b2ffb95
+size 13485144
diff --git a/data/pA6Z-qYhSNg_60.0_210.0.mp4 b/data/pA6Z-qYhSNg_60.0_210.0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2ba3f64450c0f778e547d7ce2d7279bae48e52cc
--- /dev/null
+++ b/data/pA6Z-qYhSNg_60.0_210.0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c84660fd4ebd8c23a2a7364174b1e819fec8b0e1cb8b9d9cd86f9e429cbdf66c
+size 8658509
diff --git a/data/rrTIeJRVGjg_60.0_210.0.mp4 b/data/rrTIeJRVGjg_60.0_210.0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..9adaea8c1071a4b2e1caab23223b057bc8a9d052
--- /dev/null
+++ b/data/rrTIeJRVGjg_60.0_210.0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:efe6f48a49963bd4880ef5065840e05dd25e2aa975870140bcdaf4220bbd2827
+size 11410412
diff --git a/data/yId2wIocTys_210.0_360.0.mp4 b/data/yId2wIocTys_210.0_360.0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..ad9d951b134ed06f0473f0cd5da48bdd404d66a5
--- /dev/null
+++ b/data/yId2wIocTys_210.0_360.0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:447fcb1fd1f94ed6a88d56dd0f6f859646cb8c58ed8e3b7a82f374e2cfee1646
+size 14769130
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..77959fa0157d314d3505add89f3372c8e7a27662
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,26 @@
+accelerate==1.2.1
+decord==0.6.0
+gradio==4.44.1
+pandas==2.2.3
+peft==0.14.0
+pysrt==1.1.2
+scikit-image==0.25.0
+scikit-learn==1.6.1
+sentencepiece==0.2.0
+termplotlib==0.3.9
+triton==3.0.0
+
+# our codebase contains necessary patches for 4.45.2
+transformers==4.45.2
+
+# https://github.com/microsoft/DeepSpeed/issues/6793
+deepspeed==0.15.4
+
+# https://github.com/pytorch/pytorch/issues/138386
+torch==2.4.1
+torchvision==0.19.1
+
+# torch-npu only supports torch 2.4.0
+# torch==2.4.0+cpu
+# torch-npu==2.4.0.post2
+# torchvision==0.19.0+cpu
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..3e1da030391c89b484ebd58e303514e1e409af80
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,16 @@
+[yapf]
+column_limit = 120
+based_on_style = pep8
+blank_line_before_nested_class_or_def = true
+split_before_expression_after_opening_paren = true
+
+[isort]
+line_length = 120
+multi_line_output = 0
+known_third_party = decord,deepspeed,gradio,huggingface_hub,nncore,numpy,pandas,peft,PIL,pysrt,safetensors,tabulate,termplotlib,torch,torchvision,transformers
+no_lines_before = STDLIB,LOCALFOLDER
+default_section = FIRSTPARTY
+
+[flake8]
+max-line-length = 500
+extend-ignore = E741
diff --git a/videomind/constants.py b/videomind/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..31da716b64740f8e5574765328c67ee7f42ce1e9
--- /dev/null
+++ b/videomind/constants.py
@@ -0,0 +1,42 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+IGNORE_INDEX = -100
+
+REG_TOKEN = '<|reg|>'
+
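+# segment boundary tokens wrapped around the candidate moment that the verifier
+# is asked to check (see VERIFIER_PROMPT below)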
+SEG_S_TOKEN = '<|seg_start|>'
+SEG_E_TOKEN = '<|seg_end|>'
+
+PLANNER_PROMPT = (
+ 'You are acting as the planner now. '
+ 'Given a question about the video, your task is to analyze the question and identify the best way to answer this question. '
+ 'You have access to the following tools:\n\n'
+    'Grounder: Accepts a text query and localizes the relevant video segment according to the query.\n'
+ 'Verifier: A tool supporting grounder by verifying the reliability of its outputs.\n'
+ 'Answerer: Answer a given question directly based on the whole video or a cropped video segment.\n\n'
+ 'Your response must be a list in JSON format. '
+    'A valid plan for reasoning could be "grounder, verifier, answerer", "grounder, verifier", or "answerer", depending on the given question. '
+ 'Please see an example for the format below.\n\n'
+ '[{{"type": "grounder", "value": ""}}, {{"type": "verifier"}}, {{"type": "answerer"}}]\n\n'
+ 'Note that only the grounder can accept an argument called "value", which is the text query used for grounding. '
+ "Now I give you the question: '{}'. "
+ 'Please think carefully and respond with your plan in JSON directly.')
+
+GROUNDER_PROMPT = (
+ 'You are acting as the grounder now. '
+ 'Given a video and a text query, your goal is to temporally localize the video moment described by the query. '
+ 'If the query is directly describing a moment, simply localize it according to its content. '
+ "Otherwise, if the moment is described as 'before/after a pivotal event', you need to determine the actual event it refers to. "
+ 'The localized moment should only cover the target event. '
+ "Now I give you the query: '{}'. "
+ 'Please think carefully and provide your response.')
+
+VERIFIER_PROMPT = (
+ 'You are acting as the verifier now. '
+    'You will be presented with a text query describing a moment that potentially happens in the given video. '
+ f'Your task is to identify whether the video segment between {SEG_S_TOKEN} and {SEG_E_TOKEN} perfectly covers the moment. '
+ f'If the described moment can be seen in the video, please focus on verifying whether the moment starts at {SEG_S_TOKEN} and ends at {SEG_E_TOKEN}. '
+ "Respond with 'Yes' if you think the moment boundaries are correct, otherwise 'No'. "
+ "If the described moment cannot be seen in the video, respond with 'No' directly. "
+ "Now I give you the query: '{}'. "
+ "Please think carefully and respond with 'Yes' or 'No' directly.")
diff --git a/videomind/conversation.py b/videomind/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..04e8f645099b8cc0838ac8e1acde1180a771158e
--- /dev/null
+++ b/videomind/conversation.py
@@ -0,0 +1,49 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class Conversation:
+ style: str
+ system: str
+ roles: List[str]
+ seps: List[str]
+ messages: List[str]
+
+ def append_message(self, role, msg):
+ self.messages.append([role, msg])
+
+ def clear(self):
+ self.messages = []
+
+ def get_prompt(self):
+ assert self.style in ('chatml', )
+
+ prompt = self.system + self.seps[0] if self.system is not None else ''
+
+ for i, (role, msg) in enumerate(self.messages):
+ prompt += role
+ sep = self.seps[i % 2]
+ if msg is not None:
+ prompt += msg
+ if not prompt.endswith(sep):
+ prompt += sep
+
+ prompt = prompt.lstrip('\n')
+ return prompt
+
+
+def get_conv(conv_type):
+ if conv_type == 'chatml':
+ conv = Conversation(
+ style='chatml',
+ system='<|im_start|>system\nYou are a helpful assistant.',
+ roles=('\n<|im_start|>user\n', '\n<|im_start|>assistant\n'),
+ seps=('<|im_end|>', '<|im_end|>'),
+ messages=[])
+ else:
+ raise ValueError(f'unknown conversation type: {conv_type}')
+
+ return conv
diff --git a/videomind/dataset/__init__.py b/videomind/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7a88064ccc36a04f9e565bfe7ba1d850509c655
--- /dev/null
+++ b/videomind/dataset/__init__.py
@@ -0,0 +1,61 @@
+from .collator import HybridDataCollator
+from .hybrid import HybridDataset
+from .sub_classes import (ActivitynetCaptionsBiasDataset, ActivitynetCaptionsDataset, ActivitynetRTLDataset,
+ CGBenchDataset, CharadesSTADataset, CosMoCapDataset, DiDeMoDataset, Ego4DNaQDataset,
+ Ego4DNLQDataset, EgoTimeQACropDataset, EgoTimeQADataset, EgoTimeQAGroundingDataset,
+ HiRESTGroundingDataset, HiRESTStepBiasDataset, HiRESTStepDataset, InternVidVTimeDataset,
+ LongVideoBenchDataset, LVBenchDataset, MLVUDataset, MVBenchDataset, NExTGQACropDataset,
+ NExTGQADataset, NExTGQAGroundingDataset, NExTQADataset, QAEgo4DCropDataset, QAEgo4DDataset,
+ QAEgo4DGroundingDataset, QuerYDDataset, QVHighlightsDataset, ReXTimeCropDataset,
+ ReXTimeDataset, ReXTimeGroundingDataset, STARDataset, TACoSDataset, VideoMMEDataset,
+ VideoXumDataset, VidMorpDataset, YouCook2BiasDataset, YouCook2Dataset)
+from .wrappers import AnsweringCropDataset, AnsweringDataset, GroundingDataset, PlanningDataset, VerifyingDataset
+
+__all__ = [
+ 'HybridDataCollator',
+ 'HybridDataset',
+ 'ActivitynetCaptionsBiasDataset',
+ 'ActivitynetCaptionsDataset',
+ 'ActivitynetRTLDataset',
+ 'CGBenchDataset',
+ 'CharadesSTADataset',
+ 'CosMoCapDataset',
+ 'DiDeMoDataset',
+ 'Ego4DNaQDataset',
+ 'Ego4DNLQDataset',
+ 'EgoTimeQACropDataset',
+ 'EgoTimeQADataset',
+ 'EgoTimeQAGroundingDataset',
+ 'HiRESTGroundingDataset',
+ 'HiRESTStepBiasDataset',
+ 'HiRESTStepDataset',
+ 'InternVidVTimeDataset',
+ 'LongVideoBenchDataset',
+ 'LVBenchDataset',
+ 'MLVUDataset',
+ 'MVBenchDataset',
+ 'NExTGQACropDataset',
+ 'NExTGQADataset',
+ 'NExTGQAGroundingDataset',
+ 'NExTQADataset',
+ 'QAEgo4DCropDataset',
+ 'QAEgo4DDataset',
+ 'QAEgo4DGroundingDataset',
+ 'QuerYDDataset',
+ 'QVHighlightsDataset',
+ 'ReXTimeCropDataset',
+ 'ReXTimeDataset',
+ 'ReXTimeGroundingDataset',
+ 'STARDataset',
+ 'TACoSDataset',
+ 'VideoMMEDataset',
+ 'VideoXumDataset',
+ 'VidMorpDataset',
+ 'YouCook2BiasDataset',
+ 'YouCook2Dataset',
+ 'AnsweringCropDataset',
+ 'AnsweringDataset',
+ 'GroundingDataset',
+ 'PlanningDataset',
+ 'VerifyingDataset',
+]
diff --git a/videomind/dataset/collator.py b/videomind/dataset/collator.py
new file mode 100644
index 0000000000000000000000000000000000000000..de6d2f34d36c9a1b1970d6f5035b5f10e1e3557f
--- /dev/null
+++ b/videomind/dataset/collator.py
@@ -0,0 +1,40 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import warnings
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+from videomind.constants import IGNORE_INDEX
+
+
+class HybridDataCollator(object):
+
+ def __init__(self, tokenizer):
+ self.tokenizer = tokenizer
+
+ def __call__(self, batch):
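+        # right-pad input_ids with the pad token and labels with IGNORE_INDEX so
+        # that padded positions are masked out of the loss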
+ input_ids = [d['input_ids'] for d in batch]
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+
+ labels = [d['labels'] for d in batch]
+ labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
+
+ assert input_ids.size() == labels.size()
+
+ seq_len, max_len = input_ids.size(1), self.tokenizer.model_max_length
+ if seq_len > max_len:
+            warnings.warn(f'Input sequence length exceeds model max length: {seq_len} > {max_len}')
+ input_ids, labels = input_ids[:, :max_len], labels[:, :max_len]
+
+ data = dict(input_ids=input_ids, labels=labels, attention_mask=input_ids != self.tokenizer.pad_token_id)
+
+ for key in ('pixel_values', 'pixel_values_videos', 'image_grid_thw', 'video_grid_thw'):
+ if key in batch[0]:
+ data[key] = torch.cat([d[key] for d in batch])
+
+ for key in ('timestamps', 'saliency', 'pos_clip'):
+ if key in batch[0]:
+ data[key] = [d[key] for d in batch]
+
+ return data
diff --git a/videomind/dataset/hybrid.py b/videomind/dataset/hybrid.py
new file mode 100644
index 0000000000000000000000000000000000000000..69f7ea6677f0d6cc057a31642995b7a1b4762739
--- /dev/null
+++ b/videomind/dataset/hybrid.py
@@ -0,0 +1,180 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import math
+import random
+from collections import defaultdict
+from itertools import accumulate
+
+import nncore
+import numpy as np
+import termplotlib as tpl
+import torch
+from tabulate import tabulate
+from torch.utils.data import Dataset
+
+from videomind.constants import IGNORE_INDEX
+from videomind.dataset.utils import preprocess, process_vision_info
+from videomind.utils.parser import parse_span
+
+DATASETS = nncore.Registry('datasets')
+
+
+class HybridDataset(Dataset):
+
+ def __init__(self, processor, model_config, model_args, data_args, training_args):
+ super().__init__()
+
+ datasets = []
+ for key in data_args.datasets.split(','):
+ datasets.append(DATASETS.get(key)(processor, model_args, data_args, training_args))
+
+ data_types = [a['data_type'] for d in datasets for a in d.annos]
+
+ cum_length = [0] + list(accumulate([len(d) for d in datasets]))
+ idx_ranges = [[cum_length[i], cum_length[i + 1]] for i in range(len(cum_length) - 1)]
+
+ if training_args.local_rank in (0, -1):
+ raw_length = sum(d.raw_length for d in datasets)
+ cur_length = idx_ranges[-1][-1]
+
+ ratio = round(cur_length / raw_length * 100, 2)
+ print(f'Number of samples: {raw_length} (original) -> {cur_length} (filtered) {ratio}%')
+
+ data_type_cnt = ' '.join([f'{data_types.count(t)} ({t})' for t in list(set(data_types))])
+ print(f'Data types: {data_type_cnt}')
+
+ tab = defaultdict(int)
+ for dataset in datasets:
+ for anno in dataset.annos:
+ tab[anno.get('source', 'unknown')] += 1
+
+ tab = [[k, v, round(v / cur_length, 3)] for k, v in tab.items()]
+ print(tabulate(tab, headers=['Source', '#Samples', 'Ratio'], tablefmt='pretty', stralign='left'))
+
+ d, _ = torch.Tensor([a['duration'] for d in datasets for a in d.annos if 'duration' in a]).sort()
+ if d.size(0) > 0:
+ n, r = min(d.size(0), 10), d.flip(0)
+ print(f'Top-{n} max video durations: {[round(r[i].item(), 1) for i in range(n)]}')
+ print(f'Top-{n} min video durations: {[round(d[i].item(), 1) for i in range(n)]}')
+ print(f'Average video duration ({d.size(0)} samples): {round(d.mean().item(), 1)}s')
+
+ print('Video duration histogram:')
+ counts, edges = np.histogram(d)
+ labels = [f'{edges[i]:.2f}s - {edges[i + 1]:.2f}s' for i in range(len(edges) - 1)]
+ fig = tpl.figure()
+ fig.barh(counts, labels)
+ fig.show()
+
+ d, _ = torch.Tensor([abs(b[0] - b[1]) for d in datasets for a in d.annos if 'span' in a
+ for b in a['span']]).sort()
+ if d.size(0) > 0:
+ n, r = min(d.size(0), 10), d.flip(0)
+ print(f'Top-{n} max span durations: {[round(r[i].item(), 1) for i in range(n)]}')
+ print(f'Top-{n} min span durations: {[round(d[i].item(), 1) for i in range(n)]}')
+ print(f'Average span duration ({d.size(0)} samples): {round(d.mean().item(), 1)}s')
+
+ print('Span duration histogram:')
+ counts, edges = np.histogram(d)
+ labels = [f'{edges[i]:.2f}s - {edges[i + 1]:.2f}s' for i in range(len(edges) - 1)]
+ fig = tpl.figure()
+ fig.barh(counts, labels)
+ fig.show()
+
+ self.datasets = datasets
+ self.data_types = data_types
+ self.idx_ranges = idx_ranges
+ self.processor = processor
+ self.model_config = model_config
+ self.model_args = model_args
+ self.data_args = data_args
+ self.training_args = training_args
+
+ def __len__(self):
+ return self.idx_ranges[-1][-1]
+
+ def __getitem__(self, idx):
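+        # if a sample fails to load, retry with another random sample of the same data type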
+ for retry in range(self.data_args.max_retries + 1):
+ try:
+ return self.fetch_data(idx)
+ except Exception as e:
+ print(f'Error in loading {idx}: {type(e).__name__}({e})')
+ idx = random.choice([i for i, t in enumerate(self.data_types) if t == self.data_types[idx]])
+
+ raise RuntimeError(f'Data loading failed after {retry} retries')
+
+ def map(self, *args, **kwargs):
+ return self
+
+ def fetch_data(self, idx):
+ for (s, e), dataset in zip(self.idx_ranges, self.datasets):
+ if s <= idx < e:
+ meta = dataset[idx - s]
+ break
+
+ text = self.processor.apply_chat_template(meta['messages'])
+ text = [text.strip()]
+
+ images, videos = process_vision_info(meta['messages'], sanity_check=True)
+
+ data = self.processor(text=text, images=images, videos=videos, return_tensors='pt')
+ assert data['input_ids'].size(0) == 1
+
+ data['input_ids'] = data['input_ids'][0]
+ data['labels'] = preprocess(data['input_ids'], text[0], self.processor.tokenizer, self.model_args.conv_type)
+
+ # insert segment start/end tokens
+ if 'ss' in meta and 'se' in meta:
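+            # 'ss' / 'se' are normalized (0-1) positions of the target segment; map
+            # them to visual token indices and insert the boundary tokens there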
+ video_grid_thw = data['video_grid_thw'][0]
+ num_frames, window = int(video_grid_thw[0]), int(video_grid_thw[1] * video_grid_thw[2] / 4)
+ assert num_frames * window * 4 == data['pixel_values_videos'].size(0)
+
+ pos_s, pos_e = round(meta['ss'] * num_frames), round(meta['se'] * num_frames)
+ pos_s, pos_e = min(max(0, pos_s), num_frames), min(max(0, pos_e), num_frames)
+ assert pos_s <= pos_e, (num_frames, meta['ss'], meta['se'])
+
+ base_idx = torch.nonzero(data['input_ids'] == self.model_config.vision_start_token_id).item()
+ pos_s, pos_e = pos_s * window + base_idx + 1, pos_e * window + base_idx + 2
+
+ input_ids = data['input_ids'].tolist()
+ input_ids.insert(pos_s, self.model_config.seg_s_token_id)
+ input_ids.insert(pos_e, self.model_config.seg_e_token_id)
+ data['input_ids'] = torch.LongTensor(input_ids)
+
+ labels = data['labels'].tolist()
+ labels.insert(pos_s, IGNORE_INDEX)
+ labels.insert(pos_e, IGNORE_INDEX)
+ data['labels'] = torch.LongTensor(labels)
+
+ if 'span' in meta:
+ span, duration = meta['span'], meta['duration']
+
+ pixel_values_videos, video_grid_thw = data['pixel_values_videos'], data['video_grid_thw']
+ num_frames = int(video_grid_thw[0][0])
+
+ assert video_grid_thw.size(0) == 1
+ assert video_grid_thw.prod() == pixel_values_videos.size(0)
+
+ # actual fps would be 1/2 of config (temporal patch size = 2)
+ fps = num_frames / duration
+
+ safe_span = [parse_span(b, duration, 1 / fps) for b in span]
+
+ # num_reg_tokens -> num_bnds -> s & e
+ timestamps = [[[s / duration, e / duration] for s, e in safe_span]]
+
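+            # per-frame binary saliency labels over the target span(s); pos_clip is
+            # one randomly sampled positive frame index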
+ saliency, pos_inds = torch.zeros(num_frames), []
+ for s, e in safe_span:
+ span_ind = max(0, s * fps), min(e * fps, num_frames)
+ pos_inds = list(range(math.ceil(span_ind[0]), math.ceil(span_ind[1])))
+ assert len(pos_inds) > 0, f'empty pos_inds ({idx}): {fps} {num_frames} {duration} {span}'
+ saliency[pos_inds] = 1
+
+ assert saliency.any(), f'empty saliency ({idx}): {pos_inds} {fps} {num_frames} {duration} {span}'
+ pos_clip = random.sample(saliency.nonzero()[:, 0].tolist(), 1)
+ pos_clip = torch.LongTensor(pos_clip)
+
+ data['timestamps'] = timestamps
+ data['saliency'] = saliency
+ data['pos_clip'] = pos_clip
+
+ return data
diff --git a/videomind/dataset/sub_classes/__init__.py b/videomind/dataset/sub_classes/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..54279430775aef6395c99f636cbc2a2ac9d0d883
--- /dev/null
+++ b/videomind/dataset/sub_classes/__init__.py
@@ -0,0 +1,69 @@
+from .activitynet_captions import ActivitynetCaptionsBiasDataset, ActivitynetCaptionsDataset
+from .activitynet_rtl import ActivitynetRTLDataset
+from .cgbench import CGBenchDataset
+from .charades_sta import CharadesSTADataset
+from .cosmo_cap import CosMoCapDataset
+from .didemo import DiDeMoDataset
+from .ego4d_naq import Ego4DNaQDataset
+from .ego4d_nlq import Ego4DNLQDataset
+from .ego_timeqa import EgoTimeQACropDataset, EgoTimeQADataset, EgoTimeQAGroundingDataset
+from .hirest import HiRESTGroundingDataset, HiRESTStepBiasDataset, HiRESTStepDataset
+from .internvit_vtime import InternVidVTimeDataset
+from .longvideobench import LongVideoBenchDataset
+from .lvbench import LVBenchDataset
+from .mlvu import MLVUDataset
+from .mvbench import MVBenchDataset
+from .nextgqa import NExTGQACropDataset, NExTGQADataset, NExTGQAGroundingDataset
+from .nextqa import NExTQADataset
+from .qa_ego4d import QAEgo4DCropDataset, QAEgo4DDataset, QAEgo4DGroundingDataset
+from .queryd import QuerYDDataset
+from .qvhighlights import QVHighlightsDataset
+from .rextime import ReXTimeCropDataset, ReXTimeDataset, ReXTimeGroundingDataset
+from .star import STARDataset
+from .tacos import TACoSDataset
+from .vid_morp import VidMorpDataset
+from .videomme import VideoMMEDataset
+from .videoxum import VideoXumDataset
+from .youcook2 import YouCook2BiasDataset, YouCook2Dataset
+
+__all__ = [
+ 'ActivitynetCaptionsBiasDataset',
+ 'ActivitynetCaptionsDataset',
+ 'ActivitynetRTLDataset',
+ 'CGBenchDataset',
+ 'CharadesSTADataset',
+ 'CosMoCapDataset',
+ 'DiDeMoDataset',
+ 'Ego4DNaQDataset',
+ 'Ego4DNLQDataset',
+ 'EgoTimeQACropDataset',
+ 'EgoTimeQADataset',
+ 'EgoTimeQAGroundingDataset',
+ 'HiRESTGroundingDataset',
+ 'HiRESTStepBiasDataset',
+ 'HiRESTStepDataset',
+ 'InternVidVTimeDataset',
+ 'LongVideoBenchDataset',
+ 'LVBenchDataset',
+ 'MLVUDataset',
+ 'MVBenchDataset',
+ 'NExTGQACropDataset',
+ 'NExTGQADataset',
+ 'NExTGQAGroundingDataset',
+ 'NExTQADataset',
+ 'QAEgo4DCropDataset',
+ 'QAEgo4DDataset',
+ 'QAEgo4DGroundingDataset',
+ 'QuerYDDataset',
+ 'QVHighlightsDataset',
+ 'ReXTimeCropDataset',
+ 'ReXTimeDataset',
+ 'ReXTimeGroundingDataset',
+ 'STARDataset',
+ 'TACoSDataset',
+ 'VidMorpDataset',
+ 'VideoMMEDataset',
+ 'VideoXumDataset',
+ 'YouCook2BiasDataset',
+ 'YouCook2Dataset',
+]
diff --git a/videomind/dataset/sub_classes/activitynet_captions.py b/videomind/dataset/sub_classes/activitynet_captions.py
new file mode 100644
index 0000000000000000000000000000000000000000..84487635903dbf4665e203d143d06ed82b338228
--- /dev/null
+++ b/videomind/dataset/sub_classes/activitynet_captions.py
@@ -0,0 +1,96 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+from collections import OrderedDict
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='activitynet_captions')
+class ActivitynetCaptionsDataset(GroundingDataset):
+
+ ANNO_PATH_TRAIN = 'data/activitynet_captions/train.json'
+ ANNO_PATH_VALID = 'data/activitynet_captions/val_1.json'
+ ANNO_PATH_TEST = 'data/activitynet_captions/val_2.json'
+
+ VIDEO_ROOT = 'data/activitynet/videos_3fps_480_noaudio'
+ DURATIONS = 'data/activitynet/durations.json'
+
+ UNIT = 0.01
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
+ elif split == 'valid':
+ raw_annos = nncore.load(self.ANNO_PATH_VALID, object_pairs_hook=OrderedDict)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_TEST, object_pairs_hook=OrderedDict)
+
+ durations = nncore.load(self.DURATIONS)
+
+ annos = []
+ for vid, raw_anno in raw_annos.items():
+ for query, span in zip(raw_anno['sentences'], raw_anno['timestamps']):
+ anno = dict(
+ source='activitynet_captions',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=durations[vid],
+ query=parse_query(query),
+ span=[span])
+
+ annos.append(anno)
+
+ return annos
+
+
+@DATASETS.register(name='activitynet_captions_bias')
+class ActivitynetCaptionsBiasDataset(ActivitynetCaptionsDataset):
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
+ elif split == 'valid':
+ raw_annos = nncore.load(self.ANNO_PATH_VALID, object_pairs_hook=OrderedDict)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_TEST, object_pairs_hook=OrderedDict)
+
+ durations = nncore.load(self.DURATIONS)
+
+ annos = []
+ for vid, raw_anno in raw_annos.items():
+ assert len(raw_anno['sentences']) == len(raw_anno['timestamps'])
+
+ for i in range(len(raw_anno['sentences']) - 1):
+ span_a = raw_anno['timestamps'][i]
+ span_b = raw_anno['timestamps'][i + 1]
+
+ if span_b[0] - span_a[1] < 3:
+ query_a = parse_query(f"The moment before {raw_anno['sentences'][i + 1]}")
+ query_b = parse_query(f"The moment after {raw_anno['sentences'][i]}")
+
+ anno_a = dict(
+ source='activitynet_captions_bias',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=durations[vid],
+ query=query_a,
+ span=[span_a])
+
+ anno_b = dict(
+ source='activitynet_captions_bias',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=durations[vid],
+ query=query_b,
+ span=[span_b])
+
+ annos.append(anno_a)
+ annos.append(anno_b)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/activitynet_rtl.py b/videomind/dataset/sub_classes/activitynet_rtl.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d6e892d9739c8b511ce3aa493cfeda09c37f0c4
--- /dev/null
+++ b/videomind/dataset/sub_classes/activitynet_rtl.py
@@ -0,0 +1,68 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import re
+from collections import OrderedDict
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='activitynet_rtl')
+class ActivitynetRTLDataset(GroundingDataset):
+
+ ANNO_PATH_TRAIN = 'data/activitynet_rtl/activitynet_train_gpt-4-0613_temp_6_f10009.json'
+ ANNO_PATH_TEST = 'data/activitynet_rtl/annot_val_1_q229.json'
+
+ VIDEO_ROOT = 'data/activitynet/videos_3fps_480_noaudio'
+
+ UNIT = 0.01
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
+
+ annos = []
+ for vid, raw_anno in raw_annos.items():
+ for meta in raw_anno['QA']:
+ match = re.findall(r'<(\d+(\.\d+)?)>', meta['a'])
+ span = [float(m[0]) for m in match[:2]]
+
+ # some samples do not have timestamps
+ if len(span) != 2:
+ continue
+
+ anno = dict(
+ source='activitynet_rtl',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=raw_anno['duration'],
+ query=parse_query(meta['q']),
+ span=[span])
+
+ annos.append(anno)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_TEST, object_pairs_hook=OrderedDict)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = f"v_{raw_anno['vid']}"
+
+ match = re.findall(r'<(\d+(\.\d+)?)>', raw_anno['answer'])
+ span = [float(m[0]) for m in match[:2]]
+ assert len(span) == 2
+
+ anno = dict(
+ source='activitynet_rtl',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=raw_anno['duration'],
+ query=parse_query(raw_anno['question']),
+ span=[span])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/cgbench.py b/videomind/dataset/sub_classes/cgbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3130038172c971b28e114f67615fbd81a863e4e
--- /dev/null
+++ b/videomind/dataset/sub_classes/cgbench.py
@@ -0,0 +1,47 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+from torch.utils.data import Dataset
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.utils.parser import parse_query, parse_question
+
+
+@DATASETS.register(name='cgbench')
+class CGBenchDataset(Dataset):
+
+ ANNO_PATH_TEST = 'data/cgbench/cgbench_mini.json'
+
+ VIDEO_ROOT = 'data/cgbench/videos_3fps_480_noaudio'
+ SUBTITLE_ROOT = 'data/cgbench/subtitles'
+
+ UNIT = 0.001
+
+ @classmethod
+ def load_annos(self, split='test'):
+ assert split == 'test'
+
+ raw_annos = nncore.load(self.ANNO_PATH_TEST)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['video_uid']
+
+ anno = dict(
+ source='cgbench',
+ data_type='multimodal',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ subtitle_path=nncore.join(self.SUBTITLE_ROOT, vid + '.srt'),
+ duration=raw_anno['duration'],
+ query=parse_query(raw_anno['question']),
+ question=parse_question(raw_anno['question']),
+ options=[o.capitalize() for o in raw_anno['choices']],
+ answer=raw_anno['answer'].capitalize(),
+ ans=raw_anno['right_answer'],
+ span=raw_anno['clue_intervals'],
+ task=raw_anno['sub_category'],
+ domain=raw_anno['domain'])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/charades_sta.py b/videomind/dataset/sub_classes/charades_sta.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f4448e29a32fe69363da3efec5dbbeabde9bf7
--- /dev/null
+++ b/videomind/dataset/sub_classes/charades_sta.py
@@ -0,0 +1,45 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='charades_sta')
+class CharadesSTADataset(GroundingDataset):
+
+ ANNO_PATH_TRAIN = 'data/charades_sta/charades_sta_train.txt'
+ ANNO_PATH_TEST = 'data/charades_sta/charades_sta_test.txt'
+
+ VIDEO_ROOT = 'data/charades_sta/videos_3fps_480_noaudio'
+ DURATIONS = 'data/charades_sta/durations.json'
+
+ UNIT = 0.1
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_TEST)
+
+ durations = nncore.load(self.DURATIONS)
+
+ annos = []
+ for raw_anno in raw_annos:
+ info, query = raw_anno.split('##')
+ vid, s, e = info.split()
+
+ anno = dict(
+ source='charades_sta',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=durations[vid],
+ query=parse_query(query),
+ span=[[float(s), float(e)]])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/cosmo_cap.py b/videomind/dataset/sub_classes/cosmo_cap.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e7fea88b2d57310db72429a1300d1aa26c67908
--- /dev/null
+++ b/videomind/dataset/sub_classes/cosmo_cap.py
@@ -0,0 +1,37 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='cosmo_cap')
+class CosMoCapDataset(GroundingDataset):
+
+ ANNO_PATH = 'data/cosmo_cap/anno_cosmo_cap.jsonl'
+
+ VIDEO_ROOT = 'data/cosmo_cap/videos_3fps_480_noaudio'
+
+ UNIT = 1.0
+
+ @classmethod
+ def load_annos(self, split='train'):
+ assert split == 'train'
+
+ raw_annos = nncore.load(self.ANNO_PATH)
+
+ annos = []
+ for raw_anno in raw_annos:
+ anno = dict(
+ source='cosmo_cap',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, raw_anno['vid'] + '.mp4'),
+ duration=raw_anno['duration'],
+ query=parse_query(raw_anno['query']),
+ span=[raw_anno['span']])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/didemo.py b/videomind/dataset/sub_classes/didemo.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6bdba5eb6b62c52bc57207f9b487cf10d5378c6
--- /dev/null
+++ b/videomind/dataset/sub_classes/didemo.py
@@ -0,0 +1,59 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import random
+
+import nncore
+import numpy as np
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='didemo')
+class DiDeMoDataset(GroundingDataset):
+
+ ANNO_PATH_TRAIN = 'data/didemo/train_data.json'
+ ANNO_PATH_VALID = 'data/didemo/val_data.json'
+ ANNO_PATH_TEST = 'data/didemo/test_data.json'
+
+ VIDEO_ROOT = 'data/didemo/videos_3fps_480_noaudio'
+ DURATIONS = 'data/didemo/durations.json'
+
+ UNIT = 1.0
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
+ elif split == 'valid':
+ raw_annos = nncore.load(self.ANNO_PATH_VALID)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_TEST)
+
+ durations = nncore.load(self.DURATIONS)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['video'].split('.')[0]
+
+ # apply mean on multiple spans
+ span = np.array(raw_anno['times']).mean(axis=0).tolist()
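+            # DiDeMo times are indices of 5-second segments (end index inclusive); convert them to seconds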
+ span = [round(span[0] * 5), round((span[1] + 1) * 5)]
+
+ # augment spans during training
+ if split == 'train':
+ offset = random.randint(-2, 2)
+ span = [span[0] + offset, span[1] + offset]
+
+ anno = dict(
+ source='didemo',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=durations[vid],
+ query=parse_query(raw_anno['description']),
+ span=[span])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/ego4d_naq.py b/videomind/dataset/sub_classes/ego4d_naq.py
new file mode 100644
index 0000000000000000000000000000000000000000..917e4d1fef969e12f6a6de455f9ed5d25d1c8004
--- /dev/null
+++ b/videomind/dataset/sub_classes/ego4d_naq.py
@@ -0,0 +1,81 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+from collections import OrderedDict
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='ego4d_naq')
+class Ego4DNaQDataset(GroundingDataset):
+
+ ANNO_PATH_TRAIN = 'data/ego4d_naq/train.json'
+ ANNO_PATH_VALID = 'data/ego4d_naq/val.json'
+ ANNO_PATH_TEST = 'data/ego4d_naq/test.json'
+
+ VIDEO_ROOT = 'data/ego4d/v2/videos_3fps_480_noaudio'
+
+ UNIT = 0.001
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
+ elif split == 'valid':
+ raw_annos = nncore.load(self.ANNO_PATH_VALID, object_pairs_hook=OrderedDict)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_TEST, object_pairs_hook=OrderedDict)
+
+ annos = []
+ for vid, raw_anno in raw_annos.items():
+ duration = raw_anno['num_frames'] / raw_anno['fps']
+
+            # 300s: 254k samples (dropped 121k samples, merged 156k samples)
+            # 480s: 567k samples (dropped 249k samples, merged 328k samples)
+ if split == 'train' and (duration < 10 or duration > 600):
+ continue
+
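+            # group spans by query; queries matching multiple moments are skipped below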
+ meta = dict()
+ for span, query in zip(raw_anno['exact_times'], raw_anno['sentences']):
+ span = [round(span[0], 3), round(span[1], 3)]
+
+ query = parse_query(query)
+
+ # these annotations might be from nlq
+ nlq_keys = ('who', 'what', 'when', 'in what', 'did', 'where', 'how', 'i what')
+ if split == 'train' and any(query.startswith(k) for k in nlq_keys):
+ continue
+
+ # bad samples
+ if split == 'train' and '#unsure' in query:
+ continue
+
+ # too short or too long samples
+ num_words = len(query.split(' '))
+ if split == 'train' and (num_words < 3 or num_words > 30):
+ continue
+
+ if query not in meta:
+ meta[query] = []
+
+ meta[query].append(span)
+
+ for query, span in meta.items():
+ # skip samples with multiple moments
+ if len(span) > 1:
+ continue
+
+ anno = dict(
+ source='ego4d_naq',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=duration,
+ query=query,
+ span=span)
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/ego4d_nlq.py b/videomind/dataset/sub_classes/ego4d_nlq.py
new file mode 100644
index 0000000000000000000000000000000000000000..936c1f2b427e2023c37d9d8c839f906dee8e9dbc
--- /dev/null
+++ b/videomind/dataset/sub_classes/ego4d_nlq.py
@@ -0,0 +1,41 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='ego4d_nlq')
+class Ego4DNLQDataset(GroundingDataset):
+
+ ANNO_PATH_TRAIN = 'data/ego4d_nlq/nlq_train.jsonl'
+ ANNO_PATH_VALID = 'data/ego4d_nlq/nlq_val.jsonl'
+
+ VIDEO_ROOT = 'data/ego4d/v2/videos_3fps_480_noaudio'
+
+ UNIT = 0.001
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_VALID)
+
+ annos = []
+ for raw_anno in raw_annos:
+ assert len(raw_anno['relevant_windows']) == 1
+
+ anno = dict(
+ source='ego4d_nlq',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, raw_anno['vid'] + '.mp4'),
+ duration=raw_anno['duration'],
+ query=parse_query(raw_anno['query']),
+ span=raw_anno['relevant_windows'])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/ego_timeqa.py b/videomind/dataset/sub_classes/ego_timeqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..521b02dfa13e95c9788b4bed23b3f00f6d744621
--- /dev/null
+++ b/videomind/dataset/sub_classes/ego_timeqa.py
@@ -0,0 +1,93 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import random
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import AnsweringCropDataset, AnsweringDataset, GroundingDataset
+from videomind.utils.parser import parse_query, parse_question
+
+
+@DATASETS.register(name='ego_timeqa')
+class EgoTimeQADataset(AnsweringDataset):
+
+ ANNO_PATH_TRAIN = 'data/ego_timeqa/annotations.EgoTimeQA.json'
+
+ VIDEO_ROOT = 'data/ego4d/v2/videos_3fps_480_noaudio'
+ DURATIONS = 'data/ego4d/v2/durations.json'
+
+ SOURCE = 'ego_timeqa'
+ DATA_TYPE = 'multimodal'
+
+ UNIT = 0.001
+
+ @classmethod
+ def load_annos(self, split='train'):
+ assert split == 'train'
+
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
+ durations = nncore.load(self.DURATIONS)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['video_id']
+
+ duration = durations[vid]
+
+ # 303k -> 284k (to be verified)
+ if duration < 10 or duration > 600:
+ continue
+
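+            # convert frame indices to seconds, assuming the videos are 30 fps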
+ span = [raw_anno['moment_start_frame'] / 30, raw_anno['moment_end_frame'] / 30]
+ span = [round(span[0], 3), round(span[1], 3)]
+
+ # this would remove many samples (284k -> 37k)
+ # if span[1] - span[0] < 2:
+ # continue
+
+            question = raw_anno['question'].capitalize().replace(' l ', ' I ')
+ question = parse_question(question)
+ query = parse_query(question)
+
+ # too short or too long samples
+ num_words = len(query.split(' '))
+ if split == 'train' and (num_words < 3 or num_words > 30):
+ continue
+
+ answer = raw_anno['answer'].capitalize()
+
+ assert len(raw_anno['wrong_answers']) == 3
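+            # insert the correct answer at a random position among the three wrong answers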
+ idx = random.randint(0, 3)
+ ans = chr(ord('A') + idx)
+ options = [o.capitalize() for o in raw_anno['wrong_answers']]
+ options.insert(idx, answer)
+
+ anno = dict(
+ source=self.SOURCE,
+ data_type=self.DATA_TYPE,
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=duration,
+ query=query,
+ question=question,
+ options=options,
+ answer=answer,
+ ans=ans,
+ span=[span])
+
+ annos.append(anno)
+
+ return annos
+
+
+@DATASETS.register(name='ego_timeqa_crop')
+class EgoTimeQACropDataset(AnsweringCropDataset, EgoTimeQADataset):
+
+ SOURCE = 'ego_timeqa_crop'
+
+
+@DATASETS.register(name='ego_timeqa_grounding')
+class EgoTimeQAGroundingDataset(GroundingDataset, EgoTimeQADataset):
+
+ SOURCE = 'ego_timeqa_grounding'
+ DATA_TYPE = 'grounding'
diff --git a/videomind/dataset/sub_classes/hirest.py b/videomind/dataset/sub_classes/hirest.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6681379e88a7005b0f255420b22b4c563219935
--- /dev/null
+++ b/videomind/dataset/sub_classes/hirest.py
@@ -0,0 +1,150 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+from collections import OrderedDict
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='hirest_grounding')
+class HiRESTGroundingDataset(GroundingDataset):
+
+ ANNO_PATH_TRAIN = 'data/hirest/all_data_train.json'
+ ANNO_PATH_VALID = 'data/hirest/all_data_val.json'
+
+ VIDEO_ROOT = 'data/hirest/videos_3fps_480_noaudio'
+
+ UNIT = 1.0
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_VALID, object_pairs_hook=OrderedDict)
+
+ all_videos = nncore.ls(self.VIDEO_ROOT, ext='.mp4')
+ all_videos = set(v[:11] for v in all_videos)
+
+ annos = []
+ for query, videos in raw_annos.items():
+ for video_name, raw_anno in videos.items():
+ if not raw_anno['relevant'] or not raw_anno['clip']:
+ continue
+
+ assert len(raw_anno['bounds']) == 2
+
+ vid = video_name.split('.')[0]
+
+ if vid not in all_videos:
+ continue
+
+ anno = dict(
+ source='hirest_grounding',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, video_name),
+ duration=raw_anno['v_duration'],
+ query=parse_query(query),
+ span=[raw_anno['bounds']])
+
+ annos.append(anno)
+
+ return annos
+
+
+@DATASETS.register(name='hirest_step')
+class HiRESTStepDataset(HiRESTGroundingDataset):
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_VALID, object_pairs_hook=OrderedDict)
+
+ all_videos = nncore.ls(self.VIDEO_ROOT, ext='.mp4')
+ all_videos = set(v[:11] for v in all_videos)
+
+ annos = []
+ for query, videos in raw_annos.items():
+ for video_name, raw_anno in videos.items():
+ if not raw_anno['relevant'] or not raw_anno['clip'] or len(raw_anno['steps']) == 0:
+ continue
+
+ vid = video_name.split('.')[0]
+
+ if vid not in all_videos:
+ continue
+
+ for step in raw_anno['steps']:
+ assert len(step['absolute_bounds']) == 2
+
+ anno = dict(
+ source='hirest_step',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, video_name),
+ duration=raw_anno['v_duration'],
+ query=parse_query(step['heading']),
+ span=[step['absolute_bounds']])
+
+ annos.append(anno)
+
+ return annos
+
+
+@DATASETS.register(name='hirest_step_bias')
+class HiRESTStepBiasDataset(HiRESTStepDataset):
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_VALID, object_pairs_hook=OrderedDict)
+
+ all_videos = nncore.ls(self.VIDEO_ROOT, ext='.mp4')
+ all_videos = set(v[:11] for v in all_videos)
+
+ annos = []
+ for query, videos in raw_annos.items():
+ for video_name, raw_anno in videos.items():
+ if not raw_anno['relevant'] or not raw_anno['clip'] or len(raw_anno['steps']) == 0:
+ continue
+
+ vid = video_name.split('.')[0]
+
+ if vid not in all_videos:
+ continue
+
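+                # build boundary-biased 'before'/'after' queries from pairs of adjacent steps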
+ for i in range(len(raw_anno['steps']) - 1):
+ span_a = raw_anno['steps'][i]['absolute_bounds']
+ span_b = raw_anno['steps'][i + 1]['absolute_bounds']
+
+ assert len(span_a) == 2 and len(span_b) == 2 and span_a[1] == span_b[0]
+
+ query_a = parse_query(f"The moment before {raw_anno['steps'][i + 1]['heading']}")
+ query_b = parse_query(f"The moment after {raw_anno['steps'][i]['heading']}")
+
+ anno_a = dict(
+ source='hirest_step_bias',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, video_name),
+ duration=raw_anno['v_duration'],
+ query=query_a,
+ span=[span_a])
+
+ anno_b = dict(
+ source='hirest_step_bias',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, video_name),
+ duration=raw_anno['v_duration'],
+ query=query_b,
+ span=[span_b])
+
+ annos.append(anno_a)
+ annos.append(anno_b)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/internvit_vtime.py b/videomind/dataset/sub_classes/internvit_vtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7e5abd90c86bd7447b06b732f2033d73be0f5ab
--- /dev/null
+++ b/videomind/dataset/sub_classes/internvit_vtime.py
@@ -0,0 +1,45 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='internvid_vtime')
+class InternVidVTimeDataset(GroundingDataset):
+
+ ANNO_PATH = 'data/internvid_vtime/anno_internvid_vtime_query_gpt4o_mini.jsonl'
+
+ VIDEO_ROOT = 'data/internvid_vtime/videos_crop_3fps_480_noaudio'
+
+ UNIT = 0.1
+
+ @classmethod
+ def load_annos(self, split='train'):
+ assert split == 'train'
+
+ raw_annos = nncore.load(self.ANNO_PATH)
+
+ all_videos = nncore.ls(self.VIDEO_ROOT, ext='.mp4')
+ all_videos = set(v[:11] for v in all_videos)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['vid']
+
+ if vid not in all_videos:
+ continue
+
+ anno = dict(
+ source='internvid_vtime',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=raw_anno['duration'],
+ query=parse_query(raw_anno['query']),
+ span=[raw_anno['span']])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/longvideobench.py b/videomind/dataset/sub_classes/longvideobench.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba0406555b43d76fbd12ee508c9f55983ded6337
--- /dev/null
+++ b/videomind/dataset/sub_classes/longvideobench.py
@@ -0,0 +1,53 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+from torch.utils.data import Dataset
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.utils.parser import parse_query, parse_question
+
+
+@DATASETS.register(name='longvideobench')
+class LongVideoBenchDataset(Dataset):
+
+ ANNO_PATH_VALID = 'data/longvideobench/lvb_val.json'
+ ANNO_PATH_TEST = 'data/longvideobench/lvb_test_wo_gt.json'
+
+ VIDEO_ROOT = 'data/longvideobench/videos_3fps_480_noaudio'
+
+ @classmethod
+ def load_annos(self, split='valid'):
+ if split == 'valid':
+ raw_annos = nncore.load(self.ANNO_PATH_VALID)
+ else:
+ print('WARNING: Test split does not have ground truth annotations')
+ raw_annos = nncore.load(self.ANNO_PATH_TEST)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['video_id']
+
+ if vid.startswith('@'):
+ vid = vid[-19:]
+
+            # videos might come from YouTube (11-char ids) or other platforms (19-char ids)
+ assert len(vid) in (11, 19)
+
+ anno = dict(
+ source='longvideobench',
+ data_type='multimodal',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ query=parse_query(raw_anno['question']),
+ question=parse_question(raw_anno['question']),
+ options=raw_anno['candidates'],
+ task=str(raw_anno['duration_group']),
+ level=raw_anno['level'],
+ question_category=raw_anno['question_category'])
+
+ if 'correct_choice' in raw_anno:
+ anno['answer'] = raw_anno['candidates'][raw_anno['correct_choice']]
+ anno['ans'] = chr(ord('A') + raw_anno['correct_choice'])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/lvbench.py b/videomind/dataset/sub_classes/lvbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..0464c31d5dea502381305d128552231484623224
--- /dev/null
+++ b/videomind/dataset/sub_classes/lvbench.py
@@ -0,0 +1,52 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+from torch.utils.data import Dataset
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.utils.parser import parse_query, parse_question
+
+
+@DATASETS.register(name='lvbench')
+class LVBenchDataset(Dataset):
+
+ ANNO_PATH = 'data/lvbench/LVBench/video_info.meta.jsonl'
+
+ VIDEO_ROOT = 'data/lvbench/videos_3fps_480_noaudio'
+
+ @classmethod
+ def load_annos(self, split='test'):
+ assert split == 'test'
+
+ raw_annos = nncore.load(self.ANNO_PATH)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['key']
+
+ for meta in raw_anno['qa']:
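+                # each QA entry packs the question and four '(A) ...' style options on separate lines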
+ tok = meta['question'].split('\n')
+
+ assert len(tok) == 5
+ assert all(any(o.startswith(k) for k in ('(A) ', '(B) ', '(C) ', '(D) ')) for o in tok[1:])
+
+ options = [o[4:] for o in tok[1:]]
+                ans = meta['answer']
+                assert ans in 'ABCD'
+                answer = options[ord(ans) - ord('A')]
+
+ anno = dict(
+ source='lvbench',
+ data_type='multimodal',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ query=parse_query(tok[0]),
+ question=parse_question(tok[0]),
+ options=options,
+ answer=answer,
+ ans=ans,
+ task=meta['question_type'],
+ time_reference=meta['time_reference'])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/mlvu.py b/videomind/dataset/sub_classes/mlvu.py
new file mode 100644
index 0000000000000000000000000000000000000000..28dbe2bffc2ceb75d28e2b863440ce5fbac81d06
--- /dev/null
+++ b/videomind/dataset/sub_classes/mlvu.py
@@ -0,0 +1,55 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+from torch.utils.data import Dataset
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.utils.parser import parse_query, parse_question
+
+
+@DATASETS.register(name='mlvu')
+class MLVUDataset(Dataset):
+
+ TASK_TO_DIR_MAP = {
+ 'plotQA': '1_plotQA',
+ 'findNeedle': '2_needle',
+ 'ego': '3_ego',
+ 'count': '4_count',
+ 'order': '5_order',
+ 'anomaly_reco': '6_anomaly_reco',
+ 'topic_reasoning': '7_topic_reasoning'
+ }
+
+ DATA_ROOT = 'data/mlvu'
+
+ @classmethod
+ def load_annos(self, split='test'):
+ assert split == 'test'
+
+ paths = [nncore.join(self.DATA_ROOT, 'json', f'{n}.json') for n in self.TASK_TO_DIR_MAP.values()]
+
+ raw_annos = nncore.flatten([nncore.load(p) for p in paths])
+
+ annos = []
+ for raw_anno in raw_annos:
+ task = raw_anno['question_type']
+ video_name = nncore.join(self.TASK_TO_DIR_MAP[task], raw_anno['video'])
+
+ options = raw_anno['candidates']
+ answer = raw_anno['answer']
+ ans = chr(ord('A') + options.index(answer))
+
+ anno = dict(
+ source='mlvu',
+ data_type='multimodal',
+ video_path=nncore.join(self.DATA_ROOT, 'video', video_name),
+ query=parse_query(raw_anno['question']),
+ question=parse_question(raw_anno['question']),
+ options=options,
+ answer=answer,
+ ans=ans,
+ task=task)
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/mvbench.py b/videomind/dataset/sub_classes/mvbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..559660cecff898640de3d429a932eb8ad8e7272f
--- /dev/null
+++ b/videomind/dataset/sub_classes/mvbench.py
@@ -0,0 +1,74 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+from torch.utils.data import Dataset
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.utils.parser import parse_query, parse_question
+
+
+@DATASETS.register(name='mvbench')
+class MVBenchDataset(Dataset):
+
+ META_DATA = [('Episodic Reasoning', 'episodic_reasoning.json', 'tvqa/frames_fps3_hq', 'frame'),
+ ('Action Sequence', 'action_sequence.json', 'star/Charades_v1_480', 'video'),
+ ('Action Prediction', 'action_prediction.json', 'star/Charades_v1_480', 'video'),
+ ('Action Antonym', 'action_antonym.json', 'ssv2_video', 'video'),
+ ('Fine-grained Action', 'fine_grained_action.json', 'Moments_in_Time_Raw/videos', 'video'),
+ ('Unexpected Action', 'unexpected_action.json', 'FunQA_test/test', 'video'),
+ ('Object Existence', 'object_existence.json', 'clevrer/video_validation', 'video'),
+ ('Object Interaction', 'object_interaction.json', 'star/Charades_v1_480', 'video'),
+ ('Object Shuffle', 'object_shuffle.json', 'perception/videos', 'video'),
+ ('Moving Direction', 'moving_direction.json', 'clevrer/video_validation', 'video'),
+ ('Action Localization', 'action_localization.json', 'sta/sta_video', 'video'),
+ ('Scene Transition', 'scene_transition.json', 'scene_qa/video', 'video'),
+ ('Action Count', 'action_count.json', 'perception/videos', 'video'),
+ ('Moving Count', 'moving_count.json', 'clevrer/video_validation', 'video'),
+ ('Moving Attribute', 'moving_attribute.json', 'clevrer/video_validation', 'video'),
+ ('State Change', 'state_change.json', 'perception/videos', 'video'),
+ ('Fine-grained Pose', 'fine_grained_pose.json', 'nturgbd', 'video'),
+ ('Character Order', 'character_order.json', 'perception/videos', 'video'),
+ ('Egocentric Navigation', 'egocentric_navigation.json', 'vlnqa', 'video'),
+ ('Counterfactual Inference', 'counterfactual_inference.json', 'clevrer/video_validation', 'video')]
+
+ DATA_ROOT = 'data/mvbench'
+
+ MIN_LEN = 64
+
+ @classmethod
+ def load_annos(self, split='test', sample_frames=32):
+ assert split == 'test'
+
+ annos = []
+ for meta in self.META_DATA:
+ raw_annos = nncore.load(nncore.join(self.DATA_ROOT, 'json', meta[1]))
+
+ for raw_anno in raw_annos:
+ video_name = nncore.join(meta[2], raw_anno['video'])
+ video_path = nncore.join(self.DATA_ROOT, 'video', video_name)
+
+ if meta[3] == 'frame':
+ num_frames = len(nncore.ls(video_path, ext='.jpg'))
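+                    # sample at most sample_frames frames at a fixed stride (frame files are 1-indexed)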
+ video_path = [
+ nncore.join(video_path, f'{i:0>5}.jpg')
+ for i in range(1, num_frames + 1, num_frames // (sample_frames - 1))
+ ][:sample_frames]
+
+ options = raw_anno['candidates']
+ answer = raw_anno['answer']
+ ans = chr(ord('A') + options.index(answer))
+
+ anno = dict(
+ source='mvbench',
+ data_type='multimodal',
+ video_path=video_path,
+ query=parse_query(raw_anno['question']),
+ question=parse_question(raw_anno['question']),
+ options=options,
+ answer=answer,
+ ans=ans,
+ task=meta[0])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/nextgqa.py b/videomind/dataset/sub_classes/nextgqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..5afbb6e8bda0d3f940814b8cbad5f24742478ec9
--- /dev/null
+++ b/videomind/dataset/sub_classes/nextgqa.py
@@ -0,0 +1,87 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import csv
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import AnsweringCropDataset, AnsweringDataset, GroundingDataset
+from videomind.utils.parser import parse_query, parse_question
+
+
+@DATASETS.register(name='nextgqa')
+class NExTGQADataset(AnsweringDataset):
+
+ ANNO_PATH_VALID = 'data/nextgqa/val.csv'
+ ANNO_PATH_TEST = 'data/nextgqa/test.csv'
+
+ SPAN_PATH_VALID = 'data/nextgqa/gsub_val.json'
+ SPAN_PATH_TEST = 'data/nextgqa/gsub_test.json'
+
+ VIDEO_ID_MAP = 'data/nextgqa/map_vid_vidorID.json'
+ VIDEO_ROOT = 'data/nextqa/videos'
+
+ SOURCE = 'nextgqa'
+ DATA_TYPE = 'multimodal'
+
+ UNIT = 0.1
+
+ @classmethod
+ def load_annos(self, split='valid'):
+ assert split in ('valid', 'test')
+
+ if split == 'valid':
+ anno_path = self.ANNO_PATH_VALID
+ raw_spans = nncore.load(self.SPAN_PATH_VALID)
+ else:
+ anno_path = self.ANNO_PATH_TEST
+ raw_spans = nncore.load(self.SPAN_PATH_TEST)
+
+ with open(anno_path, mode='r') as f:
+ reader = csv.DictReader(f)
+ raw_annos = [d for d in reader]
+
+ video_id_map = nncore.load(self.VIDEO_ID_MAP)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['video_id']
+ qid = raw_anno['qid']
+
+ video_id = video_id_map[vid]
+
+ query = parse_query(raw_anno['question'].capitalize() + '?')
+ question = parse_question(raw_anno['question'].capitalize() + '?')
+ options = [raw_anno[k].capitalize() for k in ('a0', 'a1', 'a2', 'a3', 'a4')]
+ answer = raw_anno['answer'].capitalize()
+ ans = chr(ord('A') + options.index(answer))
+
+ anno = dict(
+ source=self.SOURCE,
+ data_type=self.DATA_TYPE,
+ video_path=nncore.join(self.VIDEO_ROOT, video_id + '.mp4'),
+ duration=raw_spans[vid]['duration'],
+ query=query,
+ question=question,
+ options=options,
+ answer=answer,
+ ans=ans,
+ span=raw_spans[vid]['location'][qid],
+ task=raw_anno['type'])
+
+ annos.append(anno)
+
+ return annos
+
+
+@DATASETS.register(name='nextgqa_crop')
+class NExTGQACropDataset(AnsweringCropDataset, NExTGQADataset):
+
+ SOURCE = 'nextgqa_crop'
+
+
+@DATASETS.register(name='nextgqa_grounding')
+class NExTGQAGroundingDataset(GroundingDataset, NExTGQADataset):
+
+ SOURCE = 'nextgqa_grounding'
+ DATA_TYPE = 'grounding'
diff --git a/videomind/dataset/sub_classes/nextqa.py b/videomind/dataset/sub_classes/nextqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..11e46acb67fd24da9e61c0a91522c5dd0e648093
--- /dev/null
+++ b/videomind/dataset/sub_classes/nextqa.py
@@ -0,0 +1,63 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import csv
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import AnsweringDataset
+from videomind.utils.parser import parse_query, parse_question
+
+
+@DATASETS.register(name='nextqa')
+class NExTQADataset(AnsweringDataset):
+
+ ANNO_PATH_TRAIN = 'data/nextqa/train.csv'
+ ANNO_PATH_VALID = 'data/nextqa/val.csv'
+ ANNO_PATH_TEST = 'data/nextqa/test.csv'
+
+ VIDEO_ID_MAP = 'data/nextqa/map_vid_vidorID.json'
+ VIDEO_ROOT = 'data/nextqa/NExTVideo'
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ anno_path = self.ANNO_PATH_TRAIN
+ elif split == 'valid':
+ anno_path = self.ANNO_PATH_VALID
+ else:
+ anno_path = self.ANNO_PATH_TEST
+
+ with open(anno_path, mode='r') as f:
+ reader = csv.DictReader(f)
+ raw_annos = [d for d in reader]
+
+ video_id_map = nncore.load(self.VIDEO_ID_MAP)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['video']
+ qid = raw_anno['qid']
+
+ video_id = video_id_map[vid]
+ query = parse_query(raw_anno['question'].capitalize() + '?')
+ question = parse_question(raw_anno['question'].capitalize() + '?')
+ options = [raw_anno[k].capitalize() for k in ('a0', 'a1', 'a2', 'a3', 'a4')]
+ ans = chr(ord('A') + int(raw_anno['answer']))
+ answer = options[int(raw_anno['answer'])]
+
+ anno = dict(
+ source='nextqa',
+ data_type='multimodal',
+ uid=f'{vid}_{qid}',
+ video_path=nncore.join(self.VIDEO_ROOT, video_id + '.mp4'),
+ query=query,
+ question=question,
+ options=options,
+ answer=answer,
+ ans=ans,
+ task=raw_anno['type'])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/qa_ego4d.py b/videomind/dataset/sub_classes/qa_ego4d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7db13cb507cfce7cf9c038a0c38156b8e834f62
--- /dev/null
+++ b/videomind/dataset/sub_classes/qa_ego4d.py
@@ -0,0 +1,98 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import random
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import AnsweringCropDataset, AnsweringDataset, GroundingDataset
+from videomind.utils.parser import parse_query, parse_question
+
+
+@DATASETS.register(name='qa_ego4d')
+class QAEgo4DDataset(AnsweringDataset):
+
+ ANNO_PATH_TRAIN = 'data/qa_ego4d/annotations.QaEgo4D_train.json'
+ ANNO_PATH_VALID = 'data/qa_ego4d/annotations.QaEgo4D_val_options.json'
+ ANNO_PATH_TEST = 'data/qa_ego4d/annotations.QaEgo4D_test_options.json'
+
+ VIDEO_ROOT = 'data/ego4d/v1/videos_3fps_480_noaudio'
+ DURATIONS = 'data/ego4d/v1/durations.json'
+
+ SOURCE = 'qa_ego4d'
+ DATA_TYPE = 'multimodal'
+
+ UNIT = 0.001
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
+ elif split == 'valid':
+ raw_annos = nncore.load(self.ANNO_PATH_VALID)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_TEST)
+
+ durations = nncore.load(self.DURATIONS)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['video_id']
+
+ duration = durations[vid]
+
+ # too short or too long samples
+ if split == 'train' and (duration < 10 or duration > 600):
+ continue
+
+ span = [raw_anno['moment_start_frame'] / 30, raw_anno['moment_end_frame'] / 30]
+ span = [round(span[0], 3), round(span[1], 3)]
+
+ # skip samples with too short moments
+ # if split == 'train' and span[1] - span[0] < 2:
+ # continue
+
+ answer = raw_anno['answer'].capitalize()
+
+ if 'options' in raw_anno:
+ options = [o.capitalize() for o in raw_anno['options']]
+ idx = options.index(answer)
+ ans = chr(ord('A') + idx)
+ else:
+                # NOTE: non-deterministic evaluation (the correct answer is inserted at a random position)
+ assert len(raw_anno['wrong_answers']) == 3
+ idx = random.randint(0, 3)
+ ans = chr(ord('A') + idx)
+ options = [o.capitalize() for o in raw_anno['wrong_answers']]
+ options.insert(idx, answer)
+
+ assert len(options) == 4, options
+
+ anno = dict(
+ source=self.SOURCE,
+ data_type=self.DATA_TYPE,
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=duration,
+ query=parse_query(raw_anno['question'].capitalize()),
+ question=parse_question(raw_anno['question'].capitalize()),
+ options=options,
+ answer=answer,
+ ans=ans,
+ span=[span])
+
+ annos.append(anno)
+
+ return annos
+
+
+@DATASETS.register(name='qa_ego4d_crop')
+class QAEgo4DCropDataset(AnsweringCropDataset, QAEgo4DDataset):
+
+ SOURCE = 'qa_ego4d_crop'
+
+
+@DATASETS.register(name='qa_ego4d_grounding')
+class QAEgo4DGroundingDataset(GroundingDataset, QAEgo4DDataset):
+
+ SOURCE = 'qa_ego4d_grounding'
+ DATA_TYPE = 'grounding'
diff --git a/videomind/dataset/sub_classes/queryd.py b/videomind/dataset/sub_classes/queryd.py
new file mode 100644
index 0000000000000000000000000000000000000000..928a4b17541b6166223542498182f2cee51145f0
--- /dev/null
+++ b/videomind/dataset/sub_classes/queryd.py
@@ -0,0 +1,49 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='queryd')
+class QuerYDDataset(GroundingDataset):
+
+ VID_PATH = 'data/queryd/train_list.txt'
+ QUERY_PATH = 'data/queryd/raw_captions_combined_filtered-v2.pkl'
+ SPAN_PATH = 'data/queryd/times_captions_combined_filtered-v2.pkl'
+
+ VIDEO_ROOT = 'data/queryd/videos_3fps_480_noaudio'
+ DURATIONS = 'data/queryd/durations.json'
+
+ UNIT = 0.001
+
+ @classmethod
+ def load_annos(self, split='train'):
+ assert split == 'train'
+
+ vids = nncore.load(self.VID_PATH)
+ queries = nncore.load(self.QUERY_PATH)
+ spans = nncore.load(self.SPAN_PATH)
+ durations = nncore.load(self.DURATIONS)
+
+ annos = []
+ for vid in vids:
+ for query, span in zip(queries[vid], spans[vid]):
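+                # QuerYD video ids carry a 6-character prefix; strip it to get the raw video name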
+ video_name = vid[6:]
+
+ if video_name not in durations:
+ continue
+
+ anno = dict(
+ source='queryd',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, video_name + '.mp4'),
+ duration=durations[video_name],
+ query=parse_query(' '.join(query)),
+ span=[span])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/qvhighlights.py b/videomind/dataset/sub_classes/qvhighlights.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9b1cfc25bd938c1e44b627e2ac71efdea22c183
--- /dev/null
+++ b/videomind/dataset/sub_classes/qvhighlights.py
@@ -0,0 +1,78 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='qvhighlights')
+class QVHighlightsDataset(GroundingDataset):
+
+ ANNO_PATH_TRAIN = 'data/qvhighlights/highlight_train_release.jsonl'
+ ANNO_PATH_VALID = 'data/qvhighlights/highlight_val_release.jsonl'
+ ANNO_PATH_TEST = 'data/qvhighlights/highlight_test_release.jsonl'
+
+ VIDEO_ROOT = 'data/qvhighlights/videos_3fps_480_noaudio'
+
+ UNIT = 2.0
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
+ elif split == 'valid':
+ raw_annos = nncore.load(self.ANNO_PATH_VALID)
+ else:
+ print('WARNING: Test split does not have ground truth annotations')
+ raw_annos = nncore.load(self.ANNO_PATH_TEST)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['vid']
+ qid = raw_anno['qid']
+
+ anno = dict(
+ source='qvhighlights',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=raw_anno['duration'],
+ query=parse_query(raw_anno['query']),
+ span=raw_anno.get('relevant_windows'),
+ vid=vid,
+ qid=qid)
+
+ annos.append(anno)
+
+ return annos
+
+
+@DATASETS.register(name='qvhighlights_single')
+class QVHighlightsSingleDataset(QVHighlightsDataset):
+
+ @classmethod
+ def load_annos(self, split='train'):
+ assert split == 'train'
+
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
+
+ annos = []
+ for raw_anno in raw_annos:
+ # skip samples with multiple moments
+ if len(raw_anno['relevant_windows']) > 1:
+ continue
+
+ vid = raw_anno['vid']
+
+ anno = dict(
+ source='qvhighlights_single',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=raw_anno['duration'],
+ query=parse_query(raw_anno['query']),
+ span=raw_anno.get('relevant_windows'))
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/rextime.py b/videomind/dataset/sub_classes/rextime.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d782dcb8538f0fee50fcbd47a0078b4d029b3d4
--- /dev/null
+++ b/videomind/dataset/sub_classes/rextime.py
@@ -0,0 +1,81 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import AnsweringCropDataset, AnsweringDataset, GroundingDataset
+from videomind.utils.parser import parse_query, parse_question
+
+
+@DATASETS.register(name='rextime')
+class ReXTimeDataset(AnsweringDataset):
+
+ ANNO_PATH_TRAIN = 'data/rextime/rextime_train.json'
+ ANNO_PATH_VALID = 'data/rextime/rextime_val.json'
+ ANNO_PATH_TEST = 'data/rextime/rextime_test_release.json'
+
+ VIDEO_ROOT_ANET = 'data/activitynet/videos_3fps_480_noaudio'
+ VIDEO_ROOT_QVHL = 'data/qvhighlights/videos_3fps_480_noaudio'
+
+ DURATIONS_ANET = 'data/activitynet/durations.json'
+ DURATIONS_QVHL = 'data/qvhighlights/durations.json'
+
+ SOURCE = 'rextime'
+ DATA_TYPE = 'multimodal'
+
+ UNIT = 1.0
+ MIN_LEN = 64
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
+ elif split == 'valid':
+ raw_annos = nncore.load(self.ANNO_PATH_VALID)
+ else:
+ print('WARNING: Test split does not have ground truth annotations')
+ raw_annos = nncore.load(self.ANNO_PATH_TEST)
+
+ durations_anet = nncore.load(self.DURATIONS_ANET)
+ durations_qvhl = nncore.load(self.DURATIONS_QVHL)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['vid']
+
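+            # 13-char ids ('v_' + 11-char YouTube id) belong to ActivityNet; longer ids come from QVHighlights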
+ if len(vid) == 13:
+ video_path = nncore.join(self.VIDEO_ROOT_ANET, vid + '.mp4')
+ duration = durations_anet[vid]
+ else:
+ video_path = nncore.join(self.VIDEO_ROOT_QVHL, vid + '.mp4')
+ duration = durations_qvhl[vid]
+
+ anno = dict(
+ source=self.SOURCE,
+ data_type=self.DATA_TYPE,
+ video_path=video_path,
+ duration=duration,
+ query=parse_query(raw_anno['question']),
+ question=parse_question(raw_anno['question']),
+ options=[o.capitalize() for o in raw_anno['options']],
+ answer=raw_anno['answer'].replace('From to , ', '').capitalize(),
+ ans=raw_anno['ans'],
+ span=[raw_anno['span']],
+ task=raw_anno['category'])
+
+ annos.append(anno)
+
+ return annos
+
+
+@DATASETS.register(name='rextime_crop')
+class ReXTimeCropDataset(AnsweringCropDataset, ReXTimeDataset):
+
+ SOURCE = 'rextime_crop'
+
+
+@DATASETS.register(name='rextime_grounding')
+class ReXTimeGroundingDataset(GroundingDataset, ReXTimeDataset):
+
+ SOURCE = 'rextime_grounding'
+ DATA_TYPE = 'grounding'
diff --git a/videomind/dataset/sub_classes/star.py b/videomind/dataset/sub_classes/star.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f1f7349fd411969ccd2c024f07939f964fffe41
--- /dev/null
+++ b/videomind/dataset/sub_classes/star.py
@@ -0,0 +1,54 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import AnsweringCropDataset
+from videomind.utils.parser import parse_query, parse_question
+
+
+@DATASETS.register(name='star')
+class STARDataset(AnsweringCropDataset):
+
+ ANNO_PATH_TRAIN = 'data/star/STAR_train.json'
+ ANNO_PATH_VALID = 'data/star/STAR_val.json'
+
+ VIDEO_ROOT = 'data/charades_sta/videos_3fps_480_noaudio'
+ DURATIONS = 'data/charades_sta/durations.json'
+
+ UNIT = 0.1
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_VALID)
+
+ durations = nncore.load(self.DURATIONS)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['video_id']
+
+ options = [c['choice'] for c in raw_anno['choices']]
+ answer = raw_anno['answer']
+ ans = chr(ord('A') + options.index(answer))
+
+ anno = dict(
+ source='star',
+ data_type='multimodal',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=durations[vid],
+ query=parse_query(raw_anno['question']),
+ question=parse_question(raw_anno['question']),
+ options=options,
+ answer=answer,
+ ans=ans,
+ span=[[raw_anno['start'], raw_anno['end']]],
+ task=raw_anno['question_id'].split('_')[0],
+ no_aug=True)
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/tacos.py b/videomind/dataset/sub_classes/tacos.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e606d77c56b787d5ed6b7b47b423ad82cc02032
--- /dev/null
+++ b/videomind/dataset/sub_classes/tacos.py
@@ -0,0 +1,46 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='tacos')
+class TACoSDataset(GroundingDataset):
+
+ ANNO_PATH_TRAIN = 'data/tacos/train.jsonl'
+ ANNO_PATH_VALID = 'data/tacos/val.jsonl'
+ ANNO_PATH_TEST = 'data/tacos/test.jsonl'
+
+ VIDEO_ROOT = 'data/tacos/videos_3fps_480_noaudio'
+
+ UNIT = 0.001
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
+        elif split in ('val', 'valid'):
+ raw_annos = nncore.load(self.ANNO_PATH_VALID)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_TEST)
+
+ annos = []
+ for raw_anno in raw_annos:
+ assert len(raw_anno['relevant_windows']) == 1
+
+ vid = raw_anno['vid']
+
+ anno = dict(
+ source='tacos',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '-cam-002.mp4'),
+ duration=raw_anno['duration'],
+ query=parse_query(raw_anno['query']),
+ span=raw_anno['relevant_windows'])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/vid_morp.py b/videomind/dataset/sub_classes/vid_morp.py
new file mode 100644
index 0000000000000000000000000000000000000000..3992abe238f888d0ba15dfd7924d951e14803f47
--- /dev/null
+++ b/videomind/dataset/sub_classes/vid_morp.py
@@ -0,0 +1,45 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='vid_morp')
+class VidMorpDataset(GroundingDataset):
+
+ ANNO_PATH = 'data/vid_morp/anno_vid_morp.jsonl'
+
+ VIDEO_ROOT = 'data/vid_morp/videos_3fps_480_noaudio'
+
+ UNIT = 0.001
+
+ @classmethod
+ def load_annos(self, split='train'):
+ assert split == 'train'
+
+ raw_annos = nncore.load(self.ANNO_PATH)
+
+ all_videos = nncore.ls(self.VIDEO_ROOT, ext='.mp4')
+ all_videos = set(v[:11] for v in all_videos)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['vid']
+
+ if vid not in all_videos:
+ continue
+
+ anno = dict(
+ source='vid_morp',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=raw_anno['duration'],
+ query=parse_query(raw_anno['query']),
+ span=[raw_anno['span']])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/videomme.py b/videomind/dataset/sub_classes/videomme.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee7fd32f443e87ce6dbdf16222dce5efcad43058
--- /dev/null
+++ b/videomind/dataset/sub_classes/videomme.py
@@ -0,0 +1,52 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+from torch.utils.data import Dataset
+
+import pandas as pd
+from videomind.dataset.hybrid import DATASETS
+from videomind.utils.parser import parse_query, parse_question
+
+
+@DATASETS.register(name='videomme')
+class VideoMMEDataset(Dataset):
+
+ ANNO_PATH = 'data/videomme/test-00000-of-00001.parquet'
+
+ VIDEO_ROOT = 'data/videomme/videos'
+ SUBTITLE_ROOT = 'data/videomme/subtitles'
+
+ @classmethod
+ def load_annos(self, split='test'):
+ assert split == 'test'
+
+ raw_annos = pd.read_parquet(self.ANNO_PATH).to_dict(orient='records')
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['videoID']
+
+ options = raw_anno['options'].tolist()
+
+ assert len(options) == 4
+ assert all(any(o.startswith(k) for k in ('A. ', 'B. ', 'C. ', 'D. ')) for o in options)
+
+ options = [o[3:] for o in options]
+            ans = raw_anno['answer']
+            assert ans in 'ABCD'
+            answer = options[ord(ans) - ord('A')]
+
+ anno = dict(
+ source='videomme',
+ data_type='multimodal',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ query=parse_query(raw_anno['question']),
+ question=parse_question(raw_anno['question']),
+ options=options,
+ answer=answer,
+ ans=ans,
+ task=raw_anno['duration'])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/videoxum.py b/videomind/dataset/sub_classes/videoxum.py
new file mode 100644
index 0000000000000000000000000000000000000000..16e60e8c04ea82f29e23bc81f510a58a732d6b08
--- /dev/null
+++ b/videomind/dataset/sub_classes/videoxum.py
@@ -0,0 +1,54 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='videoxum')
+class VideoXumDataset(GroundingDataset):
+
+ ANNO_PATH_TRAIN = 'data/videoxum/train_videoxum.json'
+ ANNO_PATH_VALID = 'data/videoxum/val_videoxum.json'
+ ANNO_PATH_TEST = 'data/videoxum/test_videoxum.json'
+
+ VIDEO_ROOT = 'data/activitynet/videos_3fps_480_noaudio'
+
+ UNIT = 0.01
+
+ @classmethod
+ def load_annos(self, split='train'):
+ if split == 'train':
+ raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
+ elif split == 'valid':
+ raw_annos = nncore.load(self.ANNO_PATH_VALID)
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH_TEST)
+
+ annos = []
+ for raw_anno in raw_annos:
+ vid = raw_anno['video_id']
+
+ duration = raw_anno['duration']
+
+ for query, spans in zip(raw_anno['tsum'], raw_anno['vsum']):
+ assert len(spans) == 10
+
+ # average the spans from 10 annotators
+ span = [round(sum(s[0] for s in spans) / 10, 2), round(sum(s[1] for s in spans) / 10, 2)]
+
+ anno = dict(
+ source='videoxum',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=duration,
+ query=parse_query(query),
+ span=[span])
+
+ annos.append(anno)
+
+ return annos
diff --git a/videomind/dataset/sub_classes/youcook2.py b/videomind/dataset/sub_classes/youcook2.py
new file mode 100644
index 0000000000000000000000000000000000000000..e277ebe1317c4d3d5eac37e32ba1ebe8db694a8e
--- /dev/null
+++ b/videomind/dataset/sub_classes/youcook2.py
@@ -0,0 +1,107 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+from collections import OrderedDict
+
+import nncore
+
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.wrappers import GroundingDataset
+from videomind.utils.parser import parse_query
+
+
+@DATASETS.register(name='youcook2')
+class YouCook2Dataset(GroundingDataset):
+
+ ANNO_PATH = 'data/youcook2/youcookii_annotations_trainval.json'
+
+ VIDEO_ROOT = 'data/youcook2/videos_3fps_480_noaudio'
+
+ UNIT = 1.0
+
+ @classmethod
+ def load_annos(self, split='train'):
+ subset = 'training' if split == 'train' else 'validation'
+
+ raw_annos = nncore.load(self.ANNO_PATH, object_pairs_hook=OrderedDict)['database']
+
+ all_videos = nncore.ls(self.VIDEO_ROOT, ext='.mp4')
+ all_videos = set(v[:11] for v in all_videos)
+
+ annos = []
+ for vid, raw_anno in raw_annos.items():
+ if raw_anno['subset'] != subset:
+ continue
+
+ if vid not in all_videos:
+ continue
+
+ duration = raw_anno['duration']
+
+ for meta in raw_anno['annotations']:
+ anno = dict(
+ source='youcook2',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=duration,
+ query=parse_query(meta['sentence']),
+ span=[meta['segment']])
+
+ annos.append(anno)
+
+ return annos
+
+
+@DATASETS.register(name='youcook2_bias')
+class YouCook2BiasDataset(YouCook2Dataset):
+
+ @classmethod
+ def load_annos(self, split='train'):
+ subset = 'training' if split == 'train' else 'validation'
+
+ raw_annos = nncore.load(self.ANNO_PATH, object_pairs_hook=OrderedDict)['database']
+
+ all_videos = nncore.ls(self.VIDEO_ROOT, ext='.mp4')
+ all_videos = set(v[:11] for v in all_videos)
+
+ annos = []
+ for vid, raw_anno in raw_annos.items():
+ if raw_anno['subset'] != subset:
+ continue
+
+ if vid not in all_videos:
+ continue
+
+ duration = raw_anno['duration']
+
+ moments = raw_anno['annotations']
+
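+            # build 'before'/'after' queries from adjacent segments that are less than 3s apart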
+ for i in range(len(moments) - 1):
+ span_a = moments[i]['segment']
+ span_b = moments[i + 1]['segment']
+
+ if span_b[0] - span_a[1] < 3:
+ query_a = parse_query(f"The moment before {moments[i + 1]['sentence']}")
+ query_b = parse_query(f"The moment after {moments[i]['sentence']}")
+
+ anno_a = dict(
+ source='youcook2_bias',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=duration,
+                        query=query_a,
+ span=[span_a])
+
+ anno_b = dict(
+ source='youcook2_bias',
+ data_type='grounding',
+ video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
+ duration=duration,
+                        query=query_b,
+ span=[span_b])
+
+ annos.append(anno_a)
+ annos.append(anno_b)
+
+ return annos
diff --git a/videomind/dataset/utils.py b/videomind/dataset/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..61378c8eec998aaaa1b2d5ec11d41bfff49e724b
--- /dev/null
+++ b/videomind/dataset/utils.py
@@ -0,0 +1,351 @@
+# Modified from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
+
+import base64
+import math
+import warnings
+from io import BytesIO
+
+import decord
+import numpy as np
+import torch
+from PIL import Image, ImageSequence
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+
+import requests
+from videomind.constants import IGNORE_INDEX
+from videomind.conversation import get_conv
+
+IMAGE_FACTOR = 28
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+VIDEO_MIN_PIXELS = 128 * 28 * 28
+VIDEO_MAX_PIXELS = 768 * 28 * 28
+VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
+FRAME_FACTOR = 2
+FPS = 2.0
+FPS_MIN_FRAMES = 4
+FPS_MAX_FRAMES = 768
+
+
+def round_by_factor(number: int, factor: int) -> int:
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
+ return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+ return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+ return math.floor(number / factor) * factor
+
+
+def smart_resize(height: int,
+ width: int,
+ factor: int = IMAGE_FACTOR,
+ min_pixels: int = MIN_PIXELS,
+ max_pixels: int = MAX_PIXELS) -> tuple[int, int]:
+ """
+ Rescales the image so that the following conditions are met:
+
+ 1. Both dimensions (height and width) are divisible by 'factor'.
+
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+ 3. The aspect ratio of the image is maintained as closely as possible.
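+
+    For example, with the default factor of 28, smart_resize(1080, 1920) returns (1092, 1932):
+    both sides are multiples of 28 and the pixel count stays within ['min_pixels', 'max_pixels'].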
+ """
+ if max(height, width) / min(height, width) > MAX_RATIO:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}")
+ h_bar = max(factor, round_by_factor(height, factor))
+ w_bar = max(factor, round_by_factor(width, factor))
+    # handle min_pixels before max_pixels so that the result never exceeds max_pixels
+ if h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = ceil_by_factor(height * beta, factor)
+ w_bar = ceil_by_factor(width * beta, factor)
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = floor_by_factor(height / beta, factor)
+ w_bar = floor_by_factor(width / beta, factor)
+ return h_bar, w_bar
+
+
+def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
+ if "image" in ele:
+ image = ele["image"]
+ else:
+ image = ele["image_url"]
+ image_obj = None
+ if isinstance(image, Image.Image):
+ image_obj = image
+ elif image.startswith("http://") or image.startswith("https://"):
+ image_obj = Image.open(requests.get(image, stream=True).raw)
+ elif image.startswith("file://"):
+ image_obj = Image.open(image[7:])
+ elif image.startswith("data:image"):
+ if "base64," in image:
+ _, base64_data = image.split("base64,", 1)
+ data = base64.b64decode(base64_data)
+ image_obj = Image.open(BytesIO(data))
+ else:
+ image_obj = Image.open(image)
+ if image_obj is None:
+ raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+ image = image_obj.convert("RGB")
+
+ if "resized_height" in ele and "resized_width" in ele:
+ resized_height, resized_width = smart_resize(
+ ele["resized_height"],
+ ele["resized_width"],
+ factor=size_factor,
+ )
+ else:
+ width, height = image.size
+ min_pixels = ele.get("min_pixels", MIN_PIXELS)
+ max_pixels = ele.get("max_pixels", MAX_PIXELS)
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=size_factor,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ )
+ image = image.resize((resized_width, resized_height))
+
+ return image
+
+
+def smart_nframes(
+ ele: dict,
+ total_frames: int,
+ video_fps: int | float,
+) -> int:
+ """calculate the number of frames for video used for model inputs.
+
+ Args:
+ ele (dict): a dict contains the configuration of video.
+ support either `fps` or `nframes`:
+ - nframes: the number of frames to extract for model inputs.
+ - fps: the fps to extract frames for model inputs.
+ - min_frames: the minimum number of frames of the video, only used when fps is provided.
+ - max_frames: the maximum number of frames of the video, only used when fps is provided.
+ total_frames (int): the original total number of frames of the video.
+ video_fps (int | float): the original fps of the video.
+
+ Raises:
+ ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
+
+ Returns:
+ int: the number of frames for video used for model inputs.
+ """
+ assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
+ if "nframes" in ele:
+ nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
+ else:
+ fps = ele.get("fps", FPS)
+ min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
+ max_frames = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
+ nframes = total_frames / video_fps * fps
+ nframes = min(max(nframes, min_frames), max_frames)
+ nframes = round_by_factor(nframes, FRAME_FACTOR)
+ if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
+ raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
+ return nframes
+
+
+def _read_video_gif(path):
+ gif = Image.open(path)
+ frames = []
+ for frame in ImageSequence.Iterator(gif):
+ frames.append(np.array(frame.convert('RGB')))
+ frames = np.stack(frames, axis=0)
+ return frames
+
+
+def _read_video_decord(ele: dict) -> torch.Tensor:
+    """Read a video using decord.VideoReader.
+
+    Args:
+        ele (dict): a dict containing the video configuration.
+            Supported keys:
+                - video: the path of the video. Supports "file://", "http://", "https://" and local paths.
+                - video_start: the start time of the video.
+                - video_end: the end time of the video.
+
+    Returns:
+        torch.Tensor: the video tensor with shape (T, C, H, W).
+    """
+ video_path = ele["video"]
+ if video_path.endswith('.gif'):
+ video = _read_video_gif(video_path)
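+        # GIFs carry no reliable FPS metadata, so fall back to the requested (or default) FPS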
+ total_frames, video_fps = video.shape[0], ele.get('fps', FPS)
+ else:
+ vr = decord.VideoReader(video_path, num_threads=ele.get('num_threads', 0))
+ total_frames, video_fps = len(vr), vr.get_avg_fps()
+
+ # 1. re-calculate total frames
+ s = ele.get('video_start')
+ s = 0 if s is None else s
+ e = ele.get('video_end')
+ e = total_frames / video_fps if e is None else e
+ s_frame = min(max(0, round(s * video_fps)), total_frames - 1)
+ e_frame = min(max(0, round(e * video_fps)), total_frames - 1)
+ if s_frame > e_frame:
+ warnings.warn(f's_frame ({s_frame}) is greater than e_frame ({e_frame}), total_frames: {total_frames}')
+ s_frame, e_frame = e_frame, s_frame
+
+    # TODO: the actual total_frames should be computed as e_frame - s_frame + 1,
+    # but that would affect the verifier's performance when video_start and video_end get clamped;
+    # this should be fixed by using normalized timestamps instead of real time
+ total_frames = min(max(1, round((e - s) * video_fps)), total_frames)
+
+ if total_frames > FPS_MIN_FRAMES:
+ nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
+ else:
+ nframes = total_frames
+
+ # 2. generate frame ids
+ idx = torch.linspace(s_frame, e_frame, nframes).round().long().tolist()
+ assert len(idx) == nframes, (len(idx), nframes)
+
+ if video_path.endswith('.gif'):
+ video = video[idx]
+ else:
+ video = vr.get_batch(idx).asnumpy()
+
+ video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format
+ return video
+
+
+def fetch_video(ele: dict, image_factor: int = IMAGE_FACTOR, sanity_check=False) -> torch.Tensor | list[Image.Image]:
+ if isinstance(ele["video"], str):
+ video = _read_video_decord(ele)
+ nframes, _, height, width = video.shape
+
+ min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
+ total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
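+        # budget pixels per frame so that the whole clip stays within total_pixels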
+ max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
+ max_pixels = ele.get("max_pixels", max_pixels)
+ if "resized_height" in ele and "resized_width" in ele:
+ resized_height, resized_width = smart_resize(
+ ele["resized_height"],
+ ele["resized_width"],
+ factor=image_factor,
+ )
+ else:
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=image_factor,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ )
+ video = transforms.functional.resize(
+ video,
+ [resized_height, resized_width],
+ interpolation=InterpolationMode.BICUBIC,
+ antialias=True,
+ ).float()
+
+ if sanity_check and (video == 0).all():
+ raise ValueError("video '{}' contains all zeros".format(ele["video"]))
+
+ return video
+ else:
+ assert isinstance(ele["video"], (list, tuple))
+ process_info = ele.copy()
+ process_info.pop("type", None)
+ process_info.pop("video", None)
+ images = [
+ fetch_image({
+ "image": video_element,
+ **process_info
+ }, size_factor=image_factor) for video_element in ele["video"]
+ ]
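+        # pad the frame list to a multiple of FRAME_FACTOR by repeating the last frame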
+ nframes = ceil_by_factor(len(images), FRAME_FACTOR)
+ if len(images) < nframes:
+ images.extend([images[-1]] * (nframes - len(images)))
+ return images
+
+
+def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
+ vision_infos = []
+ if isinstance(conversations[0], dict):
+ conversations = [conversations]
+ for conversation in conversations:
+ for message in conversation:
+ if isinstance(message["content"], list):
+ for ele in message["content"]:
+ if ("image" in ele or "image_url" in ele or "video" in ele
+ or ele["type"] in ("image", "image_url", "video")):
+ vision_infos.append(ele)
+ return vision_infos
+
+
+def process_vision_info(
+ conversations: list[dict] | list[list[dict]],
+ sanity_check=False) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None]:
+ vision_infos = extract_vision_info(conversations)
+ # Read images or videos
+ image_inputs = []
+ video_inputs = []
+ for vision_info in vision_infos:
+ if "image" in vision_info or "image_url" in vision_info:
+ image_inputs.append(fetch_image(vision_info))
+ elif "video" in vision_info:
+ video_inputs.append(fetch_video(vision_info, sanity_check=sanity_check))
+ else:
+ raise ValueError("image, image_url or video should in content.")
+ if len(image_inputs) == 0:
+ image_inputs = None
+ if len(video_inputs) == 0:
+ video_inputs = None
+ return image_inputs, video_inputs
+
+
+def preprocess_chatml(input_ids, text, tokenizer):
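+    # build training labels by masking out everything except the assistant responses,
+    # so the loss is only computed on tokens the model should generate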
+ conv = get_conv('chatml')
+
+ rounds = [m + conv.seps[0] for m in text.split(conv.seps[0])]
+ assert (len(rounds) % 2 == 0) == (conv.system is not None)
+ assert rounds[-1] == conv.seps[0]
+ rounds = rounds[:-1]
+
+ if conv.system is None:
+ rounds = [''.join(rounds[i:i + 2]) for i in range(0, len(rounds), 2)]
+ else:
+ rounds = [''.join(rounds[:3])] + [''.join(rounds[i:i + 2]) for i in range(3, len(rounds), 2)]
+
+ labels = input_ids.clone()
+
+ sep = conv.seps[0] + conv.roles[1]
+ cur_len = 0
+
+ for i, rou in enumerate(rounds):
+ if len(rou) == 0:
+ break
+
+ ins = sep.join(rou.split(sep)[:-1]) + sep
+
+ rou_len = tokenizer(rou, return_length=True).length[0]
+ ins_len = tokenizer(ins, return_length=True).length[0]
+
+ labels[cur_len:cur_len + ins_len] = IGNORE_INDEX
+ cur_len += rou_len
+
+ if labels.size(0) != cur_len:
+ warnings.warn(f'Tokenization mismatch: {labels.size(0)} and {cur_len}')
+
+ return labels
+
+
+def preprocess(input_ids, text, tokenizer, conv_type):
+ if conv_type == 'chatml':
+ return preprocess_chatml(input_ids, text, tokenizer)
+ else:
+ raise ValueError(f'unknown conversation type: {conv_type}')
diff --git a/videomind/dataset/wrappers/__init__.py b/videomind/dataset/wrappers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..45952e692101072127f58be8115b8b6f0d0fac78
--- /dev/null
+++ b/videomind/dataset/wrappers/__init__.py
@@ -0,0 +1,6 @@
+from .answering import AnsweringCropDataset, AnsweringDataset
+from .grounding import GroundingDataset
+from .planning import PlanningDataset
+from .verifying import VerifyingDataset
+
+__all__ = ['AnsweringCropDataset', 'AnsweringDataset', 'GroundingDataset', 'PlanningDataset', 'VerifyingDataset']
diff --git a/videomind/dataset/wrappers/answering.py b/videomind/dataset/wrappers/answering.py
new file mode 100644
index 0000000000000000000000000000000000000000..42d5ad3364391b3552ff6d49020a327e487103d3
--- /dev/null
+++ b/videomind/dataset/wrappers/answering.py
@@ -0,0 +1,112 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import copy
+import random
+
+from torch.utils.data import Dataset
+
+from videomind.utils.parser import parse_span
+
+
+class AnsweringDataset(Dataset):
+
+ def __init__(self, processor, model_args, data_args, training_args):
+ super(AnsweringDataset, self).__init__()
+
+ raw_annos = self.load_annos()
+
+ annos = []
+ for anno in raw_annos:
+ num_words = len(anno['question'].split(' ')) + len(anno['answer'].split(' '))
+ if data_args.min_num_words >= 0 and num_words < data_args.min_num_words:
+ continue
+ if data_args.max_num_words >= 0 and num_words > data_args.max_num_words:
+ continue
+ if data_args.min_video_len >= 0 and anno.get('duration', float('inf')) < data_args.min_video_len:
+ continue
+ if data_args.max_video_len >= 0 and anno.get('duration', 0) > data_args.max_video_len:
+ continue
+ annos.append(anno)
+
+ self.annos = annos
+ self.raw_length = len(raw_annos)
+ self.processor = processor
+ self.model_args = model_args
+ self.data_args = data_args
+ self.training_args = training_args
+
+ def __len__(self):
+ return len(self.annos)
+
+ def __getitem__(self, idx):
+ anno = copy.deepcopy(self.annos[idx])
+
+ video_path, question, answer = anno['video_path'], anno['question'], anno['answer']
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video_path,
+ 'min_pixels': 128 * 28 * 28,
+ 'max_pixels': 256 * 28 * 28,
+ 'max_frames': 32,
+ 'fps': 2.0
+ }, {
+ 'type': 'text',
+ 'text': question
+ }]
+ }, {
+ 'role': 'assistant',
+ 'content': answer
+ }]
+
+ meta = dict(messages=messages)
+ return meta
+
+
+class AnsweringCropDataset(AnsweringDataset):
+
+ def __getitem__(self, idx):
+ anno = copy.deepcopy(self.annos[idx])
+
+ video_path, question, answer = anno['video_path'], anno['question'], anno['answer']
+
+ if anno.get('no_aug'):
+ s, e = anno['span'][0]
+ else:
+ # max 32 frames / 2 fps
+ s, e = parse_span(anno['span'][0], anno['duration'], 16)
+
+ # apply temporal jittering
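+        # each boundary moves by at most 25% of the span length, e.g. span [10, 30] -> offset 5,
+        # s ~ U(5, 15), e ~ U(25, 35); the result is re-clamped to the video duration below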
+ offset = (e - s) / 4
+ s = random.uniform(s - offset, s + offset)
+ e = random.uniform(e - offset, e + offset)
+
+ # clamp the augmented span
+ s, e = parse_span([s, e], anno['duration'])
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video_path,
+ 'video_start': s,
+ 'video_end': e,
+ 'min_pixels': 128 * 28 * 28,
+ 'max_pixels': 256 * 28 * 28,
+ 'max_frames': 32,
+ 'fps': 2.0
+ }, {
+ 'type': 'text',
+ 'text': question
+ }]
+ }, {
+ 'role': 'assistant',
+ 'content': answer
+ }]
+
+ meta = dict(messages=messages)
+ return meta
diff --git a/videomind/dataset/wrappers/grounding.py b/videomind/dataset/wrappers/grounding.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f0908d0ddf8f0c66f8c9cd6efe208b53e58a050
--- /dev/null
+++ b/videomind/dataset/wrappers/grounding.py
@@ -0,0 +1,65 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import copy
+
+from torch.utils.data import Dataset
+
+from videomind.constants import GROUNDER_PROMPT, REG_TOKEN
+
+
+class GroundingDataset(Dataset):
+
+ def __init__(self, processor, model_args, data_args, training_args):
+ super(GroundingDataset, self).__init__()
+
+ raw_annos = self.load_annos()
+
+ annos = []
+ for anno in raw_annos:
+ num_words = len(anno['query'].split(' '))
+ if data_args.min_num_words >= 0 and num_words < data_args.min_num_words:
+ continue
+ if data_args.max_num_words >= 0 and num_words > data_args.max_num_words:
+ continue
+ if data_args.min_video_len >= 0 and anno.get('duration', float('inf')) < data_args.min_video_len:
+ continue
+ if data_args.max_video_len >= 0 and anno.get('duration', 0) > data_args.max_video_len:
+ continue
+ annos.append(anno)
+
+ self.annos = annos
+ self.raw_length = len(raw_annos)
+ self.processor = processor
+ self.model_args = model_args
+ self.data_args = data_args
+ self.training_args = training_args
+
+ def __len__(self):
+ return len(self.annos)
+
+ def __getitem__(self, idx):
+ anno = copy.deepcopy(self.annos[idx])
+
+ video_path, duration, query, span = anno['video_path'], anno['duration'], anno['query'], anno['span']
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video_path,
+ 'min_pixels': 36 * 28 * 28,
+ 'max_pixels': 64 * 28 * 28,
+ 'max_frames': 150,
+ 'fps': 1.0
+ }, {
+ 'type': 'text',
+ 'text': GROUNDER_PROMPT.format(query)
+ }]
+ }, {
+ 'role': 'assistant',
+ 'content': f'The relevant moment happens in {REG_TOKEN}.'
+ }]
+
+ meta = dict(messages=messages, span=span, duration=duration)
+ return meta
diff --git a/videomind/dataset/wrappers/planning.py b/videomind/dataset/wrappers/planning.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b801a79e32b8b7555913935d159f83f6e9c1160
--- /dev/null
+++ b/videomind/dataset/wrappers/planning.py
@@ -0,0 +1,94 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import copy
+
+import nncore
+from torch.utils.data import Dataset
+
+from videomind.constants import PLANNER_PROMPT
+from videomind.dataset.hybrid import DATASETS
+
+
+class PlanningDataset(Dataset):
+
+ def __init__(self, processor, model_args, data_args, training_args):
+ super(PlanningDataset, self).__init__()
+
+ raw_annos = self.load_annos()
+
+ annos = []
+ for anno in raw_annos:
+ num_words = len(anno.get('question', '').split(' ')) + len(anno.get('query', '').split(' '))
+ if data_args.min_num_words >= 0 and num_words < data_args.min_num_words:
+ continue
+ if data_args.max_num_words >= 0 and num_words > data_args.max_num_words:
+ continue
+ if data_args.min_video_len >= 0 and anno.get('duration', float('inf')) < data_args.min_video_len:
+ continue
+ if data_args.max_video_len >= 0 and anno.get('duration', 0) > data_args.max_video_len:
+ continue
+ annos.append(anno)
+
+ self.annos = annos
+ self.raw_length = len(raw_annos)
+ self.processor = processor
+ self.model_args = model_args
+ self.data_args = data_args
+ self.training_args = training_args
+
+ def __len__(self):
+ return len(self.annos)
+
+ @classmethod
+ def load_annos(self, split='train'):
+ assert split == 'train'
+ annos = nncore.load(self.ANNO_PATH)
+ return annos
+
+ def __getitem__(self, idx):
+ anno = copy.deepcopy(self.annos[idx])
+
+ video_path, route, question, query = anno['video_path'], anno['route'], anno['question'], anno.get('query')
+
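+        # the route encodes which downstream roles the planner dispatches:
+        # routes 1/3 use the rephrased query, route 2 reuses the raw question, route 4 skips grounding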
+ if route == 1:
+ # rephrasing + grounding + answering
+ response = f'[{{"type": "grounder", "value": "{query}"}}, {{"type": "verifier"}}, {{"type": "answerer"}}]'
+ elif route == 2:
+ # grounding + answering
+ response = f'[{{"type": "grounder", "value": "{question}"}}, {{"type": "verifier"}}, {{"type": "answerer"}}]'
+ elif route == 3:
+ # rephrasing + grounding
+ response = f'[{{"type": "grounder", "value": "{query}"}}, {{"type": "verifier"}}]'
+ elif route == 4:
+ # answering
+ response = '[{"type": "answerer"}]'
+ else:
+ raise KeyError(f'unknown route type: {route}')
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video_path,
+ 'min_pixels': 36 * 28 * 28,
+ 'max_pixels': 64 * 28 * 28,
+ 'max_frames': 100,
+ 'fps': 1.0
+ }, {
+ 'type': 'text',
+ 'text': PLANNER_PROMPT.format(question)
+ }]
+ }, {
+ 'role': 'assistant',
+ 'content': response
+ }]
+
+ meta = dict(messages=messages)
+ return meta
+
+
+@DATASETS.register(name='mixed_planning')
+class MixedPlanningDataset(PlanningDataset):
+
+ ANNO_PATH = 'data/planning/planning_nextqa_qvhighlights_gpt4o_mini.jsonl'
diff --git a/videomind/dataset/wrappers/verifying.py b/videomind/dataset/wrappers/verifying.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3279ead5ffece6c9c15932057b5065669c2fbd8
--- /dev/null
+++ b/videomind/dataset/wrappers/verifying.py
@@ -0,0 +1,180 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import copy
+
+import nncore
+import torch
+from nncore.ops import temporal_iou
+from torch.utils.data import Dataset
+
+from videomind.constants import VERIFIER_PROMPT
+from videomind.dataset.hybrid import DATASETS
+from videomind.utils.parser import parse_span
+
+
+class VerifyingDataset(Dataset):
+
+ def __init__(self, processor, model_args, data_args, training_args):
+ super(VerifyingDataset, self).__init__()
+
+ raw_annos = self.load_annos()
+
+ annos = []
+ for anno in raw_annos:
+ num_words = len(anno['query'].split(' '))
+ if data_args.min_num_words >= 0 and num_words < data_args.min_num_words:
+ continue
+ if data_args.max_num_words >= 0 and num_words > data_args.max_num_words:
+ continue
+ if data_args.min_video_len >= 0 and anno.get('duration', float('inf')) < data_args.min_video_len:
+ continue
+ if data_args.max_video_len >= 0 and anno.get('duration', 0) > data_args.max_video_len:
+ continue
+ annos.append(anno)
+
+ self.annos = annos
+ self.raw_length = len(raw_annos)
+ self.processor = processor
+ self.model_args = model_args
+ self.data_args = data_args
+ self.training_args = training_args
+
+ def __len__(self):
+ return len(self.annos)
+
+ @classmethod
+ def load_annos(self, split='train'):
+ assert split == 'train'
+
+ if nncore.is_dir(self.ANNO_PATH):
+ raw_paths = nncore.ls(self.ANNO_PATH, ext='json', join_path=True, sort=True)
+ raw_annos = nncore.flatten([nncore.load(p) for p in raw_paths])
+ else:
+ raw_annos = nncore.load(self.ANNO_PATH)
+
+ annos = []
+ for raw_anno in raw_annos:
+ # using top-5 predictions
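+            # a prediction is labeled positive when its best IoU with any ground-truth span reaches 0.5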
+ for pred in raw_anno['pred'][:5]:
+ iou = temporal_iou(torch.Tensor([pred]), torch.Tensor(raw_anno['span']))
+ iou = torch.where(iou.isfinite(), iou, 0)
+ iou = iou.max().item()
+
+ positive = iou >= 0.5
+
+ anno = dict(
+ source=self.SOURCE,
+ data_type='multimodal',
+ video_path=raw_anno['video_path'],
+ duration=raw_anno['duration'],
+ query=raw_anno['query'],
+ span=raw_anno['span'],
+ pred=pred,
+ positive=positive,
+ task=raw_anno.get('task', 'unknown'))
+
+ annos.append(anno)
+
+ pos_inds = [i for i, a in enumerate(annos) if a['positive']]
+ neg_inds = [i for i, a in enumerate(annos) if not a['positive']]
+
+ num_pos = len(pos_inds)
+ num_neg = len(neg_inds)
+
+ print(f'[{self.SOURCE}] pos: {num_pos} neg: {num_neg} n/p ratio: {num_neg / num_pos}')
+
+ # filter negative samples
+ # if num_neg > num_pos * 3:
+ # neg_inds = random.sample(neg_inds, int(num_pos * 3))
+
+ # inds = pos_inds + neg_inds
+ # random.shuffle(inds)
+ # inds = comm.broadcast(inds)
+
+ # annos = [annos[i] for i in inds]
+
+ return annos
+
+ def __getitem__(self, idx):
+ anno = copy.deepcopy(self.annos[idx])
+
+ video_path, duration, query, positive = anno['video_path'], anno['duration'], anno['query'], anno['positive']
+
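+        # pad the predicted span by 50% of its length on each side to form a context window,
+        # e.g. pred=[10, 20] with duration=60 -> (s0, e0)=(10, 20), (s1, e1)=(5, 25);
+        # (s, e) below are then the normalized boundaries of the prediction within that window, i.e. (0.25, 0.75)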
+ s0, e0 = parse_span(anno['pred'], duration, 2)
+ offset = (e0 - s0) / 2
+ s1, e1 = parse_span([s0 - offset, e0 + offset], duration)
+
+ # percentage of s0, e0 within s1, e1
+ s = (s0 - s1) / (e1 - s1)
+ e = (e0 - s1) / (e1 - s1)
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video_path,
+ 'video_start': s1,
+ 'video_end': e1,
+ 'min_pixels': 36 * 28 * 28,
+ 'max_pixels': 64 * 28 * 28,
+ 'max_frames': 64,
+ 'fps': 2.0
+ }, {
+ 'type': 'text',
+ 'text': VERIFIER_PROMPT.format(query)
+ }]
+ }]
+
+ messages = messages + [{'role': 'assistant', 'content': 'Yes.' if positive else 'No.'}]
+ meta = dict(messages=messages, ss=s, se=e)
+
+ return meta
+
+
+@DATASETS.register(name='qvhighlights_verify_2b')
+class QVHighlightsVerify2BDataset(VerifyingDataset):
+
+ ANNO_PATH = 'data/verifying/verifying_qvhighlights_2b.json'
+
+ SOURCE = 'qvhighlights_verify_2b'
+
+
+@DATASETS.register(name='didemo_verify_2b')
+class DiDeMoVerify2BDataset(VerifyingDataset):
+
+ ANNO_PATH = 'data/verifying/verifying_didemo_2b.json'
+
+ SOURCE = 'didemo_verify_2b'
+
+
+@DATASETS.register(name='tacos_verify_2b')
+class TACoSVerify2BDataset(VerifyingDataset):
+
+ ANNO_PATH = 'data/verifying/verifying_tacos_2b.json'
+
+ SOURCE = 'tacos_verify_2b'
+
+
+@DATASETS.register(name='qvhighlights_verify_7b')
+class QVHighlightsVerify7BDataset(VerifyingDataset):
+
+ ANNO_PATH = 'data/verifying/verifying_qvhighlights_7b.json'
+
+ SOURCE = 'qvhighlights_verify_7b'
+
+
+@DATASETS.register(name='didemo_verify_7b')
+class DiDeMoVerify7BDataset(VerifyingDataset):
+
+ ANNO_PATH = 'data/verifying/verifying_didemo_7b.json'
+
+ SOURCE = 'didemo_verify_7b'
+
+
+@DATASETS.register(name='tacos_verify_7b')
+class TACoSVerify7BDataset(VerifyingDataset):
+
+ ANNO_PATH = 'data/verifying/verifying_tacos_7b.json'
+
+ SOURCE = 'tacos_verify_7b'
diff --git a/videomind/eval/eval_auto.py b/videomind/eval/eval_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..e840b672d3f1dc5888e5091b154647d93f35fd1a
--- /dev/null
+++ b/videomind/eval/eval_auto.py
@@ -0,0 +1,289 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import argparse
+
+import nncore
+import torch
+from nncore.ops import temporal_area, temporal_intersection, temporal_iof, temporal_iou
+from tabulate import tabulate
+
+
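+# an int subclass whose division returns 0 instead of raising ZeroDivisionError,
+# so metric tables with empty buckets can still be formatted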
+class SafeInt(int):
+
+ def __truediv__(self, other):
+ try:
+ return SafeInt(super().__truediv__(other))
+ except ZeroDivisionError:
+ return SafeInt(0)
+
+
+def check_ans(options, ans, response):
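+    # compare the ground-truth letter with the leading option letter parsed from the response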
+ a = ans.lower()
+ b = response.lower().split(' ')[0].replace('(', '').replace(')', '').replace('.', '')
+    if len(b) == 0:
+        nncore.log(f'ERROR: empty response: {response}')
+        return
+    if len(b) != 1:
+        b = b[0]
+        nncore.log(f'WARNING: {response} -> {b}')
+ if b not in [chr(ord('a') + i) for i in range(len(options))]:
+ nncore.log(f'ERROR: {response} -> {b}')
+ return
+ return a == b
+
+
+def compute_iou(pred, span, conf, cgbench_mode, conf_thr):
+ pred_tensor = torch.Tensor(pred)
+ span_tensor = torch.Tensor(span)
+
+ if cgbench_mode:
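+        # CG-Bench computes a single IoU between the union of kept predictions (top-1 plus any
+        # above the confidence threshold) and the union of ground-truth spans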
+ if conf_thr > 0:
+ conf_tensor = torch.Tensor(conf)
+ keep = torch.cat((torch.LongTensor([0]), torch.where(conf_tensor > conf_thr)[0])).unique()
+ pred_tensor = pred_tensor[keep]
+ else:
+ pred_tensor = pred_tensor[:1]
+ pred_area = temporal_area(pred_tensor).sum()
+ span_area = temporal_area(span_tensor).sum()
+ inter = temporal_intersection(pred_tensor, span_tensor).sum()
+ iou = (inter / (pred_area + span_area - inter)).unsqueeze(0)
+ assert iou.numel() == 1
+ else:
+ iou = temporal_iou(pred_tensor, span_tensor)
+
+ iou = torch.where(iou.isfinite(), iou, 0)
+ return iou
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('pred_path')
+ parser.add_argument('--dataset')
+ parser.add_argument('--out_name', default='metrics.log')
+ parser.add_argument('--conf_thr', type=float, default=-1)
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ assert nncore.is_dir(args.pred_path)
+
+ log_file = nncore.join(args.pred_path, args.out_name)
+ nncore.set_default_logger(logger='eval', fmt=None, log_file=log_file)
+
+ if args.dataset is not None:
+ cgbench_mode = args.dataset == 'cgbench'
+ nncore.log(f'CG-Bench mode: {cgbench_mode}')
+ else:
+ cgbench_mode = False
+ nncore.log('Dataset is unknown, using default mode', log_level='WARNING')
+
+ pred_paths = nncore.ls(args.pred_path, ext=['json', 'jsonl'], join_path=True)
+ nncore.log(f'Total number of files: {len(pred_paths)}')
+
+ if cgbench_mode:
+ top_k = [1]
+ thres = [0.1, 0.2, 0.3, 0.4, 0.5]
+ else:
+ top_k = [1, 3, 5]
+ thres = [0.3, 0.5, 0.7]
+
+ tab_iou, tab_iop, tab_ans = dict(), dict(), dict()
+ iou_raise, iou_lower, iop_raise, iop_lower = SafeInt(0), SafeInt(0), SafeInt(0), SafeInt(0)
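+    # table layout: [0] = #samples, [1] = #failed, [2:-1] = hit counts per (top-k, threshold) pair,
+    # [-1] = accumulated top-1 IoU/IoP (turned into mIoU/mIoP when printing); for the QA tables:
+    # [0] = #samples, [1] = #unparsable, [2] = #correct, [3]/[4] = #correct with IoU/IoP >= 0.5,
+    # [5:] = #correct at each IoU threshold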
+ tab_iou_all = [SafeInt(0) for _ in range(len(top_k) * len(thres) + 3)]
+ tab_iop_all = [SafeInt(0) for _ in range(len(top_k) * len(thres) + 3)]
+ tab_ans_all = [SafeInt(0) for _ in range(len(thres) + 5)]
+
+ for path in pred_paths:
+ data = nncore.load(path)
+
+ for sample in data:
+ task = sample.get('task', 'unknown')
+
+ # samples in lvbench might have multiple tasks
+ if isinstance(task, str):
+ task = [task]
+
+ for t in task:
+ if t not in tab_iou:
+ tab_iou[t] = [SafeInt(0) for _ in range(len(top_k) * len(thres) + 3)]
+
+ if t not in tab_iop:
+ tab_iop[t] = [SafeInt(0) for _ in range(len(top_k) * len(thres) + 3)]
+
+ if t not in tab_ans:
+ tab_ans[t] = [SafeInt(0) for _ in range(len(thres) + 5)]
+
+ iou_hit = [False for _ in range(len(thres) + 1)]
+ iop_hit = False
+
+ if 'pred' in sample and 'conf' in sample and 'span' in sample:
+ for t in task:
+ tab_iou[t][0] += 1
+ tab_iop[t][0] += 1
+ tab_iou_all[0] += 1
+ tab_iop_all[0] += 1
+
+ iou = compute_iou(sample['pred'], sample['span'], sample['conf'], cgbench_mode, args.conf_thr)
+ top = iou[0].max().item()
+
+ for t in task:
+ tab_iou[t][-1] += top
+ tab_iou_all[-1] += top
+
+ for i, k in enumerate(top_k):
+ for j, h in enumerate(thres):
+ if iou[:k].max() >= h:
+ for t in task:
+ tab_iou[t][i * len(thres) + j + 2] += 1
+ tab_iou_all[i * len(thres) + j + 2] += 1
+ if k == 1:
+ iou_hit[j + 1] = True
+ if h == 0.5:
+ iou_hit[0] = True
+
+ if sample.get('pred_ori') is not None:
+ iou = compute_iou(sample['pred_ori'], sample['span'], sample['conf_ori'], cgbench_mode,
+ args.conf_thr)
+ iou = iou[0].max().item()
+
+ if iou < top:
+ iou_raise += 1
+ if iou > top:
+ iou_lower += 1
+
+ iop = temporal_iof(torch.Tensor(sample['pred']), torch.Tensor(sample['span']))
+ iop = torch.where(iop.isfinite(), iop, 0)
+ top = iop[0].max().item()
+
+ for t in task:
+ tab_iop[t][-1] += top
+ tab_iop_all[-1] += top
+
+ for i, k in enumerate(top_k):
+ for j, h in enumerate(thres):
+ if iop[:k].max() >= h:
+ for t in task:
+ tab_iop[t][i * len(thres) + j + 2] += 1
+ tab_iop_all[i * len(thres) + j + 2] += 1
+ if k == 1 and h == 0.5:
+ iop_hit = True
+
+ if sample.get('pred_ori') is not None:
+ iop = temporal_iof(torch.Tensor(sample['pred_ori']), torch.Tensor(sample['span']))
+ iop = torch.where(iop.isfinite(), iop, 0)
+ iop = iop[0].max().item()
+
+ if iop < top:
+ iop_raise += 1
+ if iop > top:
+ iop_lower += 1
+
+ if not sample.get('grounder_success', True):
+ for t in task:
+ tab_iou[t][1] += 1
+ tab_iop[t][1] += 1
+ tab_iou_all[1] += 1
+ tab_iop_all[1] += 1
+
+ if 'question' in sample and 'response' in sample:
+ for t in task:
+ tab_ans[t][0] += 1
+ tab_ans_all[0] += 1
+
+ correct = check_ans(sample['options'], sample['ans'], sample['response'])
+
+ if correct:
+ for t in task:
+ tab_ans[t][2] += 1
+ tab_ans_all[2] += 1
+ if iou_hit[0]:
+ for t in task:
+ tab_ans[t][3] += 1
+ tab_ans_all[3] += 1
+ if iop_hit:
+ for t in task:
+ tab_ans[t][4] += 1
+ tab_ans_all[4] += 1
+ for i in range(1, len(iou_hit)):
+ if iou_hit[i]:
+ for t in task:
+ tab_ans[t][i + 4] += 1
+ tab_ans_all[i + 4] += 1
+ elif correct is None:
+ for t in task:
+ tab_ans[t][1] += 1
+ tab_ans_all[1] += 1
+
+ tasks = sorted(list(set(list(tab_iou.keys()) + list(tab_iop.keys()) + list(tab_ans.keys()))))
+
+ if cgbench_mode:
+ nncore.log('\nGrounding (IoU):')
+ tab = tabulate(
+ [[task, tab_iou[task][0], tab_iou[task][1]] +
+ [f'{tab_iou[task][i] / tab_iou[task][0] * 100:.2f}' for i in range(2, len(tab_iou[task]))] +
+ [f'{sum(tab_iou[task][i] / tab_iou[task][0] for i in range(2, 2 + len(thres))) / len(thres) * 100:.2f}']
+ for task in tasks if task in tab_iou] +
+ [['all', tab_iou_all[0], tab_iou_all[1]] +
+ [f'{tab_iou_all[i] / tab_iou_all[0] * 100:.2f}' for i in range(2, len(tab_iou_all))] +
+ [f'{sum(tab_iou_all[i] / tab_iou_all[0] for i in range(2, 2 + len(thres))) / len(thres) * 100:.2f}']],
+ headers=['Task', '#Samples', 'Failed'] + [f'R{k}@{t}' for k in top_k for t in thres] + ['mIoU', 'rec.@IoU'],
+ tablefmt='pretty',
+ stralign='left')
+ nncore.log(tab)
+
+ nncore.log(f'\nIoU Raise ({tab_iou_all[0]} Samples): {iou_raise} ({iou_raise / tab_iou_all[0] * 100:.2f}%)')
+ nncore.log(f'IoU Lower ({tab_iou_all[0]} Samples): {iou_lower} ({iou_lower / tab_iou_all[0] * 100:.2f}%)')
+
+ nncore.log('\nQA:')
+ tab = tabulate(
+ [[task, tab_ans[task][0], tab_ans[task][1], f'{tab_ans[task][2] / tab_ans[task][0] * 100:.2f}'] +
+ [f'{sum(tab_ans[task][i] / tab_ans[task][0] for i in range(5, 5 + len(thres))) / len(thres) * 100:.2f}']
+ for task in tasks if task in tab_ans] +
+ [['all', tab_ans_all[0], tab_ans_all[1], f'{tab_ans_all[2] / tab_ans_all[0] * 100:.2f}'] +
+ [f'{sum(tab_ans_all[i] / tab_ans_all[0] for i in range(5, 5 + len(thres))) / len(thres) * 100:.2f}']],
+ headers=['Task', '#Samples', 'Failed', 'long-acc.', 'acc.@IoU'],
+ tablefmt='pretty',
+ stralign='left')
+ nncore.log(tab)
+ else:
+ nncore.log('\nGrounding (IoU):')
+ tab = tabulate(
+ [[task, tab_iou[task][0], tab_iou[task][1]] +
+ [f'{tab_iou[task][i] / tab_iou[task][0] * 100:.2f}' for i in range(2, len(tab_iou[task]))]
+ for task in tasks if task in tab_iou] +
+ [['all', tab_iou_all[0], tab_iou_all[1]] +
+ [f'{tab_iou_all[i] / tab_iou_all[0] * 100:.2f}' for i in range(2, len(tab_iou_all))]],
+ headers=['Task', '#Samples', 'Failed'] + [f'R{k}@{t}' for k in top_k for t in thres] + ['mIoU'],
+ tablefmt='pretty',
+ stralign='left')
+ nncore.log(tab)
+
+ nncore.log(f'\nIoU Raise ({tab_iou_all[0]} Samples): {iou_raise} ({iou_raise / tab_iou_all[0] * 100:.2f}%)')
+ nncore.log(f'IoU Lower ({tab_iou_all[0]} Samples): {iou_lower} ({iou_lower / tab_iou_all[0] * 100:.2f}%)')
+
+ nncore.log('\nGrounding (IoP):')
+ tab = tabulate(
+ [[task, tab_iop[task][0], tab_iop[task][1]] +
+ [f'{tab_iop[task][i] / tab_iop[task][0] * 100:.2f}' for i in range(2, len(tab_iop[task]))]
+ for task in tasks if task in tab_iop] +
+ [['all', tab_iop_all[0], tab_iop_all[1]] +
+ [f'{tab_iop_all[i] / tab_iop_all[0] * 100:.2f}' for i in range(2, len(tab_iop_all))]],
+ headers=['Task', '#Samples', 'Failed'] + [f'R{k}@{t}' for k in top_k for t in thres] + ['mIoP'],
+ tablefmt='pretty',
+ stralign='left')
+ nncore.log(tab)
+
+ nncore.log(f'\nIoP Raise ({tab_iop_all[0]} Samples): {iop_raise} ({iop_raise / tab_iop_all[0] * 100:.2f}%)')
+ nncore.log(f'IoP Lower ({tab_iop_all[0]} Samples): {iop_lower} ({iop_lower / tab_iop_all[0] * 100:.2f}%)')
+
+ nncore.log('\nQA:')
+ tab = tabulate(
+ [[task, tab_ans[task][0], tab_ans[task][1]] +
+ [f'{tab_ans[task][i] / tab_ans[task][0] * 100:.2f}' for i in range(2, 5)]
+ for task in tasks if task in tab_ans] +
+ [['all', tab_ans_all[0], tab_ans_all[1]] +
+ [f'{tab_ans_all[i] / tab_ans_all[0] * 100:.2f}' for i in range(2, 5)]],
+ headers=['Task', '#Samples', 'Failed', 'Acc', 'Acc (IoU >= 0.5)', 'Acc (IoP >= 0.5)'],
+ tablefmt='pretty',
+ stralign='left')
+ nncore.log(tab)
diff --git a/videomind/eval/eval_qvhighlights.py b/videomind/eval/eval_qvhighlights.py
new file mode 100644
index 0000000000000000000000000000000000000000..34ed2128660f2765e78e3c4e99c9beb43af2d8dc
--- /dev/null
+++ b/videomind/eval/eval_qvhighlights.py
@@ -0,0 +1,413 @@
+# Modified from https://github.com/showlab/UniVTG/blob/main/eval/eval.py
+
+import argparse
+import copy
+from collections import OrderedDict, defaultdict
+from functools import partial
+
+import nncore
+import numpy as np
+
+from sklearn.metrics import precision_recall_curve
+
+
+def compute_temporal_iou_batch_paired(a, b):
+ intersection = np.maximum(0, np.minimum(a[:, 1], b[:, 1]) - np.maximum(a[:, 0], b[:, 0]))
+ union = np.maximum(a[:, 1], b[:, 1]) - np.minimum(a[:, 0], b[:, 0])
+ return np.divide(intersection, union, out=np.zeros_like(intersection), where=union != 0)
+
+
+def compute_temporal_iou_batch_cross(spans1, spans2):
+ areas1 = spans1[:, 1] - spans1[:, 0]
+ areas2 = spans2[:, 1] - spans2[:, 0]
+ l = np.maximum(spans1[:, None, 0], spans2[None, :, 0])
+ r = np.minimum(spans1[:, None, 1], spans2[None, :, 1])
+ inter = np.clip(r - l, 0, None)
+ union = areas1[:, None] + areas2[None, :] - inter
+ iou = inter / union
+ return iou, union
+
+
+def interpolated_precision_recall(prc, rec):
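+    # VOC-style interpolated AP: make the precision envelope non-increasing from right to left,
+    # then sum precision over the recall increments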
+ mprc = np.hstack([[0], prc, [0]])
+ mrec = np.hstack([[0], rec, [1]])
+ for i in range(len(mprc) - 1)[::-1]:
+ mprc[i] = max(mprc[i], mprc[i + 1])
+ idx = np.where(mrec[1::] != mrec[0:-1])[0] + 1
+ ap = np.sum((mrec[idx] - mrec[idx - 1]) * mprc[idx])
+ return ap
+
+
+def compute_average_precision_detection(annos, prediction, tiou_thresholds=np.linspace(0.5, 0.95, 10)):
+ num_thresholds = len(tiou_thresholds)
+ num_gts = len(annos)
+ num_preds = len(prediction)
+ ap = np.zeros(num_thresholds)
+ if len(prediction) == 0:
+ return ap
+
+ num_positive = float(num_gts)
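+    # lock_gt[t, g] stores the index of the prediction already matched to ground truth g at
+    # threshold t (-1 means unmatched), so each ground truth yields at most one true positive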
+ lock_gt = np.ones((num_thresholds, num_gts)) * -1
+ prediction.sort(key=lambda x: -x['score'])
+ tp = np.zeros((num_thresholds, num_preds))
+ fp = np.zeros((num_thresholds, num_preds))
+
+ ground_truth_by_videoid = dict()
+ for i, item in enumerate(annos):
+ item['index'] = i
+ ground_truth_by_videoid.setdefault(item['video-id'], []).append(item)
+
+ for idx, pred in enumerate(prediction):
+ if pred['video-id'] in ground_truth_by_videoid:
+ gts = ground_truth_by_videoid[pred['video-id']]
+ else:
+ fp[:, idx] = 1
+ continue
+
+ _pred = np.array([[pred['t-start'], pred['t-end']]])
+ _gt = np.array([[gt['t-start'], gt['t-end']] for gt in gts])
+ tiou_arr = compute_temporal_iou_batch_cross(_pred, _gt)[0]
+
+ tiou_arr = tiou_arr.reshape(-1)
+ tiou_sorted_idx = tiou_arr.argsort()[::-1]
+ for t_idx, tiou_threshold in enumerate(tiou_thresholds):
+ for j_idx in tiou_sorted_idx:
+ if tiou_arr[j_idx] < tiou_threshold:
+ fp[t_idx, idx] = 1
+ break
+ if lock_gt[t_idx, gts[j_idx]['index']] >= 0:
+ continue
+ tp[t_idx, idx] = 1
+ lock_gt[t_idx, gts[j_idx]['index']] = idx
+ break
+
+ if fp[t_idx, idx] == 0 and tp[t_idx, idx] == 0:
+ fp[t_idx, idx] = 1
+
+ tp_cumsum = np.cumsum(tp, axis=1).astype(float)
+ fp_cumsum = np.cumsum(fp, axis=1).astype(float)
+ recall_cumsum = tp_cumsum / num_positive
+
+ precision_cumsum = tp_cumsum / (tp_cumsum + fp_cumsum)
+
+ for t_idx in range(len(tiou_thresholds)):
+ ap[t_idx] = interpolated_precision_recall(precision_cumsum[t_idx, :], recall_cumsum[t_idx, :])
+
+ return ap
+
+
+def get_ap(y_true, y_pred, interpolate=True, point_11=False):
+ assert len(y_true) == len(y_pred), 'Prediction and ground truth need to be of the same length'
+ if len(set(y_true)) == 1:
+ if y_true[0] == 0:
+ return 0
+ else:
+ return 1
+ else:
+ assert sorted(set(y_true)) == [0, 1], 'Ground truth can only contain elements {0,1}'
+
+ precision, recall, _ = precision_recall_curve(y_true, y_pred)
+ recall = recall.astype(np.float32)
+
+ if interpolate:
+ for i in range(1, len(precision)):
+ precision[i] = max(precision[i - 1], precision[i])
+
+ if point_11:
+ precision_11 = [precision[np.where(recall >= t)[0][-1]] for t in np.arange(0, 1.01, 0.1)]
+ return np.mean(precision_11)
+ else:
+ indices = np.where(np.diff(recall))
+ return np.mean(precision[indices])
+
+
+def compute_average_precision_detection_wrapper(input_triple, tiou_thresholds=np.linspace(0.5, 0.95, 10)):
+ qid, annos, prediction = input_triple
+ scores = compute_average_precision_detection(annos, prediction, tiou_thresholds=tiou_thresholds)
+ return qid, scores
+
+
+def compute_mr_ap(preds, annos, iou_thds=np.linspace(0.5, 0.95, 10), max_gt_windows=None, max_pred_windows=10):
+ iou_thds = [float(f'{e:.2f}') for e in iou_thds]
+ pred_qid2data = defaultdict(list)
+ for d in preds:
+ pred_windows = d['pred_relevant_windows'][:max_pred_windows] \
+ if max_pred_windows is not None else d['pred_relevant_windows']
+ qid = d['qid']
+ for w in pred_windows:
+ pred_qid2data[qid].append({'video-id': d['qid'], 't-start': w[0], 't-end': w[1], 'score': w[2]})
+
+ gt_qid2data = defaultdict(list)
+ for d in annos:
+ gt_windows = d['relevant_windows'][:max_gt_windows] \
+ if max_gt_windows is not None else d['relevant_windows']
+ qid = d['qid']
+ for w in gt_windows:
+ gt_qid2data[qid].append({'video-id': d['qid'], 't-start': w[0], 't-end': w[1]})
+ qid2ap_list = dict()
+ data_triples = [[qid, gt_qid2data[qid], pred_qid2data[qid]] for qid in pred_qid2data]
+    compute_ap_from_triple = partial(compute_average_precision_detection_wrapper, tiou_thresholds=iou_thds)
+
+ for data_triple in data_triples:
+ qid, scores = compute_ap_from_triple(data_triple)
+ qid2ap_list[qid] = scores
+
+ ap_array = np.array(list(qid2ap_list.values()))
+ ap_thds = ap_array.mean(0)
+ iou_thd2ap = dict(zip([str(e) for e in iou_thds], ap_thds))
+ iou_thd2ap['average'] = np.mean(ap_thds)
+
+ iou_thd2ap = {k: float(f'{100 * v:.2f}') for k, v in iou_thd2ap.items()}
+ return iou_thd2ap
+
+
+def compute_mr_r1(preds, annos, iou_thds=np.linspace(0.3, 0.95, 14)):
+ iou_thds = [float(f'{e:.2f}') for e in iou_thds]
+ pred_qid2window = {d['qid']: d['pred_relevant_windows'][0][:2] for d in preds}
+ gt_qid2window = dict()
+ for d in annos:
+ cur_gt_windows = d['relevant_windows']
+ cur_qid = d['qid']
+ cur_max_iou_idx = 0
+ if len(cur_gt_windows) > 0:
+ cur_ious = compute_temporal_iou_batch_cross(
+ np.array([pred_qid2window[cur_qid]]), np.array(d['relevant_windows']))[0]
+ cur_max_iou_idx = np.argmax(cur_ious)
+ gt_qid2window[cur_qid] = cur_gt_windows[cur_max_iou_idx]
+
+ qids = list(pred_qid2window.keys())
+ pred_windows = np.array([pred_qid2window[k] for k in qids]).astype(float)
+ gt_windows = np.array([gt_qid2window[k] for k in qids]).astype(float)
+ pred_gt_iou = compute_temporal_iou_batch_paired(pred_windows, gt_windows)
+ iou_thd2recall_at_one = dict()
+ miou_at_one = float(f'{np.mean(pred_gt_iou) * 100:.2f}')
+ for thd in iou_thds:
+ iou_thd2recall_at_one[str(thd)] = float(f'{np.mean(pred_gt_iou >= thd) * 100:.2f}')
+ return iou_thd2recall_at_one, miou_at_one
+
+
+def compute_mr_r5(preds, annos, iou_thds=np.linspace(0.3, 0.95, 14)):
+ iou_thds = [float(f'{e:.2f}') for e in iou_thds]
+ pred_qid2window = {d['qid']: [x[:2] for x in d['pred_relevant_windows'][:5]] for d in preds}
+ gt_qid2window = dict()
+ pred_optimal_qid2window = dict()
+ for d in annos:
+ cur_gt_windows = d['relevant_windows']
+ cur_qid = d['qid']
+ cur_max_iou_pred = 0
+ cur_max_iou_gt = 0
+ if len(cur_gt_windows) > 0:
+ cur_ious = compute_temporal_iou_batch_cross(
+ np.array(pred_qid2window[cur_qid]), np.array(d['relevant_windows']))[0]
+ cur_ious[np.isnan(cur_ious)] = 0
+ cur_max_iou_pred, cur_max_iou_gt = np.where(cur_ious == np.max(cur_ious))
+ cur_max_iou_pred, cur_max_iou_gt = cur_max_iou_pred[0], cur_max_iou_gt[0]
+ pred_optimal_qid2window[cur_qid] = pred_qid2window[cur_qid][cur_max_iou_pred]
+ gt_qid2window[cur_qid] = cur_gt_windows[cur_max_iou_gt]
+
+ qids = list(pred_qid2window.keys())
+ pred_windows = np.array([pred_optimal_qid2window[k] for k in qids]).astype(float)
+ gt_windows = np.array([gt_qid2window[k] for k in qids]).astype(float)
+ pred_gt_iou = compute_temporal_iou_batch_paired(pred_windows, gt_windows)
+ iou_thd2recall_at_one = dict()
+ for thd in iou_thds:
+ iou_thd2recall_at_one[str(thd)] = float(f'{np.mean(pred_gt_iou >= thd) * 100:.2f}')
+ return iou_thd2recall_at_one
+
+
+def get_data_by_range(preds, annos, len_range):
+ min_l, max_l = len_range
+ if min_l == 0 and max_l == float('inf'):
+ return preds, annos
+
+ ground_truth_in_range = []
+ gt_qids_in_range = set()
+ for d in annos:
+ rel_windows_in_range = [w for w in d['relevant_windows'] if min_l < (w[1] - w[0]) <= max_l]
+ if len(rel_windows_in_range) > 0:
+ d = copy.deepcopy(d)
+ d['relevant_windows'] = rel_windows_in_range
+ ground_truth_in_range.append(d)
+ gt_qids_in_range.add(d['qid'])
+
+ submission_in_range = []
+ for d in preds:
+ if d['qid'] in gt_qids_in_range:
+ submission_in_range.append(copy.deepcopy(d))
+
+ if submission_in_range == ground_truth_in_range == []:
+ return preds, annos
+
+ return submission_in_range, ground_truth_in_range
+
+
+def eval_moment_retrieval(preds, annos):
+ length_ranges = [[0, 10], [10, 30], [30, float('inf')], [0, float('inf')]]
+ range_names = ['short', 'middle', 'long', 'full']
+
+ ret_metrics = dict()
+ for l_range, name in zip(length_ranges, range_names):
+ _submission, _ground_truth = get_data_by_range(preds, annos, l_range)
+ print(f'{name}: {l_range}, {len(_ground_truth)}/{len(annos)}={100*len(_ground_truth)/len(annos):.2f} samples')
+ iou_thd2average_precision = compute_mr_ap(_submission, _ground_truth)
+ iou_thd2recall_at_one, miou_at_one = compute_mr_r1(_submission, _ground_truth)
+ iou_thd2recall_at_five = compute_mr_r5(_submission, _ground_truth)
+ ret_metrics[name] = {
+ 'MR-mIoU': miou_at_one,
+ 'MR-mAP': iou_thd2average_precision,
+ 'MR-R1': iou_thd2recall_at_one,
+ 'MR-R5': iou_thd2recall_at_five
+ }
+
+ return ret_metrics
+
+
+def compute_hl_hit1(qid2preds, qid2gt_scores_binary):
+ qid2max_scored_clip_idx = {k: np.argmax(v['pred_saliency_scores']) for k, v in qid2preds.items()}
+ hit_scores = np.zeros((len(qid2preds), 3))
+ qids = list(qid2preds.keys())
+ for idx, qid in enumerate(qids):
+ pred_clip_idx = qid2max_scored_clip_idx[qid]
+ gt_scores_binary = qid2gt_scores_binary[qid]
+ if pred_clip_idx < len(gt_scores_binary):
+ hit_scores[idx] = gt_scores_binary[pred_clip_idx]
+ hit_at_one = float(f'{100 * np.mean(np.max(hit_scores, 1)):.2f}')
+ return hit_at_one
+
+
+def compute_hl_ap(qid2preds, qid2gt_scores_binary):
+ qid2pred_scores = {k: v['pred_saliency_scores'] for k, v in qid2preds.items()}
+ ap_scores = np.zeros((len(qid2preds), 3))
+ qids = list(qid2preds.keys())
+ input_tuples = []
+ for idx, qid in enumerate(qids):
+ for w_idx in range(3):
+ y_true = qid2gt_scores_binary[qid][:, w_idx]
+ y_pred = np.array(qid2pred_scores[qid])
+ input_tuples.append((idx, w_idx, y_true, y_pred))
+
+ for input_tuple in input_tuples:
+ idx, w_idx, score = compute_ap_from_tuple(input_tuple)
+ ap_scores[idx, w_idx] = score
+
+ mean_ap = float(f'{100 * np.mean(ap_scores):.2f}')
+ return mean_ap
+
+
+def compute_ap_from_tuple(input_tuple):
+ idx, w_idx, y_true, y_pred = input_tuple
+ if len(y_true) < len(y_pred):
+ y_pred = y_pred[:len(y_true)]
+ elif len(y_true) > len(y_pred):
+ _y_predict = np.zeros(len(y_true))
+ _y_predict[:len(y_pred)] = y_pred
+ y_pred = _y_predict
+
+ score = get_ap(y_true, y_pred)
+ return idx, w_idx, score
+
+
+def mk_gt_scores(gt_data, clip_length=2):
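+    # each 2s clip carries 3 saliency ratings; clips outside relevant_clip_ids keep a score of 0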
+ num_clips = int(gt_data['duration'] / clip_length)
+ saliency_scores_full_video = np.zeros((num_clips, 3))
+ relevant_clip_ids = np.array(gt_data['relevant_clip_ids'])
+ saliency_scores_relevant_clips = np.array(gt_data['saliency_scores'])
+ saliency_scores_full_video[relevant_clip_ids] = saliency_scores_relevant_clips
+ return saliency_scores_full_video
+
+
+def eval_highlight(preds, annos):
+ qid2preds = {d['qid']: d for d in preds}
+ qid2gt_scores_full_range = {d['qid']: mk_gt_scores(d) for d in annos}
+ gt_saliency_score_min_list = [2, 3, 4]
+ saliency_score_names = ['Fair', 'Good', 'VeryGood']
+ highlight_det_metrics = dict()
+ for gt_saliency_score_min, score_name in zip(gt_saliency_score_min_list, saliency_score_names):
+ qid2gt_scores_binary = {
+ k: (v >= gt_saliency_score_min).astype(float)
+ for k, v in qid2gt_scores_full_range.items()
+ }
+ hit_at_one = compute_hl_hit1(qid2preds, qid2gt_scores_binary)
+ mean_ap = compute_hl_ap(qid2preds, qid2gt_scores_binary)
+ highlight_det_metrics[f'HL-min-{score_name}'] = {'HL-mAP': mean_ap, 'HL-Hit1': hit_at_one}
+ return highlight_det_metrics
+
+
+def qvhighlights_eval(preds, annos):
+ pred_qids = set([e['qid'] for e in preds])
+ gt_qids = set([e['qid'] for e in annos])
+ assert pred_qids == gt_qids, 'qids in annos and preds must match'
+
+ eval_metrics = dict()
+ eval_metrics_brief = OrderedDict()
+ if 'pred_relevant_windows' in preds[0]:
+ moment_ret_scores = eval_moment_retrieval(preds, annos)
+ eval_metrics.update(moment_ret_scores)
+ moment_ret_scores_brief = {
+ 'MR-full-mAP': moment_ret_scores['full']['MR-mAP']['average'],
+ 'MR-full-mAP@0.5': moment_ret_scores['full']['MR-mAP']['0.5'],
+ 'MR-full-mAP@0.75': moment_ret_scores['full']['MR-mAP']['0.75'],
+ 'MR-short-mAP': moment_ret_scores['short']['MR-mAP']['average'],
+ 'MR-middle-mAP': moment_ret_scores['middle']['MR-mAP']['average'],
+ 'MR-long-mAP': moment_ret_scores['long']['MR-mAP']['average'],
+ 'MR-short-mIoU': moment_ret_scores['short']['MR-mIoU'],
+ 'MR-middle-mIoU': moment_ret_scores['middle']['MR-mIoU'],
+ 'MR-long-mIoU': moment_ret_scores['long']['MR-mIoU'],
+ 'MR-full-mIoU': moment_ret_scores['full']['MR-mIoU'],
+ 'MR-full-R1@0.3': moment_ret_scores['full']['MR-R1']['0.3'],
+ 'MR-full-R1@0.5': moment_ret_scores['full']['MR-R1']['0.5'],
+ 'MR-full-R1@0.7': moment_ret_scores['full']['MR-R1']['0.7'],
+ 'MR-full-R5@0.3': moment_ret_scores['full']['MR-R5']['0.3'],
+ 'MR-full-R5@0.5': moment_ret_scores['full']['MR-R5']['0.5'],
+ 'MR-full-R5@0.7': moment_ret_scores['full']['MR-R5']['0.7']
+ }
+ eval_metrics_brief.update(sorted([(k, v) for k, v in moment_ret_scores_brief.items()], key=lambda x: x[0]))
+
+ if ('pred_saliency_scores' in preds[0]) and ('saliency_scores' in annos[0]):
+ if isinstance(annos[0]['saliency_scores'], list):
+ highlight_det_scores = eval_highlight(preds, annos)
+ eval_metrics.update(highlight_det_scores)
+ highlight_det_scores_brief = dict([(f"{k}-{sub_k.split('-')[1]}", v[sub_k])
+ for k, v in highlight_det_scores.items() for sub_k in v])
+ eval_metrics_brief.update(highlight_det_scores_brief)
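+            # re-insert the VeryGood entries so they appear last in the ordered brief metrics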
+ eval_metrics_brief['HL-min-VeryGood-mAP'] = eval_metrics_brief.pop('HL-min-VeryGood-mAP')
+ eval_metrics_brief['HL-min-VeryGood-Hit1'] = eval_metrics_brief.pop('HL-min-VeryGood-Hit1')
+
+ final_eval_metrics = OrderedDict()
+ final_eval_metrics['brief'] = eval_metrics_brief
+ final_eval_metrics.update(sorted([(k, v) for k, v in eval_metrics.items()], key=lambda x: x[0]))
+ return final_eval_metrics
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('pred_path')
+ parser.add_argument('--anno_path', default='data/qvhighlights/highlight_val_release.jsonl')
+ parser.add_argument('--out_name', default='metrics.log')
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ if nncore.is_dir(args.pred_path):
+ log_file = nncore.join(args.pred_path, args.out_name)
+ else:
+ log_file = nncore.same_dir(args.pred_path, args.out_name)
+
+ nncore.set_default_logger(logger='eval', fmt=None, log_file=log_file)
+
+ if nncore.is_dir(args.pred_path):
+ pred_paths = nncore.ls(args.pred_path, ext=['json', 'jsonl'], join_path=True, sort=True)
+ nncore.log(f'Total number of files: {len(pred_paths)}\n')
+ preds = nncore.flatten([nncore.load(p) for p in pred_paths])
+ else:
+ nncore.log(f'Loading predictions from {args.pred_path}')
+ preds = nncore.load(args.pred_path)
+
+ annos = nncore.load(args.anno_path)
+
+ res = qvhighlights_eval(preds, annos)['brief']
+ for k, v in res.items():
+ nncore.log(f'{k}: {v}')
diff --git a/videomind/eval/infer_auto.py b/videomind/eval/infer_auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5b54526f78d74e3913b6334a16596b3ae61d8aa
--- /dev/null
+++ b/videomind/eval/infer_auto.py
@@ -0,0 +1,437 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import argparse
+import copy
+import json
+from contextlib import nullcontext
+
+import nncore
+import torch
+
+from videomind.constants import GROUNDER_PROMPT, PLANNER_PROMPT, VERIFIER_PROMPT
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.utils import process_vision_info
+from videomind.model.builder import build_model
+from videomind.utils.io import get_duration, load_subtitle
+from videomind.utils.parser import parse_query, parse_span
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dataset')
+ parser.add_argument('--pred_path')
+ parser.add_argument('--model_gnd_path')
+ parser.add_argument('--model_ver_path')
+ parser.add_argument('--model_pla_path')
+ parser.add_argument('--model_ans_path')
+ parser.add_argument('--split', default='test', choices=['train', 'valid', 'test'])
+ parser.add_argument('--style', default='mcq', choices=['mcq', 'options', 'direct'])
+ parser.add_argument('--use_subtitle', action='store_true')
+ parser.add_argument('--auto_rephrasing', action='store_true')
+ parser.add_argument('--auto_planning', action='store_true')
+ parser.add_argument('--num_threads', type=int, default=1)
+ parser.add_argument('--device', default='auto')
+ parser.add_argument('--chunk', type=int, default=1)
+ parser.add_argument('--index', type=int, default=0)
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ if args.chunk > 1:
+ pred_path = nncore.join(args.pred_path, f'output_{args.index}.json')
+ else:
+ pred_path = nncore.join(args.pred_path, 'output.json')
+
+ print(f'Dataset: {args.dataset}({args.split}) Chunk: {args.chunk} Index: {args.index} Output Path: {pred_path}')
+
+    # NOTE:
+    # 1. the grounder is always loaded, so its state does not need to be stored
+    # 2. the answerer is always invoked (when its adapter is unavailable, the base model serves as the answerer)
+ adapter_state = dict(planner=False, verifier=False, answerer=False)
+
+ print('Initializing role *grounder*')
+ model, processor = build_model(args.model_gnd_path, device=args.device)
+ device = next(model.parameters()).device
+
+ if args.model_pla_path is not None:
+ adapter_path = nncore.join(args.model_pla_path, 'planner')
+ if nncore.is_dir(adapter_path):
+ print('Initializing role *planner*')
+ model.load_adapter(adapter_path, adapter_name='planner')
+ adapter_state['planner'] = True
+
+ if args.model_ver_path is not None:
+ adapter_path = nncore.join(args.model_ver_path, 'verifier')
+ if nncore.is_dir(adapter_path):
+ print('Initializing role *verifier*')
+ model.load_adapter(adapter_path, adapter_name='verifier')
+ adapter_state['verifier'] = True
+
+ if args.model_ans_path is not None:
+ adapter_path = nncore.join(args.model_ans_path, 'answerer')
+ if nncore.is_dir(adapter_path):
+ print('Initializing role *answerer*')
+ model.load_adapter(adapter_path, adapter_name='answerer')
+ adapter_state['answerer'] = True
+
+ annos = DATASETS.get(args.dataset).load_annos(split=args.split)
+ annos = [annos[i::args.chunk] for i in range(args.chunk)][args.index]
+
+ dumps = []
+ for i in nncore.ProgressBar(range(len(annos))):
+ anno = copy.deepcopy(annos[i])
+ dump = copy.deepcopy(annos[i])
+
+ video_path, duration, span = anno['video_path'], anno.get('duration'), anno.get('span')
+
+ if duration is None:
+ duration = get_duration(video_path, num_threads=args.num_threads)
+ dump['duration'] = duration
+
+ print()
+ print(video_path)
+ print(duration)
+
+ # sometimes the sample is for grounding only
+ do_answering = all(k in anno for k in ('question', 'options'))
+
+ if do_answering:
+ question, options, ans = anno['question'], anno['options'], anno['ans']
+
+ if args.style in ('mcq', 'options'):
+ prompt = question + '\nOptions:'
+ for idx, opt in enumerate(options):
+ prompt += f"\n({chr(ord('A') + idx)}) {opt.capitalize()}"
+ prompt += '\nPlease only give the best option.'
+ else:
+ prompt = question
+
+ print(prompt)
+ print(options)
+ print(ans)
+ else:
+ question = anno['query']
+ print(question)
+
+ # do grounding by default
+ do_grounding = True
+
+ # initialize grounding query as question
+ query = question
+
+ # initialize agent list
+ dump['agents'] = []
+
+ if adapter_state['planner'] and (args.auto_rephrasing or args.auto_planning):
+ print('=============== planner ===============')
+
+ dump['agents'].append('planner')
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video_path,
+ 'num_threads': args.num_threads,
+ 'min_pixels': 36 * 28 * 28,
+ 'max_pixels': 64 * 28 * 28,
+ 'max_frames': 100,
+ 'fps': 1.0
+ }, {
+ 'type': 'text',
+ 'text': PLANNER_PROMPT.format(question)
+ }]
+ }]
+
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ print(text)
+ images, videos = process_vision_info(messages)
+ data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+ data = data.to(device)
+
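+            # reset the LoRA layers and activate only the planner adapter for this call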
+ model.base_model.disable_adapter_layers()
+ model.base_model.enable_adapter_layers()
+ model.set_adapter('planner')
+
+ output_ids = model.generate(
+ **data,
+ do_sample=False,
+ temperature=None,
+ top_p=None,
+ top_k=None,
+ repetition_penalty=None,
+ max_new_tokens=256)
+
+ assert data.input_ids.size(0) == output_ids.size(0) == 1
+ output_ids = output_ids[0, data.input_ids.size(1):]
+ if output_ids[-1] == processor.tokenizer.eos_token_id:
+ output_ids = output_ids[:-1]
+ response = processor.decode(output_ids, clean_up_tokenization_spaces=False)
+ print(response)
+
+ dump['planner_response'] = response
+
+ try:
+ parsed = json.loads(response)
+ action = parsed[0] if isinstance(parsed, list) else parsed
+ if args.auto_rephrasing and action['type'].lower() == 'grounder' and action['value']:
+ query = action['value']
+ dump['planner_parsed_query'] = query
+ elif args.auto_planning and action['type'].lower() == 'answerer':
+ do_grounding = False
+ except Exception:
+ print('WARNING: Failed to parse planner response')
+
+ if do_grounding:
+ print('=============== grounder ===============')
+
+ dump['agents'].append('grounder')
+
+ query = parse_query(query)
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video_path,
+ 'num_threads': args.num_threads,
+ 'min_pixels': 36 * 28 * 28,
+ 'max_pixels': 64 * 28 * 28,
+ 'max_frames': 150,
+ 'fps': 1.0
+ }, {
+ 'type': 'text',
+ 'text': GROUNDER_PROMPT.format(query)
+ }]
+ }]
+
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ print(text)
+ images, videos = process_vision_info(messages)
+ data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+ data = data.to(device)
+
+ model.base_model.disable_adapter_layers()
+ model.base_model.enable_adapter_layers()
+ model.set_adapter('grounder')
+
+ output_ids = model.generate(
+ **data,
+ do_sample=False,
+ temperature=None,
+ top_p=None,
+ top_k=None,
+ repetition_penalty=None,
+ max_new_tokens=256)
+
+ assert data.input_ids.size(0) == output_ids.size(0) == 1
+ output_ids = output_ids[0, data.input_ids.size(1):]
+ if output_ids[-1] == processor.tokenizer.eos_token_id:
+ output_ids = output_ids[:-1]
+ response = processor.decode(output_ids, clean_up_tokenization_spaces=False)
+ print(response)
+
+ dump['grounder_response'] = response
+ dump['grounder_success'] = len(model.reg) > 0
+
+ if dump['grounder_success']:
+ # 1. extract timestamps and confidences
+ blob = model.reg[0].cpu().float()
+ pred, conf = blob[:, :2] * duration, blob[:, -1].tolist()
+
+ # 2. clamp timestamps
+ pred = pred.clamp(min=0, max=duration)
+
+ # 3. round timestamps to units
+ unit = getattr(DATASETS.get(args.dataset), 'UNIT', 0.001)
+ pred = torch.round(pred / unit).long() * unit
+
+                # 4. swap timestamps where end < start
+                inds = (pred[:, 1] - pred[:, 0] < 0).nonzero()[:, 0]
+                pred[inds] = pred[inds].roll(1, dims=1)
+
+ # 5. convert timestamps to list
+ pred = pred.tolist()
+ else:
+ print('WARNING: Failed to parse grounder response')
+
+ if adapter_state['verifier']:
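+                    # fall back to 5 overlapping windows (each spanning 1/3 of the video, stride 1/6)
+                    # for the verifier to re-rank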
+ pred = [[i * duration / 6, (i + 2) * duration / 6] for i in range(5)]
+ conf = [0] * 5
+ else:
+ pred = [[0, duration]]
+ conf = [0]
+
+ print(pred[0], span, duration)
+ dump['pred'] = pred
+ dump['conf'] = conf
+
+ if do_grounding and adapter_state['verifier'] and len(pred) > 1:
+ print('=============== verifier ===============')
+
+ dump['agents'].append('verifier')
+
+ # using top-5 predictions
+ probs = []
+ for cand in pred[:5]:
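+                # same normalization as in training: pad the candidate by 50% on each side and
+                # express its boundaries relative to the padded window, e.g. [10, 20] -> window [5, 25], (s, e)=(0.25, 0.75)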
+ s0, e0 = parse_span(cand, duration, 2)
+ offset = (e0 - s0) / 2
+ s1, e1 = parse_span([s0 - offset, e0 + offset], duration)
+
+ # percentage of s0, e0 within s1, e1
+ s = (s0 - s1) / (e1 - s1)
+ e = (e0 - s1) / (e1 - s1)
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video_path,
+ 'num_threads': args.num_threads,
+ 'video_start': s1,
+ 'video_end': e1,
+ 'min_pixels': 36 * 28 * 28,
+ 'max_pixels': 64 * 28 * 28,
+ 'max_frames': 64,
+ 'fps': 2.0
+ }, {
+ 'type': 'text',
+ 'text': VERIFIER_PROMPT.format(question)
+ }]
+ }]
+
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ print(text)
+ images, videos = process_vision_info(messages)
+ data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+
+ # ===== insert segment start/end tokens =====
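+                # each temporal group of the video occupies `window` tokens after the 2x2 spatial
+                # merge (hence grid_h * grid_w / 4), so frame group f starts at token
+                # base_idx + 1 + f * window; pos_e gets an extra +1 to account for the <seg_s>
+                # token inserted before it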
+ video_grid_thw = data['video_grid_thw'][0]
+ num_frames, window = int(video_grid_thw[0]), int(video_grid_thw[1] * video_grid_thw[2] / 4)
+ assert num_frames * window * 4 == data['pixel_values_videos'].size(0)
+
+ pos_s, pos_e = round(s * num_frames), round(e * num_frames)
+ pos_s, pos_e = min(max(0, pos_s), num_frames), min(max(0, pos_e), num_frames)
+ assert pos_s <= pos_e, (num_frames, s, e)
+
+ base_idx = torch.nonzero(data['input_ids'][0] == model.config.vision_start_token_id).item()
+ pos_s, pos_e = pos_s * window + base_idx + 1, pos_e * window + base_idx + 2
+
+ input_ids = data['input_ids'][0].tolist()
+ input_ids.insert(pos_s, model.config.seg_s_token_id)
+ input_ids.insert(pos_e, model.config.seg_e_token_id)
+ data['input_ids'] = torch.LongTensor([input_ids])
+ data['attention_mask'] = torch.ones_like(data['input_ids'])
+ # ===========================================
+
+ data = data.to(device)
+
+ model.base_model.disable_adapter_layers()
+ model.base_model.enable_adapter_layers()
+ model.set_adapter('verifier')
+
+ with torch.inference_mode():
+ logits = model(**data).logits[0, -1].softmax(dim=-1)
+
+ # NOTE: magic numbers here
+ # In Qwen2-VL vocab: 9454 -> Yes, 2753 -> No
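+                # the Yes/No logit gap is squashed into (0, 1) and used below to re-rank the candidates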
+ score = (logits[9454] - logits[2753]).sigmoid().item()
+ probs.append(score)
+
+ ranks = torch.Tensor(probs).argsort(descending=True).tolist()
+ print(probs)
+ print(ranks)
+
+ pred = [pred[idx] for idx in ranks]
+ conf = [conf[idx] for idx in ranks]
+ print(pred[0], span, duration)
+
+ dump['probs'] = probs
+ dump['ranks'] = ranks
+ dump['pred_ori'] = dump['pred']
+ dump['conf_ori'] = dump['conf']
+ dump['pred'] = pred
+ dump['conf'] = conf
+
+ if do_answering:
+ print('=============== answerer ===============')
+
+ dump['agents'].append('answerer')
+
+ # choose the potential best moment
+ selected = pred[0] if 'pred' in dump else [0, duration]
+
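+            # expand the selected moment (to at least MIN_LEN seconds, 32 by default) so the
+            # answerer has sufficient temporal context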
+ min_len = getattr(DATASETS.get(args.dataset), 'MIN_LEN', 32)
+ s, e = parse_span(selected, duration, min_len)
+ print([s, e], span, duration)
+
+ if args.use_subtitle and 'subtitle_path' in anno and nncore.is_file(anno['subtitle_path']):
+ # use only the first 100 subtitles to save memory
+ subs = load_subtitle(anno['subtitle_path'])[:100]
+ subs = [f'{round(a - s, 1)}s - {round(b - s, 1)}s, {t}\n' for a, b, t in subs if a >= s and b <= e]
+ subs = ''.join(subs)
+                prompt = f'You are given a video that is {round(e - s, 1)} seconds long.\nSubtitles:\n{subs}' + prompt
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video_path,
+ 'num_threads': args.num_threads,
+ 'video_start': s,
+ 'video_end': e,
+ 'min_pixels': 128 * 28 * 28,
+ 'max_pixels': 256 * 28 * 28,
+ 'max_frames': 32,
+ 'fps': 2.0
+ }, {
+ 'type': 'text',
+ 'text': prompt
+ }]
+ }]
+
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ text += 'Best Option: (' if args.style == 'mcq' else ''
+ print(text)
+ images, videos = process_vision_info(messages)
+ data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+ data = data.to(device)
+
+ if adapter_state['answerer']:
+ model.base_model.disable_adapter_layers()
+ model.base_model.enable_adapter_layers()
+ model.set_adapter('answerer')
+ context = nullcontext
+ else:
+ context = model.disable_adapter
+
+ with context():
+ output_ids = model.generate(
+ **data,
+ do_sample=False,
+ temperature=None,
+ top_p=None,
+ top_k=None,
+ repetition_penalty=None,
+ max_new_tokens=256)
+
+ assert data.input_ids.size(0) == output_ids.size(0) == 1
+ output_ids = output_ids[0, data.input_ids.size(1):]
+ if output_ids[-1] == processor.tokenizer.eos_token_id:
+ output_ids = output_ids[:-1]
+ response = processor.decode(output_ids, clean_up_tokenization_spaces=False)
+ print(response)
+
+ dump['answerer_response'] = response
+ dump['response'] = response
+
+ dumps.append(dump)
+
+ nncore.dump(dumps, pred_path)
diff --git a/videomind/eval/infer_qvhighlights.py b/videomind/eval/infer_qvhighlights.py
new file mode 100644
index 0000000000000000000000000000000000000000..c32c60da79f4307ca875eb50ff427800487eca30
--- /dev/null
+++ b/videomind/eval/infer_qvhighlights.py
@@ -0,0 +1,137 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import argparse
+import copy
+
+import nncore
+import torch
+
+from videomind.constants import GROUNDER_PROMPT
+from videomind.dataset.hybrid import DATASETS
+from videomind.dataset.utils import process_vision_info
+from videomind.model.builder import build_model
+from videomind.utils.io import get_duration
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dataset')
+ parser.add_argument('--pred_path')
+ parser.add_argument('--model_gnd_path')
+ parser.add_argument('--split', default='test', choices=['train', 'valid', 'test'])
+ parser.add_argument('--num_threads', type=int, default=1)
+ parser.add_argument('--device', default='auto')
+ parser.add_argument('--chunk', type=int, default=1)
+ parser.add_argument('--index', type=int, default=0)
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ if args.chunk > 1:
+ pred_path = nncore.join(args.pred_path, f'output_{args.index}.jsonl')
+ else:
+ pred_path = nncore.join(args.pred_path, 'output.jsonl')
+
+ print(f'Dataset: {args.dataset}({args.split}) Chunk: {args.chunk} Index: {args.index} Output Path: {pred_path}')
+
+ model, processor = build_model(args.model_gnd_path, device=args.device)
+ device = next(model.parameters()).device
+
+ annos = DATASETS.get(args.dataset).load_annos(split=args.split)
+ annos = [annos[i::args.chunk] for i in range(args.chunk)][args.index]
+
+ dumps = []
+ for i in nncore.ProgressBar(range(len(annos))):
+ anno = copy.deepcopy(annos[i])
+ dump = dict()
+
+ video_path, query, duration, span = anno['video_path'], anno['query'], anno.get('duration'), anno.get('span')
+
+ if duration is None:
+ duration = get_duration(video_path, num_threads=args.num_threads)
+
+ print()
+ print(video_path)
+ print(duration)
+ print(query)
+
+ messages = [{
+ 'role':
+ 'user',
+ 'content': [{
+ 'type': 'video',
+ 'video': video_path,
+ 'num_threads': args.num_threads,
+ 'min_pixels': 36 * 28 * 28,
+ 'max_pixels': 64 * 28 * 28,
+ 'max_frames': 150,
+ 'fps': 1.0
+ }, {
+ 'type': 'text',
+ 'text': GROUNDER_PROMPT.format(query)
+ }]
+ }]
+
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ print(text)
+
+ images, videos = process_vision_info(messages)
+
+ data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+ data = data.to(device)
+
+ output_ids = model.generate(
+ **data,
+ do_sample=False,
+ temperature=None,
+ top_p=None,
+ top_k=None,
+ repetition_penalty=None,
+ max_new_tokens=256)
+
+ assert data.input_ids.size(0) == output_ids.size(0) == 1
+ output_ids = output_ids[0, data.input_ids.size(1):]
+
+ if output_ids[-1] == processor.tokenizer.eos_token_id:
+ output_ids = output_ids[:-1]
+
+ response = processor.decode(output_ids, clean_up_tokenization_spaces=False)
+ print(response)
+
+ grounder_success = len(model.reg) > 0
+
+ if grounder_success:
+ # 1. extract timestamps and confidences
+ blob = model.reg[0].cpu().float()
+ pred, conf = blob[:, :2] * duration, blob[:, 2:]
+ print(pred[0], span, duration)
+
+ # 2. clamp timestamps
+ pred = pred.clamp(min=0, max=duration)
+
+ # 3. round timestamps to units
+ unit = getattr(DATASETS.get(args.dataset), 'UNIT', 0.001)
+ pred = torch.round(pred / unit).long() * unit
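+            # e.g. with the default unit of 0.001, a raw prediction of 13.6984s becomes
+            # 13.698s, while a coarser dataset-defined UNIT (say 2.0) would snap it to 14.0s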
+
+ # 4. sort timestamps
+ inds = (pred[:, 1] - pred[:, 0] < 0).nonzero()[:, 0]
+ pred[inds] = pred[inds].roll(1)
+
+ # 5. merge timestamps back with confidences
+ pred = torch.cat((pred, conf), dim=1)
+ else:
+ print('WARNING: Failed to parse grounder response')
+ pred = torch.Tensor([[0, duration, 1]])
+
+ print(pred[0], span, duration)
+
+ dump['vid'] = anno['vid']
+ dump['qid'] = anno['qid']
+ dump['pred_relevant_windows'] = pred.tolist()
+
+ dumps.append(dump)
+
+ nncore.dump(dumps, pred_path)
diff --git a/videomind/model/__init__.py b/videomind/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbdbfae1e29d7db467f06b152a9efe4b23026691
--- /dev/null
+++ b/videomind/model/__init__.py
@@ -0,0 +1,3 @@
+from .model import AgentQwen2VLConfig, AgentQwen2VLForConditionalGeneration
+
+MODELS = {'qwen2_vl': (AgentQwen2VLConfig, AgentQwen2VLForConditionalGeneration)}
diff --git a/videomind/model/blocks.py b/videomind/model/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..c439be495330032effdf04108f65d378613d2877
--- /dev/null
+++ b/videomind/model/blocks.py
@@ -0,0 +1,93 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nncore.nn import Parameter
+
+
+class Permute(nn.Module):
+
+ def forward(self, x):
+ return x.transpose(-1, -2)
+
+
+class LearnableEmbedding(nn.Module):
+
+ def __init__(self, dims):
+ super().__init__()
+ self.weights = Parameter(1, 1, dims)
+
+ def forward(self, x):
+ return x + self.weights.expand_as(x)
+
+
+class ConvPyramid(nn.Module):
+
+ def __init__(self, dims, strides, act_cls=nn.ReLU):
+ super().__init__()
+
+ self.blocks = nn.ModuleList()
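+        # one block per stride: stride == 1 keeps the temporal resolution (activation only),
+        # strides > 1 downsample with log2(s) strided convolutions, and fractional strides
+        # would upsample with transposed convolutions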
+ for s in strides:
+ p = int(math.log2(s))
+ if p == 0:
+ layers = act_cls()
+ else:
+ conv_cls = nn.Conv1d if p > 0 else nn.ConvTranspose1d
+ layers = nn.Sequential()
+ for _ in range(abs(p)):
+ module = [Permute(), conv_cls(dims, dims, 2, stride=2), Permute(), nn.LayerNorm(dims), act_cls()]
+ layers.extend(module)
+ self.blocks.append(layers)
+
+ self.strides = strides
+
+ def forward(self, x, mask, return_mask=False):
+ pymid, pymid_msk = [], []
+
+ for s, blk in zip(self.strides, self.blocks):
+ if x.size(1) < s:
+ continue
+
+ pymid.append(blk(x))
+
+ if return_mask:
+ if s > 1:
+ msk = F.max_pool1d(mask.float(), s, stride=s).long()
+ elif s < 1:
+ msk = mask.repeat_interleave(int(1 / s), dim=1)
+ else:
+ msk = mask
+ pymid_msk.append(msk)
+
+ return (pymid, pymid_msk) if return_mask else pymid
+
+
+class Scale(nn.Module):
+
+ def __init__(self, strides):
+ super().__init__()
+ self.scale = nn.Parameter(torch.ones(len(strides)))
+
+ def forward(self, x, i):
+ return x * self.scale[i]
+
+
+class ConvHead(nn.Module):
+
+    def __init__(self, dims, out_dims, kernel_size=3, act_cls=nn.ReLU):
+ super().__init__()
+
+ # yapf:disable
+ self.module = nn.Sequential(
+ Permute(),
+            nn.Conv1d(dims, dims, kernel_size, padding=kernel_size // 2),
+ act_cls(),
+            nn.Conv1d(dims, out_dims, kernel_size, padding=kernel_size // 2),
+ Permute())
+ # yapf:enable
+
+ def forward(self, x):
+ return self.module(x)
diff --git a/videomind/model/builder.py b/videomind/model/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbc6a2442b0b156831f4cd586b6e9c367ac4fefd
--- /dev/null
+++ b/videomind/model/builder.py
@@ -0,0 +1,108 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import warnings
+
+import nncore
+import torch
+import torch.nn as nn
+from peft import PeftModel
+from safetensors.torch import load_model
+from transformers import AutoConfig, AutoModel, AutoProcessor, GenerationConfig, Qwen2VLForConditionalGeneration
+
+
+def get_auto_device(device):
+ try:
+ import torch_npu
+ has_npu = torch_npu.npu.is_available()
+ except ImportError:
+ has_npu = False
+
+ return 'cuda' if torch.cuda.is_available() else 'npu' if has_npu else 'cpu'
+
+
+def build_model(model_path, config=None, is_trainable=False, merge_adapter=False, device='auto', dtype=torch.float16):
+ # set do_resize to false to avoid duplicated resizing
+ # https://github.com/huggingface/transformers/tree/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py
+ processor = AutoProcessor.from_pretrained(model_path, do_resize=False)
+
+ # eager attention has known & unknown bugs
+ # [4.46.2] broken causality fp16: https://github.com/huggingface/transformers/issues/35151
+ # [4.48.1] broken sliding window: https://github.com/huggingface/transformers/issues/35924
+ attn_implementation = 'sdpa'
+
+ config = config or AutoConfig.from_pretrained(model_path)
+
+ adapter_path = nncore.join(model_path, getattr(config, 'role', 'unknown'))
+ partial_path = nncore.join(model_path, 'pytorch_model.safetensors')
+
+ if nncore.is_dir(adapter_path) or nncore.is_file(partial_path):
+ print(f'Loading base model from {config.base_model_path}...')
+ model = AutoModel.from_pretrained(
+ config.base_model_path,
+ config=config,
+ low_cpu_mem_usage=True,
+ ignore_mismatched_sizes=True,
+ attn_implementation=attn_implementation,
+ torch_dtype=dtype)
+
+ try:
+ model.generation_config = GenerationConfig.from_pretrained(model_path)
+ except OSError:
+ warnings.warn('generation_config.json not found')
+
+ meta_state_dict = {
+ n: torch.empty_like(p, device='cpu')
+ for n, p in model.named_parameters() if p.device == torch.device('meta')
+ }
+ model.load_state_dict(meta_state_dict, strict=False, assign=True)
+
+ size = (model.model.embed_tokens.num_embeddings, model.model.embed_tokens.embedding_dim)
+ if model.model.embed_tokens.weight.size() != size:
+ print(f'Resizing embed_tokens to {size}...')
+ model.model.embed_tokens.weight = nn.Parameter(model.model.embed_tokens.weight.new_empty(size))
+
+ size = (model.lm_head.out_features, model.lm_head.in_features)
+ if model.lm_head.weight.size() != size:
+ print(f'Resizing lm_head to {size}...')
+ model.lm_head.weight = nn.Parameter(model.lm_head.weight.new_empty(size))
+
+ if nncore.is_dir(adapter_path):
+ print(f'Loading adapter from {adapter_path}...')
+ # transformers integration does not support merge_and_unload, use peft instead
+ model = PeftModel.from_pretrained(
+ model,
+ adapter_path,
+ adapter_name=config.role,
+ is_trainable=is_trainable,
+ low_cpu_mem_usage=True,
+ torch_device=str(model.device))
+
+ if nncore.is_file(partial_path):
+ print(f'Loading state dict from {partial_path}...')
+ _, unexpected = load_model(model, partial_path, strict=False, device=str(model.device))
+ assert len(unexpected) == 0, f'unexpected parameters: {unexpected}'
+
+ if merge_adapter and nncore.is_dir(adapter_path):
+ print('Merging adapter and unloading...')
+ model = model.merge_and_unload()
+ model._hf_peft_config_loaded = False
+ else:
+ print(f'Loading full model from {model_path}...')
+
+ if len(config.architectures) == 1 and config.model_type == 'qwen2_vl':
+ model_cls = Qwen2VLForConditionalGeneration
+ else:
+ model_cls = AutoModel
+
+ model = model_cls.from_pretrained(
+ model_path,
+ config=config,
+ low_cpu_mem_usage=True,
+ attn_implementation=attn_implementation,
+ torch_dtype=dtype)
+
+ if not is_trainable:
+ device = get_auto_device(device) if device == 'auto' else device
+ model = model.to(device).eval()
+
+ return model, processor
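+
+
+# minimal usage sketch mirroring the eval scripts (the checkpoint path is illustrative):
+#
+#   model, processor = build_model('model_zoo/VideoMind-2B', device='auto')
+#   device = next(model.parameters()).device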
diff --git a/videomind/model/generator.py b/videomind/model/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6666f7df404a8081f6a73e1b669809ee95e635f
--- /dev/null
+++ b/videomind/model/generator.py
@@ -0,0 +1,66 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import torch
+import torch.nn as nn
+
+
+class BufferList(nn.Module):
+
+ def __init__(self, buffers):
+ super(BufferList, self).__init__()
+ for i, buffer in enumerate(buffers):
+ self.register_buffer(str(i), buffer, persistent=False)
+
+ def __len__(self):
+ return len(self._buffers)
+
+ def __iter__(self):
+ return iter(self._buffers.values())
+
+
+class PointGenerator(nn.Module):
+
+ def __init__(self, strides, buffer_size, offset=False):
+ super(PointGenerator, self).__init__()
+
+ reg_range, last = [], 0
+ for stride in strides[1:]:
+ reg_range.append((last, stride))
+ last = stride
+ reg_range.append((last, float('inf')))
+
+ self.strides = strides
+ self.reg_range = reg_range
+ self.buffer_size = buffer_size
+ self.offset = offset
+
+ self.buffer = self._cache_points()
+
+ def _cache_points(self):
+ buffer_list = []
+ for stride, reg_range in zip(self.strides, self.reg_range):
+ reg_range = torch.Tensor([reg_range])
+ lv_stride = torch.Tensor([stride])
+ points = torch.arange(0, self.buffer_size, stride)[:, None]
+ if self.offset:
+ points += 0.5 * stride
+ reg_range = reg_range.repeat(points.size(0), 1)
+ lv_stride = lv_stride.repeat(points.size(0), 1)
+ buffer_list.append(torch.cat((points, reg_range, lv_stride), dim=1))
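+            # each row stores [point coordinate, reg range min, reg range max, level stride]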
+ buffer = BufferList(buffer_list)
+ return buffer
+
+ def forward(self, pymid):
+ assert self.strides[0] == 1
+ # video_size = pymid[0].size(1)
+ points = []
+ sizes = [p.size(1) for p in pymid] + [0] * (len(self.buffer) - len(pymid))
+ for size, buffer in zip(sizes, self.buffer):
+ if size == 0:
+ continue
+ assert size <= buffer.size(0), 'reached max buffer size'
+ point = buffer[:size, :].clone()
+ # point[:, 0] /= video_size
+ points.append(point)
+ points = torch.cat(points)
+ return points
diff --git a/videomind/model/loss.py b/videomind/model/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..e54f46371a991919266b9b1f9f822b2f0294804e
--- /dev/null
+++ b/videomind/model/loss.py
@@ -0,0 +1,177 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nncore.nn import LOSSES, Parameter, build_loss
+
+
+@LOSSES.register()
+class SampledNCELoss(nn.Module):
+
+ def __init__(self, temperature=0.07, max_scale=100, learnable=False, direction=('row', 'col'), loss_weight=1.0):
+ super().__init__()
+
+ scale = torch.Tensor([math.log(1 / temperature)])
+
+ if learnable:
+ self.scale = Parameter(scale)
+ else:
+ self.register_buffer('scale', scale)
+
+ self.temperature = temperature
+ self.max_scale = max_scale
+ self.learnable = learnable
+ self.direction = (direction, ) if isinstance(direction, str) else direction
+ self.loss_weight = loss_weight
+
+ def forward(self, video_emb, query_emb, video_msk, saliency, pos_clip):
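+        # InfoNCE-style objective: clips whose saliency does not exceed that of the positive
+        # clip act as candidates, and the positive clip should score highest among them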
+ batch_inds = torch.arange(video_emb.size(0), device=video_emb.device)
+
+ pos_scores = saliency[batch_inds, pos_clip].unsqueeze(-1)
+ loss_msk = (saliency <= pos_scores) * video_msk
+ if not loss_msk.any():
+ warnings.warn(f'loss_msk is all zeros: {loss_msk} {saliency} {video_msk} {pos_clip}')
+
+ scale = self.scale.exp().clamp(max=self.max_scale)
+ i_sim = F.cosine_similarity(video_emb, query_emb, dim=-1) * scale
+ i_sim = i_sim + torch.where(loss_msk > 0, .0, float('-inf'))
+
+ loss = 0
+
+ if 'row' in self.direction:
+ i_met = F.log_softmax(i_sim, dim=1)[batch_inds, pos_clip]
+ loss = loss - i_met.sum()
+
+ if 'col' in self.direction:
+ j_met = F.log_softmax(i_sim.t(), dim=1)[pos_clip, batch_inds]
+ loss = loss - j_met.sum() / j_met.size(0)
+
+ loss = loss * self.loss_weight
+ return loss
+
+
+@LOSSES.register()
+class BundleLoss(nn.Module):
+
+ def __init__(self, sample_radius=1.5, loss_cls=None, loss_reg=None, loss_sal=None):
+ super().__init__()
+
+ self._loss_cls = build_loss(loss_cls)
+ self._loss_reg = build_loss(loss_reg)
+ self._loss_sal = build_loss(loss_sal)
+
+ self.sample_radius = sample_radius
+
+ def get_target_single(self, point, gt_bnd, gt_cls):
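+        # point columns follow PointGenerator: [coordinate, reg range min, reg range max, stride]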
+ num_pts, num_gts = point.size(0), gt_bnd.size(0)
+
+ lens = gt_bnd[:, 1] - gt_bnd[:, 0]
+ lens = lens[None, :].repeat(num_pts, 1)
+
+ gt_seg = gt_bnd[None].expand(num_pts, num_gts, 2)
+ s = point[:, 0, None] - gt_seg[:, :, 0]
+ e = gt_seg[:, :, 1] - point[:, 0, None]
+ r_tgt = torch.stack((s, e), dim=-1)
+
+ if self.sample_radius > 0:
+ center = (gt_seg[:, :, 0] + gt_seg[:, :, 1]) / 2
+ t_mins = center - point[:, 3, None] * self.sample_radius
+ t_maxs = center + point[:, 3, None] * self.sample_radius
+ dist_s = point[:, 0, None] - torch.maximum(t_mins, gt_seg[:, :, 0])
+ dist_e = torch.minimum(t_maxs, gt_seg[:, :, 1]) - point[:, 0, None]
+ center = torch.stack((dist_s, dist_e), dim=-1)
+ cls_msk = center.min(-1)[0] >= 0
+ else:
+ cls_msk = r_tgt.min(-1)[0] >= 0
+
+ reg_dist = r_tgt.max(-1)[0]
+ reg_msk = torch.logical_and((reg_dist >= point[:, 1, None]), (reg_dist <= point[:, 2, None]))
+
+ lens.masked_fill_(cls_msk == 0, float('inf'))
+ lens.masked_fill_(reg_msk == 0, float('inf'))
+ min_len, min_len_inds = lens.min(dim=1)
+
+ min_len_mask = torch.logical_and((lens <= (min_len[:, None] + 1e-3)), (lens < float('inf'))).to(r_tgt.dtype)
+
+ label = F.one_hot(gt_cls[:, 0], 2).to(r_tgt.dtype)
+ c_tgt = torch.matmul(min_len_mask, label).clamp(min=0.0, max=1.0)[:, 1]
+ r_tgt = r_tgt[range(num_pts), min_len_inds] / point[:, 3, None]
+
+ return c_tgt, r_tgt
+
+ def get_target(self, data):
+ cls_tgt, reg_tgt = [], []
+
+ for i in range(data['boundary'].size(0)):
+ gt_bnd = data['boundary'][i] * data['video_emb'].size(1)
+ # gt_bnd = data['boundary'][i]
+ gt_cls = gt_bnd.new_ones(gt_bnd.size(0), 1).long()
+
+ c_tgt, r_tgt = self.get_target_single(data['point'], gt_bnd, gt_cls)
+
+ cls_tgt.append(c_tgt)
+ reg_tgt.append(r_tgt)
+
+ cls_tgt = torch.stack(cls_tgt)
+ reg_tgt = torch.stack(reg_tgt)
+
+ return cls_tgt, reg_tgt
+
+ def loss_cls(self, data, output, cls_tgt):
+ src = data['out_class'].squeeze(-1)
+ msk = torch.cat(data['pymid_msk'], dim=1)
+
+ cls_tgt = cls_tgt.repeat(src.size(0) // cls_tgt.size(0), 1)
+
+ loss_cls = self._loss_cls(src, cls_tgt, weight=msk)
+ loss_cls = (loss_cls.sum(dim=1) / msk.sum(dim=1)).sum()
+
+ output['loss_cls'] = loss_cls
+ return output
+
+ def loss_reg(self, data, output, cls_tgt, reg_tgt):
+ src = data['out_coord']
+ msk = cls_tgt.unsqueeze(2).repeat(1, 1, 2).bool()
+ assert msk.any(), 'empty mask in reg loss'
+
+ reg_tgt = reg_tgt.repeat(src.size(0) // reg_tgt.size(0), 1, 1)
+ msk = msk.repeat(src.size(0) // msk.size(0), 1, 1)
+
+ loss_reg = self._loss_reg(src, reg_tgt, weight=msk)
+ loss_reg = (loss_reg.sum(dim=[1, 2]) / msk.sum(dim=[1, 2])).sum()
+
+ output['loss_reg'] = loss_reg
+ return output
+
+ def loss_sal(self, data, output):
+ video_emb = data['video_emb']
+ query_emb = data['query_emb']
+ video_msk = data['video_msk']
+
+ saliency = data['saliency']
+ pos_clip = data['pos_clip'][:, 0]
+
+ saliency = saliency.repeat(video_emb.size(0) // saliency.size(0), 1)
+ pos_clip = pos_clip.repeat(video_emb.size(0) // pos_clip.size(0))
+
+ output['loss_sal'] = self._loss_sal(video_emb, query_emb, video_msk, saliency, pos_clip)
+ return output
+
+ def forward(self, data, output):
+ if self._loss_reg is not None:
+ cls_tgt, reg_tgt = self.get_target(data)
+ output = self.loss_reg(data, output, cls_tgt, reg_tgt)
+ else:
+ cls_tgt = data['saliency']
+
+ if self._loss_cls is not None:
+ output = self.loss_cls(data, output, cls_tgt)
+
+ if self._loss_sal is not None:
+ output = self.loss_sal(data, output)
+
+ return output
diff --git a/videomind/model/model.py b/videomind/model/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aaedf22558bbc47739bdc76f7830b3639c94c2c
--- /dev/null
+++ b/videomind/model/model.py
@@ -0,0 +1,358 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import warnings
+
+import nncore
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from nncore.nn import ModuleList, PositionalEncoding, Sequential, TransformerEncoderLayer, xavier_init_
+from nncore.ops import temporal_iou
+from transformers import AutoConfig, AutoModel, Qwen2VLConfig, Qwen2VLForConditionalGeneration, Qwen2VLModel
+from transformers.activations import ACT2CLS, ACT2FN
+from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
+from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
+
+from .blocks import ConvHead, ConvPyramid, LearnableEmbedding, Scale
+from .generator import PointGenerator
+from .loss import BundleLoss
+
+
+class AgentQwen2VLConfig(Qwen2VLConfig):
+ model_type = 'agent_qwen2_vl'
+
+
+class AgentQwen2VisionTransformerPretrainedModel(Qwen2VisionTransformerPretrainedModel):
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.gradient_checkpointing = False
+
+ # add support for gradient checkpointing
+ # https://github.com/huggingface/transformers/pull/34724
+ def forward(self, hidden_states, grid_thw):
+ hidden_states = self.patch_embed(hidden_states)
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+ dim=0, dtype=torch.int32)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+ for blk in self.blocks:
+ if self.gradient_checkpointing and self.training:
+ hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens,
+ rotary_pos_emb)
+ else:
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+
+ return self.merger(hidden_states)
+
+
+class AgentQwen2VLModel(Qwen2VLModel):
+ config_class = AgentQwen2VLConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.norm.register_forward_pre_hook(lambda module, args: setattr(module, 'state', args[0]))
+
+ def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
+        # ensure gradient tracking (in case embed_tokens has been frozen)
+ assert input_ids is None and inputs_embeds is not None
+ if self.training and not inputs_embeds.requires_grad:
+ inputs_embeds.requires_grad = True
+ return super().forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
+
+
+class AgentQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
+ config_class = AgentQwen2VLConfig
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.visual = AgentQwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
+ self.model = AgentQwen2VLModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.rope_deltas = None
+
+ if self.config.role in ('all_in_one', 'grounder'):
+ hidden_size, hidden_act = self.config.hidden_size, self.config.hidden_act
+
+ self.dims = 256
+
+ self.vis_proj = Sequential(nn.LayerNorm(hidden_size), nn.Linear(hidden_size, self.dims))
+ self.reg_proj = Sequential(nn.LayerNorm(hidden_size), nn.Linear(hidden_size, self.dims))
+ self.vis_norm = nn.LayerNorm(self.dims)
+ self.vis_fuse = ModuleList(
+ TransformerEncoderLayer(self.dims, act_cfg=ACT2FN[hidden_act]),
+ TransformerEncoderLayer(self.dims, act_cfg=ACT2FN[hidden_act]),
+ TransformerEncoderLayer(self.dims, act_cfg=ACT2FN[hidden_act]))
+
+ self.vis_pos = PositionalEncoding(self.dims, normalize=True, learnable=False)
+ self.vis_emb = LearnableEmbedding(self.dims)
+ self.reg_emb = LearnableEmbedding(self.dims)
+
+ self.strides = (1, 2, 4, 8)
+ self.vis_pad_length = self.strides[-1]
+
+ self.pyramid = ConvPyramid(self.dims, self.strides, act_cls=ACT2CLS[hidden_act])
+ self.class_head = ConvHead(self.dims, 1, act_cls=ACT2CLS[hidden_act])
+ self.coord_head = ConvHead(self.dims, 2, act_cls=ACT2CLS[hidden_act])
+
+ self.generator = PointGenerator(self.strides, 1024)
+ self.coef = Scale(self.strides)
+ self.bundle_loss = BundleLoss(
+ sample_radius=1.5,
+ loss_cls=dict(type='FocalLoss', reduction='none', loss_weight=5.0),
+ loss_reg=dict(type='L1Loss', reduction='none', loss_weight=1.0),
+ loss_sal=dict(type='SampledNCELoss', direction='row', loss_weight=0.05))
+
+ self.post_init()
+
+ def reset_conv_parameters(self):
+ for s in ('pyramid', 'class_head', 'coord_head'):
+ b = getattr(self, s, None)
+ if b is None:
+ continue
+ for n, m in b.named_modules():
+ if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
+ print(f'Reset parameters of {b.__class__.__name__} {n} ({m.__class__.__name__})')
+ xavier_init_(m, distribution='uniform')
+
+ def forward(self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ past_key_values=None,
+ inputs_embeds=None,
+ labels=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ rope_deltas=None,
+ timestamps=None,
+ saliency=None,
+ pos_clip=None):
+ mode = 'training' if self.training else 'caching' if (
+ past_key_values is None or len(past_key_values) == 0) else 'generating'
+
+ # https://github.com/huggingface/transformers/pull/33487
+ if position_ids is None and input_ids is not None:
+ position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
+
+ if mode in ('training', 'caching'):
+ vision_s_inds = torch.nonzero(input_ids == self.config.vision_start_token_id).tolist()
+ vision_e_inds = torch.nonzero(input_ids == self.config.vision_end_token_id).tolist()
+ assert len(vision_s_inds) == len(vision_e_inds)
+
+ self.cache_vision_inds = [[] for _ in range(input_ids.size(0))]
+ for i in range(len(vision_s_inds)):
+ assert vision_s_inds[i][0] == vision_e_inds[i][0]
+ self.cache_vision_inds[vision_s_inds[i][0]].append([vision_s_inds[i][1] + 1, vision_e_inds[i][1]])
+
+ outputs = super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=not self.training,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=True,
+ pixel_values=pixel_values,
+ pixel_values_videos=pixel_values_videos,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ rope_deltas=rope_deltas)
+
+ if mode == 'caching':
+ self.cache_norm_state = self.model.norm.state
+ self.reg = []
+ self.sal = []
+
+ if mode == 'training' and timestamps is not None:
+ loss_regs, avg_factors = [], []
+ shift_labels = labels[..., 1:].contiguous()
+ for batch_idx, (vision_inds, ts) in enumerate(zip(self.cache_vision_inds, timestamps)):
+ # only consider the first video
+ s, e = vision_inds[0]
+
+ # spatial merge size set to 2
+ window = int(video_grid_thw[0][1] * video_grid_thw[0][2] / 4)
+ assert video_grid_thw[0][0] * window == e - s
+
+ inds = torch.where(shift_labels[batch_idx] == self.config.reg_token_id)[0]
+ reg_tokens = self.reg_proj(self.model.norm.state[batch_idx, inds, None])
+ # reg_tokens: num_reg_tokens * 1 * channel
+
+ vis_tokens = self.model.norm.state[batch_idx, None, s:e]
+ vis_tokens = vis_tokens.transpose(-1, -2)
+ vis_tokens = F.avg_pool1d(vis_tokens.float(), window, stride=window).to(vis_tokens.dtype)
+ vis_tokens = vis_tokens.transpose(-1, -2)
+ vis_tokens = self.vis_proj(vis_tokens).repeat(reg_tokens.size(0), 1, 1)
+ # vis_tokens: num_reg_tokens * num_frames * channel
+
+ vis_tokens = self.vis_emb(vis_tokens)
+ reg_tokens = self.reg_emb(reg_tokens)
+ pe = self.vis_pos(vis_tokens).to(vis_tokens.dtype)
+
+ joint_tokens = torch.cat((vis_tokens + pe, reg_tokens), dim=1)
+ collected = [joint_tokens]
+ for blk in self.vis_fuse:
+ collected.append(blk(collected[-1]))
+ collected = collected[1:]
+ joint_tokens = torch.cat(collected)
+ joint_tokens = self.vis_norm(joint_tokens)
+
+ video_emb = joint_tokens[:, :-1]
+ # video_emb: num_reg_tokens * num_frames * channel
+
+ query_emb = joint_tokens[:, -1:]
+ # query_emb: num_reg_tokens * 1 * channel
+
+ b, t, c = video_emb.size()
+ video_msk = video_emb.new_ones(b, t)
+
+ if t < self.vis_pad_length:
+ emb_pad = video_emb.new_zeros(b, self.vis_pad_length - t, c)
+ msk_pad = video_msk.new_zeros(b, self.vis_pad_length - t)
+ pymid_emb = torch.cat((video_emb, emb_pad), dim=1)
+ pymid_msk = torch.cat((video_msk, msk_pad), dim=1)
+ else:
+ pymid_emb, pymid_msk = video_emb, video_msk
+
+ pymid, pymid_msk = self.pyramid(pymid_emb, pymid_msk, return_mask=True)
+ if not len(pymid) == len(pymid_msk) == len(self.strides):
+ warnings.warn(f'pyramid size mismatch: {len(pymid)} {len(pymid_msk)} {len(self.strides)}')
+
+ point = self.generator(pymid)
+
+ out_class = [self.class_head(e) for e in pymid]
+ out_class = torch.cat(out_class, dim=1)
+
+ out_coord = [self.coef(self.coord_head(e).exp(), i) for i, e in enumerate(pymid)]
+ out_coord = torch.cat(out_coord, dim=1)
+
+ data = dict(
+ point=point,
+ video_emb=video_emb,
+ query_emb=query_emb,
+ video_msk=video_msk,
+ pymid_msk=pymid_msk,
+ out_class=out_class,
+ out_coord=out_coord,
+ boundary=point.new_tensor(ts),
+ saliency=saliency[batch_idx].unsqueeze(0),
+ pos_clip=pos_clip[batch_idx].unsqueeze(0))
+
+ losses = self.bundle_loss(data, dict())
+ # print({k: v.item() for k, v in losses.items()})
+
+ loss_regs.append(sum(v for v in losses.values()))
+ avg_factors.append(len(ts))
+
+ assert len(loss_regs) in (1, 2) and len(loss_regs) == len(avg_factors)
+
+ if len(loss_regs) == 2 and loss_regs[0] > loss_regs[1]:
+ loss_reg, avg_factor = loss_regs[1], avg_factors[1]
+ else:
+ loss_reg, avg_factor = loss_regs[0], avg_factors[0]
+
+ if avg_factor > 0:
+ outputs.loss = outputs.loss + loss_reg / avg_factor
+ elif mode == 'generating':
+ logits = outputs.logits[0, -1]
+ if logits.argmax() == self.config.reg_token_id:
+ assert self.model.norm.state.size() == (1, 1, self.config.hidden_size)
+
+ # only consider the first video
+ s, e = self.cache_vision_inds[0][0]
+
+ # spatial merge size set to 2
+ window = int(video_grid_thw[0][1] * video_grid_thw[0][2] / 4)
+ assert video_grid_thw[0][0] * window == e - s
+
+ reg_tokens = self.reg_proj(self.model.norm.state)
+ # reg_tokens: num_reg_tokens * 1 * channel
+
+ vis_tokens = self.cache_norm_state[:, s:e]
+ vis_tokens = vis_tokens.transpose(-1, -2)
+ vis_tokens = F.avg_pool1d(vis_tokens.float(), window, stride=window).to(vis_tokens.dtype)
+ vis_tokens = vis_tokens.transpose(-1, -2)
+ vis_tokens = self.vis_proj(vis_tokens).repeat(reg_tokens.size(0), 1, 1)
+ # vis_tokens: num_reg_tokens * num_frames * channel
+
+ vis_tokens = self.vis_emb(vis_tokens)
+ reg_tokens = self.reg_emb(reg_tokens)
+ pe = self.vis_pos(vis_tokens).to(vis_tokens.dtype)
+
+ joint_tokens = torch.cat((vis_tokens + pe, reg_tokens), dim=1)
+ for blk in self.vis_fuse:
+ joint_tokens = blk(joint_tokens)
+ joint_tokens = self.vis_norm(joint_tokens)
+
+ video_emb = joint_tokens[:, :-1]
+ # video_emb: num_reg_tokens * num_frames * channel
+
+ query_emb = joint_tokens[:, -1:]
+ # query_emb: num_reg_tokens * 1 * channel
+
+ b, t, _ = video_emb.size()
+ video_msk = video_emb.new_ones(b, t)
+
+ pymid = self.pyramid(video_emb, video_msk)
+ point = self.generator(pymid)
+
+ out_class = [self.class_head(e).sigmoid() for e in pymid]
+ out_class = torch.cat(out_class, dim=1)
+
+ out_coord = [self.coef(self.coord_head(e).exp(), i) for i, e in enumerate(pymid)]
+ out_coord = torch.cat(out_coord, dim=1)
+
+ sal = out_class[0]
+ bnd = out_coord[0]
+
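+                # decode offsets: scale by the point stride, shift by the point coordinate,
+                # then normalize by the number of frames so boundaries are expressed as
+                # fractions of the video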
+ bnd[:, 0] *= -1
+ bnd *= point[:, 3, None].repeat(1, 2)
+ bnd += point[:, 0, None].repeat(1, 2)
+ bnd /= t
+ bnd = torch.cat((bnd, sal), dim=-1)
+
+ _, inds = bnd[:, -1].sort(descending=True)
+ bnd = bnd[inds]
+
+ # hard coding nms config here
+ nms_cfg = dict(type='normal', thres=0.75)
+ assert nms_cfg['type'] in ('normal', 'linear', 'gaussian')
+
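+                # 'normal' zeroes the scores of overlapping candidates (hard NMS), while
+                # 'linear' and 'gaussian' only decay them (soft-NMS style)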
+ for i in range(bnd.size(0)):
+ max_idx = bnd[i:, -1].argmax(dim=0)
+ bnd = nncore.swap_element(bnd, i, max_idx + i)
+ iou = temporal_iou(bnd[i, None, :-1], bnd[i + 1:, :-1])[0]
+
+ if nms_cfg['type'] == 'normal':
+ bnd[i + 1:, -1][iou >= nms_cfg['thres']] = 0
+ elif nms_cfg['type'] == 'linear':
+ bnd[i + 1:, -1] *= 1 - iou
+ else:
+ bnd[i + 1:, -1] *= (-iou.pow(2) / nms_cfg['sigma']).exp()
+
+ # save top-100 predictions
+ self.reg.append(bnd[:100])
+
+ # save all saliency scores
+ self.sal.append(sal)
+
+ return outputs
+
+
+# register the patched model as a vision-to-sequence model
+MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES[AgentQwen2VLConfig.model_type] = 'AgentQwen2VLForConditionalGeneration'
+
+AutoConfig.register(AgentQwen2VLConfig.model_type, AgentQwen2VLConfig)
+AutoModel.register(AgentQwen2VLConfig, AgentQwen2VLForConditionalGeneration)
diff --git a/videomind/train/custom_trainer.py b/videomind/train/custom_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..abe65d5b9a58552394f08ddc716e023f8930d2fc
--- /dev/null
+++ b/videomind/train/custom_trainer.py
@@ -0,0 +1,251 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import warnings
+
+import nncore
+import torch
+from deepspeed import zero
+from safetensors.torch import load_model, save_file
+from torch.utils.data import Sampler
+from transformers import Trainer, TrainerCallback
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.trainer_pt_utils import get_parameter_names
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.utils import CHAT_TEMPLATE_NAME
+
+
+def gather(param):
+ if hasattr(param, 'ds_id'):
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+def gather_lora_params(model, bias):
+ assert bias in ('lora_only', 'all', 'none')
+
+ if bias == 'lora_only':
+ state_dict, maybe_lora_bias, lora_bias_names = dict(), dict(), set()
+ for n, p in model.named_parameters():
+ if 'modules_to_save' in n:
+ state_dict[n] = p
+ elif 'lora_' in n:
+ state_dict[n] = p
+ bias_name = n.split('lora_')[0] + 'bias'
+ lora_bias_names.add(bias_name)
+ elif 'bias' in n:
+ maybe_lora_bias[n] = p
+        for n, p in maybe_lora_bias.items():
+            if n in lora_bias_names:
+                state_dict[n] = p
+ else:
+ keys = ['lora_', 'modules_to_save', 'bias'] if bias == 'all' else ['lora_', 'modules_to_save']
+ state_dict = {n: p for n, p in model.named_parameters() if any(k in n for k in keys)}
+
+ state_dict = {n: gather(p) for n, p in state_dict.items()}
+ return state_dict
+
+
+def gather_key_params(model, keys):
+ state_dict = {n: p for n, p in model.named_parameters() if p.requires_grad and any(k in n for k in keys)}
+ state_dict = {n: gather(p) for n, p in state_dict.items()}
+ return state_dict
+
+
+class GroupSampler(Sampler):
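+    # groups sample indices by data type so that every yielded batch contains a single type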
+
+ def __init__(self, group_size, data_types, seed):
+ self.group_size = group_size
+ self.data_types = data_types
+ self.seed = seed
+
+ def __len__(self):
+ return len(self.data_types)
+
+ def __iter__(self):
+ g = torch.Generator()
+ g.manual_seed(self.seed + self.epoch)
+
+ # avoid using dict or set here as they are not deterministic
+ unique_types, groups = [], []
+ for i, t in enumerate(self.data_types):
+ if t not in unique_types:
+ unique_types.append(t)
+ groups.append([])
+ groups[unique_types.index(t)].append(i)
+
+ group_batches = []
+ for group in groups:
+ inds = [group[i] for i in torch.randperm(len(group), generator=g)]
+ batches = [inds[i:i + self.group_size] for i in range(0, len(inds), self.group_size)]
+
+ if len(batches[-1]) < self.group_size:
+ batches = batches[:-1]
+
+ group_batches += batches
+
+ perm_group_batches = [group_batches[i] for i in torch.randperm(len(group_batches), generator=g)]
+ inds = [i for batch in perm_group_batches for i in batch]
+
+ return iter(inds)
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+
+class SetEpochCallback(TrainerCallback):
+
+ # partially fixed in https://github.com/huggingface/accelerate/pull/3246
+ # but not for the case of batch_sampler.batch_sampler.sampler
+ def on_epoch_begin(self, args, state, control, **kwargs):
+ shard_sampler = kwargs['train_dataloader'].batch_sampler
+ batch_sampler = getattr(shard_sampler, 'batch_sampler', shard_sampler)
+ batch_sampler.sampler.set_epoch(int(state.epoch))
+
+
+class CustomTrainer(Trainer):
+
+ def __init__(self, *args, processor=None, head_keys=None, **kwargs):
+ super().__init__(*args, tokenizer=processor, **kwargs)
+ self.add_callback(SetEpochCallback())
+ self.processor = processor
+ self.head_keys = head_keys
+
+ def _get_train_sampler(self):
+ if self.args.group_by_data_type:
+ return GroupSampler(self.args.train_batch_size * self.args.world_size, self.train_dataset.data_types,
+ self.args.seed)
+ else:
+ return super()._get_train_sampler()
+
+ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+ if model is None:
+ model = self.model
+
+ super()._load_from_checkpoint(resume_from_checkpoint, model=model)
+
+ partial_path = nncore.join(resume_from_checkpoint, 'pytorch_model.safetensors')
+ if nncore.is_file(partial_path):
+ load_model(model, partial_path, strict=False, device=model.device)
+
+ def create_optimizer(self):
+ if self.optimizer is None:
+ grad_ps = [(n, p) for n, p in self.model.named_parameters() if p.requires_grad]
+
+ decay_ps = get_parameter_names(self.model, ALL_LAYERNORM_LAYERS)
+ decay_ps = [n for n in decay_ps if 'bias' not in n]
+
+ if self.args.lora_lr is None:
+ self.args.lora_lr = self.args.learning_rate
+
+ if self.args.head_lr is None:
+ self.args.head_lr = self.args.learning_rate
+
+ lora_ps = [n for n, _ in grad_ps if 'lora' in n]
+ head_ps = [n for n, _ in grad_ps if any(k in n for k in self.head_keys)]
+ assert all(n not in lora_ps for n in head_ps) and all(n not in head_ps for n in lora_ps)
+
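+            # six parameter groups: {decay, no decay} x {base, LoRA, head}, with separate
+            # learning rates for the LoRA and head parameters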
+ groups = [{
+ 'params': [p for n, p in grad_ps if (n in decay_ps and n not in lora_ps and n not in head_ps)],
+ 'weight_decay': self.args.weight_decay
+ }, {
+ 'params': [p for n, p in grad_ps if (n not in decay_ps and n not in lora_ps and n not in head_ps)],
+ 'weight_decay': 0.0
+ }, {
+ 'params': [p for n, p in grad_ps if (n in decay_ps and n in lora_ps)],
+ 'weight_decay': self.args.weight_decay,
+ 'lr': self.args.lora_lr
+ }, {
+ 'params': [p for n, p in grad_ps if (n not in decay_ps and n in lora_ps)],
+ 'weight_decay': 0.0,
+ 'lr': self.args.lora_lr
+ }, {
+ 'params': [p for n, p in grad_ps if (n in decay_ps and n in head_ps)],
+ 'weight_decay': self.args.weight_decay,
+ 'lr': self.args.head_lr
+ }, {
+ 'params': [p for n, p in grad_ps if (n not in decay_ps and n in head_ps)],
+ 'weight_decay': 0.0,
+ 'lr': self.args.head_lr
+ }]
+
+ optim_cls, kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+ self.optimizer = optim_cls(groups, **kwargs)
+
+ return self.optimizer
+
+ def gather_and_save_model(self):
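+        # saving strategy: either merge the LoRA weights (when possible) and write the full
+        # model, or save the adapter together with the extra head parameters separately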
+ deepspeed_zero3 = self.accelerator.deepspeed_config['zero_optimization']['stage'] == 3
+ output_dir = self.args.output_dir
+
+ if self.args.should_save:
+ print(f'Saving final model to {nncore.abs_path(output_dir)}...')
+
+ if self.processor is not None and self.args.should_save:
+ self.processor.save_pretrained(output_dir)
+
+ # https://github.com/huggingface/transformers/pull/33462
+ if self.processor.chat_template is not None:
+ chat_template = {'chat_template': self.processor.chat_template}
+ nncore.dump(chat_template, nncore.join(output_dir, CHAT_TEMPLATE_NAME), indent=2)
+
+ if self.args.save_full_model and self.args.lora_enable and deepspeed_zero3:
+ warnings.warn('LoRA models cannot be saved in full mode under zero3, saving adapters instead')
+ self.args.save_full_model = False
+
+ if self.args.save_full_model:
+ if self.args.lora_enable:
+ self.model = self.model.merge_and_unload()
+
+ if deepspeed_zero3 and not self.model_wrapped.zero_gather_16bit_weights_on_model_save():
+ warnings.warn('Saving zero checkpoint, use zero_to_fp32.py to recover weights')
+ self.model_wrapped.save_checkpoint(output_dir)
+ return
+
+ if deepspeed_zero3:
+ state_dict = self.model_wrapped._zero3_consolidated_16bit_state_dict()
+ else:
+ state_dict = self.model.state_dict()
+
+ if self.args.should_save:
+ state_dict = {k[17:] if k.startswith('base_model.model.') else k: v for k, v in state_dict.items()}
+ self._save(output_dir, state_dict=state_dict)
+ else:
+ if self.args.lora_enable:
+ state_dict = gather_lora_params(self.model, self.args.lora_bias)
+ if self.args.should_save:
+ self.model.save_pretrained(output_dir, state_dict=state_dict)
+
+ if self.args.should_save:
+ self.model.config.save_pretrained(output_dir)
+ self.model.generation_config.save_pretrained(output_dir)
+ self.tokenizer.save_pretrained(output_dir)
+
+ state_dict = gather_key_params(self.model, self.head_keys)
+ if self.args.should_save and state_dict:
+ save_file(state_dict, nncore.join(output_dir, 'pytorch_model.safetensors'))
+
+ def _save_checkpoint(self, model, trial, **kwargs):
+ output_dir = self._get_output_dir(trial)
+ output_dir = nncore.join(output_dir, f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}')
+
+ if self.args.should_save:
+ print(f'Saving checkpoint to {nncore.abs_path(output_dir)}...')
+
+ super()._save_checkpoint(model, trial, **kwargs)
+
+ if self.processor is not None and self.args.should_save:
+ self.processor.save_pretrained(output_dir)
+
+ # https://github.com/huggingface/transformers/pull/33462
+ if self.processor.chat_template is not None:
+ chat_template = {'chat_template': self.processor.chat_template}
+ nncore.dump(chat_template, nncore.join(output_dir, CHAT_TEMPLATE_NAME), indent=2)
+
+ if self.args.lora_enable:
+ state_dict = gather_key_params(self.model, self.head_keys)
+ if self.args.should_save:
+ self.model.config.save_pretrained(output_dir)
+ save_file(state_dict, nncore.join(output_dir, 'pytorch_model.safetensors'))
diff --git a/videomind/train/train.py b/videomind/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..3217a36fc1540e3dfff77b38eed05422e2a9e274
--- /dev/null
+++ b/videomind/train/train.py
@@ -0,0 +1,200 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+import nncore
+import torch
+import torch.nn as nn
+from peft import LoraConfig, PeftModel, get_peft_model
+from transformers import AutoProcessor, HfArgumentParser, TrainingArguments
+
+from videomind.constants import REG_TOKEN, SEG_E_TOKEN, SEG_S_TOKEN
+from videomind.dataset import HybridDataCollator, HybridDataset
+from videomind.model import MODELS
+from videomind.model.builder import build_model
+from videomind.train.custom_trainer import CustomTrainer
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default=None)
+ base_model: Optional[str] = field(default=None)
+ conv_type: Optional[str] = field(default=None)
+ role: Optional[str] = field(default=None)
+
+
+@dataclass
+class DataArguments:
+ datasets: Optional[str] = field(default=None)
+ min_video_len: Optional[int] = field(default=-1)
+ max_video_len: Optional[int] = field(default=-1)
+ min_num_words: Optional[int] = field(default=-1)
+ max_num_words: Optional[int] = field(default=-1)
+ max_retries: Optional[int] = field(default=10)
+
+
+@dataclass
+class CustomArguments:
+ optim: Optional[str] = field(default='adamw_torch')
+ group_by_data_type: Optional[bool] = field(default=True)
+ merge_adapter: Optional[bool] = field(default=False)
+ lora_enable: Optional[bool] = field(default=False)
+ lora_type: Optional[str] = field(default='qkvo')
+ lora_r: Optional[int] = field(default=64)
+ lora_alpha: Optional[int] = field(default=64)
+ lora_dropout: Optional[float] = field(default=0.1)
+ lora_bias: Optional[str] = field(default='none')
+ lora_lr: Optional[float] = field(default=None)
+ head_lr: Optional[float] = field(default=None)
+ tuning_modules: Optional[str] = field(default=None)
+ save_full_model: Optional[bool] = field(default=False)
+ remove_unused_columns: Optional[bool] = field(default=False)
+
+
+@dataclass
+class TrainingArguments(CustomArguments, TrainingArguments):
+ pass
+
+
+def get_target_modules(model, lora_type, base_model):
+ lora_type = lora_type.split('_')
+ assert all(t in ('qkvo', 'linear', 'all') for t in lora_type)
+
+ if base_model == 'qwen2_vl':
+ # all qkvo layers in the visual encoder and the llm
+ qkvo_keys = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'attn.qkv', 'attn.proj']
+
+ target_modules = set()
+ for n, m in model.named_modules():
+ if not isinstance(m, nn.Linear):
+ continue
+ if 'all' not in lora_type and 'visual' in n:
+ continue
+ if 'qkvo' in lora_type and not any(n.endswith(k) for k in qkvo_keys):
+ continue
+ target_modules.add(n)
+ else:
+ raise ValueError(f'unknown base model: {base_model}')
+
+ return target_modules
+
+
+def train(TrainingArguments, Trainer):
+ parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ assert model_args.role in ('all_in_one', 'planner', 'grounder', 'verifier', 'answerer')
+
+ config_cls, model_cls = MODELS[model_args.base_model]
+
+ dtype = torch.bfloat16 if training_args.bf16 else torch.float32
+
+ config = config_cls.from_pretrained(model_args.model_name_or_path, torch_dtype=dtype)
+ config.update(model_args.__dict__)
+
+ if config.model_type == 'agent_qwen2_vl':
+ model, processor = build_model(
+ model_args.model_name_or_path,
+ config=config,
+ is_trainable=True,
+ merge_adapter=training_args.merge_adapter,
+ dtype=dtype)
+ else:
+ # set do_resize to false to avoid duplicated resizing
+ # https://github.com/huggingface/transformers/tree/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py
+ processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, do_resize=False)
+
+ # eager attention has known & unknown bugs
+ # [4.46.2] broken causality fp16: https://github.com/huggingface/transformers/issues/35151
+ # [4.48.1] broken sliding window: https://github.com/huggingface/transformers/issues/35924
+ model = model_cls.from_pretrained(model_args.model_name_or_path, config=config, attn_implementation='sdpa')
+
+ # save base model path for inference
+ model.config.base_model_path = model_args.model_name_or_path
+
+ # conv parameters may become inf after casting to fp16
+ model.reset_conv_parameters()
+
+ model.requires_grad_(False)
+
+ if training_args.lora_enable and not isinstance(model, PeftModel):
+ target_modules = get_target_modules(model, training_args.lora_type, model.config.base_model)
+ tune_lm_head = model.config.role in ('all_in_one', 'grounder', 'verifier')
+ print(f'LoRA target modules: {target_modules}')
+ lora_config = LoraConfig(
+ task_type='CAUSAL_LM',
+ r=training_args.lora_r,
+ lora_alpha=training_args.lora_alpha,
+ lora_dropout=training_args.lora_dropout,
+ bias=training_args.lora_bias,
+ target_modules=target_modules,
+ modules_to_save=['embed_tokens', 'lm_head'] if tune_lm_head else None)
+ # transformers integration does not support merge_and_unload, use peft instead
+ model = get_peft_model(model, lora_config, adapter_name=model_args.role)
+
+ new_tokens = processor.tokenizer.add_special_tokens(
+ dict(additional_special_tokens=[REG_TOKEN, SEG_S_TOKEN, SEG_E_TOKEN]))
+ print(f'Added {new_tokens} new token(s)')
+
+ model.config.reg_token_id = processor.tokenizer.convert_tokens_to_ids(REG_TOKEN)
+ model.config.seg_s_token_id = processor.tokenizer.convert_tokens_to_ids(SEG_S_TOKEN)
+ model.config.seg_e_token_id = processor.tokenizer.convert_tokens_to_ids(SEG_E_TOKEN)
+
+ if new_tokens > 0 and len(processor.tokenizer) > model.config.vocab_size:
+ print(f'Expanding vocab size: {model.config.vocab_size} -> {len(processor.tokenizer)}')
+ model.resize_token_embeddings(len(processor.tokenizer))
+ i_emb = model.get_input_embeddings().weight.data
+ o_emb = model.get_output_embeddings().weight.data
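+        # initialize the new special token embeddings with the mean of the existing ones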
+ i_emb[-new_tokens:] = i_emb[:-new_tokens].mean(0, keepdim=True)
+ o_emb[-new_tokens:] = o_emb[:-new_tokens].mean(0, keepdim=True)
+
+ tuning_modules = [] if training_args.tuning_modules is None else training_args.tuning_modules.split(',')
+
+ head_keys = [
+ 'vis_proj', 'reg_proj', 'vis_fuse', 'vis_norm', 'vis_pos', 'vis_emb', 'reg_emb', 'pyramid', 'class_head',
+ 'coord_head', 'coef', 'bundle_loss'
+ ]
+
+ for n, p in model.named_parameters():
+ # embed_tokens and lm_head might be handled by lora
+ if not training_args.lora_enable and new_tokens > 0 and any(k in n for k in ('embed_tokens', 'lm_head')):
+ p.requires_grad = True
+
+ if 'projector' in tuning_modules and 'visual.merger' in n:
+ p.requires_grad = True
+
+ if model_args.role in ('all_in_one', 'grounder') and any(k in n for k in head_keys):
+ p.requires_grad = True
+
+ if training_args.local_rank in (0, -1):
+ for n, p in model.named_parameters():
+ print(p.requires_grad, p.dtype, p.shape, n)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ ratio = round(learnable_params / total_params * 100, 2) if total_params > 0 else 0
+ print(f'Total params: {total_params} Learnable params: {learnable_params} ({ratio}%)')
+
+ i_size = model.get_input_embeddings().num_embeddings
+ o_size = model.get_output_embeddings().out_features
+ assert i_size == o_size, (i_size, o_size)
+ print(f'Tokenizer size: {len(processor.tokenizer)} Vocab size: {model.config.vocab_size} Embed size: {i_size}')
+
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ data_collator=HybridDataCollator(processor.tokenizer),
+ train_dataset=HybridDataset(processor, model.config, model_args, data_args, training_args),
+ processor=processor,
+ head_keys=head_keys)
+
+ has_ckpt = bool(nncore.find(training_args.output_dir, 'checkpoint-*'))
+ trainer.train(resume_from_checkpoint=has_ckpt)
+
+ trainer.save_state()
+ trainer.gather_and_save_model()
+
+
+if __name__ == '__main__':
+ train(TrainingArguments, CustomTrainer)
diff --git a/videomind/utils/io.py b/videomind/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..4040700eeba7c01c17b877f2e15dc80c90d6d21e
--- /dev/null
+++ b/videomind/utils/io.py
@@ -0,0 +1,30 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import pysrt
+from decord import VideoReader
+
+
+def time_to_seconds(t):
+ return (t.hour * 60 + t.minute) * 60 + t.second + t.microsecond / 1000000
+
+
+def load_subtitle(path):
+ subs = pysrt.open(path)
+
+ parsed = []
+ for sub in subs:
+ s = time_to_seconds(sub.start.to_time())
+ e = time_to_seconds(sub.end.to_time())
+ parsed.append((s, e, sub.text))
+
+ return parsed
+
+
+def get_duration(path, num_threads=1):
+ # sometimes the video is loaded as a list of frames
+ if isinstance(path, list):
+ return len(path)
+
+ vr = VideoReader(path, num_threads=num_threads)
+ duration = len(vr) / vr.get_avg_fps()
+ return duration
diff --git a/videomind/utils/parser.py b/videomind/utils/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ca1a643c8e050680686ecb7ce65d1ac073de4d8
--- /dev/null
+++ b/videomind/utils/parser.py
@@ -0,0 +1,25 @@
+# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.
+
+import re
+
+
+def parse_span(span, duration, min_len=-1):
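+    # e.g. parse_span((5.2, 3.1), 60, 10) -> (0.0, 10.0): endpoints are clamped and
+    # re-ordered, then the span is expanded to min_len around its (clamped) center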
+ s, e = span
+ s, e = min(duration, max(0, s)), min(duration, max(0, e))
+ s, e = min(s, e), max(s, e)
+
+ if min_len != -1 and e - s < min_len:
+ h = min_len / 2
+ c = min(duration - h, max(h, (s + e) / 2))
+ s, e = c - h, c + h
+
+ s, e = min(duration, max(0, s)), min(duration, max(0, e))
+ return s, e
+
+
+def parse_query(query):
+ return re.sub(r'\s+', ' ', query).strip().strip('.').strip()
+
+
+def parse_question(question):
+ return re.sub(r'\s+', ' ', question).strip()