Spaces:
Running
on
A100
Running
on
A100
| # Copyright (c) 2025 NVIDIA CORPORATION. | |
| # Licensed under the MIT license. | |
| # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. | |
| # LICENSE is in incl_licenses directory. | |
| import random | |
| from typing import Any, Dict, List | |
| import torch | |
| from torch.utils.data import Dataset | |
| from transformers import PreTrainedTokenizer | |
| from llava.mm_utils import dynamic_process_images_and_prompt, dynamic_s2_process_images_and_prompt, process_images | |
| from llava.train.args import DataArguments | |
| from llava.utils.logging import logger | |
| from llava.utils.media import extract_media | |
| from llava.utils.tokenizer import preprocess_conversation | |
| __all__ = ["BaseDataset"] | |
| def _process_speech(speech: List[Any], data_args: DataArguments) -> torch.Tensor: | |
| return torch.tensor(speech) | |
| def _process_sound(sound: List[Any], data_args: DataArguments) -> torch.Tensor: | |
| return torch.tensor(sound) | |
| def _process_sound_masks(sound_masks: List[Any], data_args: DataArguments) -> torch.Tensor: | |
| return torch.tensor(sound_masks) | |
class BaseDataset(Dataset):
    """Base class for conversation datasets that yield tokenized examples.

    ``__getitem__`` looks up a raw entry in ``self.instances``, converts it to
    a conversation with :meth:`process`, extracts media from it, tokenizes it
    with ``preprocess_conversation`` and attaches any speech / sound tensors.
    ``self.instances`` starts empty and :meth:`process` is abstract here, so
    subclasses are expected to provide both.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        data_args: DataArguments,
        no_system_prompt: bool = False,
        **kwargs: Any,
    ) -> None:
        """Store the shared configuration.

        Args:
            tokenizer: Tokenizer passed to ``preprocess_conversation``.
            data_args: Data arguments passed to ``extract_media`` and the
                speech / sound helpers.
            no_system_prompt: Forwarded to ``preprocess_conversation``.
            **kwargs: Optional settings. Only ``global_batch_size``
                (default ``1``) and ``resample_on_failure`` (default ``True``)
                are read here; other keys are ignored by this class.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.data_args = data_args
        self.no_system_prompt = no_system_prompt
        self.instances = []
        self.enable_dynamic_res = False
        self.enable_dynamic_res_s2 = False

        # Optional keyword-only settings.
        self.global_batch_size = kwargs.get("global_batch_size", 1)
        # By default, a failing instance is replaced by a randomly drawn one.
        self.resample_on_failure = kwargs.get("resample_on_failure", True)

    def process(self, instance: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Convert one raw instance into a conversation.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the processed example for ``self.instances[index]``.

        Args:
            index: Position in ``self.instances``. An out-of-range index
                raises ``IndexError`` before any resampling logic runs.

        Returns:
            The mapping produced by ``preprocess_conversation``, extended with
            ``"speech"`` and/or ``"sound"``, ``"sound_feature_masks"`` and
            ``"sound_embed_masks"`` when the conversation contains that media.

        Raises:
            Exception: Whatever processing raised, when
                ``resample_on_failure`` is false.
        """
        instance = self.instances[index]

        try:
            # Process instance to conversation
            conversation = self.process(instance)

            # Extract media from conversation
            media, media_meta = extract_media(conversation, self.data_args)

            # Convert audio media first (same order as before: media, then
            # tokenization) and attach the tensors once ``data`` exists.
            extras: Dict[str, torch.Tensor] = {}
            if "speech" in media:
                extras["speech"] = _process_speech(media["speech"], self.data_args)
            if "sound" in media:
                extras["sound"] = _process_sound(media["sound"], self.data_args)
                extras["sound_feature_masks"] = _process_sound_masks(
                    media_meta["sound_feature_masks"], self.data_args
                )
                extras["sound_embed_masks"] = _process_sound_masks(
                    media_meta["sound_embed_masks"], self.data_args
                )

            # Prepare "input_ids" and "labels" for training
            data = preprocess_conversation(conversation, self.tokenizer, no_system_prompt=self.no_system_prompt)
            for key, value in extras.items():
                data[key] = value
        except Exception as e:
            if not self.resample_on_failure:
                # Bare ``raise`` re-raises the active exception unchanged.
                raise
            logger.exception(f"Error processing instance '{instance}': '{e}'. Resampling.")
            # NOTE(review): resampling recurses with no retry bound; if every
            # instance fails this ends in a RecursionError.
            return self.__getitem__(random.randint(0, len(self.instances) - 1))

        return data

    def __len__(self) -> int:
        """Return the number of raw instances held by the dataset."""
        return len(self.instances)