	Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +16 -0
- .gitignore +9 -0
- app.py +640 -0
- assets/bot.png +0 -0
- assets/user.png +0 -0
- data/10309844035.mp4 +3 -0
- data/13887487955.mp4 +3 -0
- data/4167294363.mp4 +3 -0
- data/4742652230.mp4 +3 -0
- data/4766274786.mp4 +3 -0
- data/5012237466.mp4 +3 -0
- data/5188348585.mp4 +3 -0
- data/9383140374.mp4 +3 -0
- data/DTInxNfWXVc_210.0_360.0.mp4 +3 -0
- data/RoripwjYFp8_210.0_360.0.mp4 +3 -0
- data/UFWQKrcbhjI_360.0_510.0.mp4 +3 -0
- data/Z3-IZ3HAmIA_60.0_210.0.mp4 +3 -0
- data/h6QKDqomIPk_210.0_360.0.mp4 +3 -0
- data/pA6Z-qYhSNg_60.0_210.0.mp4 +3 -0
- data/rrTIeJRVGjg_60.0_210.0.mp4 +3 -0
- data/yId2wIocTys_210.0_360.0.mp4 +3 -0
- requirements.txt +26 -0
- setup.cfg +16 -0
- videomind/constants.py +42 -0
- videomind/conversation.py +49 -0
- videomind/dataset/__init__.py +61 -0
- videomind/dataset/collator.py +40 -0
- videomind/dataset/hybrid.py +180 -0
- videomind/dataset/sub_classes/__init__.py +69 -0
- videomind/dataset/sub_classes/activitynet_captions.py +96 -0
- videomind/dataset/sub_classes/activitynet_rtl.py +68 -0
- videomind/dataset/sub_classes/cgbench.py +47 -0
- videomind/dataset/sub_classes/charades_sta.py +45 -0
- videomind/dataset/sub_classes/cosmo_cap.py +37 -0
- videomind/dataset/sub_classes/didemo.py +59 -0
- videomind/dataset/sub_classes/ego4d_naq.py +81 -0
- videomind/dataset/sub_classes/ego4d_nlq.py +41 -0
- videomind/dataset/sub_classes/ego_timeqa.py +93 -0
- videomind/dataset/sub_classes/hirest.py +150 -0
- videomind/dataset/sub_classes/internvit_vtime.py +45 -0
- videomind/dataset/sub_classes/longvideobench.py +53 -0
- videomind/dataset/sub_classes/lvbench.py +52 -0
- videomind/dataset/sub_classes/mlvu.py +55 -0
- videomind/dataset/sub_classes/mvbench.py +74 -0
- videomind/dataset/sub_classes/nextgqa.py +87 -0
- videomind/dataset/sub_classes/nextqa.py +63 -0
- videomind/dataset/sub_classes/qa_ego4d.py +98 -0
- videomind/dataset/sub_classes/queryd.py +49 -0
- videomind/dataset/sub_classes/qvhighlights.py +78 -0
- videomind/dataset/sub_classes/rextime.py +81 -0
    	
.gitattributes CHANGED
@@ -33,3 +33,19 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/10309844035.mp4 filter=lfs diff=lfs merge=lfs -text
+data/13887487955.mp4 filter=lfs diff=lfs merge=lfs -text
+data/4167294363.mp4 filter=lfs diff=lfs merge=lfs -text
+data/4742652230.mp4 filter=lfs diff=lfs merge=lfs -text
+data/4766274786.mp4 filter=lfs diff=lfs merge=lfs -text
+data/5012237466.mp4 filter=lfs diff=lfs merge=lfs -text
+data/5188348585.mp4 filter=lfs diff=lfs merge=lfs -text
+data/9383140374.mp4 filter=lfs diff=lfs merge=lfs -text
+data/DTInxNfWXVc_210.0_360.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/RoripwjYFp8_210.0_360.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/UFWQKrcbhjI_360.0_510.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/Z3-IZ3HAmIA_60.0_210.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/h6QKDqomIPk_210.0_360.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/pA6Z-qYhSNg_60.0_210.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/rrTIeJRVGjg_60.0_210.0.mp4 filter=lfs diff=lfs merge=lfs -text
+data/yId2wIocTys_210.0_360.0.mp4 filter=lfs diff=lfs merge=lfs -text
    	
.gitignore ADDED
@@ -0,0 +1,9 @@
+# Byte-compiled / optimized / DLL files
+__pycache__
+*.egg-info
+*.py[cod]
+*$py.class
+
+# Temporary data
+.DS_Store
+._*
    	
app.py ADDED
@@ -0,0 +1,640 @@
+# Copyright (c) 2024 Ye Liu. Licensed under the BSD-3-Clause license.
+
+import html
+import json
+import os
+import random
+import time
+from functools import partial
+from threading import Thread
+
+import gradio as gr
+import nncore
+import torch
+from huggingface_hub import snapshot_download
+from transformers import TextIteratorStreamer
+
+from videomind.constants import GROUNDER_PROMPT, PLANNER_PROMPT, VERIFIER_PROMPT
+from videomind.dataset.utils import process_vision_info
+from videomind.model.builder import build_model
+from videomind.utils.io import get_duration
+from videomind.utils.parser import parse_query, parse_span
+
+BASE_MODEL = 'model_zoo/Qwen2-VL-2B-Instruct'
+BASE_MODEL_HF = 'Qwen/Qwen2-VL-2B-Instruct'
+
+MODEL = 'model_zoo/VideoMind-2B'
+MODEL_HF = 'yeliudev/VideoMind-2B'
+
+TITLE = 'VideoMind: A Chain-of-LoRA Agent for Long Video Reasoning'
+
+TITLE_MD = f'<h1 align="center">💡 {TITLE}</h1>'
+DESCRIPTION_MD = """VideoMind is a multi-modal agent framework that enhances video reasoning by emulating *human-like* processes, such as *breaking down tasks*, *localizing and verifying moments*, and *synthesizing answers*. This approach addresses the unique challenges of temporal-grounded reasoning in a progressive strategy. Please find more details at our <a href="https://videomind.github.io/" target="_blank">Project Page</a>, <a href="https://arxiv.org/abs/2503.13444" target="_blank">Tech Report</a> and <a href="https://github.com/yeliudev/VideoMind" target="_blank">GitHub Repo</a>."""  # noqa
+
+# yapf:disable
+EXAMPLES = [
+    ('data/4167294363.mp4', 'Why did the old man stand up?', ['pla', 'gnd', 'ver', 'ans']),
+    ('data/5012237466.mp4', 'How does the child in stripes react about the fountain?', ['pla', 'gnd', 'ver', 'ans']),
+    ('data/13887487955.mp4', 'What did the excavator do after it pushed the cement forward?', ['pla', 'gnd', 'ver', 'ans']),
+    ('data/5188348585.mp4', 'What did the person do before pouring the liquor?', ['pla', 'gnd', 'ver', 'ans']),
+    ('data/4766274786.mp4', 'What did the girl do after the baby lost the balloon?', ['pla', 'gnd', 'ver', 'ans']),
+    ('data/4742652230.mp4', 'Why is the girl pushing the boy only around the toy but not to other places?', ['pla', 'gnd', 'ver', 'ans']),
+    ('data/9383140374.mp4', 'How does the girl in pink control the movement of the claw?', ['pla', 'gnd', 'ver', 'ans']),
+    ('data/10309844035.mp4', 'Why are they holding up the phones?', ['pla', 'gnd', 'ver', 'ans']),
+    ('data/pA6Z-qYhSNg_60.0_210.0.mp4', 'Different types of meat products are being cut, shaped and prepared', ['gnd', 'ver']),
+    ('data/UFWQKrcbhjI_360.0_510.0.mp4', 'A man talks to the camera whilst walking along a roadside in a rural area', ['gnd', 'ver']),
+    ('data/RoripwjYFp8_210.0_360.0.mp4', 'A woman wearing glasses eating something at a street market', ['gnd', 'ver']),
+    ('data/h6QKDqomIPk_210.0_360.0.mp4', 'A toddler sits in his car seat, holding his yellow tablet', ['gnd', 'ver']),
+    ('data/Z3-IZ3HAmIA_60.0_210.0.mp4', 'A view from the window as the plane accelerates and takes off from the runway', ['gnd', 'ver']),
+    ('data/yId2wIocTys_210.0_360.0.mp4', "Temporally locate the visual content mentioned in the text query 'kids exercise in front of parked cars' within the video.", ['pla', 'gnd', 'ver']),
+    ('data/rrTIeJRVGjg_60.0_210.0.mp4', "Localize the moment that provides relevant context about 'man stands in front of a white building monologuing'.", ['pla', 'gnd', 'ver']),
+    ('data/DTInxNfWXVc_210.0_360.0.mp4', "Find the video segment that corresponds to the given textual query 'man with headphones talking'.", ['pla', 'gnd', 'ver']),
+]
+# yapf:enable
+
+CSS = """button .box { text-align: left }"""
+
+JS = """
+function init() {
+    var info = document.getElementById('role').querySelectorAll('[class^="svelte"]')[1]
+    info.innerHTML = info.innerHTML.replace(/&lt;/g, '<').replace(/&gt;/g, '>')
+}
+"""
+
+
+class CustomStreamer(TextIteratorStreamer):
+
+    def put(self, value):
+        if len(value.shape) > 1 and value.shape[0] > 1:
+            raise ValueError('TextStreamer only supports batch size 1')
+        elif len(value.shape) > 1:
+            value = value[0]
+
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+            return
+
+        self.token_cache.extend(value.tolist())
+
+        # force skipping eos token
+        if self.token_cache[-1] == self.tokenizer.eos_token_id:
+            self.token_cache = self.token_cache[:-1]
+
+        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
+
+        # cache decoded text for future use
+        self.text_cache = text
+
+        if text.endswith('\n'):
+            printable_text = text[self.print_len:]
+            self.token_cache = []
+            self.print_len = 0
+        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
+            printable_text = text[self.print_len:]
+            self.print_len += len(printable_text)
+        else:
+            printable_text = text[self.print_len:text.rfind(' ') + 1]
+            self.print_len += len(printable_text)
+
+        self.on_finalized_text(printable_text)
+
+
+def seconds_to_hms(seconds):
+    hours, remainder = divmod(round(seconds), 3600)
+    minutes, seconds = divmod(remainder, 60)
+    return f'{hours:02}:{minutes:02}:{seconds:02}'
+
+
+def enable_btns():
+    return (gr.Button(interactive=True), ) * 3
+
+
+def disable_btns():
+    return (gr.Button(interactive=False), ) * 3
+
+
+def update_placeholder(role):
+    placeholder = 'Ask a question about the video...' if 'ans' in role else 'Write a query to search for a moment...'
+    return gr.Textbox(placeholder=placeholder)
+
+
+def main(video, prompt, role, temperature, max_new_tokens, model, processor, streamer, device):
+    history = []
+
+    if not video:
+        gr.Warning('Please upload a video or click [Random] to sample one.')
+        return history
+
+    if not prompt:
+        gr.Warning('Please provide a prompt or click [Random] to sample one.')
+        return history
+
+    if 'gnd' not in role and 'ans' not in role:
+        gr.Warning('Please select at least Grounder or Answerer.')
+        return history
+
+    if 'ver' in role and 'gnd' not in role:
+        gr.Warning('Verifier cannot be used without Grounder.')
+        return history
+
+    if 'pla' in role and any(k not in role for k in ('gnd', 'ver', 'ans')):
+        gr.Warning('Planner can only be used when all other roles are selected.')
+        return history
+
+    history.append({'role': 'user', 'content': prompt})
+    yield history
+
+    duration = get_duration(video)
+
+    # do grounding and answering by default
+    do_grounding = True
+    do_answering = True
+
+    # initialize grounding query as prompt
+    query = prompt
+
+    if 'pla' in role:
+        text = PLANNER_PROMPT.format(prompt)
+
+        history.append({
+            'metadata': {
+                'title': '🗺️ Working as Planner...'
+            },
+            'role': 'assistant',
+            'content': f'##### Planner Prompt:\n\n{html.escape(text)}\n\n##### Planner Response:\n\n...'
+        })
+        yield history
+
+        start_time = time.perf_counter()
+
+        messages = [{
+            'role':
+            'user',
+            'content': [{
+                'type': 'video',
+                'video': video,
+                'num_threads': 1,
+                'min_pixels': 36 * 28 * 28,
+                'max_pixels': 64 * 28 * 28,
+                'max_frames': 100,
+                'fps': 1.0
+            }, {
+                'type': 'text',
+                'text': text
+            }]
+        }]
+
+        text = processor.apply_chat_template(messages, add_generation_prompt=True)
+
+        images, videos = process_vision_info(messages)
+        data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+        data = data.to(device)
+
+        model.base_model.disable_adapter_layers()
+        model.base_model.enable_adapter_layers()
+        model.set_adapter('planner')
+
+        generation_kwargs = dict(
+            **data,
+            streamer=streamer,
+            do_sample=temperature > 0,
+            temperature=temperature if temperature > 0 else None,
+            top_p=None,
+            top_k=None,
+            repetition_penalty=None,
+            max_new_tokens=max_new_tokens)
+
+        t = Thread(target=model.generate, kwargs=generation_kwargs)
+        t.start()
+
+        skipped = False
+        for i, text in enumerate(streamer):
+            if text and not skipped:
+                history[-1]['content'] = history[-1]['content'].rstrip('.')
+                skipped = True
+            history[-1]['content'] += text
+            yield history
+
+        elapsed_time = round(time.perf_counter() - start_time, 1)
+        history[-1]['metadata']['title'] += f' ({elapsed_time} seconds)'
+        yield history
+
+        try:
+            parsed = json.loads(streamer.text_cache)
+            action = parsed[0] if isinstance(parsed, list) else parsed
+            if action['type'].lower() == 'grounder' and action['value']:
+                query = action['value']
+            elif action['type'].lower() == 'answerer':
+                do_grounding = False
+                do_answering = True
+        except Exception:
+            pass
+
+        response = 'After browsing the video and the question, my plan to figure out the answer is as follows:\n'
+        step_idx = 1
+        if 'gnd' in role and do_grounding:
+            response += f'\n{step_idx}. Localize the relevant moment in this video using the query "<span style="color:red">{query}</span>".'
+            step_idx += 1
+        if 'ver' in role and do_grounding:
+            response += f'\n{step_idx}. Verify the grounded moments one-by-one and select the best candidate.'
+            step_idx += 1
+        if 'ans' in role and do_answering:
+            if step_idx > 1:
+                response += f'\n{step_idx}. Crop the video segment and zoom in to a higher resolution.'
+            else:
+                response += f'\n{step_idx}. Analyze the whole video directly without cropping.'
+
+        history.append({'role': 'assistant', 'content': ''})
+        for i, text in enumerate(response.split(' ')):
+            history[-1]['content'] += ' ' + text if i > 0 else text
+            yield history
+
+    if 'gnd' in role and do_grounding:
+        query = parse_query(query)
+
+        text = GROUNDER_PROMPT.format(query)
+
+        history.append({
+            'metadata': {
+                'title': '🔍 Working as Grounder...'
+            },
+            'role': 'assistant',
+            'content': f'##### Grounder Prompt:\n\n{html.escape(text)}\n\n##### Grounder Response:\n\n...'
+        })
+        yield history
+
+        start_time = time.perf_counter()
+
+        messages = [{
+            'role':
+            'user',
+            'content': [{
+                'type': 'video',
+                'video': video,
+                'num_threads': 1,
+                'min_pixels': 36 * 28 * 28,
+                'max_pixels': 64 * 28 * 28,
+                'max_frames': 150,
+                'fps': 1.0
+            }, {
+                'type': 'text',
+                'text': text
+            }]
+        }]
+
+        text = processor.apply_chat_template(messages, add_generation_prompt=True)
+        images, videos = process_vision_info(messages)
+        data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+        data = data.to(device)
+
+        model.base_model.disable_adapter_layers()
+        model.base_model.enable_adapter_layers()
+        model.set_adapter('grounder')
+
+        generation_kwargs = dict(
+            **data,
+            streamer=streamer,
+            do_sample=temperature > 0,
+            temperature=temperature if temperature > 0 else None,
+            top_p=None,
+            top_k=None,
+            repetition_penalty=None,
+            max_new_tokens=max_new_tokens)
+
+        t = Thread(target=model.generate, kwargs=generation_kwargs)
+        t.start()
+
+        skipped = False
+        for i, text in enumerate(streamer):
+            if text and not skipped:
+                history[-1]['content'] = history[-1]['content'].rstrip('.')
+                skipped = True
+            history[-1]['content'] += text
+            yield history
+
+        elapsed_time = round(time.perf_counter() - start_time, 1)
+        history[-1]['metadata']['title'] += f' ({elapsed_time} seconds)'
+        yield history
+
+        if len(model.reg) > 0:
+            # 1. extract timestamps and confidences
+            blob = model.reg[0].cpu().float()
+            pred, conf = blob[:, :2] * duration, blob[:, -1].tolist()
+
+            # 2. clamp timestamps
+            pred = pred.clamp(min=0, max=duration)
+
+            # 3. sort timestamps
+            inds = (pred[:, 1] - pred[:, 0] < 0).nonzero()[:, 0]
+            pred[inds] = pred[inds].roll(1)
+
+            # 4. convert timestamps to list
+            pred = pred.tolist()
+        else:
+            if 'ver' in role:
+                pred = [[i * duration / 6, (i + 2) * duration / 6] for i in range(5)]
+                conf = [0] * 5
+            else:
+                pred = [[0, duration]]
+                conf = [0]
+
+        response = 'The candidate moments and confidence scores are as follows:\n'
+        response += '\n| ID | Start Time | End Time | Confidence |'
+        response += '\n| :-: | :-: | :-: | :-: |'
+
+        # using top-5 predictions
+        for i, (p, c) in enumerate(zip(pred[:5], conf[:5])):
+            response += f'\n| {i} | {seconds_to_hms(p[0])} | {seconds_to_hms(p[1])} | {c:.2f} |'
+
+        response += f'\n\nTherefore, the target moment might happen from <span style="color:red">{seconds_to_hms(pred[0][0])}</span> to <span style="color:red">{seconds_to_hms(pred[0][1])}</span>.'
+
+        history.append({'role': 'assistant', 'content': ''})
+        for i, text in enumerate(response.split(' ')):
+            history[-1]['content'] += ' ' + text if i > 0 else text
+            yield history
+
+    if 'ver' in role and do_grounding:
+        text = VERIFIER_PROMPT.format(query)
+
+        history.append({
+            'metadata': {
+                'title': '📊 Working as Verifier...'
+            },
+            'role': 'assistant',
+            'content': f'##### Verifier Prompt:\n\n{html.escape(text)}\n\n##### Verifier Response:\n\n...'
+        })
+        yield history
+
+        start_time = time.perf_counter()
+
+        # using top-5 predictions
+        prob = []
+        for i, cand in enumerate(pred[:5]):
+            s0, e0 = parse_span(cand, duration, 2)
+            offset = (e0 - s0) / 2
+            s1, e1 = parse_span([s0 - offset, e0 + offset], duration)
+
+            # percentage of s0, e0 within s1, e1
+            s = (s0 - s1) / (e1 - s1)
+            e = (e0 - s1) / (e1 - s1)
+
+            messages = [{
+                'role':
+                'user',
+                'content': [{
+                    'type': 'video',
+                    'video': video,
+                    'num_threads': 1,
+                    'video_start': s1,
+                    'video_end': e1,
+                    'min_pixels': 36 * 28 * 28,
+                    'max_pixels': 64 * 28 * 28,
+                    'max_frames': 64,
+                    'fps': 2.0
+                }, {
+                    'type': 'text',
+                    'text': text
+                }]
+            }]
+
+            text = processor.apply_chat_template(messages, add_generation_prompt=True)
+            images, videos = process_vision_info(messages)
+            data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+
+            # ===== insert segment start/end tokens =====
+            video_grid_thw = data['video_grid_thw'][0]
+            num_frames, window = int(video_grid_thw[0]), int(video_grid_thw[1] * video_grid_thw[2] / 4)
+            assert num_frames * window * 4 == data['pixel_values_videos'].size(0)
+
+            pos_s, pos_e = round(s * num_frames), round(e * num_frames)
+            pos_s, pos_e = min(max(0, pos_s), num_frames), min(max(0, pos_e), num_frames)
+            assert pos_s <= pos_e, (num_frames, s, e)
+
+            base_idx = torch.nonzero(data['input_ids'][0] == model.config.vision_start_token_id).item()
+            pos_s, pos_e = pos_s * window + base_idx + 1, pos_e * window + base_idx + 2
+
+            input_ids = data['input_ids'][0].tolist()
+            input_ids.insert(pos_s, model.config.seg_s_token_id)
+            input_ids.insert(pos_e, model.config.seg_e_token_id)
+            data['input_ids'] = torch.LongTensor([input_ids])
+            data['attention_mask'] = torch.ones_like(data['input_ids'])
+            # ===========================================
+
+            data = data.to(device)
+
+            model.base_model.disable_adapter_layers()
+            model.base_model.enable_adapter_layers()
+            model.set_adapter('verifier')
+
+            with torch.inference_mode():
+                logits = model(**data).logits[0, -1].softmax(dim=-1)
+
+            # NOTE: magic numbers here
+            # In Qwen2-VL vocab: 9454 -> Yes, 2753 -> No
+            score = (logits[9454] - logits[2753]).sigmoid().item()
+            prob.append(score)
+
+            if i == 0:
+                history[-1]['content'] = history[-1]['content'].rstrip('.')[:-1]
+
+            response = f'\nCandidate ID {i}: P(Yes) = {score:.2f}'
+            for j, text in enumerate(response.split(' ')):
+                history[-1]['content'] += ' ' + text if j > 0 else text
+                yield history
+
+        elapsed_time = round(time.perf_counter() - start_time, 1)
+        history[-1]['metadata']['title'] += f' ({elapsed_time} seconds)'
+        yield history
+
+        ranks = torch.Tensor(prob).argsort(descending=True).tolist()
+
+        prob = [prob[idx] for idx in ranks]
+        pred = [pred[idx] for idx in ranks]
+        conf = [conf[idx] for idx in ranks]
+
+        response = 'After verification, the candidate moments are re-ranked as follows:\n'
+        response += '\n| ID | Start Time | End Time | Score |'
+        response += '\n| :-: | :-: | :-: | :-: |'
+
+        ids = list(range(len(ranks)))
+        for r, p, c in zip(ranks, pred, prob):
+            response += f'\n| {ids[r]} | {seconds_to_hms(p[0])} | {seconds_to_hms(p[1])} | {c:.2f} |'
+
+        response += f'\n\nTherefore, the target moment should be from <span style="color:red">{seconds_to_hms(pred[0][0])}</span> to <span style="color:red">{seconds_to_hms(pred[0][1])}</span>.'
+
+        history.append({'role': 'assistant', 'content': ''})
+        for i, text in enumerate(response.split(' ')):
+            history[-1]['content'] += ' ' + text if i > 0 else text
+            yield history
+
+    if 'ans' in role and do_answering:
+        text = f'{prompt} Please think step by step and provide your response.'
+
+        history.append({
+            'metadata': {
+                'title': '📝 Working as Answerer...'
+            },
+            'role': 'assistant',
+            'content': f'##### Answerer Prompt:\n\n{html.escape(text)}\n\n##### Answerer Response:\n\n...'
+        })
+        yield history
+
+        start_time = time.perf_counter()
+
+        # choose the potential best moment
+        selected = pred[0] if 'gnd' in role and do_grounding else [0, duration]
+        s, e = parse_span(selected, duration, 32)
+
+        messages = [{
+            'role':
+            'user',
+            'content': [{
+                'type': 'video',
+                'video': video,
+                'num_threads': 1,
+                'video_start': s,
+                'video_end': e,
+                'min_pixels': 128 * 28 * 28,
+                'max_pixels': 256 * 28 * 28,
+                'max_frames': 32,
+                'fps': 2.0
+            }, {
+                'type': 'text',
+                'text': text
+            }]
+        }]
+
+        text = processor.apply_chat_template(messages, add_generation_prompt=True)
+        images, videos = process_vision_info(messages)
+        data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
+        data = data.to(device)
+
+        with model.disable_adapter():
+            generation_kwargs = dict(
+                **data,
+                streamer=streamer,
+                do_sample=temperature > 0,
+                temperature=temperature if temperature > 0 else None,
+                top_p=None,
+                top_k=None,
+                repetition_penalty=None,
+                max_new_tokens=max_new_tokens)
+
+            t = Thread(target=model.generate, kwargs=generation_kwargs)
+            t.start()
+
+            skipped = False
+            for i, text in enumerate(streamer):
+                if text and not skipped:
+                    history[-1]['content'] = history[-1]['content'].rstrip('.')
+                    skipped = True
+                history[-1]['content'] += text
+                yield history
+
+        elapsed_time = round(time.perf_counter() - start_time, 1)
+        history[-1]['metadata']['title'] += f' ({elapsed_time} seconds)'
+        yield history
+
+        if 'gnd' in role and do_grounding:
+            response = f'After zooming in and analyzing the target moment, I finalize my answer: <span style="color:green">{streamer.text_cache}</span>'
+        else:
+            response = f'After watching the whole video, my answer is: <span style="color:green">{streamer.text_cache}</span>'
+
+        history.append({'role': 'assistant', 'content': ''})
+        for i, text in enumerate(response.split(' ')):
+            history[-1]['content'] += ' ' + text if i > 0 else text
+            yield history
+
+
+if __name__ == '__main__':
+    if not nncore.is_dir(BASE_MODEL):
+        snapshot_download(BASE_MODEL_HF, local_dir=BASE_MODEL)
+
+    if not nncore.is_dir(MODEL):
+        snapshot_download(MODEL_HF, local_dir=MODEL)
+
+    print('Initializing role *grounder*')
+    model, processor = build_model(MODEL)
+
+    print('Initializing role *planner*')
+    model.load_adapter(nncore.join(MODEL, 'planner'), adapter_name='planner')
+
+    print('Initializing role *verifier*')
+    model.load_adapter(nncore.join(MODEL, 'verifier'), adapter_name='verifier')
+
+    streamer = CustomStreamer(processor.tokenizer, skip_prompt=True)
+
+    device = next(model.parameters()).device
+
+    main = partial(main, model=model, processor=processor, streamer=streamer, device=device)
+
+    path = os.path.dirname(os.path.realpath(__file__))
+
+    chat = gr.Chatbot(
+        type='messages',
+        height='70vh',
+        avatar_images=[f'{path}/assets/user.png', f'{path}/assets/bot.png'],
+        placeholder='A conversation with VideoMind',
+        label='VideoMind')
+
+    prompt = gr.Textbox(label='Text Prompt', placeholder='Ask a question about the video...')
+
+    with gr.Blocks(title=TITLE, css=CSS, js=JS) as demo:
+        gr.Markdown(TITLE_MD)
+        gr.Markdown(DESCRIPTION_MD)
+
+        with gr.Row():
+            with gr.Column(scale=3):
+                video = gr.Video()
+
+                with gr.Group():
+                    role = gr.CheckboxGroup(
+                        choices=[('🗺️ Planner', 'pla'), ('🔍 Grounder', 'gnd'), ('📊 Verifier', 'ver'),
+                                 ('📝 Answerer', 'ans')],
+                        value=['pla', 'gnd', 'ver', 'ans'],
+                        interactive=True,
+                        elem_id='role',
+                        label='Role(s) To Use',
+                        info='[Auto Planning]: Planner + Grounder + Verifier + Answerer<br>'
+                        '[Grounded Video Question-Answering]: Grounder + Verifier + Answerer<br>'
         | 
| 600 | 
            +
                                    '[Video Temporal Grounding]: Grounder + Verifier<br>'
         | 
| 601 | 
            +
                                    '[Direct Video Question-Answering]: Answerer<br>')
         | 
| 602 | 
            +
                                role.change(update_placeholder, role, prompt)
         | 
| 603 | 
            +
             | 
| 604 | 
            +
                                with gr.Accordion(label='Hyperparameters', open=False):
         | 
| 605 | 
            +
                                    temperature = gr.Slider(
         | 
| 606 | 
            +
                                        0,
         | 
| 607 | 
            +
                                        1,
         | 
| 608 | 
            +
                                        value=0,
         | 
| 609 | 
            +
                                        step=0.1,
         | 
| 610 | 
            +
                                        interactive=True,
         | 
| 611 | 
            +
                                        label='Temperature',
         | 
| 612 | 
            +
                                        info='Higher value leads to more creativity and randomness (Default: 0)')
         | 
| 613 | 
            +
                                    max_new_tokens = gr.Slider(
         | 
| 614 | 
            +
                                        1,
         | 
| 615 | 
            +
                                        1024,
         | 
| 616 | 
            +
                                        value=256,
         | 
| 617 | 
            +
                                        interactive=True,
         | 
| 618 | 
            +
                                        label='Max Output Tokens',
         | 
| 619 | 
            +
                                        info='The maximum number of output tokens for each role (Default: 256)')
         | 
| 620 | 
            +
             | 
| 621 | 
            +
                            prompt.render()
         | 
| 622 | 
            +
             | 
| 623 | 
            +
                            with gr.Row():
         | 
| 624 | 
            +
                                random_btn = gr.Button(value='🔮 Random')
         | 
| 625 | 
            +
                                random_btn.click(lambda: random.choice(EXAMPLES), None, [video, prompt, role])
         | 
| 626 | 
            +
             | 
| 627 | 
            +
                                reset_btn = gr.ClearButton([video, prompt, chat], value='🗑️ Reset')
         | 
| 628 | 
            +
                                reset_btn.click(lambda: (['pla', 'gnd', 'ver', 'ans'], 0, 256), None,
         | 
| 629 | 
            +
                                                [role, temperature, max_new_tokens])
         | 
| 630 | 
            +
             | 
| 631 | 
            +
                                submit_btn = gr.Button(value='🚀 Submit', variant='primary')
         | 
| 632 | 
            +
                                submit_ctx = submit_btn.click(disable_btns, None, [random_btn, reset_btn, submit_btn])
         | 
| 633 | 
            +
                                submit_ctx = submit_ctx.then(main, [video, prompt, role, temperature, max_new_tokens], chat)
         | 
| 634 | 
            +
                                submit_ctx.then(enable_btns, None, [random_btn, reset_btn, submit_btn])
         | 
| 635 | 
            +
             | 
| 636 | 
            +
                        with gr.Column(scale=5):
         | 
| 637 | 
            +
                            chat.render()
         | 
| 638 | 
            +
             | 
| 639 | 
            +
                    demo.queue()
         | 
| 640 | 
            +
                    demo.launch(server_name='0.0.0.0')
         | 
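For local testing outside the Space (an assumption, not part of the upload itself), the demo can be launched directly once the dependencies pinned in requirements.txt below are installed and the model weights can be downloaded; gr.Blocks.launch() binds to 0.0.0.0 as configured above and uses Gradio's default port 7860:

    python app.py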
    	
assets/bot.png
ADDED

assets/user.png
ADDED
    	
data/10309844035.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8996ff134787d6b769c2491b9079a02c05953465ad770f07a8d9138e2668d24f
size 4041678

data/13887487955.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e5fecab1076ee42b3804718f9f64bef06cbfafd6995ad5f5ee42ba6354721429
size 5544739

data/4167294363.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d0e0a4a381836f68e16a816d87f241fed3e31ea321f544b921743d6c1c50666
size 6611151

data/4742652230.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8733ab4b0716d13ea7a79fc4ddacaf9eede567db364f0ecddfa4582c2f237f82
size 2200304

data/4766274786.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:afa38a9ce9e89f934293214d79755c89159664223b3ca366813fd5fe524ed013
size 3395545

data/5012237466.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd1929aa93d037f809f402e9801047125dc9fe8060301e69ded9ba1f2d785cc8
size 4822293

data/5188348585.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b225f448a546ba2f65958f18c6731a6dde9b1f437014e90036b22eb40e9ad0a5
size 5051675

data/9383140374.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:30b6b3eb43f711bef194150d473a59850ff5d7fec0f5cc30e7526aa9e382303f
size 2518081

data/DTInxNfWXVc_210.0_360.0.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a09eee0dc404688731fb768c120d3519605f2343376b9bd727a71b91379fd9a9
size 4999970

data/RoripwjYFp8_210.0_360.0.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4b39b15158dc20c0bc6f1758a9239c8f3eed20ba4a90953338eec2246fa8f1f0
size 9287252

data/UFWQKrcbhjI_360.0_510.0.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8669153d9ffac4b5534c20fab8d795347f5babe588da9b8330e049d623ebb443
size 14510618

data/Z3-IZ3HAmIA_60.0_210.0.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4b3a342993ee61efc5f3b859cd9c1e0d360b3331eed9deb8466891e4bcacc554
size 14397799

data/h6QKDqomIPk_210.0_360.0.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:103820de2b8a1a3935b39ed80d91cd08e546e5617310b3d1bb3dadb06b2ffb95
size 13485144

data/pA6Z-qYhSNg_60.0_210.0.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c84660fd4ebd8c23a2a7364174b1e819fec8b0e1cb8b9d9cd86f9e429cbdf66c
size 8658509

data/rrTIeJRVGjg_60.0_210.0.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:efe6f48a49963bd4880ef5065840e05dd25e2aa975870140bcdaf4220bbd2827
size 11410412

data/yId2wIocTys_210.0_360.0.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:447fcb1fd1f94ed6a88d56dd0f6f859646cb8c58ed8e3b7a82f374e2cfee1646
size 14769130
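Each of the data/*.mp4 examples above is stored with Git LFS, so the commit records only the three-line pointer file (LFS spec version, SHA-256 object id, and object size in bytes) rather than the binary video itself.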
    	
requirements.txt
ADDED
@@ -0,0 +1,26 @@
accelerate==1.2.1
decord==0.6.0
gradio==4.44.1
pandas==2.2.3
peft==0.14.0
pysrt==1.1.2
scikit-image==0.25.0
scikit-learn==1.6.1
sentencepiece==0.2.0
termplotlib==0.3.9
triton==3.0.0

# our codebase contains necessary patches for 4.45.2
transformers==4.45.2

# https://github.com/microsoft/DeepSpeed/issues/6793
deepspeed==0.15.4

# https://github.com/pytorch/pytorch/issues/138386
torch==2.4.1
torchvision==0.19.1

# torch-npu only supports torch 2.4.0
# torch==2.4.0+cpu
# torch-npu==2.4.0.post2
# torchvision==0.19.0+cpu
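To reproduce this environment locally (standard pip usage, not part of the diff), the pinned dependencies can be installed in one step:

    pip install -r requirements.txt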
    	
setup.cfg
ADDED
@@ -0,0 +1,16 @@
[yapf]
column_limit = 120
based_on_style = pep8
blank_line_before_nested_class_or_def = true
split_before_expression_after_opening_paren = true

[isort]
line_length = 120
multi_line_output = 0
known_third_party = decord,deepspeed,gradio,huggingface_hub,nncore,numpy,pandas,peft,PIL,pysrt,safetensors,tabulate,termplotlib,torch,torchvision,transformers
no_lines_before = STDLIB,LOCALFOLDER
default_section = FIRSTPARTY

[flake8]
max-line-length = 500
extend-ignore = E741
    	
videomind/constants.py
ADDED
@@ -0,0 +1,42 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

IGNORE_INDEX = -100

REG_TOKEN = '<|reg|>'

SEG_S_TOKEN = '<|seg_start|>'
SEG_E_TOKEN = '<|seg_end|>'

PLANNER_PROMPT = (
    'You are acting as the planner now. '
    'Given a question about the video, your task is to analyze the question and identify the best way to answer this question. '
    'You have access to the following tools:\n\n'
    'Grounder: Accepts a text query and localizes the relevant video segment according to the query.\n'
    'Verifier: A tool supporting the grounder by verifying the reliability of its outputs.\n'
    'Answerer: Answers a given question directly based on the whole video or a cropped video segment.\n\n'
    'Your response must be a list in JSON format. '
    'A valid plan for reasoning could be "grounder, verifier, answerer", "grounder, verifier", or "answerer", depending on the given question. '
    'Please see an example for the format below.\n\n'
    '[{{"type": "grounder", "value": "<text query>"}}, {{"type": "verifier"}}, {{"type": "answerer"}}]\n\n'
    'Note that only the grounder can accept an argument called "value", which is the text query used for grounding. '
    "Now I give you the question: '{}'. "
    'Please think carefully and respond with your plan in JSON directly.')

GROUNDER_PROMPT = (
    'You are acting as the grounder now. '
    'Given a video and a text query, your goal is to temporally localize the video moment described by the query. '
    'If the query is directly describing a moment, simply localize it according to its content. '
    "Otherwise, if the moment is described as 'before/after a pivotal event', you need to determine the actual event it refers to. "
    'The localized moment should only cover the target event. '
    "Now I give you the query: '{}'. "
    'Please think carefully and provide your response.')

VERIFIER_PROMPT = (
    'You are acting as the verifier now. '
    'You will be presented a text query describing a moment that potentially happens in the given video. '
    f'Your task is to identify whether the video segment between {SEG_S_TOKEN} and {SEG_E_TOKEN} perfectly covers the moment. '
    f'If the described moment can be seen in the video, please focus on verifying whether the moment starts at {SEG_S_TOKEN} and ends at {SEG_E_TOKEN}. '
    "Respond with 'Yes' if you think the moment boundaries are correct, otherwise 'No'. "
    "If the described moment cannot be seen in the video, respond with 'No' directly. "
    "Now I give you the query: '{}'. "
    "Please think carefully and respond with 'Yes' or 'No' directly.")
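These templates are plain Python format strings, which is why the JSON example in PLANNER_PROMPT doubles its braces. A minimal sketch of filling them (illustrative only; the question and query texts are hypothetical):

    from videomind.constants import GROUNDER_PROMPT, PLANNER_PROMPT

    question = 'What does the man do after opening the fridge?'  # hypothetical example question
    print(PLANNER_PROMPT.format(question))                       # '{{' and '}}' render as literal JSON braces
    print(GROUNDER_PROMPT.format('the man opens the fridge'))    # fills the single '{}' slot with the query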
    	
videomind/conversation.py
ADDED
@@ -0,0 +1,49 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

from dataclasses import dataclass
from typing import List


@dataclass
class Conversation:
    style: str
    system: str
    roles: List[str]
    seps: List[str]
    messages: List[str]

    def append_message(self, role, msg):
        self.messages.append([role, msg])

    def clear(self):
        self.messages = []

    def get_prompt(self):
        assert self.style in ('chatml', )

        prompt = self.system + self.seps[0] if self.system is not None else ''

        for i, (role, msg) in enumerate(self.messages):
            prompt += role
            sep = self.seps[i % 2]
            if msg is not None:
                prompt += msg
                if not prompt.endswith(sep):
                    prompt += sep

        prompt = prompt.lstrip('\n')
        return prompt


def get_conv(conv_type):
    if conv_type == 'chatml':
        conv = Conversation(
            style='chatml',
            system='<|im_start|>system\nYou are a helpful assistant.',
            roles=('\n<|im_start|>user\n', '\n<|im_start|>assistant\n'),
            seps=('<|im_end|>', '<|im_end|>'),
            messages=[])
    else:
        raise ValueError(f'unknown conversation type: {conv_type}')

    return conv
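A minimal usage sketch of the helper above (illustrative only): build a ChatML-style prompt with an open assistant turn for generation.

    from videomind.conversation import get_conv

    conv = get_conv('chatml')
    conv.append_message(conv.roles[0], 'Describe the video.')  # user turn
    conv.append_message(conv.roles[1], None)                   # assistant turn left open
    print(conv.get_prompt())
    # <|im_start|>system
    # You are a helpful assistant.<|im_end|>
    # <|im_start|>user
    # Describe the video.<|im_end|>
    # <|im_start|>assistant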
    	
videomind/dataset/__init__.py
ADDED
@@ -0,0 +1,61 @@
from .collator import HybridDataCollator
from .hybrid import HybridDataset
from .sub_classes import (ActivitynetCaptionsBiasDataset, ActivitynetCaptionsDataset, ActivitynetRTLDataset,
                          CGBenchDataset, CharadesSTADataset, CosMoCapDataset, DiDeMoDataset, Ego4DNaQDataset,
                          Ego4DNLQDataset, EgoTimeQACropDataset, EgoTimeQADataset, EgoTimeQAGroundingDataset,
                          HiRESTGroundingDataset, HiRESTStepBiasDataset, HiRESTStepDataset, InternVidVTimeDataset,
                          LongVideoBenchDataset, LVBenchDataset, MLVUDataset, MVBenchDataset, NExTGQACropDataset,
                          NExTGQADataset, NExTGQAGroundingDataset, NExTQADataset, QAEgo4DCropDataset, QAEgo4DDataset,
                          QAEgo4DGroundingDataset, QuerYDDataset, QVHighlightsDataset, ReXTimeCropDataset,
                          ReXTimeDataset, ReXTimeGroundingDataset, STARDataset, TACoSDataset, VideoMMEDataset,
                          VideoXumDataset, VidMorpDataset, YouCook2BiasDataset, YouCook2Dataset)
from .wrappers import AnsweringCropDataset, AnsweringDataset, GroundingDataset, PlanningDataset, VerifyingDataset

__all__ = [
    'HybridDataCollator',
    'HybridDataset',
    'ActivitynetCaptionsBiasDataset',
    'ActivitynetCaptionsDataset',
    'ActivitynetRTLDataset',
    'CGBenchDataset',
    'CharadesSTADataset',
    'CosMoCapDataset',
    'DiDeMoDataset',
    'Ego4DNaQDataset',
    'Ego4DNLQDataset',
    'EgoTimeQACropDataset',
    'EgoTimeQADataset',
    'EgoTimeQAGroundingDataset',
    'HiRESTGroundingDataset',
    'HiRESTStepBiasDataset',
    'HiRESTStepDataset',
    'InternVidVTimeDataset',
    'LongVideoBenchDataset',
    'LVBenchDataset',
    'MLVUDataset',
    'MVBenchDataset',
    'NExTGQACropDataset',
    'NExTGQADataset',
    'NExTGQAGroundingDataset',
    'NExTQADataset',
    'QAEgo4DCropDataset',
    'QAEgo4DDataset',
    'QAEgo4DGroundingDataset',
    'QuerYDDataset',
    'QVHighlightsDataset',
    'ReXTimeCropDataset',
    'ReXTimeDataset',
    'ReXTimeGroundingDataset',
    'STARDataset',
    'TACoSDataset',
    'VideoMMEDataset',
    'VideoXumDataset',
    'VidMorpDataset',
    'YouCook2BiasDataset',
    'YouCook2Dataset',
    'AnsweringCropDataset',
    'AnsweringDataset',
    'GroundingDataset',
    'PlanningDataset',
    'VerifyingDataset',
]
    	
videomind/dataset/collator.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import warnings

import torch
from torch.nn.utils.rnn import pad_sequence

from videomind.constants import IGNORE_INDEX


class HybridDataCollator(object):

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        input_ids = [d['input_ids'] for d in batch]
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        labels = [d['labels'] for d in batch]
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        assert input_ids.size() == labels.size()

        seq_len, max_len = input_ids.size(1), self.tokenizer.model_max_length
        if seq_len > max_len:
            warnings.warn(f'The length of the input sequence exceeds the model max length: {seq_len} > {max_len}')
            input_ids, labels = input_ids[:, :max_len], labels[:, :max_len]

        data = dict(input_ids=input_ids, labels=labels, attention_mask=input_ids != self.tokenizer.pad_token_id)

        for key in ('pixel_values', 'pixel_values_videos', 'image_grid_thw', 'video_grid_thw'):
            if key in batch[0]:
                data[key] = torch.cat([d[key] for d in batch])

        for key in ('timestamps', 'saliency', 'pos_clip'):
            if key in batch[0]:
                data[key] = [d[key] for d in batch]

        return data
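A minimal sketch of the collator in isolation (illustrative only; the tokenizer stand-in is hypothetical, since only pad_token_id and model_max_length are accessed above):

    from types import SimpleNamespace

    import torch

    from videomind.dataset.collator import HybridDataCollator

    tokenizer = SimpleNamespace(pad_token_id=0, model_max_length=32)
    collator = HybridDataCollator(tokenizer)

    batch = [
        dict(input_ids=torch.tensor([5, 6, 7]), labels=torch.tensor([-100, 6, 7])),
        dict(input_ids=torch.tensor([8, 9]), labels=torch.tensor([8, 9])),
    ]
    data = collator(batch)
    print(data['input_ids'])       # tensor([[5, 6, 7], [8, 9, 0]])
    print(data['attention_mask'])  # tensor([[ True,  True,  True], [ True,  True, False]])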
    	
videomind/dataset/hybrid.py
ADDED
@@ -0,0 +1,180 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import math
import random
from collections import defaultdict
from itertools import accumulate

import nncore
import numpy as np
import termplotlib as tpl
import torch
from tabulate import tabulate
from torch.utils.data import Dataset

from videomind.constants import IGNORE_INDEX
from videomind.dataset.utils import preprocess, process_vision_info
from videomind.utils.parser import parse_span

DATASETS = nncore.Registry('datasets')


class HybridDataset(Dataset):

    def __init__(self, processor, model_config, model_args, data_args, training_args):
        super().__init__()

        datasets = []
        for key in data_args.datasets.split(','):
            datasets.append(DATASETS.get(key)(processor, model_args, data_args, training_args))

        data_types = [a['data_type'] for d in datasets for a in d.annos]

        cum_length = [0] + list(accumulate([len(d) for d in datasets]))
        idx_ranges = [[cum_length[i], cum_length[i + 1]] for i in range(len(cum_length) - 1)]

        if training_args.local_rank in (0, -1):
            raw_length = sum(d.raw_length for d in datasets)
            cur_length = idx_ranges[-1][-1]

            ratio = round(cur_length / raw_length * 100, 2)
            print(f'Number of samples: {raw_length} (original) -> {cur_length} (filtered) {ratio}%')

            data_type_cnt = ' '.join([f'{data_types.count(t)} ({t})' for t in list(set(data_types))])
            print(f'Data types: {data_type_cnt}')

            tab = defaultdict(int)
            for dataset in datasets:
                for anno in dataset.annos:
                    tab[anno.get('source', 'unknown')] += 1

            tab = [[k, v, round(v / cur_length, 3)] for k, v in tab.items()]
            print(tabulate(tab, headers=['Source', '#Samples', 'Ratio'], tablefmt='pretty', stralign='left'))

            d, _ = torch.Tensor([a['duration'] for d in datasets for a in d.annos if 'duration' in a]).sort()
            if d.size(0) > 0:
                n, r = min(d.size(0), 10), d.flip(0)
                print(f'Top-{n} max video durations: {[round(r[i].item(), 1) for i in range(n)]}')
                print(f'Top-{n} min video durations: {[round(d[i].item(), 1) for i in range(n)]}')
                print(f'Average video duration ({d.size(0)} samples): {round(d.mean().item(), 1)}s')

                print('Video duration histogram:')
                counts, edges = np.histogram(d)
                labels = [f'{edges[i]:.2f}s - {edges[i + 1]:.2f}s' for i in range(len(edges) - 1)]
                fig = tpl.figure()
                fig.barh(counts, labels)
                fig.show()

            d, _ = torch.Tensor([abs(b[0] - b[1]) for d in datasets for a in d.annos if 'span' in a
                                 for b in a['span']]).sort()
            if d.size(0) > 0:
                n, r = min(d.size(0), 10), d.flip(0)
                print(f'Top-{n} max span durations: {[round(r[i].item(), 1) for i in range(n)]}')
                print(f'Top-{n} min span durations: {[round(d[i].item(), 1) for i in range(n)]}')
                print(f'Average span duration ({d.size(0)} samples): {round(d.mean().item(), 1)}s')

                print('Span duration histogram:')
                counts, edges = np.histogram(d)
                labels = [f'{edges[i]:.2f}s - {edges[i + 1]:.2f}s' for i in range(len(edges) - 1)]
                fig = tpl.figure()
                fig.barh(counts, labels)
                fig.show()

        self.datasets = datasets
        self.data_types = data_types
        self.idx_ranges = idx_ranges
        self.processor = processor
        self.model_config = model_config
        self.model_args = model_args
        self.data_args = data_args
        self.training_args = training_args

    def __len__(self):
        return self.idx_ranges[-1][-1]

    def __getitem__(self, idx):
        for retry in range(self.data_args.max_retries + 1):
            try:
                return self.fetch_data(idx)
            except Exception as e:
                print(f'Error in loading {idx}: {type(e).__name__}({e})')
                idx = random.choice([i for i, t in enumerate(self.data_types) if t == self.data_types[idx]])

        raise RuntimeError(f'Data loading failed after {retry} retries')

    def map(self, *args, **kwargs):
        return self

    def fetch_data(self, idx):
        for (s, e), dataset in zip(self.idx_ranges, self.datasets):
            if s <= idx < e:
                meta = dataset[idx - s]
                break

        text = self.processor.apply_chat_template(meta['messages'])
        text = [text.strip()]

        images, videos = process_vision_info(meta['messages'], sanity_check=True)

        data = self.processor(text=text, images=images, videos=videos, return_tensors='pt')
        assert data['input_ids'].size(0) == 1

        data['input_ids'] = data['input_ids'][0]
        data['labels'] = preprocess(data['input_ids'], text[0], self.processor.tokenizer, self.model_args.conv_type)

        # insert segment start/end tokens
        if 'ss' in meta and 'se' in meta:
            video_grid_thw = data['video_grid_thw'][0]
            num_frames, window = int(video_grid_thw[0]), int(video_grid_thw[1] * video_grid_thw[2] / 4)
            assert num_frames * window * 4 == data['pixel_values_videos'].size(0)

            pos_s, pos_e = round(meta['ss'] * num_frames), round(meta['se'] * num_frames)
            pos_s, pos_e = min(max(0, pos_s), num_frames), min(max(0, pos_e), num_frames)
            assert pos_s <= pos_e, (num_frames, meta['ss'], meta['se'])

            base_idx = torch.nonzero(data['input_ids'] == self.model_config.vision_start_token_id).item()
            pos_s, pos_e = pos_s * window + base_idx + 1, pos_e * window + base_idx + 2

            input_ids = data['input_ids'].tolist()
            input_ids.insert(pos_s, self.model_config.seg_s_token_id)
            input_ids.insert(pos_e, self.model_config.seg_e_token_id)
            data['input_ids'] = torch.LongTensor(input_ids)

            labels = data['labels'].tolist()
            labels.insert(pos_s, IGNORE_INDEX)
            labels.insert(pos_e, IGNORE_INDEX)
            data['labels'] = torch.LongTensor(labels)

        if 'span' in meta:
            span, duration = meta['span'], meta['duration']

            pixel_values_videos, video_grid_thw = data['pixel_values_videos'], data['video_grid_thw']
            num_frames = int(video_grid_thw[0][0])

            assert video_grid_thw.size(0) == 1
            assert video_grid_thw.prod() == pixel_values_videos.size(0)

            # actual fps would be 1/2 of config (temporal patch size = 2)
            fps = num_frames / duration

            safe_span = [parse_span(b, duration, 1 / fps) for b in span]

            # num_reg_tokens -> num_bnds -> s & e
            timestamps = [[[s / duration, e / duration] for s, e in safe_span]]

            saliency, pos_inds = torch.zeros(num_frames), []
            for s, e in safe_span:
                span_ind = max(0, s * fps), min(e * fps, num_frames)
                pos_inds = list(range(math.ceil(span_ind[0]), math.ceil(span_ind[1])))
                assert len(pos_inds) > 0, f'empty pos_inds ({idx}): {fps} {num_frames} {duration} {span}'
                saliency[pos_inds] = 1

            assert saliency.any(), f'empty saliency ({idx}): {pos_inds} {fps} {num_frames} {duration} {span}'
            pos_clip = random.sample(saliency.nonzero()[:, 0].tolist(), 1)
            pos_clip = torch.LongTensor(pos_clip)

            data['timestamps'] = timestamps
            data['saliency'] = saliency
            data['pos_clip'] = pos_clip

        return data
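A worked example of the saliency construction in fetch_data above (illustrative numbers; the parse_span clamping is skipped): an 8-frame video spanning 16 s gives fps = 0.5, so a ground-truth span of 4 s to 10 s marks frames 2, 3, and 4 as positive clips.

    import math

    import torch

    num_frames, duration, span = 8, 16.0, (4.0, 10.0)
    fps = num_frames / duration                       # 0.5 frames per second

    saliency = torch.zeros(num_frames)
    span_ind = max(0, span[0] * fps), min(span[1] * fps, num_frames)
    pos_inds = list(range(math.ceil(span_ind[0]), math.ceil(span_ind[1])))
    saliency[pos_inds] = 1

    print(pos_inds)   # [2, 3, 4]
    print(saliency)   # tensor([0., 0., 1., 1., 1., 0., 0., 0.])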
    	
        videomind/dataset/sub_classes/__init__.py
    ADDED
    
    | @@ -0,0 +1,69 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .activitynet_captions import ActivitynetCaptionsBiasDataset, ActivitynetCaptionsDataset
         | 
| 2 | 
            +
            from .activitynet_rtl import ActivitynetRTLDataset
         | 
| 3 | 
            +
            from .cgbench import CGBenchDataset
         | 
| 4 | 
            +
            from .charades_sta import CharadesSTADataset
         | 
| 5 | 
            +
            from .cosmo_cap import CosMoCapDataset
         | 
| 6 | 
            +
            from .didemo import DiDeMoDataset
         | 
| 7 | 
            +
            from .ego4d_naq import Ego4DNaQDataset
         | 
| 8 | 
            +
            from .ego4d_nlq import Ego4DNLQDataset
         | 
| 9 | 
            +
            from .ego_timeqa import EgoTimeQACropDataset, EgoTimeQADataset, EgoTimeQAGroundingDataset
         | 
from .hirest import HiRESTGroundingDataset, HiRESTStepBiasDataset, HiRESTStepDataset
from .internvit_vtime import InternVidVTimeDataset
from .longvideobench import LongVideoBenchDataset
from .lvbench import LVBenchDataset
from .mlvu import MLVUDataset
from .mvbench import MVBenchDataset
from .nextgqa import NExTGQACropDataset, NExTGQADataset, NExTGQAGroundingDataset
from .nextqa import NExTQADataset
from .qa_ego4d import QAEgo4DCropDataset, QAEgo4DDataset, QAEgo4DGroundingDataset
from .queryd import QuerYDDataset
from .qvhighlights import QVHighlightsDataset
from .rextime import ReXTimeCropDataset, ReXTimeDataset, ReXTimeGroundingDataset
from .star import STARDataset
from .tacos import TACoSDataset
from .vid_morp import VidMorpDataset
from .videomme import VideoMMEDataset
from .videoxum import VideoXumDataset
from .youcook2 import YouCook2BiasDataset, YouCook2Dataset

__all__ = [
    'ActivitynetCaptionsBiasDataset',
    'ActivitynetCaptionsDataset',
    'ActivitynetRTLDataset',
    'CGBenchDataset',
    'CharadesSTADataset',
    'CosMoCapDataset',
    'DiDeMoDataset',
    'Ego4DNaQDataset',
    'Ego4DNLQDataset',
    'EgoTimeQACropDataset',
    'EgoTimeQADataset',
    'EgoTimeQAGroundingDataset',
    'HiRESTGroundingDataset',
    'HiRESTStepBiasDataset',
    'HiRESTStepDataset',
    'InternVidVTimeDataset',
    'LongVideoBenchDataset',
    'LVBenchDataset',
    'MLVUDataset',
    'MVBenchDataset',
    'NExTGQACropDataset',
    'NExTGQADataset',
    'NExTGQAGroundingDataset',
    'NExTQADataset',
    'QAEgo4DCropDataset',
    'QAEgo4DDataset',
    'QAEgo4DGroundingDataset',
    'QuerYDDataset',
    'QVHighlightsDataset',
    'ReXTimeCropDataset',
    'ReXTimeDataset',
    'ReXTimeGroundingDataset',
    'STARDataset',
    'TACoSDataset',
    'VidMorpDataset',
    'VideoMMEDataset',
    'VideoXumDataset',
    'YouCook2BiasDataset',
    'YouCook2Dataset',
]
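A minimal usage sketch (not part of the upload): each class re-exported above is also registered in DATASETS by name, so it can be imported directly and its annotations loaded with the classmethod load_annos shown in the files below. This assumes the annotation and video files under data/ have been prepared at the paths hard-coded in each class.

from videomind.dataset.sub_classes import CharadesSTADataset

# returns a list of dicts with keys such as 'video_path', 'query' and 'span'
annos = CharadesSTADataset.load_annos(split='train')
print(len(annos), annos[0]['query'], annos[0]['span'])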
    	
        videomind/dataset/sub_classes/activitynet_captions.py
    ADDED
    
@@ -0,0 +1,96 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

from collections import OrderedDict

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import GroundingDataset
from videomind.utils.parser import parse_query


@DATASETS.register(name='activitynet_captions')
class ActivitynetCaptionsDataset(GroundingDataset):

    ANNO_PATH_TRAIN = 'data/activitynet_captions/train.json'
    ANNO_PATH_VALID = 'data/activitynet_captions/val_1.json'
    ANNO_PATH_TEST = 'data/activitynet_captions/val_2.json'

    VIDEO_ROOT = 'data/activitynet/videos_3fps_480_noaudio'
    DURATIONS = 'data/activitynet/durations.json'

    UNIT = 0.01

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
        elif split == 'valid':
            raw_annos = nncore.load(self.ANNO_PATH_VALID, object_pairs_hook=OrderedDict)
        else:
            raw_annos = nncore.load(self.ANNO_PATH_TEST, object_pairs_hook=OrderedDict)

        durations = nncore.load(self.DURATIONS)

        annos = []
        for vid, raw_anno in raw_annos.items():
            for query, span in zip(raw_anno['sentences'], raw_anno['timestamps']):
                anno = dict(
                    source='activitynet_captions',
                    data_type='grounding',
                    video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                    duration=durations[vid],
                    query=parse_query(query),
                    span=[span])

                annos.append(anno)

        return annos


@DATASETS.register(name='activitynet_captions_bias')
class ActivitynetCaptionsBiasDataset(ActivitynetCaptionsDataset):

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
        elif split == 'valid':
            raw_annos = nncore.load(self.ANNO_PATH_VALID, object_pairs_hook=OrderedDict)
        else:
            raw_annos = nncore.load(self.ANNO_PATH_TEST, object_pairs_hook=OrderedDict)

        durations = nncore.load(self.DURATIONS)

        annos = []
        for vid, raw_anno in raw_annos.items():
            assert len(raw_anno['sentences']) == len(raw_anno['timestamps'])

            for i in range(len(raw_anno['sentences']) - 1):
                span_a = raw_anno['timestamps'][i]
                span_b = raw_anno['timestamps'][i + 1]

                if span_b[0] - span_a[1] < 3:
                    query_a = parse_query(f"The moment before {raw_anno['sentences'][i + 1]}")
                    query_b = parse_query(f"The moment after {raw_anno['sentences'][i]}")

                    anno_a = dict(
                        source='activitynet_captions_bias',
                        data_type='grounding',
                        video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                        duration=durations[vid],
                        query=query_a,
                        span=[span_a])

                    anno_b = dict(
                        source='activitynet_captions_bias',
                        data_type='grounding',
                        video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                        duration=durations[vid],
                        query=query_b,
                        span=[span_b])

                    annos.append(anno_a)
                    annos.append(anno_b)

        return annos
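A toy illustration (values invented, not from the dataset) of the rule in ActivitynetCaptionsBiasDataset: two consecutive captions whose temporal gap is under 3 seconds yield a "before"/"after" query pair, grounded to the earlier and later span respectively.

sentences = ['the dog runs to the door', 'the owner opens the door']
timestamps = [[4.0, 9.5], [10.2, 15.0]]

span_a, span_b = timestamps[0], timestamps[1]
if span_b[0] - span_a[1] < 3:
    query_a = f'The moment before {sentences[1]}'  # grounded to span_a
    query_b = f'The moment after {sentences[0]}'   # grounded to span_b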
    	
        videomind/dataset/sub_classes/activitynet_rtl.py
    ADDED
    
@@ -0,0 +1,68 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import re
from collections import OrderedDict

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import GroundingDataset
from videomind.utils.parser import parse_query


@DATASETS.register(name='activitynet_rtl')
class ActivitynetRTLDataset(GroundingDataset):

    ANNO_PATH_TRAIN = 'data/activitynet_rtl/activitynet_train_gpt-4-0613_temp_6_f10009.json'
    ANNO_PATH_TEST = 'data/activitynet_rtl/annot_val_1_q229.json'

    VIDEO_ROOT = 'data/activitynet/videos_3fps_480_noaudio'

    UNIT = 0.01

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)

            annos = []
            for vid, raw_anno in raw_annos.items():
                for meta in raw_anno['QA']:
                    match = re.findall(r'<(\d+(\.\d+)?)>', meta['a'])
                    span = [float(m[0]) for m in match[:2]]

                    # some samples do not have timestamps
                    if len(span) != 2:
                        continue

                    anno = dict(
                        source='activitynet_rtl',
                        data_type='grounding',
                        video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                        duration=raw_anno['duration'],
                        query=parse_query(meta['q']),
                        span=[span])

                    annos.append(anno)
        else:
            raw_annos = nncore.load(self.ANNO_PATH_TEST, object_pairs_hook=OrderedDict)

            annos = []
            for raw_anno in raw_annos:
                vid = f"v_{raw_anno['vid']}"

                match = re.findall(r'<(\d+(\.\d+)?)>', raw_anno['answer'])
                span = [float(m[0]) for m in match[:2]]
                assert len(span) == 2

                anno = dict(
                    source='activitynet_rtl',
                    data_type='grounding',
                    video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                    duration=raw_anno['duration'],
                    query=parse_query(raw_anno['question']),
                    span=[span])

                annos.append(anno)

        return annos
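A worked example of the timestamp pattern used above (the answer text is invented): RTL answers embed times as <start> ... <end>, and re.findall with the nested group returns tuples whose first element is the full number.

import re

answer = 'He mixes the ingredients from <12.5> to <47.0> in the video.'
match = re.findall(r'<(\d+(\.\d+)?)>', answer)   # [('12.5', '.5'), ('47.0', '.0')]
span = [float(m[0]) for m in match[:2]]
assert span == [12.5, 47.0]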
    	
        videomind/dataset/sub_classes/cgbench.py
    ADDED
    
@@ -0,0 +1,47 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import nncore
from torch.utils.data import Dataset

from videomind.dataset.hybrid import DATASETS
from videomind.utils.parser import parse_query, parse_question


@DATASETS.register(name='cgbench')
class CGBenchDataset(Dataset):

    ANNO_PATH_TEST = 'data/cgbench/cgbench_mini.json'

    VIDEO_ROOT = 'data/cgbench/videos_3fps_480_noaudio'
    SUBTITLE_ROOT = 'data/cgbench/subtitles'

    UNIT = 0.001

    @classmethod
    def load_annos(self, split='test'):
        assert split == 'test'

        raw_annos = nncore.load(self.ANNO_PATH_TEST)

        annos = []
        for raw_anno in raw_annos:
            vid = raw_anno['video_uid']

            anno = dict(
                source='cgbench',
                data_type='multimodal',
                video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                subtitle_path=nncore.join(self.SUBTITLE_ROOT, vid + '.srt'),
                duration=raw_anno['duration'],
                query=parse_query(raw_anno['question']),
                question=parse_question(raw_anno['question']),
                options=[o.capitalize() for o in raw_anno['choices']],
                answer=raw_anno['answer'].capitalize(),
                ans=raw_anno['right_answer'],
                span=raw_anno['clue_intervals'],
                task=raw_anno['sub_category'],
                domain=raw_anno['domain'])

            annos.append(anno)

        return annos
    	
        videomind/dataset/sub_classes/charades_sta.py
    ADDED
    
@@ -0,0 +1,45 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import GroundingDataset
from videomind.utils.parser import parse_query


@DATASETS.register(name='charades_sta')
class CharadesSTADataset(GroundingDataset):

    ANNO_PATH_TRAIN = 'data/charades_sta/charades_sta_train.txt'
    ANNO_PATH_TEST = 'data/charades_sta/charades_sta_test.txt'

    VIDEO_ROOT = 'data/charades_sta/videos_3fps_480_noaudio'
    DURATIONS = 'data/charades_sta/durations.json'

    UNIT = 0.1

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
        else:
            raw_annos = nncore.load(self.ANNO_PATH_TEST)

        durations = nncore.load(self.DURATIONS)

        annos = []
        for raw_anno in raw_annos:
            info, query = raw_anno.split('##')
            vid, s, e = info.split()

            anno = dict(
                source='charades_sta',
                data_type='grounding',
                video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                duration=durations[vid],
                query=parse_query(query),
                span=[[float(s), float(e)]])

            annos.append(anno)

        return annos
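An illustration of the annotation format parsed above: each line of the Charades-STA text files is "<vid> <start> <end>##<query>" (the sample line here is illustrative, not copied from the files in this upload).

raw_anno = 'AO8RW 0.0 6.9##a person is putting a book on a shelf.'

info, query = raw_anno.split('##')
vid, s, e = info.split()
span = [[float(s), float(e)]]  # [[0.0, 6.9]]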
    	
        videomind/dataset/sub_classes/cosmo_cap.py
    ADDED
    
@@ -0,0 +1,37 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import GroundingDataset
from videomind.utils.parser import parse_query


@DATASETS.register(name='cosmo_cap')
class CosMoCapDataset(GroundingDataset):

    ANNO_PATH = 'data/cosmo_cap/anno_cosmo_cap.jsonl'

    VIDEO_ROOT = 'data/cosmo_cap/videos_3fps_480_noaudio'

    UNIT = 1.0

    @classmethod
    def load_annos(self, split='train'):
        assert split == 'train'

        raw_annos = nncore.load(self.ANNO_PATH)

        annos = []
        for raw_anno in raw_annos:
            anno = dict(
                source='cosmo_cap',
                data_type='grounding',
                video_path=nncore.join(self.VIDEO_ROOT, raw_anno['vid'] + '.mp4'),
                duration=raw_anno['duration'],
                query=parse_query(raw_anno['query']),
                span=[raw_anno['span']])

            annos.append(anno)

        return annos
    	
        videomind/dataset/sub_classes/didemo.py
    ADDED
    
@@ -0,0 +1,59 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import random

import nncore
import numpy as np

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import GroundingDataset
from videomind.utils.parser import parse_query


@DATASETS.register(name='didemo')
class DiDeMoDataset(GroundingDataset):

    ANNO_PATH_TRAIN = 'data/didemo/train_data.json'
    ANNO_PATH_VALID = 'data/didemo/val_data.json'
    ANNO_PATH_TEST = 'data/didemo/test_data.json'

    VIDEO_ROOT = 'data/didemo/videos_3fps_480_noaudio'
    DURATIONS = 'data/didemo/durations.json'

    UNIT = 1.0

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
        elif split == 'valid':
            raw_annos = nncore.load(self.ANNO_PATH_VALID)
        else:
            raw_annos = nncore.load(self.ANNO_PATH_TEST)

        durations = nncore.load(self.DURATIONS)

        annos = []
        for raw_anno in raw_annos:
            vid = raw_anno['video'].split('.')[0]

            # apply mean on multiple spans
            span = np.array(raw_anno['times']).mean(axis=0).tolist()
            span = [round(span[0] * 5), round((span[1] + 1) * 5)]

            # augment spans during training
            if split == 'train':
                offset = random.randint(-2, 2)
                span = [span[0] + offset, span[1] + offset]

            anno = dict(
                source='didemo',
                data_type='grounding',
                video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                duration=durations[vid],
                query=parse_query(raw_anno['description']),
                span=[span])

            annos.append(anno)

        return annos
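A worked example of the DiDeMo span conversion above (annotator values invented): 'times' holds 5-second segment indices from several annotators, and their mean pair is mapped to seconds as [start * 5, (end + 1) * 5].

import numpy as np

times = [[1, 2], [1, 2], [2, 3]]
span = np.array(times).mean(axis=0).tolist()            # [1.33..., 2.33...]
span = [round(span[0] * 5), round((span[1] + 1) * 5)]   # [7, 17] seconds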
    	
        videomind/dataset/sub_classes/ego4d_naq.py
    ADDED
    
@@ -0,0 +1,81 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

from collections import OrderedDict

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import GroundingDataset
from videomind.utils.parser import parse_query


@DATASETS.register(name='ego4d_naq')
class Ego4DNaQDataset(GroundingDataset):

    ANNO_PATH_TRAIN = 'data/ego4d_naq/train.json'
    ANNO_PATH_VALID = 'data/ego4d_naq/val.json'
    ANNO_PATH_TEST = 'data/ego4d_naq/test.json'

    VIDEO_ROOT = 'data/ego4d/v2/videos_3fps_480_noaudio'

    UNIT = 0.001

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
        elif split == 'valid':
            raw_annos = nncore.load(self.ANNO_PATH_VALID, object_pairs_hook=OrderedDict)
        else:
            raw_annos = nncore.load(self.ANNO_PATH_TEST, object_pairs_hook=OrderedDict)

        annos = []
        for vid, raw_anno in raw_annos.items():
            duration = raw_anno['num_frames'] / raw_anno['fps']

            # 300s: 254k samples (dropped 121k samples merged 156k samples)
            # 480s: 567k samples (dropped 249k samples merged 328k samples)
            if split == 'train' and (duration < 10 or duration > 600):
                continue

            meta = dict()
            for span, query in zip(raw_anno['exact_times'], raw_anno['sentences']):
                span = [round(span[0], 3), round(span[1], 3)]

                query = parse_query(query)

                # these annotations might be from nlq
                nlq_keys = ('who', 'what', 'when', 'in what', 'did', 'where', 'how', 'i what')
                if split == 'train' and any(query.startswith(k) for k in nlq_keys):
                    continue

                # bad samples
                if split == 'train' and '#unsure' in query:
                    continue

                # too short or too long samples
                num_words = len(query.split(' '))
                if split == 'train' and (num_words < 3 or num_words > 30):
                    continue

                if query not in meta:
                    meta[query] = []

                meta[query].append(span)

            for query, span in meta.items():
                # skip samples with multiple moments
                if len(span) > 1:
                    continue

                anno = dict(
                    source='ego4d_naq',
                    data_type='grounding',
                    video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                    duration=duration,
                    query=query,
                    span=span)

                annos.append(anno)

        return annos
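A toy illustration (queries invented) of the merging step above: spans are grouped by identical query text, and any query that maps to more than one moment is skipped, so only unambiguous single-moment queries are kept for grounding.

meta = dict()
samples = [([5.0, 9.0], 'c opens the fridge'),
           ([40.0, 44.0], 'c opens the fridge'),
           ([12.0, 20.0], 'c washes a plate')]

for span, query in samples:
    meta.setdefault(query, []).append(span)

kept = {query: spans for query, spans in meta.items() if len(spans) == 1}
# only 'c washes a plate' survives; the repeated query maps to two moments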
    	
        videomind/dataset/sub_classes/ego4d_nlq.py
    ADDED
    
@@ -0,0 +1,41 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import GroundingDataset
from videomind.utils.parser import parse_query


@DATASETS.register(name='ego4d_nlq')
class Ego4DNLQDataset(GroundingDataset):

    ANNO_PATH_TRAIN = 'data/ego4d_nlq/nlq_train.jsonl'
    ANNO_PATH_VALID = 'data/ego4d_nlq/nlq_val.jsonl'

    VIDEO_ROOT = 'data/ego4d/v2/videos_3fps_480_noaudio'

    UNIT = 0.001

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
        else:
            raw_annos = nncore.load(self.ANNO_PATH_VALID)

        annos = []
        for raw_anno in raw_annos:
            assert len(raw_anno['relevant_windows']) == 1

            anno = dict(
                source='ego4d_nlq',
                data_type='grounding',
                video_path=nncore.join(self.VIDEO_ROOT, raw_anno['vid'] + '.mp4'),
                duration=raw_anno['duration'],
                query=parse_query(raw_anno['query']),
                span=raw_anno['relevant_windows'])

            annos.append(anno)

        return annos
    	
        videomind/dataset/sub_classes/ego_timeqa.py
    ADDED
    
@@ -0,0 +1,93 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import random

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import AnsweringCropDataset, AnsweringDataset, GroundingDataset
from videomind.utils.parser import parse_query, parse_question


@DATASETS.register(name='ego_timeqa')
class EgoTimeQADataset(AnsweringDataset):

    ANNO_PATH_TRAIN = 'data/ego_timeqa/annotations.EgoTimeQA.json'

    VIDEO_ROOT = 'data/ego4d/v2/videos_3fps_480_noaudio'
    DURATIONS = 'data/ego4d/v2/durations.json'

    SOURCE = 'ego_timeqa'
    DATA_TYPE = 'multimodal'

    UNIT = 0.001

    @classmethod
    def load_annos(self, split='train'):
        assert split == 'train'

        raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
        durations = nncore.load(self.DURATIONS)

        annos = []
        for raw_anno in raw_annos:
            vid = raw_anno['video_id']

            duration = durations[vid]

            # 303k -> 284k (to be verified)
            if duration < 10 or duration > 600:
                continue

            span = [raw_anno['moment_start_frame'] / 30, raw_anno['moment_end_frame'] / 30]
            span = [round(span[0], 3), round(span[1], 3)]

            # this would remove many samples (284k -> 37k)
            # if span[1] - span[0] < 2:
            #     continue

            question = raw_anno['question'].replace(' l ', ' I ').capitalize()
            question = parse_question(question)
            query = parse_query(question)

            # too short or too long samples
            num_words = len(query.split(' '))
            if split == 'train' and (num_words < 3 or num_words > 30):
                continue

            answer = raw_anno['answer'].capitalize()

            assert len(raw_anno['wrong_answers']) == 3
            idx = random.randint(0, 3)
            ans = chr(ord('A') + idx)
            options = [o.capitalize() for o in raw_anno['wrong_answers']]
            options.insert(idx, answer)

            anno = dict(
                source=self.SOURCE,
                data_type=self.DATA_TYPE,
                video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                duration=duration,
                query=query,
                question=question,
                options=options,
                answer=answer,
                ans=ans,
                span=[span])

            annos.append(anno)

        return annos


@DATASETS.register(name='ego_timeqa_crop')
class EgoTimeQACropDataset(AnsweringCropDataset, EgoTimeQADataset):

    SOURCE = 'ego_timeqa_crop'


@DATASETS.register(name='ego_timeqa_grounding')
class EgoTimeQAGroundingDataset(GroundingDataset, EgoTimeQADataset):

    SOURCE = 'ego_timeqa_grounding'
    DATA_TYPE = 'grounding'
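A sketch (answer strings invented) of the multiple-choice construction above: the correct answer is inserted at a random position among the three wrong answers, and 'ans' records the matching option letter.

import random

answer = 'He put it on the table'
wrong_answers = ['He threw it away', 'He gave it to a friend', 'He left it outside']

idx = random.randint(0, 3)
ans = chr(ord('A') + idx)

options = [o for o in wrong_answers]
options.insert(idx, answer)
assert options[idx] == answer and ans in 'ABCD'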
    	
        videomind/dataset/sub_classes/hirest.py
    ADDED
    
@@ -0,0 +1,150 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

from collections import OrderedDict

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import GroundingDataset
from videomind.utils.parser import parse_query


@DATASETS.register(name='hirest_grounding')
class HiRESTGroundingDataset(GroundingDataset):

    ANNO_PATH_TRAIN = 'data/hirest/all_data_train.json'
    ANNO_PATH_VALID = 'data/hirest/all_data_val.json'

    VIDEO_ROOT = 'data/hirest/videos_3fps_480_noaudio'

    UNIT = 1.0

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
        else:
            raw_annos = nncore.load(self.ANNO_PATH_VALID, object_pairs_hook=OrderedDict)

        all_videos = nncore.ls(self.VIDEO_ROOT, ext='.mp4')
        all_videos = set(v[:11] for v in all_videos)

        annos = []
        for query, videos in raw_annos.items():
            for video_name, raw_anno in videos.items():
                if not raw_anno['relevant'] or not raw_anno['clip']:
                    continue

                assert len(raw_anno['bounds']) == 2

                vid = video_name.split('.')[0]

                if vid not in all_videos:
                    continue

                anno = dict(
                    source='hirest_grounding',
                    data_type='grounding',
                    video_path=nncore.join(self.VIDEO_ROOT, video_name),
                    duration=raw_anno['v_duration'],
                    query=parse_query(query),
                    span=[raw_anno['bounds']])

                annos.append(anno)

        return annos


@DATASETS.register(name='hirest_step')
class HiRESTStepDataset(HiRESTGroundingDataset):

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
        else:
            raw_annos = nncore.load(self.ANNO_PATH_VALID, object_pairs_hook=OrderedDict)

        all_videos = nncore.ls(self.VIDEO_ROOT, ext='.mp4')
        all_videos = set(v[:11] for v in all_videos)

        annos = []
        for query, videos in raw_annos.items():
            for video_name, raw_anno in videos.items():
                if not raw_anno['relevant'] or not raw_anno['clip'] or len(raw_anno['steps']) == 0:
                    continue

                vid = video_name.split('.')[0]

                if vid not in all_videos:
                    continue

                for step in raw_anno['steps']:
                    assert len(step['absolute_bounds']) == 2

                    anno = dict(
                        source='hirest_step',
                        data_type='grounding',
                        video_path=nncore.join(self.VIDEO_ROOT, video_name),
                        duration=raw_anno['v_duration'],
                        query=parse_query(step['heading']),
                        span=[step['absolute_bounds']])

                    annos.append(anno)

        return annos


@DATASETS.register(name='hirest_step_bias')
class HiRESTStepBiasDataset(HiRESTStepDataset):

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN, object_pairs_hook=OrderedDict)
        else:
            raw_annos = nncore.load(self.ANNO_PATH_VALID, object_pairs_hook=OrderedDict)

        all_videos = nncore.ls(self.VIDEO_ROOT, ext='.mp4')
        all_videos = set(v[:11] for v in all_videos)

        annos = []
        for query, videos in raw_annos.items():
            for video_name, raw_anno in videos.items():
                if not raw_anno['relevant'] or not raw_anno['clip'] or len(raw_anno['steps']) == 0:
                    continue

                vid = video_name.split('.')[0]

                if vid not in all_videos:
                    continue

                for i in range(len(raw_anno['steps']) - 1):
                    span_a = raw_anno['steps'][i]['absolute_bounds']
                    span_b = raw_anno['steps'][i + 1]['absolute_bounds']

                    assert len(span_a) == 2 and len(span_b) == 2 and span_a[1] == span_b[0]

                    query_a = parse_query(f"The moment before {raw_anno['steps'][i + 1]['heading']}")
                    query_b = parse_query(f"The moment after {raw_anno['steps'][i]['heading']}")

                    anno_a = dict(
                        source='hirest_step_bias',
                        data_type='grounding',
                        video_path=nncore.join(self.VIDEO_ROOT, video_name),
         | 
| 135 | 
            +
                                    duration=raw_anno['v_duration'],
         | 
| 136 | 
            +
                                    query=query_a,
         | 
| 137 | 
            +
                                    span=[span_a])
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                                anno_b = dict(
         | 
| 140 | 
            +
                                    source='hirest_step_bias',
         | 
| 141 | 
            +
                                    data_type='grounding',
         | 
| 142 | 
            +
                                    video_path=nncore.join(self.VIDEO_ROOT, video_name),
         | 
| 143 | 
            +
                                    duration=raw_anno['v_duration'],
         | 
| 144 | 
            +
                                    query=query_b,
         | 
| 145 | 
            +
                                    span=[span_b])
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                                annos.append(anno_a)
         | 
| 148 | 
            +
                                annos.append(anno_b)
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    return annos
         | 
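
Note on the hirest_step_bias loader above: adjacent steps share a boundary (enforced by the assertion span_a[1] == span_b[0]), and each adjacent pair yields two complementary grounding samples, one anchored before the later step and one after the earlier step. A minimal sketch of that pairing with made-up step data (headings and bounds are invented; parse_query is omitted):

# Toy illustration of the hirest_step_bias pairing (invented values).
steps = [
    {'heading': 'peel the potatoes', 'absolute_bounds': [10.0, 25.0]},
    {'heading': 'boil the potatoes', 'absolute_bounds': [25.0, 40.0]},
]

for i in range(len(steps) - 1):
    span_a = steps[i]['absolute_bounds']      # [10.0, 25.0]
    span_b = steps[i + 1]['absolute_bounds']  # [25.0, 40.0]

    # the "before" query is grounded to the earlier span, the "after" query to the later one
    query_a = f"The moment before {steps[i + 1]['heading']}"  # -> span_a
    query_b = f"The moment after {steps[i]['heading']}"       # -> span_b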
    	
        videomind/dataset/sub_classes/internvit_vtime.py
    ADDED
    
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import GroundingDataset
from videomind.utils.parser import parse_query


@DATASETS.register(name='internvid_vtime')
class InternVidVTimeDataset(GroundingDataset):

    ANNO_PATH = 'data/internvid_vtime/anno_internvid_vtime_query_gpt4o_mini.jsonl'

    VIDEO_ROOT = 'data/internvid_vtime/videos_crop_3fps_480_noaudio'

    UNIT = 0.1

    @classmethod
    def load_annos(self, split='train'):
        assert split == 'train'

        raw_annos = nncore.load(self.ANNO_PATH)

        all_videos = nncore.ls(self.VIDEO_ROOT, ext='.mp4')
        all_videos = set(v[:11] for v in all_videos)

        annos = []
        for raw_anno in raw_annos:
            vid = raw_anno['vid']

            if vid not in all_videos:
                continue

            anno = dict(
                source='internvid_vtime',
                data_type='grounding',
                video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                duration=raw_anno['duration'],
                query=parse_query(raw_anno['query']),
                span=[raw_anno['span']])

            annos.append(anno)

        return annos
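
The v[:11] slice above keeps only the first 11 characters of each file name, i.e. the length of a YouTube video ID, so annotations whose clips are missing from VIDEO_ROOT are skipped. A tiny sketch with hypothetical file names:

# Hypothetical listing; only the 11-character YouTube ID is kept for membership tests.
all_videos = ['dQw4w9WgXcQ.mp4', 'aBcDeFgHiJk.mp4']
all_videos = set(v[:11] for v in all_videos)   # {'dQw4w9WgXcQ', 'aBcDeFgHiJk'}
print('dQw4w9WgXcQ' in all_videos)             # True  -> annotation kept
print('zzzzzzzzzzz' in all_videos)             # False -> annotation skipped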
    	
        videomind/dataset/sub_classes/longvideobench.py
    ADDED
    
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import nncore
from torch.utils.data import Dataset

from videomind.dataset.hybrid import DATASETS
from videomind.utils.parser import parse_query, parse_question


@DATASETS.register(name='longvideobench')
class LongVideoBenchDataset(Dataset):

    ANNO_PATH_VALID = 'data/longvideobench/lvb_val.json'
    ANNO_PATH_TEST = 'data/longvideobench/lvb_test_wo_gt.json'

    VIDEO_ROOT = 'data/longvideobench/videos_3fps_480_noaudio'

    @classmethod
    def load_annos(self, split='valid'):
        if split == 'valid':
            raw_annos = nncore.load(self.ANNO_PATH_VALID)
        else:
            print('WARNING: Test split does not have ground truth annotations')
            raw_annos = nncore.load(self.ANNO_PATH_TEST)

        annos = []
        for raw_anno in raw_annos:
            vid = raw_anno['video_id']

            if vid.startswith('@'):
                vid = vid[-19:]

            # videos might come from youtube or other sources
            assert len(vid) in (11, 19)

            anno = dict(
                source='longvideobench',
                data_type='multimodal',
                video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                query=parse_query(raw_anno['question']),
                question=parse_question(raw_anno['question']),
                options=raw_anno['candidates'],
                task=str(raw_anno['duration_group']),
                level=raw_anno['level'],
                question_category=raw_anno['question_category'])

            if 'correct_choice' in raw_anno:
                anno['answer'] = raw_anno['candidates'][raw_anno['correct_choice']]
                anno['ans'] = chr(ord('A') + raw_anno['correct_choice'])

            annos.append(anno)

        return annos
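
For the validation split, correct_choice is an index into candidates, and the letter label is derived with chr(ord('A') + correct_choice). A quick sketch with placeholder candidates:

# Index-to-letter mapping as used above (placeholder candidates).
candidates = ['a red car', 'a blue car', 'a green car', 'a yellow car']
correct_choice = 2
answer = candidates[correct_choice]    # 'a green car'
ans = chr(ord('A') + correct_choice)   # 'C'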
    	
        videomind/dataset/sub_classes/lvbench.py
    ADDED
    
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import nncore
from torch.utils.data import Dataset

from videomind.dataset.hybrid import DATASETS
from videomind.utils.parser import parse_query, parse_question


@DATASETS.register(name='lvbench')
class LVBenchDataset(Dataset):

    ANNO_PATH = 'data/lvbench/LVBench/video_info.meta.jsonl'

    VIDEO_ROOT = 'data/lvbench/videos_3fps_480_noaudio'

    @classmethod
    def load_annos(self, split='test'):
        assert split == 'test'

        raw_annos = nncore.load(self.ANNO_PATH)

        annos = []
        for raw_anno in raw_annos:
            vid = raw_anno['key']

            for meta in raw_anno['qa']:
                tok = meta['question'].split('\n')

                assert len(tok) == 5
                assert all(any(o.startswith(k) for k in ('(A) ', '(B) ', '(C) ', '(D) ')) for o in tok[1:])

                options = [o[4:] for o in tok[1:]]
                ans = meta['answer']
                answer = options[ord(ans) - ord('A')]
                assert ans in 'ABCD'

                anno = dict(
                    source='lvbench',
                    data_type='multimodal',
                    video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                    query=parse_query(tok[0]),
                    question=parse_question(tok[0]),
                    options=options,
                    answer=answer,
                    ans=ans,
                    task=meta['question_type'],
                    time_reference=meta['time_reference'])

                annos.append(anno)

        return annos
    	
        videomind/dataset/sub_classes/mlvu.py
    ADDED
    
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import nncore
from torch.utils.data import Dataset

from videomind.dataset.hybrid import DATASETS
from videomind.utils.parser import parse_query, parse_question


@DATASETS.register(name='mlvu')
class MLVUDataset(Dataset):

    TASK_TO_DIR_MAP = {
        'plotQA': '1_plotQA',
        'findNeedle': '2_needle',
        'ego': '3_ego',
        'count': '4_count',
        'order': '5_order',
        'anomaly_reco': '6_anomaly_reco',
        'topic_reasoning': '7_topic_reasoning'
    }

    DATA_ROOT = 'data/mlvu'

    @classmethod
    def load_annos(self, split='test'):
        assert split == 'test'

        paths = [nncore.join(self.DATA_ROOT, 'json', f'{n}.json') for n in self.TASK_TO_DIR_MAP.values()]

        raw_annos = nncore.flatten([nncore.load(p) for p in paths])

        annos = []
        for raw_anno in raw_annos:
            task = raw_anno['question_type']
            video_name = nncore.join(self.TASK_TO_DIR_MAP[task], raw_anno['video'])

            options = raw_anno['candidates']
            answer = raw_anno['answer']
            ans = chr(ord('A') + options.index(answer))

            anno = dict(
                source='mlvu',
                data_type='multimodal',
                video_path=nncore.join(self.DATA_ROOT, 'video', video_name),
                query=parse_query(raw_anno['question']),
                question=parse_question(raw_anno['question']),
                options=options,
                answer=answer,
                ans=ans,
                task=task)

            annos.append(anno)

        return annos
    	
        videomind/dataset/sub_classes/mvbench.py
    ADDED
    
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import nncore
from torch.utils.data import Dataset

from videomind.dataset.hybrid import DATASETS
from videomind.utils.parser import parse_query, parse_question


@DATASETS.register(name='mvbench')
class MVBenchDataset(Dataset):

    META_DATA = [('Episodic Reasoning', 'episodic_reasoning.json', 'tvqa/frames_fps3_hq', 'frame'),
                 ('Action Sequence', 'action_sequence.json', 'star/Charades_v1_480', 'video'),
                 ('Action Prediction', 'action_prediction.json', 'star/Charades_v1_480', 'video'),
                 ('Action Antonym', 'action_antonym.json', 'ssv2_video', 'video'),
                 ('Fine-grained Action', 'fine_grained_action.json', 'Moments_in_Time_Raw/videos', 'video'),
                 ('Unexpected Action', 'unexpected_action.json', 'FunQA_test/test', 'video'),
                 ('Object Existence', 'object_existence.json', 'clevrer/video_validation', 'video'),
                 ('Object Interaction', 'object_interaction.json', 'star/Charades_v1_480', 'video'),
                 ('Object Shuffle', 'object_shuffle.json', 'perception/videos', 'video'),
                 ('Moving Direction', 'moving_direction.json', 'clevrer/video_validation', 'video'),
                 ('Action Localization', 'action_localization.json', 'sta/sta_video', 'video'),
                 ('Scene Transition', 'scene_transition.json', 'scene_qa/video', 'video'),
                 ('Action Count', 'action_count.json', 'perception/videos', 'video'),
                 ('Moving Count', 'moving_count.json', 'clevrer/video_validation', 'video'),
                 ('Moving Attribute', 'moving_attribute.json', 'clevrer/video_validation', 'video'),
                 ('State Change', 'state_change.json', 'perception/videos', 'video'),
                 ('Fine-grained Pose', 'fine_grained_pose.json', 'nturgbd', 'video'),
                 ('Character Order', 'character_order.json', 'perception/videos', 'video'),
                 ('Egocentric Navigation', 'egocentric_navigation.json', 'vlnqa', 'video'),
                 ('Counterfactual Inference', 'counterfactual_inference.json', 'clevrer/video_validation', 'video')]

    DATA_ROOT = 'data/mvbench'

    MIN_LEN = 64

    @classmethod
    def load_annos(self, split='test', sample_frames=32):
        assert split == 'test'

        annos = []
        for meta in self.META_DATA:
            raw_annos = nncore.load(nncore.join(self.DATA_ROOT, 'json', meta[1]))

            for raw_anno in raw_annos:
                video_name = nncore.join(meta[2], raw_anno['video'])
                video_path = nncore.join(self.DATA_ROOT, 'video', video_name)

                if meta[3] == 'frame':
                    num_frames = len(nncore.ls(video_path, ext='.jpg'))
                    video_path = [
                        nncore.join(video_path, f'{i:0>5}.jpg')
                        for i in range(1, num_frames + 1, num_frames // (sample_frames - 1))
                    ][:sample_frames]

                options = raw_anno['candidates']
                answer = raw_anno['answer']
                ans = chr(ord('A') + options.index(answer))

                anno = dict(
                    source='mvbench',
                    data_type='multimodal',
                    video_path=video_path,
                    query=parse_query(raw_anno['question']),
                    question=parse_question(raw_anno['question']),
                    options=options,
                    answer=answer,
                    ans=ans,
                    task=meta[0])

                annos.append(anno)

        return annos
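
For the frame-based 'Episodic Reasoning' task, video_path becomes a list of JPEG paths sampled at a fixed stride of num_frames // (sample_frames - 1). A worked example assuming a hypothetical clip with 100 extracted frames and the default sample_frames=32 (the MIN_LEN = 64 constant suggests clips are long enough for the stride to stay positive):

# Stride-based frame sampling as in MVBenchDataset.load_annos (hypothetical counts).
num_frames = 100
sample_frames = 32
step = num_frames // (sample_frames - 1)        # 100 // 31 = 3
indices = list(range(1, num_frames + 1, step))  # 1, 4, 7, ..., 100 (34 candidates)
indices = indices[:sample_frames]               # truncated to 32
paths = [f'{i:0>5}.jpg' for i in indices]       # '00001.jpg', '00004.jpg', ..., '00094.jpg'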
    	
        videomind/dataset/sub_classes/nextgqa.py
    ADDED
    
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import csv

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import AnsweringCropDataset, AnsweringDataset, GroundingDataset
from videomind.utils.parser import parse_query, parse_question


@DATASETS.register(name='nextgqa')
class NExTGQADataset(AnsweringDataset):

    ANNO_PATH_VALID = 'data/nextgqa/val.csv'
    ANNO_PATH_TEST = 'data/nextgqa/test.csv'

    SPAN_PATH_VALID = 'data/nextgqa/gsub_val.json'
    SPAN_PATH_TEST = 'data/nextgqa/gsub_test.json'

    VIDEO_ID_MAP = 'data/nextgqa/map_vid_vidorID.json'
    VIDEO_ROOT = 'data/nextqa/videos'

    SOURCE = 'nextgqa'
    DATA_TYPE = 'multimodal'

    UNIT = 0.1

    @classmethod
    def load_annos(self, split='valid'):
        assert split in ('valid', 'test')

        if split == 'valid':
            anno_path = self.ANNO_PATH_VALID
            raw_spans = nncore.load(self.SPAN_PATH_VALID)
        else:
            anno_path = self.ANNO_PATH_TEST
            raw_spans = nncore.load(self.SPAN_PATH_TEST)

        with open(anno_path, mode='r') as f:
            reader = csv.DictReader(f)
            raw_annos = [d for d in reader]

        video_id_map = nncore.load(self.VIDEO_ID_MAP)

        annos = []
        for raw_anno in raw_annos:
            vid = raw_anno['video_id']
            qid = raw_anno['qid']

            video_id = video_id_map[vid]

            query = parse_query(raw_anno['question'].capitalize() + '?')
            question = parse_question(raw_anno['question'].capitalize() + '?')
            options = [raw_anno[k].capitalize() for k in ('a0', 'a1', 'a2', 'a3', 'a4')]
            answer = raw_anno['answer'].capitalize()
            ans = chr(ord('A') + options.index(answer))

            anno = dict(
                source=self.SOURCE,
                data_type=self.DATA_TYPE,
                video_path=nncore.join(self.VIDEO_ROOT, video_id + '.mp4'),
                duration=raw_spans[vid]['duration'],
                query=query,
                question=question,
                options=options,
                answer=answer,
                ans=ans,
                span=raw_spans[vid]['location'][qid],
                task=raw_anno['type'])

            annos.append(anno)

        return annos


@DATASETS.register(name='nextgqa_crop')
class NExTGQACropDataset(AnsweringCropDataset, NExTGQADataset):

    SOURCE = 'nextgqa_crop'


@DATASETS.register(name='nextgqa_grounding')
class NExTGQAGroundingDataset(GroundingDataset, NExTGQADataset):

    SOURCE = 'nextgqa_grounding'
    DATA_TYPE = 'grounding'
    	
        videomind/dataset/sub_classes/nextqa.py
    ADDED
    
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import csv

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import AnsweringDataset
from videomind.utils.parser import parse_query, parse_question


@DATASETS.register(name='nextqa')
class NExTQADataset(AnsweringDataset):

    ANNO_PATH_TRAIN = 'data/nextqa/train.csv'
    ANNO_PATH_VALID = 'data/nextqa/val.csv'
    ANNO_PATH_TEST = 'data/nextqa/test.csv'

    VIDEO_ID_MAP = 'data/nextqa/map_vid_vidorID.json'
    VIDEO_ROOT = 'data/nextqa/NExTVideo'

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            anno_path = self.ANNO_PATH_TRAIN
        elif split == 'valid':
            anno_path = self.ANNO_PATH_VALID
        else:
            anno_path = self.ANNO_PATH_TEST

        with open(anno_path, mode='r') as f:
            reader = csv.DictReader(f)
            raw_annos = [d for d in reader]

        video_id_map = nncore.load(self.VIDEO_ID_MAP)

        annos = []
        for raw_anno in raw_annos:
            vid = raw_anno['video']
            qid = raw_anno['qid']

            video_id = video_id_map[vid]
            query = parse_query(raw_anno['question'].capitalize() + '?')
            question = parse_question(raw_anno['question'].capitalize() + '?')
            options = [raw_anno[k].capitalize() for k in ('a0', 'a1', 'a2', 'a3', 'a4')]
            ans = chr(ord('A') + int(raw_anno['answer']))
            answer = options[int(raw_anno['answer'])]

            anno = dict(
                source='nextqa',
                data_type='multimodal',
                uid=f'{vid}_{qid}',
                video_path=nncore.join(self.VIDEO_ROOT, video_id + '.mp4'),
                query=query,
                question=question,
                options=options,
                answer=answer,
                ans=ans,
                task=raw_anno['type'])

            annos.append(anno)

        return annos
    	
        videomind/dataset/sub_classes/qa_ego4d.py
    ADDED
    
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import random

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import AnsweringCropDataset, AnsweringDataset, GroundingDataset
from videomind.utils.parser import parse_query, parse_question


@DATASETS.register(name='qa_ego4d')
class QAEgo4DDataset(AnsweringDataset):

    ANNO_PATH_TRAIN = 'data/qa_ego4d/annotations.QaEgo4D_train.json'
    ANNO_PATH_VALID = 'data/qa_ego4d/annotations.QaEgo4D_val_options.json'
    ANNO_PATH_TEST = 'data/qa_ego4d/annotations.QaEgo4D_test_options.json'

    VIDEO_ROOT = 'data/ego4d/v1/videos_3fps_480_noaudio'
    DURATIONS = 'data/ego4d/v1/durations.json'

    SOURCE = 'qa_ego4d'
    DATA_TYPE = 'multimodal'

    UNIT = 0.001

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
        elif split == 'valid':
            raw_annos = nncore.load(self.ANNO_PATH_VALID)
        else:
            raw_annos = nncore.load(self.ANNO_PATH_TEST)

        durations = nncore.load(self.DURATIONS)

        annos = []
        for raw_anno in raw_annos:
            vid = raw_anno['video_id']

            duration = durations[vid]

            # too short or too long samples
            if split == 'train' and (duration < 10 or duration > 600):
                continue

            span = [raw_anno['moment_start_frame'] / 30, raw_anno['moment_end_frame'] / 30]
            span = [round(span[0], 3), round(span[1], 3)]

            # skip samples with too short moments
            # if split == 'train' and span[1] - span[0] < 2:
            #     continue

            answer = raw_anno['answer'].capitalize()

            if 'options' in raw_anno:
                options = [o.capitalize() for o in raw_anno['options']]
                idx = options.index(answer)
                ans = chr(ord('A') + idx)
            else:
                # NOTE: indeterministic evaluation
                assert len(raw_anno['wrong_answers']) == 3
                idx = random.randint(0, 3)
                ans = chr(ord('A') + idx)
                options = [o.capitalize() for o in raw_anno['wrong_answers']]
                options.insert(idx, answer)

            assert len(options) == 4, options

            anno = dict(
                source=self.SOURCE,
                data_type=self.DATA_TYPE,
                video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                duration=duration,
                query=parse_query(raw_anno['question'].capitalize()),
                question=parse_question(raw_anno['question'].capitalize()),
                options=options,
                answer=answer,
                ans=ans,
                span=[span])

            annos.append(anno)

        return annos


@DATASETS.register(name='qa_ego4d_crop')
class QAEgo4DCropDataset(AnsweringCropDataset, QAEgo4DDataset):

    SOURCE = 'qa_ego4d_crop'


@DATASETS.register(name='qa_ego4d_grounding')
class QAEgo4DGroundingDataset(GroundingDataset, QAEgo4DDataset):

    SOURCE = 'qa_ego4d_grounding'
    DATA_TYPE = 'grounding'
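
When a QA-Ego4D sample carries no pre-built options, the loader above inserts the correct answer at a random slot among the three wrong answers, which is why the inline NOTE flags the evaluation as indeterministic; moment frame indices are also converted to seconds assuming 30 fps. A small sketch of that fallback branch with invented answers and a fixed seed:

import random

# Invented example mirroring the no-options branch of QAEgo4DDataset.load_annos.
random.seed(0)
answer = 'Opened the fridge'
wrong_answers = ['closed the door', 'picked up a cup', 'washed the dishes']

idx = random.randint(0, 3)                         # random slot for the correct answer
ans = chr(ord('A') + idx)                          # its letter label
options = [o.capitalize() for o in wrong_answers]
options.insert(idx, answer)                        # 4 options in total
assert len(options) == 4 and options[idx] == answer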
    	
        videomind/dataset/sub_classes/queryd.py
    ADDED
    
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import GroundingDataset
from videomind.utils.parser import parse_query


@DATASETS.register(name='queryd')
class QuerYDDataset(GroundingDataset):

    VID_PATH = 'data/queryd/train_list.txt'
    QUERY_PATH = 'data/queryd/raw_captions_combined_filtered-v2.pkl'
    SPAN_PATH = 'data/queryd/times_captions_combined_filtered-v2.pkl'

    VIDEO_ROOT = 'data/queryd/videos_3fps_480_noaudio'
    DURATIONS = 'data/queryd/durations.json'

    UNIT = 0.001

    @classmethod
    def load_annos(self, split='train'):
        assert split == 'train'

        vids = nncore.load(self.VID_PATH)
        queries = nncore.load(self.QUERY_PATH)
        spans = nncore.load(self.SPAN_PATH)
        durations = nncore.load(self.DURATIONS)

        annos = []
        for vid in vids:
            for query, span in zip(queries[vid], spans[vid]):
                video_name = vid[6:]

                if video_name not in durations:
                    continue

                anno = dict(
                    source='queryd',
                    data_type='grounding',
                    video_path=nncore.join(self.VIDEO_ROOT, video_name + '.mp4'),
                    duration=durations[video_name],
                    query=parse_query(' '.join(query)),
                    span=[span])

                annos.append(anno)

        return annos
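
Each subclass in this folder exposes the same load_annos classmethod and returns a flat list of annotation dicts, so a loader can be sanity-checked on its own. A usage sketch, assuming the QuerYD files listed above are present locally (looking the class up through the DATASETS registry instead of importing it directly may also work, but the exact registry API is an assumption):

from videomind.dataset.sub_classes.queryd import QuerYDDataset

# Assumes the data/queryd/* files from the constants above exist on disk.
annos = QuerYDDataset.load_annos(split='train')
print(len(annos))
print(annos[0]['query'], annos[0]['span'], annos[0]['video_path'])
# Each record carries: source, data_type, video_path, duration, query, span.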
    	
        videomind/dataset/sub_classes/qvhighlights.py
    ADDED
    
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import GroundingDataset
from videomind.utils.parser import parse_query


@DATASETS.register(name='qvhighlights')
class QVHighlightsDataset(GroundingDataset):

    ANNO_PATH_TRAIN = 'data/qvhighlights/highlight_train_release.jsonl'
    ANNO_PATH_VALID = 'data/qvhighlights/highlight_val_release.jsonl'
    ANNO_PATH_TEST = 'data/qvhighlights/highlight_test_release.jsonl'

    VIDEO_ROOT = 'data/qvhighlights/videos_3fps_480_noaudio'

    UNIT = 2.0

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
        elif split == 'valid':
            raw_annos = nncore.load(self.ANNO_PATH_VALID)
        else:
            print('WARNING: Test split does not have ground truth annotations')
            raw_annos = nncore.load(self.ANNO_PATH_TEST)

        annos = []
        for raw_anno in raw_annos:
            vid = raw_anno['vid']
            qid = raw_anno['qid']

            anno = dict(
                source='qvhighlights',
                data_type='grounding',
                video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                duration=raw_anno['duration'],
                query=parse_query(raw_anno['query']),
                span=raw_anno.get('relevant_windows'),
                vid=vid,
                qid=qid)

            annos.append(anno)

        return annos


@DATASETS.register(name='qvhighlights_single')
class QVHighlightsSingleDataset(QVHighlightsDataset):

    @classmethod
    def load_annos(self, split='train'):
        assert split == 'train'

        raw_annos = nncore.load(self.ANNO_PATH_TRAIN)

        annos = []
        for raw_anno in raw_annos:
            # skip samples with multiple moments
            if len(raw_anno['relevant_windows']) > 1:
                continue

            vid = raw_anno['vid']

            anno = dict(
                source='qvhighlights_single',
                data_type='grounding',
                video_path=nncore.join(self.VIDEO_ROOT, vid + '.mp4'),
                duration=raw_anno['duration'],
                query=parse_query(raw_anno['query']),
                span=raw_anno.get('relevant_windows'))
| 75 | 
            +
             | 
| 76 | 
            +
                        annos.append(anno)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                    return annos
         | 
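A minimal usage sketch of the loader above, assuming the annotation files under data/qvhighlights/ are in place; the class can also be resolved by its registered name through the DATASETS registry, which is not shown here:

    # Illustrative only: load the validation annotations and inspect one record.
    from videomind.dataset.sub_classes.qvhighlights import QVHighlightsDataset

    # Returns a list of flat annotation dicts; each span is a list of [start, end] windows.
    annos = QVHighlightsDataset.load_annos(split='valid')
    print(len(annos), annos[0]['query'], annos[0]['span'])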
    	
    videomind/dataset/sub_classes/rextime.py
    ADDED
    @@ -0,0 +1,81 @@
# Copyright (c) Ye Liu. Licensed under the BSD 3-Clause License.

import nncore

from videomind.dataset.hybrid import DATASETS
from videomind.dataset.wrappers import AnsweringCropDataset, AnsweringDataset, GroundingDataset
from videomind.utils.parser import parse_query, parse_question


@DATASETS.register(name='rextime')
class ReXTimeDataset(AnsweringDataset):

    ANNO_PATH_TRAIN = 'data/rextime/rextime_train.json'
    ANNO_PATH_VALID = 'data/rextime/rextime_val.json'
    ANNO_PATH_TEST = 'data/rextime/rextime_test_release.json'

    VIDEO_ROOT_ANET = 'data/activitynet/videos_3fps_480_noaudio'
    VIDEO_ROOT_QVHL = 'data/qvhighlights/videos_3fps_480_noaudio'

    DURATIONS_ANET = 'data/activitynet/durations.json'
    DURATIONS_QVHL = 'data/qvhighlights/durations.json'

    SOURCE = 'rextime'
    DATA_TYPE = 'multimodal'

    UNIT = 1.0
    MIN_LEN = 64

    @classmethod
    def load_annos(self, split='train'):
        if split == 'train':
            raw_annos = nncore.load(self.ANNO_PATH_TRAIN)
        elif split == 'valid':
            raw_annos = nncore.load(self.ANNO_PATH_VALID)
        else:
            print('WARNING: Test split does not have ground truth annotations')
            raw_annos = nncore.load(self.ANNO_PATH_TEST)

        durations_anet = nncore.load(self.DURATIONS_ANET)
        durations_qvhl = nncore.load(self.DURATIONS_QVHL)

        annos = []
        for raw_anno in raw_annos:
            vid = raw_anno['vid']

            if len(vid) == 13:
                video_path = nncore.join(self.VIDEO_ROOT_ANET, vid + '.mp4')
                duration = durations_anet[vid]
            else:
                video_path = nncore.join(self.VIDEO_ROOT_QVHL, vid + '.mp4')
                duration = durations_qvhl[vid]

            anno = dict(
                source=self.SOURCE,
                data_type=self.DATA_TYPE,
                video_path=video_path,
                duration=duration,
                query=parse_query(raw_anno['question']),
                question=parse_question(raw_anno['question']),
                options=[o.capitalize() for o in raw_anno['options']],
                answer=raw_anno['answer'].replace('From <s0> to <e0>, ', '').capitalize(),
                ans=raw_anno['ans'],
                span=[raw_anno['span']],
                task=raw_anno['category'])

            annos.append(anno)

        return annos


@DATASETS.register(name='rextime_crop')
class ReXTimeCropDataset(AnsweringCropDataset, ReXTimeDataset):

    SOURCE = 'rextime_crop'


@DATASETS.register(name='rextime_grounding')
class ReXTimeGroundingDataset(GroundingDataset, ReXTimeDataset):

    SOURCE = 'rextime_grounding'
    DATA_TYPE = 'grounding'
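The len(vid) == 13 branch in ReXTimeDataset.load_annos is a heuristic for routing each sample to the right video root: ActivityNet IDs follow the 'v_' + 11-character YouTube-ID pattern (13 characters total), while QVHighlights IDs are longer strings that also encode a clip range. A small sketch of the same routing logic; the IDs below are hypothetical examples of the two formats:

    # Illustrative only: mirror the source-routing heuristic with hypothetical IDs.
    def route_video(vid):
        # 'v_' + 11-char YouTube ID -> ActivityNet; anything else -> QVHighlights
        return 'activitynet' if len(vid) == 13 else 'qvhighlights'

    assert route_video('v_QOlSCBRmfWY') == 'activitynet'
    assert route_video('RoripwjYFp8_210.0_360.0') == 'qvhighlights'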