import os import torch from dataclasses import dataclass from accelerate import PartialState from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format from kto_dataset_processor import process_feel_dataset, SupportedLanguages from datetime import datetime import wandb from enum import Enum from typing import Optional from pathlib import Path # PEFT library: attach and load adapters from peft import get_peft_model, PeftModel #################################### # CONFIGURATION #################################### @dataclass class ScriptArguments: """ Configuration for the script. """ process_dataset_func: callable = process_feel_dataset checkpoint_path: str = None push_to_hub: bool = True language: str = "English" # Default to English def __post_init__(self): """Validate the language after initialization""" try: # This will raise ValueError if language is not in the enum SupportedLanguages(self.language) except ValueError: supported_langs = "\n- ".join([lang.value for lang in SupportedLanguages]) raise ValueError( f"Invalid language: '{self.language}'\n" f"Supported languages are:\n- {supported_langs}" ) @dataclass class ModelArguments(ModelConfig): """ Configuration for the model. """ model_name: str = "CohereForAI/aya-expanse-8b" use_peft: bool = True lora_target_modules: str = "all-linear" lora_r: int = 16 lora_alpha: int = 16 trust_remote_code: bool = True @dataclass class TrainingArguments(KTOConfig): """ Configuration for the KTO trainer. """ output_dir: str = f"kto_{ModelArguments.model_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" num_train_epochs: int = 1 per_device_train_batch_size: int = 4 learning_rate: float = 5e-7 lr_scheduler_type: str = "cosine" gradient_accumulation_steps: int = 1 logging_steps: int = 10 eval_steps: int = 500 warmup_ratio: float = 0.1 bf16: bool = True logging_first_step: bool = True # Initialize configurations script_args = ScriptArguments() training_args = TrainingArguments() model_args = ModelArguments() #################################### # HELPER FUNCTIONS #################################### def load_model_and_tokenizer(model_args): """ Load the base model and tokenizer from the Hugging Face Hub. """ model = AutoModelForCausalLM.from_pretrained( model_args.model_name, trust_remote_code=model_args.trust_remote_code, torch_dtype=torch.float16, device_map="auto" ) tokenizer = AutoTokenizer.from_pretrained( model_args.model_name, trust_remote_code=model_args.trust_remote_code ) # Set pad token if it is missing if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Setup chat format if not available on the tokenizer if not getattr(tokenizer, "chat_template", None): model, tokenizer = setup_chat_format(model, tokenizer) return model, tokenizer def get_adapter_path(model_name: str, language: str, timestamp: str = None) -> Path: """ Generate standardized adapter path. If timestamp is None, returns the base language directory. Otherwise, returns specific adapter version path. Format: adapters/{model_name}/{language}/version_{timestamp} """ # Clean model name (remove slashes, etc.) clean_model_name = model_name.replace('/', '_') base_path = Path("adapters") / clean_model_name / language if timestamp: return base_path / f"version_{timestamp}" return base_path def load_latest_adapter(model, model_name: str, language: str) -> tuple[PeftModel, str]: """ Load the most recent adapter for given model and language. Returns: (loaded_model, timestamp of loaded adapter) """ adapter_base = get_adapter_path(model_name, language) if not adapter_base.exists(): return None, None # Get all version directories and sort by timestamp versions = sorted( [d for d in adapter_base.glob("version_*")], key=lambda x: x.name, reverse=True ) if not versions: return None, None latest_version = versions[0] timestamp = latest_version.name.replace("version_", "") model = PeftModel.from_pretrained(model, latest_version, is_trainable=True) return model, timestamp #################################### # MAIN LOGIC #################################### def main(): # Initialize wandb for logging wandb.init(project="kto") # Get timestamp at start of training training_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') print("Loading base model and tokenizer...") model, tokenizer = load_model_and_tokenizer(model_args) ref_model, _ = load_model_and_tokenizer(model_args) print("Models and tokenizer loaded.") # Load existing adapter or create new one loaded_model, previous_timestamp = load_latest_adapter( model, model_args.model_name, script_args.language ) if loaded_model is not None: model = loaded_model print(f"Loaded existing adapter trained at {previous_timestamp}") else: # Initialize new LoRA adapter peft_config = get_peft_config(model_args) model = get_peft_model(model, peft_config) print("Initialized new adapter") # ----------------------------- # Data Preparation and Training # ----------------------------- print("Processing dataset...") dataset = script_args.process_dataset_func(script_args.language) print("Dataset processed.") print("Initializing trainer...") trainer = KTOTrainer( model=model, ref_model=ref_model, args=training_args, train_dataset=dataset["train"], eval_dataset=dataset["test"], processing_class=tokenizer, peft_config=peft_config, ) # Training print("Starting training...") trainer.train() print("Training completed.") # Evaluation print("Evaluating model...") metrics = trainer.evaluate() print(f"Metrics: {metrics}") trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics) # Log metrics to wandb wandb.log({ "epoch": metrics.get("epoch"), "grad_norm": metrics.get("grad_norm"), "kl": metrics.get("kl"), "learning_rate": metrics.get("learning_rate"), "logits/chosen": metrics.get("logits/chosen"), "logits/rejected": metrics.get("logits/rejected"), "logps/chosen": metrics.get("logps/chosen"), "logps/rejected": metrics.get("logps/rejected"), "loss": metrics.get("loss"), "rewards/chosen": metrics.get("rewards/chosen"), "rewards/margins": metrics.get("rewards/margins"), "rewards/rejected": metrics.get("rewards/rejected"), "step": metrics.get("step") }) # Save the adapter adapter_path = get_adapter_path( model_args.model_name, script_args.language, training_timestamp ) adapter_path.parent.mkdir(parents=True, exist_ok=True) print(f"Saving adapter to: {adapter_path}") model.save_pretrained(adapter_path) # Save metadata metadata = AdapterMetadata( training_timestamp=training_timestamp, model_name=model_args.model_name, language=script_args.language, ) metadata.save(adapter_path / "metadata.json") if script_args.push_to_hub: repo_id = f"feel-fl/adapters/{model_args.model_name.replace('/', '_')}/{script_args.language}" print(f"Pushing adapter to Hugging Face Hub at {repo_id}...") model.push_to_hub(repo_id=repo_id) print("Process completed.") # Finish wandb run wandb.finish() if __name__ == "__main__": main()