import os
import torch
from dataclasses import dataclass
from accelerate import PartialState
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
from kto_dataset_processor import process_feel_dataset, SupportedLanguages
from datetime import datetime
import wandb
from enum import Enum
from typing import Optional
from pathlib import Path


# PEFT library: attach and load adapters
from peft import get_peft_model, PeftModel

####################################
#  CONFIGURATION
####################################


@dataclass
class ScriptArguments:
    """
    Configuration for the script.
    """
    process_dataset_func: callable = process_feel_dataset
    checkpoint_path: Optional[str] = None
    push_to_hub: bool = True
    language: str = "English"  # Default to English

    def __post_init__(self):
        """Validate the language after initialization"""
        try:
            # This will raise ValueError if language is not in the enum
            SupportedLanguages(self.language)
        except ValueError:
            supported_langs = "\n- ".join([lang.value for lang in SupportedLanguages])
            raise ValueError(
                f"Invalid language: '{self.language}'\n"
                f"Supported languages are:\n- {supported_langs}"
            )

@dataclass
class ModelArguments(ModelConfig):
    """
    Configuration for the model.
    """
    model_name: str = "CohereForAI/aya-expanse-8b"
    use_peft: bool = True
    lora_target_modules: str = "all-linear"
    lora_r: int = 16
    lora_alpha: int = 16
    trust_remote_code: bool = True

@dataclass
class TrainingArguments(KTOConfig):
    """
    Configuration for the KTO trainer.
    """
    # Strip "/" from the model name so the output directory is a single folder
    output_dir: str = f"kto_{ModelArguments.model_name.replace('/', '_')}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 4
    learning_rate: float = 5e-7
    lr_scheduler_type: str = "cosine"
    gradient_accumulation_steps: int = 1
    logging_steps: int = 10
    eval_steps: int = 500
    warmup_ratio: float = 0.1
    bf16: bool = True
    logging_first_step: bool = True

# Initialize configurations
script_args = ScriptArguments()
training_args = TrainingArguments()
model_args = ModelArguments()

####################################
#  HELPER FUNCTIONS
####################################

def load_model_and_tokenizer(model_args):
    """
    Load the base model and tokenizer from the Hugging Face Hub.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code
    )

    # Set pad token if it is missing
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Setup chat format if not available on the tokenizer
    if not getattr(tokenizer, "chat_template", None):
        model, tokenizer = setup_chat_format(model, tokenizer)

    return model, tokenizer

def get_adapter_path(model_name: str, language: str, timestamp: Optional[str] = None) -> Path:
    """
    Generate standardized adapter path.
    If timestamp is None, returns the base language directory.
    Otherwise, returns specific adapter version path.

    Format: adapters/{model_name}/{language}/version_{timestamp}
    """
    # Clean model name (remove slashes, etc.)
    clean_model_name = model_name.replace('/', '_')

    base_path = Path("adapters") / clean_model_name / language
    if timestamp:
        return base_path / f"version_{timestamp}"
    return base_path

def load_latest_adapter(model, model_name: str, language: str) -> tuple[Optional[PeftModel], Optional[str]]:
    """
    Load the most recent adapter for given model and language.
    Returns: (loaded_model, timestamp of loaded adapter)
    """
    adapter_base = get_adapter_path(model_name, language)

    if not adapter_base.exists():
        return None, None

    # Get all version directories and sort by timestamp
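    # The timestamp format "%Y-%m-%d_%H-%M-%S" sorts lexicographically in
    # chronological order, so sorting by directory name yields the newest first.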
    versions = sorted(
        [d for d in adapter_base.glob("version_*")],
        key=lambda x: x.name,
        reverse=True
    )

    if not versions:
        return None, None

    latest_version = versions[0]
    timestamp = latest_version.name.replace("version_", "")

    model = PeftModel.from_pretrained(model, latest_version, is_trainable=True)
    return model, timestamp
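
# NOTE: `AdapterMetadata` is used in main() below but is neither defined nor
# imported anywhere in this file. The following is a minimal sketch of what it
# would need to provide (an assumption based on the fields and the `save` call
# in main()), not the original definition:
import json
from dataclasses import asdict


@dataclass
class AdapterMetadata:
    """Metadata stored alongside a saved adapter version."""
    training_timestamp: str
    model_name: str
    language: str

    def save(self, path: Path):
        # Serialize the dataclass fields to a JSON file at `path`.
        with open(path, "w") as f:
            json.dump(asdict(self), f, indent=2)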

####################################
#  MAIN LOGIC
####################################

def main():
    # Initialize wandb for logging
    wandb.init(project="kto")

    # Get timestamp at start of training
    training_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

    print("Loading base model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(model_args)
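    # Load a second copy of the base model to serve as the KTO reference model.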
    ref_model, _ = load_model_and_tokenizer(model_args)
    print("Models and tokenizer loaded.")

    # Load existing adapter or create new one
    loaded_model, previous_timestamp = load_latest_adapter(
        model,
        model_args.model_name,
        script_args.language
    )

    if loaded_model is not None:
        model = loaded_model
        print(f"Loaded existing adapter trained at {previous_timestamp}")
    else:
        # Initialize new LoRA adapter
        peft_config = get_peft_config(model_args)
        model = get_peft_model(model, peft_config)
        print("Initialized new adapter")

    # -----------------------------
    # Data Preparation and Training
    # -----------------------------
    print("Processing dataset...")
    dataset = script_args.process_dataset_func(script_args.language)
    print("Dataset processed.")

    print("Initializing trainer...")
    trainer = KTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,
        # The model is already wrapped as a PeftModel above, so no peft_config is
        # passed here (it would also be undefined when an existing adapter is loaded).
    )

    # Training
    print("Starting training...")
    trainer.train()
    print("Training completed.")

    # Evaluation
    print("Evaluating model...")
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Log evaluation metrics to wandb. `trainer.evaluate()` returns keys prefixed
    # with "eval_" (e.g. "eval_loss", "eval_kl", "eval_rewards/chosen"), so the
    # full dict is logged rather than looking up unprefixed keys that would be None.
    wandb.log(metrics)

    # Save the adapter
    adapter_path = get_adapter_path(
        model_args.model_name,
        script_args.language,
        training_timestamp
    )
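    # `save_pretrained` creates the adapter directory itself, so only the parent
    # directory needs to exist here.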
    adapter_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Saving adapter to: {adapter_path}")
    model.save_pretrained(adapter_path)

    # Save metadata
    metadata = AdapterMetadata(
        training_timestamp=training_timestamp,
        model_name=model_args.model_name,
        language=script_args.language,
    )
    metadata.save(adapter_path / "metadata.json")

    if script_args.push_to_hub:
        repo_id = f"feel-fl/adapters/{model_args.model_name.replace('/', '_')}/{script_args.language}"
        print(f"Pushing adapter to Hugging Face Hub at {repo_id}...")
        model.push_to_hub(repo_id=repo_id)

    print("Process completed.")

    # Finish wandb run
    wandb.finish()

if __name__ == "__main__":
    main()