jenbenarye committed · Commit 056b95d · 1 Parent(s): 801c17a

changed file name
ml/{kto_lora.py → trainer.py}  RENAMED  (+85 -23)

@@ -9,6 +9,7 @@ from datetime import datetime
 import wandb
 from enum import Enum
 from typing import Optional
+from pathlib import Path
 
 
 # PEFT library: attach and load adapters
@@ -104,6 +105,48 @@ def load_model_and_tokenizer(model_args):
 
     return model, tokenizer
 
+def get_adapter_path(model_name: str, language: str, timestamp: str = None) -> Path:
+    """
+    Generate standardized adapter path.
+    If timestamp is None, returns the base language directory.
+    Otherwise, returns specific adapter version path.
+
+    Format: adapters/{model_name}/{language}/version_{timestamp}
+    """
+    # Clean model name (remove slashes, etc.)
+    clean_model_name = model_name.replace('/', '_')
+
+    base_path = Path("adapters") / clean_model_name / language
+    if timestamp:
+        return base_path / f"version_{timestamp}"
+    return base_path
+
+def load_latest_adapter(model, model_name: str, language: str) -> tuple[PeftModel, str]:
+    """
+    Load the most recent adapter for given model and language.
+    Returns: (loaded_model, timestamp of loaded adapter)
+    """
+    adapter_base = get_adapter_path(model_name, language)
+
+    if not adapter_base.exists():
+        return None, None
+
+    # Get all version directories and sort by timestamp (newest first)
+    versions = sorted(
+        [d for d in adapter_base.glob("version_*")],
+        key=lambda x: x.name,
+        reverse=True
+    )
+
+    if not versions:
+        return None, None
+
+    latest_version = versions[0]
+    timestamp = latest_version.name.replace("version_", "")
+
+    model = PeftModel.from_pretrained(model, latest_version, is_trainable=True)
+    return model, timestamp
+
 ####################################
 # MAIN LOGIC
 ####################################
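For orientation only (not part of the commit): with a hypothetical model name and language, the directory layout produced by get_adapter_path would resolve as sketched below.

from pathlib import Path

# Hypothetical inputs; mirrors the path logic of get_adapter_path above.
model_name, language, timestamp = "org/base-model", "en", "2025-01-01_12-00-00"

base = Path("adapters") / model_name.replace("/", "_") / language
print(base)                           # adapters/org_base-model/en
print(base / f"version_{timestamp}")  # adapters/org_base-model/en/version_2025-01-01_12-00-00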
@@ -112,26 +155,29 @@ def main():
     # Initialize wandb for logging
     wandb.init(project="kto")
 
+    # Get timestamp at start of training
+    training_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+
     print("Loading base model and tokenizer...")
     model, tokenizer = load_model_and_tokenizer(model_args)
     ref_model, _ = load_model_and_tokenizer(model_args)
     print("Models and tokenizer loaded.")
 
-    #
-    if
-        print(f"Loaded existing adapter for language '{script_args.language}' from {adapter_dir}.")
+    # Load existing adapter or create new one
+    loaded_model, previous_timestamp = load_latest_adapter(
+        model,
+        model_args.model_name,
+        script_args.language
+    )
+
+    if loaded_model is not None:
+        model = loaded_model
+        print(f"Loaded existing adapter trained at {previous_timestamp}")
     else:
-        #
+        # Initialize new LoRA adapter
+        peft_config = get_peft_config(model_args)
         model = get_peft_model(model, peft_config)
-        print(
+        print("Initialized new adapter")
 
     # -----------------------------
     # Data Preparation and Training
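A side note on the versioning scheme (illustration only, not part of the commit): the '%Y-%m-%d_%H-%M-%S' timestamp orders its fields from most to least significant, so plain string sorting of the version_* directory names is enough for load_latest_adapter to pick the newest adapter.

from datetime import datetime

# Because the timestamp format is lexically sortable, string comparison
# matches chronological order.
stamps = ["2024-12-31_23-59-59", "2025-01-01_00-00-01", "2024-02-10_08-15-00"]
print(max(stamps))  # 2025-01-01_00-00-01 -> the most recent run
print(datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))  # e.g. 2025-06-01_12-34-56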
@@ -180,16 +226,32 @@ def main():
         "step": metrics.get("step")
     })
 
-    #
+    # Save the adapter
+    adapter_path = get_adapter_path(
+        model_args.model_name,
+        script_args.language,
+        training_timestamp
+    )
+    adapter_path.parent.mkdir(parents=True, exist_ok=True)
+
+    print(f"Saving adapter to: {adapter_path}")
+    model.save_pretrained(adapter_path)
+
+    # Save metadata
+    metadata = AdapterMetadata(
+        training_timestamp=training_timestamp,
+        dataset_entries=[entry["id"] for entry in dataset],
+        training_params={
+            "max_weight": script_args.max_weight,
+            "min_weight": script_args.min_weight,
+            "decay_factor": script_args.decay_factor,
+            "training_mode": script_args.training_mode
+        },
+        model_name=model_args.model_name,
+        language=script_args.language,
+        version=training_timestamp
+    )
+    metadata.save(adapter_path / "metadata.json")
 
     if script_args.push_to_hub:
         # Using a consistent naming pattern that links to the FEEL project
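An illustration of consuming the saved metadata (not part of the commit): assuming AdapterMetadata.save writes plain JSON, as the metadata.json filename suggests, a saved adapter version could be inspected without loading the model.

import json
from pathlib import Path

# Hypothetical adapter version directory written by a training run as above.
adapter_dir = Path("adapters/org_base-model/en/version_2025-01-01_12-00-00")

with open(adapter_dir / "metadata.json") as f:
    meta = json.load(f)

print(meta["training_timestamp"])
print(meta["training_params"])
print(len(meta["dataset_entries"]), "dataset entries used for training")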