import os
import torch
from dataclasses import dataclass
from accelerate import PartialState
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
from kto_dataset_processor import process_feel_dataset, SupportedLanguages
from datetime import datetime
import wandb
from enum import Enum
from typing import Optional
from pathlib import Path

# PEFT library: attach and load adapters
from peft import get_peft_model, PeftModel

####################################
# CONFIGURATION
####################################

@dataclass
class ScriptArguments:
    """
    Configuration for the script.
    """
    process_dataset_func: callable = process_feel_dataset
    checkpoint_path: Optional[str] = None
    push_to_hub: bool = True
    language: str = "English"  # Default to English

    def __post_init__(self):
        """Validate the language after initialization."""
        try:
            # This will raise ValueError if the language is not in the enum
            SupportedLanguages(self.language)
        except ValueError:
            supported_langs = "\n- ".join([lang.value for lang in SupportedLanguages])
            raise ValueError(
                f"Invalid language: '{self.language}'\n"
                f"Supported languages are:\n- {supported_langs}"
            )

@dataclass
class ModelArguments(ModelConfig):
    """
    Configuration for the model.
    """
    model_name: str = "CohereForAI/aya-expanse-8b"
    use_peft: bool = True
    lora_target_modules: str = "all-linear"
    lora_r: int = 16
    lora_alpha: int = 16
    trust_remote_code: bool = True

@dataclass
class TrainingArguments(KTOConfig):
    """
    Configuration for the KTO trainer.
    """
    output_dir: str = f"kto_{ModelArguments.model_name.replace('/', '_')}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 4
    learning_rate: float = 5e-7
    lr_scheduler_type: str = "cosine"
    gradient_accumulation_steps: int = 1
    logging_steps: int = 10
    eval_steps: int = 500
    warmup_ratio: float = 0.1
    bf16: bool = True
    logging_first_step: bool = True
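
# KTO-specific hyperparameters such as `beta`, `desirable_weight`, and
# `undesirable_weight` are not set here, so they fall back to the KTOConfig
# defaults; override them in TrainingArguments above if needed.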

# Initialize configurations
script_args = ScriptArguments()
training_args = TrainingArguments()
model_args = ModelArguments()


####################################
# HELPER FUNCTIONS
####################################
def load_model_and_tokenizer(model_args):
    """
    Load the base model and tokenizer from the Hugging Face Hub.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch.bfloat16,  # match the bf16 training precision set above
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code
    )

    # Set pad token if it is missing
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Set up a chat format if the tokenizer does not already provide a chat template
    if not getattr(tokenizer, "chat_template", None):
        model, tokenizer = setup_chat_format(model, tokenizer)

    return model, tokenizer
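
# Note: TRL's setup_chat_format applies a ChatML chat template and adds the
# corresponding special tokens (resizing the model's embedding layer), which is
# why it is only invoked when the tokenizer ships without a chat template.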

def get_adapter_path(model_name: str, language: str, timestamp: Optional[str] = None) -> Path:
    """
    Generate a standardized adapter path.

    If timestamp is None, returns the base language directory.
    Otherwise, returns the path of a specific adapter version.
    Format: adapters/{model_name}/{language}/version_{timestamp}
    """
    # Clean the model name (replace slashes, etc.)
    clean_model_name = model_name.replace('/', '_')
    base_path = Path("adapters") / clean_model_name / language

    if timestamp:
        return base_path / f"version_{timestamp}"
    return base_path
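
# Example (derived from the path format above):
#   get_adapter_path("CohereForAI/aya-expanse-8b", "English", "2024-01-01_00-00-00")
#   -> adapters/CohereForAI_aya-expanse-8b/English/version_2024-01-01_00-00-00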

def load_latest_adapter(model, model_name: str, language: str) -> tuple[PeftModel, str]:
    """
    Load the most recent adapter for the given model and language.
    Returns: (loaded_model, timestamp of the loaded adapter)
    """
    adapter_base = get_adapter_path(model_name, language)
    if not adapter_base.exists():
        return None, None

    # Get all version directories and sort by timestamp (newest first)
    versions = sorted(
        [d for d in adapter_base.glob("version_*")],
        key=lambda x: x.name,
        reverse=True
    )
    if not versions:
        return None, None

    latest_version = versions[0]
    timestamp = latest_version.name.replace("version_", "")
    model = PeftModel.from_pretrained(model, latest_version, is_trainable=True)
    return model, timestamp
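
# NOTE: `AdapterMetadata` is used below when saving adapters but is not defined
# or imported anywhere in this snippet; the following is a minimal sketch of
# such a record, with fields inferred from how it is constructed in main().
import json
from dataclasses import asdict


@dataclass
class AdapterMetadata:
    training_timestamp: str
    model_name: str
    language: str

    def save(self, path: Path):
        """Write the metadata as JSON next to the saved adapter."""
        with open(path, "w") as f:
            json.dump(asdict(self), f, indent=2)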

####################################
# MAIN LOGIC
####################################
def main():
    # Initialize wandb for logging
    wandb.init(project="kto")

    # Get timestamp at the start of training
    training_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    print("Loading base model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(model_args)
    ref_model, _ = load_model_and_tokenizer(model_args)
    print("Models and tokenizer loaded.")

    # Load an existing adapter or create a new one
    loaded_model, previous_timestamp = load_latest_adapter(
        model,
        model_args.model_name,
        script_args.language
    )

    if loaded_model is not None:
        model = loaded_model
        print(f"Loaded existing adapter trained at {previous_timestamp}")
    else:
        # Initialize a new LoRA adapter
        peft_config = get_peft_config(model_args)
        model = get_peft_model(model, peft_config)
        print("Initialized new adapter")

    # -----------------------------
    # Data Preparation and Training
    # -----------------------------
    print("Processing dataset...")
    dataset = script_args.process_dataset_func(script_args.language)
    print("Dataset processed.")
print("Initializing trainer...") | |
trainer = KTOTrainer( | |
model=model, | |
ref_model=ref_model, | |
args=training_args, | |
train_dataset=dataset["train"], | |
eval_dataset=dataset["test"], | |
processing_class=tokenizer, | |
peft_config=peft_config, | |
) | |

    # Training
    print("Starting training...")
    trainer.train()
    print("Training completed.")

    # Evaluation
    print("Evaluating model...")
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Log evaluation metrics to wandb. Note that trainer.evaluate() returns
    # keys prefixed with "eval_" (e.g. "eval_loss", "eval_kl",
    # "eval_rewards/chosen"), so the dict is logged as-is instead of looking up
    # unprefixed training-only keys that would all resolve to None.
    wandb.log(metrics)

    # Save the adapter
    adapter_path = get_adapter_path(
        model_args.model_name,
        script_args.language,
        training_timestamp
    )
    adapter_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Saving adapter to: {adapter_path}")
    model.save_pretrained(adapter_path)

    # Save metadata alongside the adapter
    metadata = AdapterMetadata(
        training_timestamp=training_timestamp,
        model_name=model_args.model_name,
        language=script_args.language,
    )
    metadata.save(adapter_path / "metadata.json")

    if script_args.push_to_hub:
        # Hub repo ids may contain only a single "/" (namespace/name), so the
        # model name and language are flattened into the repository name here
        # (the exact naming scheme is an assumption).
        repo_id = f"feel-fl/kto_{model_args.model_name.replace('/', '_')}_{script_args.language}"
        print(f"Pushing adapter to Hugging Face Hub at {repo_id}...")
        model.push_to_hub(repo_id=repo_id)

    print("Process completed.")

    # Finish wandb run
    wandb.finish()

if __name__ == "__main__":
    main()
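
# Launching (assumption; the filename is not given in this snippet): run the
# script directly, e.g.
#   python kto_train.py
# or, for multi-GPU setups, under Accelerate:
#   accelerate launch kto_train.py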