import json
import random
from pathlib import Path
from typing import Dict, List

from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.prompts import tt_caption_prompt, tt_caption_prompt2
from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.utils import master_print


class GPT4VTTVqaDataset(FramesTaskDataset):
    def __init__(self, anno_path, data_args=None, fps=0.5, conv_type='single',
                 task_types=None, name='gpt4v_tt_vqa'):
        self.default_fps = 0.5
        self.fps = fps
        self.conv_type = conv_type
        self.task_types = task_types
        assert self.conv_type in ('single', 'multi'), \
            "gpt4v_tt_vqa conv_type must be 'single' or 'multi'"
        self.annotation = self.get_dataset(anno_path)
        # assert hasattr(self.data_args, 'task_types'), "gpt4v_tt_vqa must have key 'task_types' in yaml config"
        # master_print(f"Finished loading dataset {name} {len(self.annotation)} samples...")
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)

    def get_dataset(self, anno_path):
        """Expand the raw annotation list into per-sample records.

        Each record gains a 'task_type' key. In 'single' conversation mode,
        every QA pair becomes its own sample; otherwise all QA pairs for a
        video stay together in one multi-turn sample.
        """
        dataset = []
        anno_path = Path(anno_path)
        with anno_path.open('r') as f:
            data = json.load(f)
        for info in data:
            for task_type in self.task_types:
                # Skip videos that lack this task type or have no entries for it.
                if task_type not in info or len(info[task_type]) == 0:
                    continue
                info_task = info.copy()
                if task_type == 'qas' and self.conv_type == 'single':
                    # One sample per QA pair.
                    for qa_pair in info_task[task_type]:
                        one_info = info_task.copy()
                        one_info[task_type] = [qa_pair]
                        one_info.update({'task_type': task_type})
                        dataset.append(one_info)
                else:
                    info_task.update({'task_type': task_type})
                    dataset.append(info_task)
        return dataset

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        all_convs = []
        if hasattr(self.data_args, 'caption_prompt'):
            # The config stores the prompt list's variable name as a string
            # (e.g. 'tt_caption_prompt2'); eval() resolves it to the list
            # imported above.
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt
        if item['task_type'] == 'caption':
            all_convs.append([
                {'from': 'human', 'value': random.choice(cap_prompt)},
                {'from': 'model', 'value': item['caption']},
            ])
        else:
            for qa in item['qas']:
                all_convs.append([
                    {'from': 'human', 'value': qa['q']},
                    {'from': 'model', 'value': qa['a']},
                ])

        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                # Prepend the video token to the first human turn only.
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)
        return conversations


@DATASETS.register_obj
def gpt4v_tt_vqa(data_args):
    if 'train_data_path' in data_args.external_args:
        anno_path = data_args.external_args['train_data_path']
    else:
        anno_path = data_configs['gpt4v_tt_vqa']['train_data_path']
    fps = data_args.external_args['fps']
    conv_type = data_args.external_args['conv_type']
    task_types = data_args.external_args['task_types']
    return GPT4VTTVqaDataset(anno_path, data_args, fps, conv_type, task_types)
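

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the annotation
# schema below is *inferred* from how get_dataset() and text_preprocess()
# read the JSON — a list of per-video dicts with optional 'caption' and
# 'qas' keys, where each QA pair is {'q': ..., 'a': ...}. All field values
# are hypothetical, and any extra keys FramesTaskDataset needs (e.g. frame
# paths) are omitted. It writes a minimal annotation file of that shape.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import tempfile

    sample_annotations = [
        {
            'caption': 'A person walks a dog along the beach.',
            'qas': [
                {'q': 'What animal appears in the video?', 'a': 'A dog.'},
                {'q': 'Where does the scene take place?', 'a': 'On a beach.'},
            ],
        }
    ]
    # With task_types=['caption', 'qas'] and conv_type='single', get_dataset()
    # would expand this entry into three samples: one caption, two QA.
    with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
        json.dump(sample_annotations, f)
    print(f'Wrote sample annotation file to {f.name}')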