TimurHromek committed on
Commit 1b22e3d · 1 Parent(s): 23b71fd

Updated HROM-TRAIN code to V1.1

Files changed (1)
  1. HROM_Trainer.py +40 -16
HROM_Trainer.py CHANGED
@@ -23,10 +23,11 @@ CONFIG = {
     "dataset": "daily_dialog",
     "vocab_size": 32000,
     "tokenizer_train_samples": 100000,
-    "learning_rate": 3e-4,
+    "learning_rate": 1e-4, # Lowered learning rate
     "max_turns": 6,
     "max_checkpoints": 5,
-    "num_epochs": 50 # Increased number of epochs for longer training
+    "num_epochs": 100, # Increased number of epochs
+    "grad_accum_steps": 4 # Gradient accumulation steps
 }

 class RotaryEmbedding(nn.Module):
@@ -242,27 +243,37 @@ class HROMTrainer:
         self.tokenizer = tokenizer

     def train_step(self, batch):
-        self.optimizer.zero_grad()
         autocast = torch.cuda.amp.autocast if self.device.type == "cuda" else nullcontext
         with autocast():
             outputs = self.model(
                 batch["input_ids"].to(self.device),
                 attention_mask=batch["attention_mask"].to(self.device)
             )
-            loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.token_to_id("<pad>"))(
+            original_loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.token_to_id("<pad>"))(
                 outputs.view(-1, CONFIG["vocab_size"]),
                 batch["labels"].view(-1).to(self.device)
             )
+        scaled_loss = original_loss / CONFIG["grad_accum_steps"]
+
+        if self.scaler is not None:
+            self.scaler.scale(scaled_loss).backward()
+        else:
+            scaled_loss.backward()
+
+        return original_loss.item()
+
+    def clip_and_step(self):
+        if self.scaler is not None:
+            self.scaler.unscale_(self.optimizer)
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+
         if self.scaler is not None:
-            self.scaler.scale(loss).backward()
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
             self.scaler.step(self.optimizer)
             self.scaler.update()
         else:
-            loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
             self.optimizer.step()
-        return loss.item()
+
+        self.optimizer.zero_grad()

 class SafetyManager:
     def __init__(self, model, tokenizer):
@@ -344,17 +355,30 @@ def train():
     safety = SafetyManager(model, tokenizer)

     step = 0
+    optimizer_step = 0
+    total_loss = 0.0
     model.train()
+
     for epoch in range(CONFIG["num_epochs"]):
         for batch in dataloader:
             loss = trainer_obj.train_step(batch)
-            if step % CONFIG["checkpoint_interval"] == 0:
-                checkpoint_manager.save(model, trainer_obj.optimizer, step)
-                safety.debug_generation()
-            if step % CONFIG["debug_interval"] == 0:
-                print(f"Step {step} | Loss: {loss:.4f}")
-                safety.debug_generation("What's the meaning of life?")
+            total_loss += loss
             step += 1

+            if step % CONFIG["grad_accum_steps"] == 0:
+                trainer_obj.clip_and_step()
+                avg_loss = total_loss / CONFIG["grad_accum_steps"]
+                total_loss = 0.0
+
+                if optimizer_step % CONFIG["checkpoint_interval"] == 0:
+                    checkpoint_manager.save(model, trainer_obj.optimizer, optimizer_step)
+                    safety.debug_generation()
+
+                if optimizer_step % CONFIG["debug_interval"] == 0:
+                    print(f"Optimizer Step {optimizer_step} | Loss: {avg_loss:.4f}")
+                    safety.debug_generation("What's the meaning of life?")
+
+                optimizer_step += 1
+
 if __name__ == "__main__":
-    train()
+    train()
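
The net effect of the V1.1 changes above is a standard gradient-accumulation loop: train_step() divides the loss by grad_accum_steps and only accumulates gradients, while clip_and_step() unscales (when AMP is active), clips, applies the optimizer step, and zeroes the gradients once every grad_accum_steps micro-batches, so the effective batch size becomes batch_size * grad_accum_steps. The following is a minimal, self-contained sketch of that pattern; the toy model, optimizer, and random data are placeholders for illustration only, not the actual HROM model or dataloader from HROM_Trainer.py.

import torch
import torch.nn as nn
from contextlib import nullcontext

# Placeholder setup: a tiny classifier standing in for the real model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None
criterion = nn.CrossEntropyLoss()
grad_accum_steps = 4

# Placeholder data: 16 micro-batches of random inputs and labels.
data = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(16)]

step = 0
optimizer_step = 0
total_loss = 0.0
model.train()

for inputs, labels in data:
    inputs, labels = inputs.to(device), labels.to(device)
    autocast = torch.cuda.amp.autocast if device.type == "cuda" else nullcontext
    with autocast():
        loss = criterion(model(inputs), labels)

    # Scale the loss so the accumulated gradients average over the micro-batches.
    scaled = loss / grad_accum_steps
    if scaler is not None:
        scaler.scale(scaled).backward()
    else:
        scaled.backward()

    total_loss += loss.item()
    step += 1

    if step % grad_accum_steps == 0:
        # Unscale before clipping so the norm threshold applies to true gradients.
        if scaler is not None:
            scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        if scaler is not None:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()

        print(f"Optimizer step {optimizer_step} | avg loss {total_loss / grad_accum_steps:.4f}")
        total_loss = 0.0
        optimizer_step += 1

One behavioral difference from V1.0 worth noting: gradient clipping now runs once per optimizer update on the unscaled, accumulated gradients, rather than once per micro-batch.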