TimurHromek committed on
Commit 3d2f665 · 1 Parent(s): b4e9c5b

Delete HROM_Trainer.py

Files changed (1)
  1. HROM_Trainer.py +0 -384
HROM_Trainer.py DELETED
@@ -1,384 +0,0 @@
- import torch
- import torch.nn as nn
- from torch.utils.data import Dataset, DataLoader
- from datasets import load_dataset
- from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors, decoders
- import math
- import os
- import re
- from datetime import datetime
- from contextlib import nullcontext
-
- # Configuration
- CONFIG = {
-     "dim": 512,
-     "n_layers": 6,
-     "n_heads": 8,
-     "ff_dim": 2048,
-     "dropout": 0.1,
-     "max_seq_len": 1024,
-     "batch_size": 32,
-     "checkpoint_interval": 1000,
-     "debug_interval": 500,
-     "dataset": "daily_dialog",
-     "vocab_size": 32000,
-     "tokenizer_train_samples": 100000,
-     "learning_rate": 1e-4,  # Lowered learning rate
-     "max_turns": 6,
-     "max_checkpoints": 5,
-     "num_epochs": 100,  # Increased number of epochs
-     "grad_accum_steps": 4  # Gradient accumulation steps
- }
-
- class RotaryEmbedding(nn.Module):
-     def __init__(self, dim):
-         super().__init__()
-         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-         self.register_buffer("inv_freq", inv_freq)
-
-     def forward(self, seq_len):
-         t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
-         freqs = torch.einsum("i, j -> i j", t, self.inv_freq)
-         return torch.cat((freqs, freqs), dim=-1)
-
- def rotate_half(x):
-     x1, x2 = x.chunk(2, dim=-1)
-     return torch.cat((-x2, x1), dim=-1)
-
- def apply_rotary_pos_emb(pos, t):
-     pos = pos.unsqueeze(0).unsqueeze(1)
-     return (t * pos.cos()) + (rotate_half(t) * pos.sin())
-
- class SwiGLU(nn.Module):
-     def forward(self, x):
-         x, gate = x.chunk(2, dim=-1)
-         return x * torch.sigmoid(gate)
-
- class HROMAttention(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.dim = CONFIG["dim"]
-         self.n_heads = CONFIG["n_heads"]
-         self.head_dim = self.dim // self.n_heads
-         self.qkv = nn.Linear(self.dim, 3 * self.dim)
-         self.proj = nn.Linear(self.dim, self.dim)
-         self.rotary = RotaryEmbedding(self.head_dim)
-         self.dropout = nn.Dropout(CONFIG["dropout"])
-
-     def forward(self, x, mask=None):
-         B, T, _ = x.shape
-         qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
-         q, k, v = qkv.unbind(2)
-         q = q.transpose(1, 2)
-         k = k.transpose(1, 2)
-         v = v.transpose(1, 2)
-         pos = self.rotary(T)
-         q = apply_rotary_pos_emb(pos, q)
-         k = apply_rotary_pos_emb(pos, k)
-         attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
-         if mask is not None:
-             mask = mask.unsqueeze(1)
-             attn = attn + mask
-         attn = torch.softmax(attn, dim=-1)
-         attn = self.dropout(attn)
-         out = attn @ v
-         out = out.transpose(1, 2).reshape(B, T, self.dim)
-         return self.proj(out)
-
- class HROMBlock(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.attn = HROMAttention()
-         self.ff = nn.Sequential(
-             nn.Linear(CONFIG["dim"], 2 * CONFIG["ff_dim"]),
-             SwiGLU(),
-             nn.Linear(CONFIG["ff_dim"], CONFIG["dim"])
-         )
-         self.norm1 = nn.LayerNorm(CONFIG["dim"])
-         self.norm2 = nn.LayerNorm(CONFIG["dim"])
-         self.dropout = nn.Dropout(CONFIG["dropout"])
-
-     def forward(self, x, mask=None):
-         x = x + self.dropout(self.attn(self.norm1(x), mask))
-         x = x + self.dropout(self.ff(self.norm2(x)))
-         return x
-
- class HROM(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.embed = nn.Embedding(CONFIG["vocab_size"], CONFIG["dim"])
-         self.blocks = nn.ModuleList([HROMBlock() for _ in range(CONFIG["n_layers"])])
-         self.norm = nn.LayerNorm(CONFIG["dim"])
-         self.head = nn.Linear(CONFIG["dim"], CONFIG["vocab_size"])
-         self.apply(self._init_weights)
-
-     def _init_weights(self, module):
-         if isinstance(module, nn.Linear):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-             if module.bias is not None:
-                 torch.nn.init.zeros_(module.bias)
-
-     def forward(self, x, attention_mask=None):
-         x = self.embed(x)
-         if attention_mask is not None:
-             B, T = attention_mask.shape
-             causal_mask = torch.triu(torch.ones(T, T) * float('-inf'), diagonal=1)
-             causal_mask = causal_mask.to(x.device)
-             pad_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=torch.float32)
-             pad_mask = (1.0 - pad_mask) * torch.finfo(torch.float32).min
-             mask = causal_mask + pad_mask.squeeze(1)
-         else:
-             B, T = x.shape[:2]
-             mask = torch.triu(torch.ones(T, T) * float('-inf'), diagonal=1)
-             mask = mask.to(x.device)
-             mask = mask.unsqueeze(0).expand(B, -1, -1)
-         for block in self.blocks:
-             x = block(x, mask)
-         return self.head(self.norm(x))
-
- class TokenizerTrainer:
-     def __init__(self):
-         self.tokenizer = Tokenizer(models.BPE())
-         self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
-         self.tokenizer.decoder = decoders.ByteLevel()
-         self.special_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>"]
-
-     def train(self, dataset_name):
-         dataset = load_dataset(dataset_name, split=f"train[:{CONFIG['tokenizer_train_samples']}]")
-         text_samples = []
-         for entry in dataset:
-             if "dialog" in entry:
-                 for i, utterance in enumerate(entry["dialog"][:CONFIG["max_turns"]]):
-                     role = "<user>" if i % 2 == 0 else "<assistant>"
-                     text_samples.append(f"{role} {utterance}")
-             else:
-                 text_samples.append(self._clean_text(entry.get("text", "")))
-         trainer = trainers.BpeTrainer(
-             vocab_size=CONFIG["vocab_size"],
-             special_tokens=self.special_tokens,
-             min_frequency=2,
-             show_progress=True
-         )
-         self.tokenizer.train_from_iterator(text_samples, trainer=trainer, length=len(text_samples))
-         self.tokenizer.post_processor = processors.TemplateProcessing(
-             single="$A </s>",
-             pair="$A $B </s>",
-             special_tokens=[("</s>", self.tokenizer.token_to_id("</s>"))],
-         )
-         os.makedirs("tokenizer", exist_ok=True)
-         self.tokenizer.save("tokenizer/hrom_tokenizer.json")
-
-     def _clean_text(self, text):
-         text = re.sub(r'[^\w\s.,!?\'\-:;<>]', '', text)
-         text = re.sub(r'\s+', ' ', text).strip()
-         return text
-
- class ChatDataset(Dataset):
-     def __init__(self, tokenizer):
-         full_dataset = load_dataset(CONFIG["dataset"], split="train")
-         num_samples = min(len(full_dataset), CONFIG["tokenizer_train_samples"])
-         self.dataset = full_dataset.shuffle(seed=42).select(range(num_samples))
-         self.tokenizer = tokenizer
-         self.max_length = CONFIG["max_seq_len"]
-         self.turn_sep = self.tokenizer.token_to_id("</s>")
-
-     def __len__(self):
-         return len(self.dataset)
-
-     def __getitem__(self, idx):
-         entry = self.dataset[idx]
-         formatted = []
-         if "dialog" in entry:
-             dialog = entry["dialog"][:CONFIG["max_turns"]]
-             for i, utterance in enumerate(dialog):
-                 role_token = "<user>" if i % 2 == 0 else "<assistant>"
-                 formatted.extend([
-                     self.tokenizer.token_to_id(role_token),
-                     *self.tokenizer.encode(utterance).ids,
-                     self.turn_sep
-                 ])
-         else:
-             text = entry.get("text", "")
-             formatted.extend([
-                 self.tokenizer.token_to_id("<user>"),
-                 *self.tokenizer.encode(text).ids,
-                 self.turn_sep
-             ])
-         formatted = formatted[:self.max_length - 2]
-         formatted = [self.tokenizer.token_to_id("<s>"), *formatted, self.tokenizer.token_to_id("</s>")]
-         return {
-             "input_ids": formatted[:-1],
-             "labels": formatted[1:]
-         }
-
-     @staticmethod
-     def collate_fn(batch):
-         max_len = max(len(item["input_ids"]) for item in batch)
-         pad_id = Tokenizer.from_file("tokenizer/hrom_tokenizer.json").token_to_id("<pad>")
-         inputs, labels, masks = [], [], []
-         for item in batch:
-             pad_len = max_len - len(item["input_ids"])
-             inputs.append(item["input_ids"] + [pad_id] * pad_len)
-             labels.append(item["labels"] + [pad_id] * pad_len)
-             masks.append([1] * len(item["input_ids"]) + [0] * pad_len)
-         return {
-             "input_ids": torch.tensor(inputs),
-             "labels": torch.tensor(labels),
-             "attention_mask": torch.tensor(masks)
-         }
-
- class HROMTrainer:
-     def __init__(self, model, tokenizer):
-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.model = model.to(self.device)
-         if self.device.type == "cuda":
-             self.scaler = torch.cuda.amp.GradScaler()
-         else:
-             self.scaler = None
-         self.optimizer = torch.optim.AdamW(
-             self.model.parameters(),
-             lr=CONFIG["learning_rate"],
-             fused=True if self.device.type == "cuda" else False
-         )
-         self.tokenizer = tokenizer
-
-     def train_step(self, batch):
-         autocast = torch.cuda.amp.autocast if self.device.type == "cuda" else nullcontext
-         with autocast():
-             outputs = self.model(
-                 batch["input_ids"].to(self.device),
-                 attention_mask=batch["attention_mask"].to(self.device)
-             )
-             original_loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.token_to_id("<pad>"))(
-                 outputs.view(-1, CONFIG["vocab_size"]),
-                 batch["labels"].view(-1).to(self.device)
-             )
-             scaled_loss = original_loss / CONFIG["grad_accum_steps"]
-
-         if self.scaler is not None:
-             self.scaler.scale(scaled_loss).backward()
-         else:
-             scaled_loss.backward()
-
-         return original_loss.item()
-
-     def clip_and_step(self):
-         if self.scaler is not None:
-             self.scaler.unscale_(self.optimizer)
-         torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
-
-         if self.scaler is not None:
-             self.scaler.step(self.optimizer)
-             self.scaler.update()
-         else:
-             self.optimizer.step()
-
-         self.optimizer.zero_grad()
-
- class SafetyManager:
-     def __init__(self, model, tokenizer):
-         self.model = model
-         self.tokenizer = tokenizer
-         self.bad_words = ["hate", "kill", "harm"]
-         self.bad_word_ids = [tokenizer.encode(w).ids for w in self.bad_words]
-
-     def content_filter(self, text):
-         tokens = self.tokenizer.encode(text).ids
-         for bad_ids in self.bad_word_ids:
-             if any(tokens[i:i+len(bad_ids)] == bad_ids for i in range(len(tokens))):
-                 return False
-         return True
-
-     def generate_safely(self, prompt, max_length=50):
-         input_ids = self.tokenizer.encode(prompt).ids
-         device = next(self.model.parameters()).device
-         for _ in range(max_length):
-             with torch.no_grad():
-                 logits = self.model(torch.tensor([input_ids]).to(device))
-                 next_token = logits.argmax(-1)[:, -1].item()
-             if next_token == self.tokenizer.token_to_id("</s>"):
-                 break
-             generated = self.tokenizer.decode(input_ids + [next_token])
-             if not self.content_filter(generated):
-                 break
-             input_ids.append(next_token)
-         return self.tokenizer.decode(input_ids)
-
-     def debug_generation(self, prompt="Hello!"):
-         print(f"\nSafety Check Generation:")
-         response = self.generate_safely(prompt)
-         print(f"Prompt: {prompt}\nResponse: {response}")
-
- class CheckpointManager:
-     def __init__(self):
-         self.checkpoint_dir = "checkpoints"
-         os.makedirs(self.checkpoint_dir, exist_ok=True)
-
-     def save(self, model, optimizer, step):
-         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-         path = f"{self.checkpoint_dir}/hrom_{timestamp}_step{step}.pt"
-         torch.save({
-             "model": model.state_dict(),
-             "optimizer": optimizer.state_dict(),
-             "step": step,
-             "config": CONFIG
-         }, path)
-         self._cleanup_old_checkpoints()
-
-     def _cleanup_old_checkpoints(self):
-         checkpoints = sorted(os.listdir(self.checkpoint_dir),
-                              key=lambda x: os.path.getmtime(os.path.join(self.checkpoint_dir, x)))
-         while len(checkpoints) > CONFIG["max_checkpoints"]:
-             os.remove(os.path.join(self.checkpoint_dir, checkpoints[0]))
-             checkpoints = checkpoints[1:]
-
- def train():
-     checkpoint_manager = CheckpointManager()
-     if not os.path.exists("tokenizer/hrom_tokenizer.json"):
-         print("Training tokenizer...")
-         tokenizer_trainer = TokenizerTrainer()
-         tokenizer_trainer.train(CONFIG["dataset"])
-
-     tokenizer = Tokenizer.from_file("tokenizer/hrom_tokenizer.json")
-     model = HROM()
-     print("Downloading and caching the dataset...")
-     _ = load_dataset(CONFIG["dataset"], split="train", download_mode="reuse_cache_if_exists")
-
-     dataset = ChatDataset(tokenizer)
-     dataloader = DataLoader(
-         dataset,
-         batch_size=CONFIG["batch_size"],
-         collate_fn=ChatDataset.collate_fn
-     )
-
-     trainer_obj = HROMTrainer(model, tokenizer)
-     safety = SafetyManager(model, tokenizer)
-
-     step = 0
-     optimizer_step = 0
-     total_loss = 0.0
-     model.train()
-
-     for epoch in range(CONFIG["num_epochs"]):
-         for batch in dataloader:
-             loss = trainer_obj.train_step(batch)
-             total_loss += loss
-             step += 1
-
-             if step % CONFIG["grad_accum_steps"] == 0:
-                 trainer_obj.clip_and_step()
-                 avg_loss = total_loss / CONFIG["grad_accum_steps"]
-                 total_loss = 0.0
-
-                 if optimizer_step % CONFIG["checkpoint_interval"] == 0:
-                     checkpoint_manager.save(model, trainer_obj.optimizer, optimizer_step)
-                     safety.debug_generation()
-
-                 if optimizer_step % CONFIG["debug_interval"] == 0:
-                     print(f"Optimizer Step {optimizer_step} | Loss: {avg_loss:.4f}")
-                     safety.debug_generation("What's the meaning of life?")
-
-                 optimizer_step += 1
-
- if __name__ == "__main__":
-     train()