# RxnIM/mllm/dataset/builder.py
from functools import partial
from typing import Callable, Dict, Tuple, Any, Optional
from torch.utils.data import Dataset
from transformers import EvalPrediction, TrainingArguments
from .root import DATASETS, METRICS, TRANSFORMS, FUNCTIONS
from .single_image_convsation import SingleImageConvDataset
from .single_image_interactive import SingleImageInteractive
from ..conversation import get_conv_template
from .utils import init_ceph_client_if_needed

DatasetDict = Dict[str, Dataset]
ComputeMetrics = Callable[[EvalPrediction], Dict]


def prepare_data(
        data_args,
        model_args,
        training_args: TrainingArguments,
        preprocessor: Dict[str, Any],
) -> Tuple[DatasetDict, Optional[ComputeMetrics]]:
    """Build the train/validation/test datasets and the optional metric callable.

    Raw datasets are wrapped lazily (as ``partial`` factories) and only
    instantiated inside ``SingleImageConvDataset``. When ``data_args.multitest``
    is set and ``training_args.do_multi_predict`` is enabled, each entry
    (``{'cfg': ..., 'compute_metric': ...}``) gets its own test dataset and
    metric under ``ds['multitest']``.
    """
    # raw dataset factories, built only for the splits that are actually used
    datasets = {
        'train': partial(DATASETS.build, data_args.train) if training_args.do_train else None,
        'validation': partial(DATASETS.build, data_args.validation) if training_args.do_eval else None,
        'test': partial(DATASETS.build, data_args.test) if training_args.do_predict else None,
    }

    # compute metric
    compute_metric_cfg = data_args.get('compute_metric', None)
    compute_metrics = build_compute_metric(compute_metric_cfg, preprocessor)

    # conversation dataset wrapper
    conv_args = model_args.conv_args
    tokenize_kwargs = conv_args.get('tokenize_kwargs', {})
    conv_template = conv_args.get('conv_template', 'vicuna_v1.1')
    conv_template = partial(get_conv_template, name=conv_template)
    transforms = conv_args.get('transforms', None)
    if transforms is not None:
        transforms = TRANSFORMS.build(transforms)

    # process functions looked up from the FUNCTIONS registry
    process_func = {}
    for k, v in model_args.process_func_args.items():
        process_func[k] = FUNCTIONS.build(cfg=v)

    conv_dataset_cls = partial(
        SingleImageConvDataset,
        preprocessor=preprocessor,
        process_func=process_func,
        tokenize_kwargs=tokenize_kwargs,
        conv_template=conv_template,
        training_args=training_args,
        transforms=transforms,
    )
    ds = {
        'train': conv_dataset_cls(dataset_generator=datasets['train'], mode='train') if datasets['train'] is not None else None,
        'validation': conv_dataset_cls(dataset_generator=datasets['validation'], mode='validation') if datasets['validation'] is not None else None,
        'test': conv_dataset_cls(dataset_generator=datasets['test'], mode='test') if datasets['test'] is not None else None,
    }

    # multiple test sets, each with its own dataset and metric
    if hasattr(data_args, 'multitest') and bool(data_args.multitest) \
            and hasattr(training_args, 'do_multi_predict') and training_args.do_multi_predict:
        print("processing multitest set")
        k2v = {}
        for k, item in data_args.multitest.items():
            _dataset_cls = partial(DATASETS.build, item['cfg'])
            _compute_metric = build_compute_metric(item['compute_metric'], preprocessor)
            k2v[k] = {
                "dataset": conv_dataset_cls(dataset_generator=_dataset_cls, mode='test'),
                "compute_metric": _compute_metric,
            }
        ds['multitest'] = k2v
        print("processing multitest set. done.")

    # Unless lazy_init is requested, initialize the ceph client here, at the
    # start of the program and, importantly, before the dataloader workers fork.
    lazy_init = data_args.get('lazy_init', True)
    if not lazy_init:
        init_ceph_client_if_needed()
    return ds, compute_metrics


def build_compute_metric(compute_metric_cfg, preprocessor):
    """Build a metric callable from config, injecting the preprocessor; return None if no config is given."""
    if compute_metric_cfg is not None:
        compute_metric_cfg = dict(compute_metric_cfg)  # copy cfg because we modify it
        compute_metric_cfg.update(dict(preprocessor=preprocessor))
        compute_metrics = METRICS.build(cfg=compute_metric_cfg)
    else:
        compute_metrics = None
    return compute_metrics
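

# ---------------------------------------------------------------------------
# Usage sketch (editorial addition, not part of the original module): one way a
# training entry point could consume `prepare_data`, assuming `data_args`,
# `model_args` and `preprocessor` are the project's parsed config objects and
# built preprocessors, and `model` is an already-constructed HuggingFace model.
# ---------------------------------------------------------------------------
def _example_train_entry(model, preprocessor, data_args, model_args, training_args: TrainingArguments):
    from transformers import Trainer  # local import, only needed for this sketch

    ds, compute_metrics = prepare_data(data_args, model_args, training_args, preprocessor)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds['train'],
        eval_dataset=ds['validation'],
        compute_metrics=compute_metrics,
    )
    if training_args.do_train:
        trainer.train()
    return trainer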


def prepare_interactive(
        model_args,
        preprocessor: Dict[str, Any],
):
    """Build a SingleImageInteractive dataset for interactive (single-sample) inference."""
    conv_args = model_args.conv_args
    tokenize_kwargs = conv_args.get('tokenize_kwargs', {})
    conv_template = conv_args.get('conv_template', 'vicuna_v1.1')
    conv_template = partial(get_conv_template, name=conv_template)
    transforms = conv_args.get('transforms', None)
    if transforms is not None:
        transforms = TRANSFORMS.build(transforms)

    # process functions looked up from the FUNCTIONS registry
    process_func = {}
    for k, v in model_args.process_func_args.items():
        process_func[k] = FUNCTIONS.build(cfg=v)

    ds = SingleImageInteractive(
        preprocessor=preprocessor,
        process_func=process_func,
        tokenize_kwargs=tokenize_kwargs,
        conv_template=conv_template,
        training_args=None,
        transforms=transforms,
        mode='test',
    )
    return ds
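

# ---------------------------------------------------------------------------
# Usage sketch (editorial addition): a single interactive round. The
# `set_image`, `append_message`, `roles` and `to_model_input` calls are
# assumptions about the SingleImageInteractive interface and may need to be
# adapted to the actual class.
# ---------------------------------------------------------------------------
def _example_interactive_round(model_args, preprocessor, image, question: str):
    ds = prepare_interactive(model_args, preprocessor)
    ds.set_image(image)                       # assumed API: attach the input PIL image
    ds.append_message(ds.roles[0], question)  # assumed API: add the user turn
    return ds.to_model_input()                # assumed API: tensors ready for generation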