# HuggingFace trainer
import json
import os
import shutil
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional

import torch
from datasets import concatenate_datasets
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, Idefics2Processor, PreTrainedModel, PreTrainedTokenizer, TrainingArguments

from colpali_engine.dataset.custom_collator import CustomCollator
from colpali_engine.loss.colbert_loss import BiEncoderLoss, BiPairwiseCELoss, ColbertLoss, ColbertPairwiseCELoss
from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.gpu_stats import print_gpu_utilization, print_summary


@dataclass
class ColModelTrainingConfig:
    """Bundle of everything needed to train and evaluate a retrieval model.

    ``__post_init__`` fills in missing defaults (output directory, training
    arguments, tokenizer) and attaches PEFT adapters to ``model`` when
    requested, so constructing this object mutates the model in place.
    """

    model: PreTrainedModel
    tr_args: Optional[TrainingArguments] = None
    # Defaults to "./models/<model name with '/' replaced by '_'>" when None.
    output_dir: Optional[str] = None
    max_length: int = 256
    run_eval: bool = True
    run_train: bool = True
    peft_config: Optional[LoraConfig] = None
    add_suffix: bool = False
    # When both processor and tokenizer are None, a tokenizer is loaded from
    # ``model.name_or_path``.
    processor: Optional[Idefics2Processor] = None
    tokenizer: Optional[PreTrainedTokenizer] = None
    # A factory is used so each config gets its own loss object instead of
    # sharing a single instance created at import time.
    loss_func: Optional[Callable] = field(default_factory=ColbertLoss)
    dataset_loading_func: Optional[Callable] = None
    eval_dataset_loader: Optional[Dict[str, Callable]] = None
    pretrained_peft_model_name_or_path: Optional[str] = None

    def __post_init__(self):
        """Resolve defaults and configure adapters on ``self.model``."""
        if self.output_dir is None:
            sanitized_name = str(self.model.name_or_path).replace("/", "_")
            self.output_dir = f"./models/{sanitized_name}"

        if self.tr_args is None:
            self.tr_args = TrainingArguments(output_dir=self.output_dir)
        elif self.tr_args.output_dir is None:
            self.tr_args.output_dir = self.output_dir

        # cast if string
        if isinstance(self.tr_args.learning_rate, str):
            self.tr_args.learning_rate = float(self.tr_args.learning_rate)
        self.tr_args.remove_unused_columns = False

        if self.processor is None and self.tokenizer is None:
            print("Using textual model tokenization")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path)

        if self.pretrained_peft_model_name_or_path is not None:
            self.model.load_adapter(self.pretrained_peft_model_name_or_path)
            print(f"Loaded pretrained adapter from {self.pretrained_peft_model_name_or_path}")

        if self.peft_config is not None:
            print("Configurating PEFT model")
            if self.processor is None:
                # Might be deprecated - use the "else" branch
                self.model = prepare_model_for_kbit_training(self.model)  # use_gradient_checkpointing=True
                # self.model.enable_input_require_grads()
                self.model = get_peft_model(self.model, self.peft_config)
                self.model.print_trainable_parameters()
            else:
                # Ugly debugging hack
                # if self.model.model.config.text_config.vocab_size == 32000:
                #     print("DEBUG: Resizing token embeddings - This should not happen in a real scenario!")
                #     self.model.model.text_model.resize_token_embeddings(32003)
                # self.model.model.vision_model.encoder.layers = self.model.model.vision_model.encoder.layers[0:2]
                # self.model.enable_input_require_grads()
                if self.pretrained_peft_model_name_or_path is None:
                    self.model.add_adapter(self.peft_config)
                    self.model.enable_adapters()
                else:
                    print(f"Adapter already loaded from {self.pretrained_peft_model_name_or_path}. Not overwriting.")

        print_gpu_utilization()


class ColModelTraining:
    """Drive training, evaluation and saving for a ``ColModelTrainingConfig``."""

    def __init__(self, config: ColModelTrainingConfig) -> None:
        """Load the dataset and build the collator and evaluator from ``config``.

        ``config.dataset_loading_func`` is called unconditionally, so it must
        be set.
        """
        self.config = config
        self.model = self.config.model
        self.dataset = self.config.dataset_loading_func()
        self.collator = CustomCollator(
            processor=self.config.processor, tokenizer=self.config.tokenizer, max_length=self.config.max_length
        )
        # Record the commit up front so ``save`` writes the hash from the
        # start of the run; the pipe is closed as soon as it has been read.
        with os.popen("git rev-parse HEAD") as git_pipe:
            self.current_git_hash = git_pipe.read().strip()
        self.retriever_evaluator = CustomEvaluator(
            is_multi_vector=isinstance(self.config.loss_func, (ColbertLoss, ColbertPairwiseCELoss))
        )

    def train(self) -> None:
        """Run ``ContrastiveTrainer`` on the train/test splits and print a summary."""
        trainer = ContrastiveTrainer(
            model=self.model,
            train_dataset=self.dataset["train"],
            eval_dataset=self.dataset["test"],
            args=self.config.tr_args,
            data_collator=self.collator,
            loss_func=self.config.loss_func,
            is_vision_model=self.config.processor is not None,
        )
        trainer.args.remove_unused_columns = False

        result = trainer.train()
        print_summary(result)

    def eval_dataset(self, test_dataset):
        """Embed ``test_dataset`` and return the evaluator's retrieval metrics.

        Samples are reordered so that those with a non-null ``"query"`` come
        first. Every sample contributes a document embedding; only batches
        that contain ``"query_input_ids"`` contribute query embeddings.

        Args:
            test_dataset: dataset exposing a ``"query"`` column and ``select``;
                each sample provides ``"image_filename"`` or ``"doc"``.

        Returns:
            Whatever ``self.retriever_evaluator.compute_metrics`` returns.
        """
        self.model.eval()

        # # debug
        # if len(test_dataset) > 200:
        #     test_dataset = test_dataset.select(range(0, 100))

        idx_with_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is not None]
        idx_without_query = [idx for idx, sample in enumerate(test_dataset["query"]) if sample is None]

        # Select each subset once and reuse it for both the loader and the
        # concatenated dataset below.
        subset_with_query = test_dataset.select(idx_with_query)
        subset_without_query = test_dataset.select(idx_without_query)

        eval_batch_size = self.config.tr_args.per_device_eval_batch_size
        dataloader_with_query = DataLoader(
            subset_with_query,
            batch_size=eval_batch_size,
            shuffle=False,
            collate_fn=self.collator,
        )
        dataloader_without_query = DataLoader(
            subset_without_query,
            batch_size=eval_batch_size,
            shuffle=False,
            collate_fn=self.collator,
        )

        # dataset is ordered so that non-null queries come first
        test_dataset = concatenate_datasets([subset_with_query, subset_without_query])

        relevant_docs = {}
        docidx_2_docid = {}
        qsidx_2_query = []
        for idx, sample in enumerate(test_dataset):
            doc_id = sample["image_filename"] if "image_filename" in sample else str(hash(sample["doc"]))
            # query_id = sample["query_id"] if "query_id" in sample else str(hash(sample["query"]))
            if sample["query"] is not None:
                relevant_docs[str(idx)] = {doc_id: 1}
                qsidx_2_query.append(str(idx))
            docidx_2_docid[str(idx)] = doc_id

        qs = []
        ps = []

        device = self.model.device
        with torch.no_grad():
            for dataloader in [dataloader_with_query, dataloader_without_query]:
                for batch in tqdm(dataloader):
                    # Text inputs are always passed; pixel inputs are added
                    # only when the collator produced them.
                    doc_inputs = {
                        "input_ids": batch["doc_input_ids"].to(device),
                        "attention_mask": batch["doc_attention_mask"].to(device),
                    }
                    if "doc_pixel_values" in batch:
                        doc_inputs["pixel_values"] = batch["doc_pixel_values"].to(device)
                        if "doc_pixel_attention_mask" in batch:
                            doc_inputs["pixel_attention_mask"] = batch["doc_pixel_attention_mask"].to(device)
                    doc = self.model(**doc_inputs)

                    ps.extend(torch.unbind(doc.to("cpu")))

                    if "query_input_ids" in batch:
                        query = self.model(
                            input_ids=batch["query_input_ids"].to(device),
                            attention_mask=batch["query_attention_mask"].to(device),
                        )
                        # variable len
                        qs.extend(torch.unbind(query.to("cpu")))

        print("Embeddings computed, evaluating")
        scores = self.retriever_evaluator.evaluate(qs, ps)
        # scores is 2d array of shape (n_queries, n_docs)
        # turn it into a dict
        results = {}
        assert scores.shape[0] == len(qsidx_2_query)
        for query_key, scores_per_query in zip(qsidx_2_query, scores):
            results[query_key] = {
                docidx_2_docid[str(docidx)]: float(score) for docidx, score in enumerate(scores_per_query)
            }

        # evaluate
        metrics = self.retriever_evaluator.compute_metrics(relevant_docs, results)
        print(metrics)
        return metrics

    def _write_results(self, all_metrics) -> None:
        """Dump ``all_metrics`` to ``<output_dir>/results.json``, creating the directory if needed."""
        os.makedirs(self.config.output_dir, exist_ok=True)
        with open(f"{self.config.output_dir}/results.json", "w") as f:
            json.dump(all_metrics, f)

    def eval(self) -> None:
        """Evaluate on the validation split and on every extra eval dataset.

        Metrics are written to ``results.json`` after each extra dataset, so
        partial results are on disk if a later evaluation fails, and once
        more at the end.
        """
        print("Evaluating on validation set")
        metrics = self.eval_dataset(self.dataset["test"])
        print(f"Metrics for validation set: {metrics}")
        all_metrics = {"validation_set": metrics}

        if self.config.eval_dataset_loader is not None:
            for test_name, test_dataset_loading_func in self.config.eval_dataset_loader.items():
                print(f"Evaluating {test_name}")
                test_ds = test_dataset_loading_func()
                metrics = self.eval_dataset(test_ds)
                all_metrics[test_name] = metrics
                print(f"Metrics for {test_name}: {metrics}")

                # checkpoint dumps
                self._write_results(all_metrics)

        # save results as json
        self._write_results(all_metrics)

    def save(self, config_file):
        """Save the model, tokenizer/processor, training config and git hash.

        Args:
            config_file: path of the training config file; it is copied to
                ``<output_dir>/training_config.yml``.
        """
        # save model
        self.model.save_pretrained(self.config.output_dir)
        if self.config.tokenizer is not None:
            self.config.tokenizer.save_pretrained(self.config.output_dir)
        if self.config.processor is not None:
            self.config.processor.save_pretrained(self.config.output_dir)

        # save config
        # Copy without going through a shell so paths containing spaces or
        # shell metacharacters are handled. A failed copy is reported but
        # does not abort the save.
        config_destination = f"{self.config.output_dir}/training_config.yml"
        try:
            shutil.copy(config_file, config_destination)
        except OSError as err:
            print(f"Could not copy {config_file} to {config_destination}: {err}")

        # save git hash of the commit at beginning of training
        with open(f"{self.config.output_dir}/git_hash.txt", "w") as f:
            f.write(self.current_git_hash)