File size: 4,691 Bytes
3e1d9f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from functools import partial
from typing import Callable, Dict, Tuple, Any, Optional
from torch.utils.data import Dataset
from transformers import EvalPrediction, TrainingArguments
from .root import DATASETS, METRICS, TRANSFORMS, FUNCTIONS
from .single_image_convsation import SingleImageConvDataset
from .single_image_interactive import SingleImageInteractive
from ..conversation import get_conv_template
from .utils import init_ceph_client_if_needed
# Mapping from split name (e.g. 'train' / 'validation' / 'test') to its Dataset.
DatasetDict = Dict[str, Dataset]
# Metric callable: takes a transformers EvalPrediction, returns a dict of metric values.
ComputeMetrics = Callable[[EvalPrediction], Dict]
def prepare_data(
        data_args,
        model_args,
        training_args: TrainingArguments,
        preprocessor: Dict[str, Any],
) -> Tuple[DatasetDict, Optional[ComputeMetrics]]:
    """Build the split datasets and the metric callable for a training run.

    Args:
        data_args: config object with ``train``/``validation``/``test`` dataset
            configs, an optional ``compute_metric`` config, an optional
            ``multitest`` mapping, and a ``lazy_init`` flag (dict-style ``get``).
        model_args: config object providing ``conv_args`` and
            ``process_func_args``.
        training_args: HuggingFace TrainingArguments; the ``do_train`` /
            ``do_eval`` / ``do_predict`` (and optional ``do_multi_predict``)
            flags decide which splits are built.
        preprocessor: shared preprocessor mapping, forwarded to both the
            conversation dataset wrapper and the metric builder.

    Returns:
        A ``(datasets, compute_metrics)`` tuple. ``datasets`` maps each split
        name to a wrapped dataset (or ``None`` when that split is disabled) and
        may carry an extra ``'multitest'`` entry mapping test-set name to
        ``{'dataset': ..., 'compute_metric': ...}``. ``compute_metrics`` is
        ``None`` when no metric config was given.
    """
    # Raw dataset builders. Wrapped in partial so construction is deferred to
    # the conversation-dataset wrapper (lazy instantiation).
    datasets = {
        'train': partial(DATASETS.build, data_args.train) if training_args.do_train else None,
        'validation': partial(DATASETS.build, data_args.validation) if training_args.do_eval else None,
        'test': partial(DATASETS.build, data_args.test) if training_args.do_predict else None,
    }

    # Metric callable (None when no 'compute_metric' config is present).
    compute_metric_cfg = data_args.get('compute_metric', None)
    compute_metrics = build_compute_metric(compute_metric_cfg, preprocessor)

    # Conversation-template / transform configuration for the dataset wrapper.
    conv_args = model_args.conv_args
    tokenize_kwargs = conv_args.get('tokenize_kwargs', {})
    conv_template = conv_args.get('conv_template', 'vicuna_v1.1')
    conv_template = partial(get_conv_template, name=conv_template)
    transforms = conv_args.get('transforms', None)
    if transforms is not None:
        transforms = TRANSFORMS.build(transforms)

    # Per-role processing functions built from the registry.
    process_func = {k: FUNCTIONS.build(cfg=v) for k, v in model_args.process_func_args.items()}

    conv_dataset_cls = partial(
        SingleImageConvDataset,
        preprocessor=preprocessor,
        process_func=process_func,
        tokenize_kwargs=tokenize_kwargs,
        conv_template=conv_template,
        training_args=training_args,
        transforms=transforms,
    )

    # Wrap each enabled split; disabled splits stay None.
    ds = {
        split: conv_dataset_cls(dataset_generator=gen, mode=split) if gen is not None else None
        for split, gen in datasets.items()
    }

    # Optional extra test sets, each with its own metric.
    if getattr(data_args, 'multitest', None) and getattr(training_args, 'do_multi_predict', False):
        print("processing multitest set")
        k2v = {}
        for k, item in data_args.multitest.items():
            _dataset_cls = partial(DATASETS.build, item['cfg'])
            _compute_metric = build_compute_metric(item['compute_metric'], preprocessor)
            k2v[k] = {
                "dataset": conv_dataset_cls(dataset_generator=_dataset_cls, mode='test'),
                "compute_metric": _compute_metric,
            }
        ds['multitest'] = k2v
        print("processing multitest set. done.")

    # By default the ceph client is initialized at program start — importantly,
    # before any dataloader worker fork. lazy_init=False forces it here.
    lazy_init = data_args.get('lazy_init', True)
    if not lazy_init:
        init_ceph_client_if_needed()
    return ds, compute_metrics
def build_compute_metric(compute_metric_cfg, preprocessor):
    """Instantiate a metric callable from its registry config.

    Returns ``None`` when ``compute_metric_cfg`` is ``None``; otherwise builds
    the metric via the METRICS registry with ``preprocessor`` injected into
    the config.
    """
    if compute_metric_cfg is None:
        return None
    # Shallow-copy before injecting so the caller's config is left untouched.
    cfg = dict(compute_metric_cfg)
    cfg['preprocessor'] = preprocessor
    return METRICS.build(cfg=cfg)
def prepare_interactive(
        model_args,
        preprocessor: Dict[str, Any],
):
    """Build a SingleImageInteractive session for inference-time use.

    Mirrors the conversation-wrapper setup of ``prepare_data`` but targets a
    single interactive dataset in test mode with no training arguments.

    Args:
        model_args: config object providing ``conv_args`` and
            ``process_func_args``.
        preprocessor: shared preprocessor mapping forwarded to the session.

    Returns:
        A configured ``SingleImageInteractive`` instance.
    """
    conv_args = model_args.conv_args
    tokenize_kwargs = conv_args.get('tokenize_kwargs', {})
    template_name = conv_args.get('conv_template', 'vicuna_v1.1')
    conv_template = partial(get_conv_template, name=template_name)

    transforms = conv_args.get('transforms', None)
    if transforms is not None:
        transforms = TRANSFORMS.build(transforms)

    # Build each configured processing function from the registry.
    process_func = {
        name: FUNCTIONS.build(cfg=func_cfg)
        for name, func_cfg in model_args.process_func_args.items()
    }

    return SingleImageInteractive(
        preprocessor=preprocessor,
        process_func=process_func,
        tokenize_kwargs=tokenize_kwargs,
        conv_template=conv_template,
        training_args=None,
        transforms=transforms,
        mode='test',
    )
|