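"""Dataset preparation helpers.

Builds the raw train/validation/test datasets from registry configs, wraps them in
conversation-style single-image datasets, and constructs the optional
``compute_metrics`` callable and multi-test sets.
"""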
from functools import partial
from typing import Callable, Dict, Tuple, Any, Optional

from torch.utils.data import Dataset
from transformers import EvalPrediction, TrainingArguments

from .root import DATASETS, METRICS, TRANSFORMS, FUNCTIONS
from .single_image_convsation import SingleImageConvDataset
from .single_image_interactive import SingleImageInteractive
from ..conversation import get_conv_template
from .utils import init_ceph_client_if_needed

DatasetDict = Dict[str, Dataset]
ComputeMetrics = Callable[[EvalPrediction], Dict]


def prepare_data(
        data_args,
        model_args,
        training_args: TrainingArguments,
        preprocessor: Dict[str, Any],
) -> Tuple[DatasetDict, Optional[ComputeMetrics]]:
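    """Build the train/validation/test datasets and the optional ``compute_metrics`` callable.

    Raw dataset builders are created lazily (as ``partial(DATASETS.build, cfg)``) only for
    the stages enabled by ``do_train`` / ``do_eval`` / ``do_predict``, then wrapped in
    ``SingleImageConvDataset``. When ``do_multi_predict`` is set, an extra ``'multitest'``
    entry maps each named test set to its dataset and metric.
    """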
    # raw dataset
    datasets = {
        'train': partial(DATASETS.build, data_args.train) if training_args.do_train else None,
        'validation': partial(DATASETS.build, data_args.validation) if training_args.do_eval else None,
        'test': partial(DATASETS.build, data_args.test) if training_args.do_predict else None,
    }
    # compute metric
    compute_metric_cfg = data_args.get('compute_metric', None)
    compute_metrics = build_compute_metric(compute_metric_cfg, preprocessor)
    # conv dataset wrap
    conv_args = model_args.conv_args
    tokenize_kwargs = conv_args.get('tokenize_kwargs', {})
    conv_template = conv_args.get('conv_template', 'vicuna_v1.1')
    conv_template = partial(get_conv_template, name=conv_template)
    transforms = conv_args.get('transforms', None)
    if transforms is not None:
        transforms = TRANSFORMS.build(transforms)
    # process func
    process_func = {}
    for k, v in model_args.process_func_args.items():
        process_func[k] = FUNCTIONS.build(cfg=v)

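    # Shared conversation-dataset wrapper: bind everything except the raw dataset and its mode.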
    conv_dataset_cls = partial(
        SingleImageConvDataset,
        preprocessor=preprocessor,
        process_func=process_func,
        tokenize_kwargs=tokenize_kwargs,
        conv_template=conv_template,
        training_args=training_args,
        transforms=transforms,
    )
    ds = {
        'train': conv_dataset_cls(dataset_generator=datasets['train'], mode='train') if datasets['train'] is not None else None,
        'validation': conv_dataset_cls(dataset_generator=datasets['validation'], mode='validation') if datasets['validation'] is not None else None,
        'test': conv_dataset_cls(dataset_generator=datasets['test'], mode='test') if datasets['test'] is not None else None,
    }

    # multi test set
    if hasattr(data_args, 'multitest') and bool(data_args.multitest) \
            and hasattr(training_args, 'do_multi_predict') and training_args.do_multi_predict:
        print(f"processing multitest set")
        k2v = {}
        for k, item in data_args.multitest.items():
            _dataset_cls = partial(DATASETS.build, item['cfg'])
            _compute_metric = build_compute_metric(item['compute_metric'], preprocessor)
            k2v[k] = {
                "dataset": conv_dataset_cls(dataset_generator=_dataset_cls, mode='test'),
                "compute_metric": _compute_metric
            }
        ds['multitest'] = k2v
        print(f"processing multitest set. done.")

    # Unless lazy_init is requested, initialize the ceph client here, at the beginning of
    # the program and, importantly, before dataloader workers are forked.
    lazy_init = data_args.get('lazy_init', True)
    if not lazy_init:
        init_ceph_client_if_needed()
    return ds, compute_metrics


def build_compute_metric(compute_metric_cfg, preprocessor):
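    """Build a metric callable from a registry config, injecting the preprocessor.

    Returns ``None`` when no config is given.
    """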
    if compute_metric_cfg is not None:
        compute_metric_cfg = dict(compute_metric_cfg)  # copy cfg because we modify it
        compute_metric_cfg.update(dict(preprocessor=preprocessor))
        compute_metrics = METRICS.build(cfg=compute_metric_cfg)
    else:
        compute_metrics = None
    return compute_metrics


def prepare_interactive(
        model_args,
        preprocessor: Dict[str, Any],
):
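    """Build a ``SingleImageInteractive`` session for inference.

    Reuses the conversation setup from ``prepare_data`` (conv template, tokenize kwargs,
    transforms, process functions) but runs in ``'test'`` mode without training arguments.
    """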
    conv_args = model_args.conv_args
    tokenize_kwargs = conv_args.get('tokenize_kwargs', {})
    conv_template = conv_args.get('conv_template', 'vicuna_v1.1')
    conv_template = partial(get_conv_template, name=conv_template)
    transforms = conv_args.get('transforms', None)
    if transforms is not None:
        transforms = TRANSFORMS.build(transforms)
    # process func
    process_func = {}
    for k, v in model_args.process_func_args.items():
        process_func[k] = FUNCTIONS.build(cfg=v)

    ds = SingleImageInteractive(
        preprocessor=preprocessor,
        process_func=process_func,
        tokenize_kwargs=tokenize_kwargs,
        conv_template=conv_template,
        training_args=None,
        transforms=transforms,
        mode='test',
    )
    return ds
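

# A minimal usage sketch (illustrative only, not part of this module). It assumes
# omegaconf/argparse-style config objects that support both attribute access and
# ``.get()``, as the accessors above imply, and a ``model`` built elsewhere:
#
#   from transformers import Trainer
#
#   ds, compute_metrics = prepare_data(data_args, model_args, training_args, preprocessor)
#   trainer = Trainer(
#       model=model,
#       args=training_args,
#       train_dataset=ds['train'],
#       eval_dataset=ds['validation'],
#       compute_metrics=compute_metrics,
#   )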