File size: 882 Bytes
3e1d9f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from functools import partial
from typing import Tuple, Dict, Any, Type
from transformers.trainer import DataCollator
from .shikra import ShikraTrainer
from .base_engine import TrainerForMMLLM, Seq2Seq2DataCollatorWithImage
# Registry of trainer classes keyed by the model type string; looked up with
# `model_args.type` in `prepare_trainer_collator`.
TYPE2TRAINER = {
    'shikra': ShikraTrainer,
}
def prepare_trainer_collator(
        model_args,
        preprocessor: Dict[str, Any],
        collator_kwargs: Dict[str, Any]
) -> Tuple[Type[TrainerForMMLLM], Dict[str, DataCollator]]:
    """Select the trainer class and build the train/eval data collators.

    Args:
        model_args: Object whose ``type`` attribute is used as the key into
            ``TYPE2TRAINER``.
        preprocessor: Passed to ``Seq2Seq2DataCollatorWithImage`` as the
            ``preprocessor`` keyword argument.
        collator_kwargs: Extra keyword arguments forwarded to
            ``Seq2Seq2DataCollatorWithImage`` for both collators.

    Returns:
        A ``(trainer_cls, collators)`` pair. ``collators`` has two entries:
        ``"train_collator"`` (built with ``inference_mode=False``) and
        ``"eval_collator"`` (built with ``inference_mode=True``).

    Raises:
        KeyError: If ``model_args.type`` is not a key of ``TYPE2TRAINER``.
    """
    # Resolve the trainer first so an unknown type fails before any collator
    # is constructed.
    trainer_cls = TYPE2TRAINER[model_args.type]

    # Bind the arguments shared by both collators once; only `inference_mode`
    # differs between the train and eval variants.
    make_collator = partial(
        Seq2Seq2DataCollatorWithImage,
        preprocessor=preprocessor,
        **collator_kwargs,
    )

    collator_modes = (
        ("train_collator", False),
        ("eval_collator", True),
    )
    collators = {
        key: make_collator(inference_mode=inference_mode)
        for key, inference_mode in collator_modes
    }
    return trainer_cls, collators
|