Commit 1b22e3d
Parent(s): 23b71fd
Updated HROM-TRAIN code to V1.1

Files changed: HROM_Trainer.py (+40, -16)

HROM_Trainer.py CHANGED
@@ -23,10 +23,11 @@ CONFIG = {
     "dataset": "daily_dialog",
     "vocab_size": 32000,
     "tokenizer_train_samples": 100000,
-    "learning_rate":
+    "learning_rate": 1e-4, # Lowered learning rate
     "max_turns": 6,
     "max_checkpoints": 5,
-    "num_epochs":
+    "num_epochs": 100, # Increased number of epochs
+    "grad_accum_steps": 4 # Gradient accumulation steps
 }

 class RotaryEmbedding(nn.Module):
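Note on the new grad_accum_steps setting: gradients from several micro-batches are now summed before a single optimizer update, so the effective batch size becomes the dataloader batch size multiplied by the accumulation steps. A minimal sketch of that arithmetic follows; the "batch_size" key is assumed for illustration and may not match the actual name used in HROM_Trainer.py's CONFIG.

# Hypothetical illustration only: "batch_size" is an assumed key, not
# necessarily the name used in the real CONFIG dictionary.
CONFIG = {"batch_size": 16, "grad_accum_steps": 4}
effective_batch = CONFIG["batch_size"] * CONFIG["grad_accum_steps"]
print(effective_batch)  # 64 samples contribute to each optimizer update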
@@ -242,27 +243,37 @@ class HROMTrainer:
         self.tokenizer = tokenizer

     def train_step(self, batch):
-        self.optimizer.zero_grad()
         autocast = torch.cuda.amp.autocast if self.device.type == "cuda" else nullcontext
         with autocast():
             outputs = self.model(
                 batch["input_ids"].to(self.device),
                 attention_mask=batch["attention_mask"].to(self.device)
             )
-            loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.token_to_id("<pad>"))(
+            original_loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.token_to_id("<pad>"))(
                 outputs.view(-1, CONFIG["vocab_size"]),
                 batch["labels"].view(-1).to(self.device)
             )
+        scaled_loss = original_loss / CONFIG["grad_accum_steps"]
+
+        if self.scaler is not None:
+            self.scaler.scale(scaled_loss).backward()
+        else:
+            scaled_loss.backward()
+
+        return original_loss.item()
+
+    def clip_and_step(self):
+        if self.scaler is not None:
+            self.scaler.unscale_(self.optimizer)
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+
         if self.scaler is not None:
-            self.scaler.scale(loss).backward()
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
             self.scaler.step(self.optimizer)
             self.scaler.update()
         else:
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
             self.optimizer.step()
-        return loss.item()
+
+        self.optimizer.zero_grad()

 class SafetyManager:
     def __init__(self, model, tokenizer):
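The training step is now split in two: train_step() runs only the forward and backward passes, dividing the loss by grad_accum_steps so the accumulated gradients average rather than sum, while clip_and_step() unscales, clips, steps the optimizer, and zeroes the gradients once per accumulation window. Below is a self-contained sketch of the same pattern on a toy model rather than HROM; the variable accum and the toy model are illustrative assumptions, not part of this commit.

# Standalone sketch of the accumulate-then-step pattern, on a toy linear
# model. The GradScaler is enabled only when CUDA is available, mirroring
# the scaler-or-None branching above; "accum" is an illustrative name.
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
accum = 4

def train_step(batch_x, batch_y):
    # Forward + backward only; dividing by accum keeps the summed gradient
    # equal to the average over the accumulation window.
    loss = nn.functional.cross_entropy(model(batch_x), batch_y)
    scaler.scale(loss / accum).backward()
    return loss.item()

def clip_and_step():
    # Unscale before clipping so the norm threshold applies to real gradients.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

for i in range(8):
    train_step(torch.randn(4, 8), torch.randint(0, 2, (4,)))
    if (i + 1) % accum == 0:
        clip_and_step()

The unscale_ call before clipping matters: clipping gradients that are still multiplied by the loss scale would make the 1.0 norm threshold effectively meaningless.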
@@ -344,17 +355,30 @@ def train():
     safety = SafetyManager(model, tokenizer)

     step = 0
+    optimizer_step = 0
+    total_loss = 0.0
     model.train()
+
     for epoch in range(CONFIG["num_epochs"]):
         for batch in dataloader:
             loss = trainer_obj.train_step(batch)
-            if step % CONFIG["checkpoint_interval"] == 0:
-                checkpoint_manager.save(model, trainer_obj.optimizer, step)
-                safety.debug_generation()
-            if step % CONFIG["debug_interval"] == 0:
-                print(f"Step {step} | Loss: {loss:.4f}")
-                safety.debug_generation("What's the meaning of life?")
+            total_loss += loss
             step += 1

+            if step % CONFIG["grad_accum_steps"] == 0:
+                trainer_obj.clip_and_step()
+                avg_loss = total_loss / CONFIG["grad_accum_steps"]
+                total_loss = 0.0
+
+                if optimizer_step % CONFIG["checkpoint_interval"] == 0:
+                    checkpoint_manager.save(model, trainer_obj.optimizer, optimizer_step)
+                    safety.debug_generation()
+
+                if optimizer_step % CONFIG["debug_interval"] == 0:
+                    print(f"Optimizer Step {optimizer_step} | Loss: {avg_loss:.4f}")
+                    safety.debug_generation("What's the meaning of life?")
+
+                optimizer_step += 1
+
 if __name__ == "__main__":
-    train()
+    train()
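Also worth calling out: step now counts micro-batches while optimizer_step counts parameter updates, so checkpoint_interval and debug_interval are measured in optimizer steps, i.e. one interval spans interval x grad_accum_steps batches. A counter-only sketch of that relationship; the interval values below are made up for the example, not taken from the real CONFIG.

# Counter-only illustration of micro-steps vs. optimizer steps; the
# interval values are invented for this example.
grad_accum_steps, debug_interval = 4, 2
step = optimizer_step = 0
for _ in range(16):                      # 16 micro-batches
    step += 1
    if step % grad_accum_steps == 0:     # one optimizer update per 4 batches
        if optimizer_step % debug_interval == 0:
            print(f"debug at micro-step {step}, optimizer step {optimizer_step}")
        optimizer_step += 1
# Prints at optimizer steps 0 and 2, i.e. micro-steps 4 and 12.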