|
""" |
|
Copyright (c) 2022, salesforce.com, inc. |
|
All rights reserved. |
|
SPDX-License-Identifier: BSD-3-Clause |
|
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
|
""" |
|
|
|
import os |
|
|
|
from minigpt4.common.registry import registry |
|
from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder |
|
from minigpt4.datasets.datasets.laion_dataset import LaionDataset |
|
from minigpt4.datasets.datasets.cc_combine_dataset import CCCombineDataset, CCAlignDataset |
|
|
|
|
|
@registry.register_builder("cc_combine")
class CCCombineBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``cc_combine``.

    Produces a single ``"train"`` split backed by ``CCCombineDataset``.
    """

    # Dataset class instantiated by ``build`` for the training split.
    train_dataset_cls = CCCombineDataset

    # Maps a config variant name to its YAML file path.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_combine/defaults.yaml"}

    def _download_ann(self):
        """Do nothing; this builder performs no annotation download."""

    def _download_vis(self):
        """Do nothing; this builder performs no visual-data download."""

    def build(self):
        """Construct the training dataset and return it keyed by split name.

        Returns:
            dict: ``{"train": <inner_dataset>}``, where the value is the
            ``inner_dataset`` attribute of the freshly constructed
            ``train_dataset_cls`` instance.
        """
        # Processors must exist before they are looked up per split below.
        self.build_processors()

        info = self.config.build_info
        split_name = "train"

        wrapper = self.train_dataset_cls(
            vis_processor=self.vis_processors[split_name],
            text_processor=self.text_processors[split_name],
            location=info.storage,
        )

        # Only the wrapped ``inner_dataset`` is exposed, not the wrapper itself.
        return {split_name: wrapper.inner_dataset}
|
|
|
|
|
@registry.register_builder("laion")
class LaionBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``laion``.

    Produces a single ``"train"`` split backed by ``LaionDataset``.
    """

    # Dataset class instantiated by ``build`` for the training split.
    train_dataset_cls = LaionDataset

    # Maps a config variant name to its YAML file path.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}

    def _download_ann(self):
        """Do nothing; this builder performs no annotation download."""

    def _download_vis(self):
        """Do nothing; this builder performs no visual-data download."""

    def build(self):
        """Construct the training dataset and return it keyed by split name.

        Returns:
            dict: ``{"train": <inner_dataset>}``, where the value is the
            ``inner_dataset`` attribute of the freshly constructed
            ``train_dataset_cls`` instance.
        """
        # Processors must exist before they are looked up per split below.
        self.build_processors()

        info = self.config.build_info
        split_name = "train"

        wrapper = self.train_dataset_cls(
            vis_processor=self.vis_processors[split_name],
            text_processor=self.text_processors[split_name],
            location=info.storage,
        )

        # Only the wrapped ``inner_dataset`` is exposed, not the wrapper itself.
        return {split_name: wrapper.inner_dataset}
|
|
|
|
|
@registry.register_builder("cc_align")
class CCAlignBuilder(BaseDatasetBuilder):
    """Dataset builder registered under the name ``cc_align``.

    Uses ``CCAlignDataset`` as its training dataset class.
    """

    # Dataset class associated with the training split.
    train_dataset_cls = CCAlignDataset

    # Maps a config variant name to its YAML file path.
    # NOTE(review): the path lives under ``cc_combine/`` although the builder
    # is registered as ``cc_align`` - presumably intentional; confirm.
    DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_combine/align.yaml"}