nishantb06 committed on
Commit 1a3e79b · verified · 1 Parent(s): f2b6aad

Upload 4 files

Files changed (4)
  1. app.py +76 -0
  2. model_weights.pt +3 -0
  3. requirements.txt +7 -0
  4. smollm_training.py +531 -0
app.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import gradio as gr
+ from smollm_training import SmolLMConfig, tokenizer, SmolLM
+
+
+ # Load the model
+ def load_model():
+     config = SmolLMConfig()
+     model = SmolLM(config)  # Create base model instead of Lightning model
+
+     # Load just the model weights
+     state_dict = torch.load("model_weights.pt", map_location="cpu")
+     model.load_state_dict(state_dict)
+
+     model.eval()
+     return model
+
+
+ def generate_text(prompt, max_tokens, temperature=0.8, top_k=40):
+     """Generate text based on the prompt"""
+     try:
+         # Encode the prompt
+         prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
+
+         # Move to device if needed
+         device = next(model.parameters()).device
+         prompt_ids = prompt_ids.to(device)
+
+         # Generate text
+         with torch.no_grad():
+             generated_ids = model.generate(  # Call generate directly on base model
+                 prompt_ids,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_k=top_k,
+             )
+
+         # Decode the generated text
+         generated_text = tokenizer.decode(generated_ids[0].tolist())
+
+         return generated_text
+
+     except Exception as e:
+         return f"An error occurred: {str(e)}"
+
+
+ # Load the model globally
+ model = load_model()
+
+ # Create the Gradio interface
+ demo = gr.Interface(
+     fn=generate_text,
+     inputs=[
+         gr.Textbox(
+             label="Enter your prompt", placeholder="Once upon a time...", lines=3
+         ),
+         gr.Slider(
+             minimum=50,
+             maximum=500,
+             value=100,
+             step=10,
+             label="Maximum number of tokens",
+         ),
+     ],
+     outputs=gr.Textbox(label="Generated Text", lines=10),
+     title="SmolLM Text Generator",
+     description="Enter a prompt and the model will generate a continuation.",
+     examples=[
+         ["Once upon a time", 100],
+         ["The future of AI is", 200],
+         ["In a galaxy far far away", 150],
+     ],
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
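
Note (not part of the commit): app.py loads a plain state_dict for the bare SmolLM module rather than a Lightning checkpoint. A minimal sketch of how such a model_weights.pt could be exported from a checkpoint written by smollm_training.py; the checkpoint path below is hypothetical.

# Sketch only: export bare SmolLM weights from a Lightning checkpoint.
# Assumes a checkpoint produced by smollm_training.py exists at this (hypothetical) path.
import torch

from smollm_training import SmolLMLightning

ckpt_path = "checkpoints/best-checkpoint.ckpt"  # hypothetical path
lightning_model = SmolLMLightning.load_from_checkpoint(ckpt_path, map_location="cpu")

# Keep only the inner SmolLM weights; app.py restores them with model.load_state_dict(...)
torch.save(lightning_model.model.state_dict(), "model_weights.pt")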
model_weights.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bbe0e1252067a3dd754093de963677c05013e5b5e1d99dfea462409c432f85b2
+ size 667028174
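
Note (not part of the commit): model_weights.pt is stored as a Git LFS pointer, so only the oid and size live in the repository; the ~667 MB blob is fetched separately. A small sketch for verifying a downloaded copy against the pointer's sha256, assuming the file is present in the working directory.

# Sketch only: integrity check of a fetched LFS blob against the pointer above.
import hashlib

EXPECTED_SHA256 = "bbe0e1252067a3dd754093de963677c05013e5b5e1d99dfea462409c432f85b2"

sha256 = hashlib.sha256()
with open("model_weights.pt", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        sha256.update(chunk)

print("sha256 matches pointer:", sha256.hexdigest() == EXPECTED_SHA256)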
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ gradio
+ transformers
+ pytorch-lightning
+ datasets
+ wandb
+ lightning
smollm_training.py ADDED
@@ -0,0 +1,531 @@
+ # import for colab/kaggle
+ # !pip install datasets transformers wandb -q
+ # !pip install pytorch-lightning lightning tiktoken -q
+ import os
+ import math
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+
+ from datasets import load_dataset
+ from transformers import GPT2Tokenizer
+
+ import pytorch_lightning as pl
+ from pytorch_lightning.callbacks import LearningRateMonitor, RichProgressBar
+ from pytorch_lightning.loggers import WandbLogger
+ from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme
+ from pytorch_lightning.callbacks import ModelCheckpoint
+
+ block_size = 512
+ batch_size = 8
+ max_lr = 1e-3
+ warmup_steps = 10
+ max_steps = 25000
+ log_every_n_steps = 100
+ save_checkpoints_every_n_steps = 10
+ effective_batch_size = 32
+
+ tokenizer: GPT2Tokenizer = GPT2Tokenizer.from_pretrained(
+     "HuggingFaceTB/cosmo2-tokenizer"
+ )
+ tokenizer.pad_token = tokenizer.eos_token
+ vocab_size = tokenizer.vocab_size
+
+
+ def load_cosmopedia_dataset(batch_size=8, seq_length=1024):
+     """
+     Returns a torch dataloader for the cosmopedia dataset
+     """
+     try:
+         dataset = load_dataset(
+             "HuggingFaceTB/smollm-corpus",
+             name="cosmopedia-v2",
+             split="train",
+             streaming=True,
+         )
+
+         def encode(examples):
+             tokens = tokenizer(
+                 examples["text"],
+                 truncation=True,
+                 padding="max_length",
+                 max_length=seq_length + 1,
+                 return_tensors="pt",
+             )
+             input_ids = tokens["input_ids"].squeeze(0).clone().detach()
+             input_ids = torch.clamp(input_ids, min=0, max=tokenizer.vocab_size - 1)
+             labels = input_ids.clone().detach()
+             labels = labels[1:].to(torch.int64)
+             input_ids = input_ids[:-1].to(torch.int64)
+
+             return {"input_ids": input_ids, "labels": labels}
+
+         dataset = dataset.map(encode, remove_columns=["text"], batched=False)
+         dataset = dataset.with_format("torch")
+         dataloader = DataLoader(dataset, batch_size=batch_size)
+         return dataloader
+     except Exception as e:
+         print(e)
+         return None
+
+
+ @dataclass
+ class SmolLMConfig:
+     block_size = 1024
+     vocab_size = 49152
+     n_layers = 30
+     n_heads = 9
+     n_embed = 576
+     dropout = 0.1
+     mlp_hidden_dim = 1536
+     attention_dropout = 0.0
+     dropout = 0.1
+     n_key_value_heads = 3
+     rms_norm_eps = 1e-5
+
+
+ ## Function which enables K and V to have less heads than Q.
+ ## it repeats the K and V heads n_rep times
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+     bs, n_kv_heads, slen, head_dim = x.shape
+     if n_rep == 1:
+         return x
+     return (
+         x[:, :, :, None, :]
+         .expand(bs, n_kv_heads, slen, n_rep, head_dim)
+         .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+     )
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         """
+         Initialize the RMSNorm normalization layer.
+
+         Args:
+             dim (int): The dimension of the input tensor.
+             eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+         Attributes:
+             eps (float): A small value added to the denominator for numerical stability.
+             weight (nn.Parameter): Learnable scaling parameter.
+
+         """
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         """
+         Apply the RMSNorm normalization to the input tensor.
+
+         Args:
+             x (torch.Tensor): The input tensor.
+
+         Returns:
+             torch.Tensor: The normalized tensor.
+
+         """
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         """
+         Forward pass through the RMSNorm layer.
+
+         Args:
+             x (torch.Tensor): The input tensor.
+
+         Returns:
+             torch.Tensor: The output tensor after applying RMSNorm.
+
+         """
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+
+ class CausalMultiHeadAttention(nn.Module):
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.config = config
+         self.n_head = config.n_heads
+         self.n_embd = config.n_embed
+
+         # Linear projections for Q, K, V
+         # self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)  # [n_embd, 3 * n_embd]
+         self.w_q = nn.Linear(config.n_embed, config.n_embed)
+         self.w_k = nn.Linear(config.n_embed, config.n_embed // config.n_key_value_heads)
+         self.w_v = nn.Linear(config.n_embed, config.n_embed // config.n_key_value_heads)
+         self.c_proj = nn.Linear(config.n_embed, config.n_embed)  # [n_embd, n_embd]
+
+         self.n_rep = self.config.n_heads // self.config.n_key_value_heads
+
+         self.resid_dropout = nn.Dropout(config.dropout)
+         self.register_buffer(
+             "bias",
+             torch.tril(torch.ones(config.block_size, config.block_size)).view(
+                 1, 1, config.block_size, config.block_size
+             ),
+         )
+
+     def forward(self, x):
+         B, T, C = x.size()  # [B, T, n_embd]
+
+         # Linear projection and split into Q, K, V
+         # q, k, v = self.c_attn(x).split(self.n_embd, dim=2)  # [B, T, n_embd] each
+         q = self.w_q(x)  # [B, T, 576]
+         k = self.w_k(x)  # [B, T, 192]
+         v = self.w_v(x)  # [B, T, 192]
+
+         # Reshape for multi-head attention
+         k = k.view(
+             B,
+             T,
+             self.config.n_key_value_heads,
+             k.size(-1) // self.config.n_key_value_heads,
+         ).transpose(
+             1, 2
+         )  # [B, 3, T, 64]
+         q = q.view(
+             B, T, self.config.n_heads, q.size(-1) // self.config.n_heads
+         ).transpose(
+             1, 2
+         )  # [B, 9, T, 64]
+         v = v.view(
+             B,
+             T,
+             self.config.n_key_value_heads,
+             v.size(-1) // self.config.n_key_value_heads,
+         ).transpose(
+             1, 2
+         )  # [B, 3, T, 64]
+
+         # repeat k and v for each head
+         k = repeat_kv(k, self.n_rep)
+         v = repeat_kv(v, self.n_rep)
+
+         # Attention scores
+         att = (q @ k.transpose(-2, -1)) * (
+             1.0 / (k.size(-1) ** 0.5)
+         )  # [B, n_head, T, T]
+         att = att.masked_fill(
+             self.bias[:, :, :T, :T] == 0, float("-inf")
+         )  # [B, n_head, T, T]
+         att = F.softmax(att, dim=-1)  # [B, n_head, T, T]
+
+         # Weighted sum of values
+         y = att @ v  # [B, n_head, T, n_embd/n_head]
+
+         # Reshape and project
+         y = y.transpose(1, 2).contiguous().view(B, T, C)  # [B, T, n_embd]
+         y = self.c_proj(y)  # [B, T, n_embd]
+         y = self.resid_dropout(y)  # [B, T, n_embd]
+
+         return y
+
+
+ class MLP(nn.Module):
+
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embed, config.mlp_hidden_dim)
+         self.silu = nn.SiLU()
+         self.c_proj = nn.Linear(config.mlp_hidden_dim, config.n_embed)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.silu(x)
+         x = self.c_proj(x)
+         return x
+
+
+ class LlamaMLP(nn.Module):
+
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.hidden_dim = config.mlp_hidden_dim  # 1536
+         self.w1 = nn.Linear(config.n_embed, self.hidden_dim)
+         self.w2 = nn.Linear(self.hidden_dim, config.n_embed)
+         self.w3 = nn.Linear(config.n_embed, self.hidden_dim)
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+ class DecoderBlockWithRMSNorm(nn.Module):
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.config = config
+         self.rms_1 = RMSNorm(self.config.n_embed, eps=self.config.rms_norm_eps)
+         self.attn = CausalMultiHeadAttention(config)
+         self.rms_2 = RMSNorm(self.config.n_embed, eps=self.config.rms_norm_eps)
+         self.mlp = LlamaMLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.rms_1(x))
+         x = x + self.mlp(self.rms_2(x))
+         return x
+
+
+ class DecoderBlockWithLayerNorm(nn.Module):
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embed)
+         self.attn = CausalMultiHeadAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embed)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class SmolLM(nn.Module):
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.config = config
+         self.wte = nn.Embedding(
+             config.vocab_size, config.n_embed
+         )  # [vocab_size, n_embd]
+         self.wpe = nn.Embedding(
+             config.block_size, config.n_embed
+         )  # [max_seq_len, n_embd]
+         self.drop = nn.Dropout(config.dropout)
+         self.blocks = nn.ModuleList(
+             [DecoderBlockWithRMSNorm(config) for _ in range(config.n_layers)]
+         )
+         self.rms_norm = RMSNorm(config.n_embed, eps=config.rms_norm_eps)  # [n_embd]
+         self.lm_head = nn.Linear(
+             config.n_embed, config.vocab_size, bias=False
+         )  # [n_embd, vocab_size]
+
+         # weight sharing
+         self.wte.weight = self.lm_head.weight
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, (nn.Linear, nn.Embedding)):
+             module.weight.data.normal_(mean=0.0, std=0.02)
+             if isinstance(module, nn.Linear) and module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+     def forward(self, idx, targets=None):
+         # idx is of shape (B, T)
+         B, T = idx.size()
+         assert (
+             T <= self.config.block_size
+         ), f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
+
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # shape (T)
+         pos_emb = self.wpe(pos)  # position embeddings of shape (T, n_embd)
+         x = self.wte(idx)  # token embeddings of shape (B, T, n_embd)
+         x = x + pos_emb
+
+         # forward the blocks of the transformer
+         for block in self.blocks:
+             x = block(x)
+         # forward the final layernorm and the classifier
+         x = self.rms_norm(x)
+         logits = self.lm_head(x)  # (B, T, vocab_size)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         """
+         Generate text given a starting sequence of tokens.
+
+         Args:
+             idx (torch.Tensor): Starting token indices, shape (B, T)
+             max_new_tokens (int): Number of tokens to generate
+             temperature (float): Sampling temperature (1.0 = no change, < 1.0 = less random, > 1.0 = more random)
+             top_k (int): If specified, only sample from the top k most probable tokens
+         """
+         for _ in range(max_new_tokens):
+             # if the sequence context is growing too long we must crop it at block_size
+             idx_cond = (
+                 idx
+                 if idx.size(1) <= self.config.block_size
+                 else idx[:, -self.config.block_size :]
+             )
+             # forward the model to get the logits for the index in the sequence
+             logits, _ = self(idx_cond)
+             # pluck the logits at the final step and scale by desired temperature
+             logits = logits[:, -1, :] / temperature
+             # optionally crop the logits to only the top k options
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float("Inf")
+             # apply softmax to convert logits to (normalized) probabilities
+             probs = F.softmax(logits, dim=-1)
+             # sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1)
+             # append sampled index to the running sequence
+             idx = torch.cat((idx, idx_next), dim=1)
+
+         return idx
+
+
+ class SmolLMLightning(pl.LightningModule):
+     def __init__(self, config: SmolLMConfig, lr, warmup_steps, max_steps):
+         super().__init__()
+         self.save_hyperparameters()
+         self.config = config
+         self.model = SmolLM(self.config)
+         self.criterion = nn.CrossEntropyLoss()
+         self.tokenizer = tokenizer
+         self.generation_prompt = "Once upon a time"
+         self._generating = False
+
+     def forward(self, x):
+         return self.model(x)
+
+     def training_step(self, batch, batch_idx):
+         input_ids = batch["input_ids"]
+         target_ids = batch["labels"]
+         logits, _ = self(input_ids)
+         loss = self.criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
+
+         # Log the loss with 4 decimal precision
+         self.log(
+             "train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, logger=True
+         )
+
+         # Generate text every n steps, but only if we're not already generating
+         if (self.global_step) % log_every_n_steps == 0 and not self._generating:
+             self._generating = True
+             self.generate_and_log_sample()
+             self._generating = False
+
+         return loss
+
+     def generate_and_log_sample(self):
+         """Generate and log a sample of text from the model"""
+         try:
+             # Encode the prompt
+             prompt_ids = self.tokenizer.encode(
+                 self.generation_prompt, return_tensors="pt"
+             ).to(self.device)
+
+             # Generate new tokens
+             generated_ids = self.model.generate(
+                 prompt_ids, max_new_tokens=50, temperature=0.8, top_k=40
+             )
+
+             # Decode the generated tokens
+             generated_text = self.tokenizer.decode(generated_ids[0].tolist())
+
+             # Create a formatted message
+             message = (
+                 f"\n{'='*40}\n"
+                 f"Step {self.global_step} generation:\n"
+                 f"Prompt: {self.generation_prompt}\n"
+                 f"Generated: {generated_text}\n"
+                 f"{'='*40}\n"
+             )
+
+             print(message)
+
+             # Log to WandB
+             if hasattr(self.logger, "experiment"):
+                 self.logger.experiment.log(
+                     {"generated_text": generated_text, "global_step": self.global_step}
+                 )
+         except Exception as e:
+             print(f"Generation failed with error: {str(e)}")
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
+
+         def lr_lambda(current_step):
+             if current_step < self.hparams.warmup_steps:
+                 return self.hparams.lr * (current_step + 1) / self.hparams.warmup_steps
+             elif current_step > self.hparams.max_steps:
+                 return self.hparams.lr * 0.1
+             decay_ratio = (current_step - self.hparams.warmup_steps) / (
+                 self.hparams.max_steps - self.hparams.warmup_steps
+             )
+             coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
+             return self.hparams.lr * 0.1 + coeff * (
+                 self.hparams.lr - self.hparams.lr * 0.1
+             )
+
+         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+         return [optimizer], [scheduler]
+
+
+ if __name__ == "__main__":
+     dataloader = load_cosmopedia_dataset(batch_size=batch_size, seq_length=block_size)
+     model = SmolLMLightning(SmolLMConfig(), max_lr, warmup_steps, max_steps)
+
+     # Replace TensorBoard logger with WandB logger
+     wandb_logger = WandbLogger(
+         project="smollm",  # your project name
+         name="transformer_experiment",  # name of the run
+         log_model=True,  # log model checkpoints
+     )
+
+     os.makedirs("checkpoints", exist_ok=True)
+     checkpoint_callback = ModelCheckpoint(
+         dirpath="checkpoints/",
+         filename="best-checkpoint",
+         verbose=True,
+         every_n_train_steps=save_checkpoints_every_n_steps,
+     )
+
+     device = "cpu"
+     if torch.cuda.is_available():
+         device = "cuda"
+     elif torch.backends.mps.is_available():
+         device = "mps"
+     print(f"using device: {device}")
+
+     torch.set_float32_matmul_precision("high")
+
+     progress_bar = RichProgressBar(
+         refresh_rate=1,
+         leave=False,
+         theme=RichProgressBarTheme(
+             description="",
+             progress_bar="#6206E0",
+             progress_bar_finished="#6206E0",
+             progress_bar_pulse="#6206E0",
+             batch_progress="",
+             time="dim",
+             processing_speed="dim underline",
+             metrics="italic",
+             metrics_text_delimiter=" ",
+             metrics_format=".3f",
+         ),
+         console_kwargs=None,
+     )
+
+     trainer = pl.Trainer(
+         max_steps=max_steps,
+         accelerator=device,
+         devices=1,
+         callbacks=[
+             LearningRateMonitor(logging_interval="step"),
+             progress_bar,
+             checkpoint_callback,
+         ],
+         precision="bf16-mixed",
+         log_every_n_steps=1,
+         enable_progress_bar=True,
+         enable_model_summary=True,
+         logger=wandb_logger,
+         accumulate_grad_batches=effective_batch_size // batch_size,
+     )
+
+     trainer.fit(model, dataloader)
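
Note (not part of the commit): a small shape and parameter sanity check for the model defined above, run on random token ids instead of the streamed dataset. Importing smollm_training also downloads the cosmo2 tokenizer as a side effect, so network access is assumed.

# Sketch only: quick smoke test of SmolLM with the committed SmolLMConfig.
import torch

from smollm_training import SmolLM, SmolLMConfig

config = SmolLMConfig()
model = SmolLM(config)

# lm_head.weight is tied to wte.weight, so the shared tensor is counted once.
n_params = sum(p.numel() for p in model.parameters())
print(f"parameters: {n_params / 1e6:.1f}M")

idx = torch.randint(0, config.vocab_size, (2, 16))  # batch of 2 sequences, 16 tokens each
logits, loss = model(idx, targets=idx)
print(logits.shape)  # expected: torch.Size([2, 16, 49152])
print(loss.item())   # untrained model on random targets: roughly ln(49152) ≈ 10.8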