import os
import torch
from dataclasses import dataclass
from accelerate import PartialState
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
from kto_dataset_processor import process_feel_dataset, SupportedLanguages
from datetime import datetime
import wandb
from enum import Enum
from typing import Optional
from pathlib import Path
# PEFT library: attach and load adapters
from peft import get_peft_model, PeftModel
####################################
# CONFIGURATION
####################################
@dataclass
class ScriptArguments:
    """
    Configuration for the script.
    """
    process_dataset_func: callable = process_feel_dataset
    checkpoint_path: Optional[str] = None
    push_to_hub: bool = True
    language: str = "English"  # Default to English

    def __post_init__(self):
        """Validate the language after initialization."""
        try:
            # This will raise ValueError if the language is not in the enum
            SupportedLanguages(self.language)
        except ValueError:
            supported_langs = "\n- ".join(lang.value for lang in SupportedLanguages)
            raise ValueError(
                f"Invalid language: '{self.language}'\n"
                f"Supported languages are:\n- {supported_langs}"
            )


@dataclass
class ModelArguments(ModelConfig):
    """
    Configuration for the model.
    """
    model_name: str = "CohereForAI/aya-expanse-8b"
    use_peft: bool = True
    lora_target_modules: str = "all-linear"
    lora_r: int = 16
    lora_alpha: int = 16
    trust_remote_code: bool = True


@dataclass
class TrainingArguments(KTOConfig):
    """
    Configuration for the KTO trainer.
    """
    # Flatten the model name so the output directory does not nest on the "/"
    output_dir: str = f"kto_{ModelArguments.model_name.replace('/', '_')}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 4
    learning_rate: float = 5e-7
    lr_scheduler_type: str = "cosine"
    gradient_accumulation_steps: int = 1
    logging_steps: int = 10
    eval_steps: int = 500
    warmup_ratio: float = 0.1
    bf16: bool = True
    logging_first_step: bool = True

# Initialize configurations
script_args = ScriptArguments()
training_args = TrainingArguments()
model_args = ModelArguments()
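
# The configurations above are instantiated with their defaults. HfArgumentParser is
# already imported, so the same dataclasses could instead be populated from the
# command line; the block below is a minimal sketch of that option, gated behind an
# environment variable (KTO_PARSE_CLI_ARGS, an assumed name) so the default behaviour
# of the script is unchanged.
if os.environ.get("KTO_PARSE_CLI_ARGS"):
    parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelArguments))
    script_args, training_args, model_args = parser.parse_args_into_dataclasses()
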
####################################
# HELPER FUNCTIONS
####################################
def load_model_and_tokenizer(model_args):
    """
    Load the base model and tokenizer from the Hugging Face Hub.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code,
    )

    # Set the pad token if it is missing
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Set up a chat format if the tokenizer does not already ship a chat template
    if not getattr(tokenizer, "chat_template", None):
        model, tokenizer = setup_chat_format(model, tokenizer)

    return model, tokenizer


def get_adapter_path(model_name: str, language: str, timestamp: Optional[str] = None) -> Path:
    """
    Generate a standardized adapter path.

    If timestamp is None, returns the base language directory; otherwise, returns
    the specific adapter version path.
    Format: adapters/{model_name}/{language}/version_{timestamp}
    """
    # Clean the model name (remove slashes, etc.)
    clean_model_name = model_name.replace('/', '_')
    base_path = Path("adapters") / clean_model_name / language

    if timestamp:
        return base_path / f"version_{timestamp}"
    return base_path
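
# For example, get_adapter_path("CohereForAI/aya-expanse-8b", "English", "2024-01-01_00-00-00")
# resolves to adapters/CohereForAI_aya-expanse-8b/English/version_2024-01-01_00-00-00, while
# omitting the timestamp yields the base directory adapters/CohereForAI_aya-expanse-8b/English.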


def load_latest_adapter(model, model_name: str, language: str) -> tuple[PeftModel, str]:
    """
    Load the most recent adapter for the given model and language.
    Returns: (loaded_model, timestamp of loaded adapter)
    """
    adapter_base = get_adapter_path(model_name, language)
    if not adapter_base.exists():
        return None, None

    # Collect all version directories and sort by timestamp, newest first
    versions = sorted(
        adapter_base.glob("version_*"),
        key=lambda x: x.name,
        reverse=True,
    )
    if not versions:
        return None, None

    latest_version = versions[0]
    timestamp = latest_version.name.replace("version_", "")
    model = PeftModel.from_pretrained(model, latest_version, is_trainable=True)
    return model, timestamp
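

####################################
# ADAPTER METADATA
####################################
# NOTE: `AdapterMetadata` is referenced in main() when saving the adapter, but the
# original script neither defines nor imports it. The dataclass below is a minimal,
# assumed stand-in (a small record persisted as JSON via a `save` method); replace
# it with the project's real implementation if one exists.
import json  # kept next to this sketch; would normally live with the top-level imports


@dataclass
class AdapterMetadata:
    training_timestamp: str
    model_name: str
    language: str

    def save(self, path: Path) -> None:
        """Write the metadata record to a JSON file at the given path."""
        path.write_text(json.dumps(self.__dict__, indent=2), encoding="utf-8")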
####################################
# MAIN LOGIC
####################################
def main():
    # Initialize wandb for logging
    wandb.init(project="kto")

    # Get the timestamp at the start of training
    training_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    # The policy model is trained; the reference model stays frozen for the KTO loss
    print("Loading base model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(model_args)
    ref_model, _ = load_model_and_tokenizer(model_args)
    print("Models and tokenizer loaded.")

    # Build the LoRA config up front so it is available both for wrapping a fresh
    # model below and for the trainer, regardless of which branch is taken
    peft_config = get_peft_config(model_args)

    # Load an existing adapter if one is available, otherwise create a new one
    loaded_model, previous_timestamp = load_latest_adapter(
        model,
        model_args.model_name,
        script_args.language,
    )
    if loaded_model is not None:
        model = loaded_model
        print(f"Loaded existing adapter trained at {previous_timestamp}")
    else:
        # Initialize a new LoRA adapter
        model = get_peft_model(model, peft_config)
        print("Initialized new adapter")

    # -----------------------------
    # Data Preparation and Training
    # -----------------------------
    print("Processing dataset...")
    dataset = script_args.process_dataset_func(script_args.language)
    print("Dataset processed.")

    print("Initializing trainer...")
    trainer = KTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    # Training
    print("Starting training...")
    trainer.train()
    print("Training completed.")

    # Evaluation
    print("Evaluating model...")
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Log metrics to wandb
    wandb.log({
        "epoch": metrics.get("epoch"),
        "grad_norm": metrics.get("grad_norm"),
        "kl": metrics.get("kl"),
        "learning_rate": metrics.get("learning_rate"),
        "logits/chosen": metrics.get("logits/chosen"),
        "logits/rejected": metrics.get("logits/rejected"),
        "logps/chosen": metrics.get("logps/chosen"),
        "logps/rejected": metrics.get("logps/rejected"),
        "loss": metrics.get("loss"),
        "rewards/chosen": metrics.get("rewards/chosen"),
        "rewards/margins": metrics.get("rewards/margins"),
        "rewards/rejected": metrics.get("rewards/rejected"),
        "step": metrics.get("step"),
    })

    # Save the adapter
    adapter_path = get_adapter_path(
        model_args.model_name,
        script_args.language,
        training_timestamp,
    )
    adapter_path.mkdir(parents=True, exist_ok=True)

    print(f"Saving adapter to: {adapter_path}")
    model.save_pretrained(adapter_path)

    # Save metadata alongside the adapter weights
    metadata = AdapterMetadata(
        training_timestamp=training_timestamp,
        model_name=model_args.model_name,
        language=script_args.language,
    )
    metadata.save(adapter_path / "metadata.json")

    if script_args.push_to_hub:
        # Hub repo ids take the form "<namespace>/<name>", so the model name and
        # language are flattened into a single repository name
        repo_id = f"feel-fl/adapters_{model_args.model_name.replace('/', '_')}_{script_args.language}"
        print(f"Pushing adapter to Hugging Face Hub at {repo_id}...")
        model.push_to_hub(repo_id=repo_id)

    print("Process completed.")

    # Finish the wandb run
    wandb.finish()

if __name__ == "__main__":
main()
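
# Example invocation (the file name is an assumption; `wandb login` and
# `huggingface-cli login` are expected to have been run beforehand so that logging
# and the optional push to the Hub can authenticate):
#
#   python kto_training.py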