jenbenarye committed on
Commit 056b95d · 1 Parent(s): 801c17a

changed file name

Files changed (1)
  1. ml/{kto_lora.py → trainer.py} +85 -23
ml/{kto_lora.py → trainer.py} RENAMED
@@ -9,6 +9,7 @@ from datetime import datetime
 import wandb
 from enum import Enum
 from typing import Optional
+from pathlib import Path


 # PEFT library: attach and load adapters
 
@@ -104,6 +105,48 @@ def load_model_and_tokenizer(model_args):

     return model, tokenizer

+def get_adapter_path(model_name: str, language: str, timestamp: str = None) -> Path:
+    """
+    Generate standardized adapter path.
+    If timestamp is None, returns the base language directory.
+    Otherwise, returns specific adapter version path.
+
+    Format: adapters/{model_name}/{language}/version_{timestamp}
+    """
+    # Clean model name (remove slashes, etc.)
+    clean_model_name = model_name.replace('/', '_')
+
+    base_path = Path("adapters") / clean_model_name / language
+    if timestamp:
+        return base_path / f"version_{timestamp}"
+    return base_path
+
+def load_latest_adapter(model, model_name: str, language: str) -> tuple[PeftModel, str]:
+    """
+    Load the most recent adapter for given model and language.
+    Returns: (loaded_model, timestamp of loaded adapter)
+    """
+    adapter_base = get_adapter_path(model_name, language)
+
+    if not adapter_base.exists():
+        return None, None
+
+    # Get all version directories and sort by timestamp
+    versions = sorted(
+        [d for d in adapter_base.glob("version_*")],
+        key=lambda x: x.name,
+        reverse=True
+    )
+
+    if not versions:
+        return None, None
+
+    latest_version = versions[0]
+    timestamp = latest_version.name.replace("version_", "")
+
+    model = PeftModel.from_pretrained(model, latest_version, is_trainable=True)
+    return model, timestamp
+
 ####################################
 # MAIN LOGIC
 ####################################
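Note (not part of this commit): a minimal usage sketch of the two helpers above, assuming they are imported from ml/trainer.py and using a hypothetical model name and language purely for illustration. Because the timestamp format %Y-%m-%d_%H-%M-%S is zero-padded, sorting the version_* directory names lexicographically (as load_latest_adapter does) also sorts them chronologically, so versions[0] is the newest adapter.

# Hypothetical inputs for illustration only
path = get_adapter_path("meta-llama/Llama-3.1-8B", "italian")
# -> Path("adapters/meta-llama_Llama-3.1-8B/italian")

path = get_adapter_path("meta-llama/Llama-3.1-8B", "italian", "2025-01-05_09-30-00")
# -> Path("adapters/meta-llama_Llama-3.1-8B/italian/version_2025-01-05_09-30-00")

# Lexicographic order matches chronological order for this timestamp format:
assert "version_2025-01-05_09-30-00" < "version_2025-11-30_18-00-00"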
 
@@ -112,26 +155,29 @@ def main():
     # Initialize wandb for logging
     wandb.init(project="kto")

+    # Get timestamp at start of training
+    training_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+
     print("Loading base model and tokenizer...")
     model, tokenizer = load_model_and_tokenizer(model_args)
     ref_model, _ = load_model_and_tokenizer(model_args)
     print("Models and tokenizer loaded.")

-    # -----------------------------
-    # Adapter Loading or Initialization
-    # -----------------------------
-    # Configure the PEFT / LoRA adapter settings
-    peft_config = get_peft_config(model_args)
-    adapter_dir = os.path.join("adapters", script_args.language)
-
-    if os.path.isdir(adapter_dir):
-        # If an adapter for this language already exists, load it into the base model.
-        model = PeftModel.from_pretrained(model, adapter_dir, is_trainable=True)
-        print(f"Loaded existing adapter for language '{script_args.language}' from {adapter_dir}.")
+    # Load existing adapter or create new one
+    loaded_model, previous_timestamp = load_latest_adapter(
+        model,
+        model_args.model_name,
+        script_args.language
+    )
+
+    if loaded_model is not None:
+        model = loaded_model
+        print(f"Loaded existing adapter trained at {previous_timestamp}")
     else:
-        # Otherwise, initialize a new LoRA adapter.
+        # Initialize new LoRA adapter
+        peft_config = get_peft_config(model_args)
         model = get_peft_model(model, peft_config)
-        print(f"No adapter found for language '{script_args.language}'. Initialized new adapter.")

     # -----------------------------
     # Data Preparation and Training
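Note (not part of this commit): the load-or-initialize branch above relies on load_latest_adapter returning (None, None) when no adapter directory exists yet; only in that case is a fresh LoRA config created via get_peft_config. A minimal sketch of using the helper on its own to check whether training would resume, assuming `model` is the base model returned by load_model_and_tokenizer and using placeholder arguments:

# Hypothetical standalone resume check; model name and language are placeholders
loaded_model, previous_timestamp = load_latest_adapter(model, "meta-llama/Llama-3.1-8B", "italian")
if loaded_model is None:
    print("No previous adapter found; training would start from a fresh LoRA adapter.")
else:
    print(f"Training would resume from the adapter saved at {previous_timestamp}.")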
 
@@ -180,16 +226,32 @@
         "step": metrics.get("step")
     })

-    # -----------------------------
-    # Adapter Saving
-    # -----------------------------
-    print("Saving adapter...")
-    # Add timestamp to adapter directory
-    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
-    new_adapter_dir = os.path.join(adapter_dir, f"version_{timestamp}")
-    os.makedirs(new_adapter_dir, exist_ok=True)
-    model.save_pretrained(new_adapter_dir)
-    print(f"Adapter for language '{script_args.language}' saved to: {new_adapter_dir}")
+    # Save the adapter
+    adapter_path = get_adapter_path(
+        model_args.model_name,
+        script_args.language,
+        training_timestamp
+    )
+    adapter_path.parent.mkdir(parents=True, exist_ok=True)
+
+    print(f"Saving adapter to: {adapter_path}")
+    model.save_pretrained(adapter_path)
+
+    # Save metadata
+    metadata = AdapterMetadata(
+        training_timestamp=training_timestamp,
+        dataset_entries=[entry["id"] for entry in dataset],
+        training_params={
+            "max_weight": script_args.max_weight,
+            "min_weight": script_args.min_weight,
+            "decay_factor": script_args.decay_factor,
+            "training_mode": script_args.training_mode
+        },
+        model_name=model_args.model_name,
+        language=script_args.language,
+        version=training_timestamp
+    )
+    metadata.save(adapter_path / "metadata.json")

     if script_args.push_to_hub:
         # Using a consistent naming pattern that links to the FEEL project
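Note (not part of this commit): AdapterMetadata is referenced in the hunk above but not defined anywhere in this file's diff. Purely as an illustration of what the saved metadata.json could contain, here is a hypothetical dataclass whose fields mirror exactly the keyword arguments passed in the call above; the project's actual definition may differ.

from dataclasses import dataclass, asdict
import json
from pathlib import Path

@dataclass
class AdapterMetadata:
    # Hypothetical reconstruction; field names mirror the keyword arguments used above
    training_timestamp: str
    dataset_entries: list
    training_params: dict
    model_name: str
    language: str
    version: str

    def save(self, path: Path) -> None:
        # Write the metadata next to the adapter weights as JSON
        path.write_text(json.dumps(asdict(self), indent=2))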