nishantb06 committed
Commit 8c0c4ae · verified · 1 Parent(s): 02a1a11

Upload 4 files

Files changed (4)
  1. app.py +76 -0
  2. model_weights.pt +3 -0
  3. requirements.txt +7 -0
  4. smollm_training.py +556 -0
app.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import gradio as gr
+ from smollm_training import SmolLMConfig, tokenizer, SmolLM
+
+
+ # Load the model
+ def load_model():
+     config = SmolLMConfig()
+     model = SmolLM(config)  # Create base model instead of Lightning model
+
+     # Load just the model weights
+     state_dict = torch.load("model_weights.pt", map_location="cpu")
+     model.load_state_dict(state_dict)
+
+     model.eval()
+     return model
+
+
+ def generate_text(prompt, max_tokens, temperature=0.8, top_k=40):
+     """Generate text based on the prompt"""
+     try:
+         # Encode the prompt
+         prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
+
+         # Move to device if needed
+         device = next(model.parameters()).device
+         prompt_ids = prompt_ids.to(device)
+
+         # Generate text
+         with torch.no_grad():
+             generated_ids = model.generate(  # Call generate directly on base model
+                 prompt_ids,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_k=top_k,
+             )
+
+         # Decode the generated text
+         generated_text = tokenizer.decode(generated_ids[0].tolist())
+
+         return generated_text
+
+     except Exception as e:
+         return f"An error occurred: {str(e)}"
+
+
+ # Load the model globally
+ model = load_model()
+
+ # Create the Gradio interface
+ demo = gr.Interface(
+     fn=generate_text,
+     inputs=[
+         gr.Textbox(
+             label="Enter your prompt", placeholder="Once upon a time...", lines=3
+         ),
+         gr.Slider(
+             minimum=50,
+             maximum=500,
+             value=100,
+             step=10,
+             label="Maximum number of tokens",
+         ),
+     ],
+     outputs=gr.Textbox(label="Generated Text", lines=10),
+     title="SmolLM Text Generator",
+     description="Enter a prompt and the model will generate a continuation.",
+     examples=[
+         ["Once upon a time", 100],
+         ["The future of AI is", 200],
+         ["In a galaxy far far away", 150],
+     ],
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
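Note: app.py expects a plain SmolLM state dict in model_weights.pt, while smollm_training.py only writes Lightning checkpoints to checkpoints/best-checkpoint.ckpt. Below is a minimal sketch of how such a weights file could have been exported; the helper is an assumption for illustration, not part of this commit, and it relies on SmolLMLightning storing the base network under the "model." key prefix.

# export_weights.py (hypothetical helper, not included in this commit)
import torch

# Lightning checkpoints contain more than tensors, so disable weights-only loading.
ckpt = torch.load("checkpoints/best-checkpoint.ckpt", map_location="cpu", weights_only=False)

# Lightning stores module weights under "state_dict"; SmolLMLightning wraps the
# base network as self.model, so keys look like "model.wte.weight" and the
# prefix must be stripped to match SmolLM's own parameter names.
state_dict = {
    key[len("model."):]: value
    for key, value in ckpt["state_dict"].items()
    if key.startswith("model.")
}

torch.save(state_dict, "model_weights.pt")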
model_weights.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9890d0e8cae8c871513e6df473dab27fde8524d0ebd1b800f97264c78931e2e9
+ size 666342726
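model_weights.pt is tracked with Git LFS; only the pointer above lives in the repository, and the real binary is about 666 MB. A small, hypothetical guard (not in this commit) that load_model() could run before torch.load to catch the common case where only the pointer text was checked out:

import os

# The LFS pointer file is only a few hundred bytes; the real weights are ~666 MB.
if os.path.getsize("model_weights.pt") < 1_000_000:
    raise RuntimeError(
        "model_weights.pt looks like a Git LFS pointer; run `git lfs pull` to fetch the weights."
    )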
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ gradio
+ transformers
+ pytorch-lightning
+ datasets
+ wandb
+ lightning
smollm_training.py ADDED
@@ -0,0 +1,556 @@
+ # installs for Colab/Kaggle
+ # !pip install datasets transformers wandb -q
+ # !pip install pytorch-lightning lightning tiktoken -q
+ import os
+ import math
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+
+ from datasets import load_dataset
+ from transformers import GPT2Tokenizer
+
+ import pytorch_lightning as pl
+ from pytorch_lightning.callbacks import LearningRateMonitor, RichProgressBar
+ from pytorch_lightning.loggers import WandbLogger
+ from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme
+ from pytorch_lightning.callbacks import ModelCheckpoint
+
+ block_size = 512
+ batch_size = 8
+ max_lr = 1e-3
+ warmup_steps = 10
+ max_steps = 25000
+ log_every_n_steps = 100
+ save_checkpoints_every_n_steps = 10
+ effective_batch_size = 32
+
+ tokenizer: GPT2Tokenizer = GPT2Tokenizer.from_pretrained(
+     "HuggingFaceTB/cosmo2-tokenizer"
+ )
+ tokenizer.pad_token = tokenizer.eos_token
+ vocab_size = tokenizer.vocab_size
+
+
+ def load_cosmopedia_dataset(batch_size=8, seq_length=1024):
+     """
+     Returns a torch dataloader for the cosmopedia dataset
+     """
+     try:
+         dataset = load_dataset(
+             "HuggingFaceTB/smollm-corpus",
+             name="cosmopedia-v2",
+             split="train",
+             streaming=True,
+         )
+
+         def encode(examples):
+             tokens = tokenizer(
+                 examples["text"],
+                 truncation=True,
+                 padding="max_length",
+                 max_length=seq_length + 1,
+                 return_tensors="pt",
+             )
+             input_ids = tokens["input_ids"].squeeze(0).clone().detach()
+             input_ids = torch.clamp(input_ids, min=0, max=tokenizer.vocab_size - 1)
+             labels = input_ids.clone().detach()
+             labels = labels[1:].to(torch.int64)
+             input_ids = input_ids[:-1].to(torch.int64)
+
+             return {"input_ids": input_ids, "labels": labels}
+
+         dataset = dataset.map(encode, remove_columns=["text"], batched=False)
+         dataset = dataset.with_format("torch")
+         dataloader = DataLoader(dataset, batch_size=batch_size)
+         return dataloader
+     except Exception as e:
+         print(e)
+         return None
+
+
+ @dataclass
+ class SmolLMConfig:
+     block_size = 1024
+     vocab_size = 49152
+     n_layers = 30
+     n_heads = 9
+     n_embed = 576
+     dropout = 0.1
+     mlp_hidden_dim = 1536
+     attention_dropout = 0.0
+     n_key_value_heads = 3
+     rms_norm_eps = 1e-5
+
+
+ ## Function which enables K and V to have fewer heads than Q:
+ ## it repeats the K and V heads n_rep times.
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep)."""
+     bs, n_kv_heads, slen, head_dim = x.shape
+     if n_rep == 1:
+         return x
+     # Insert the repeat axis directly after the head axis (dim 1, since x is
+     # [B, H_kv, T, D]) so the reshape merges adjacent dimensions correctly.
+     return (
+         x[:, :, None, :, :]
+         .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+         .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+     )
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-6):
+         """
+         Initialize the RMSNorm normalization layer.
+
+         Args:
+             dim (int): The dimension of the input tensor.
+             eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+         Attributes:
+             eps (float): A small value added to the denominator for numerical stability.
+             weight (nn.Parameter): Learnable scaling parameter.
+
+         """
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         """
+         Apply the RMSNorm normalization to the input tensor.
+
+         Args:
+             x (torch.Tensor): The input tensor.
+
+         Returns:
+             torch.Tensor: The normalized tensor.
+
+         """
+         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         """
+         Forward pass through the RMSNorm layer.
+
+         Args:
+             x (torch.Tensor): The input tensor.
+
+         Returns:
+             torch.Tensor: The output tensor after applying RMSNorm.
+
+         """
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+
+ class CausalMultiHeadAttention(nn.Module):
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.config = config
+         self.n_head = config.n_heads
+         self.n_embd = config.n_embed
+
+         # Linear projections for Q, K, V
+         # self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)  # [n_embd, 3 * n_embd]
+         self.w_q = nn.Linear(config.n_embed, config.n_embed, bias=False)
+         self.w_k = nn.Linear(
+             config.n_embed, config.n_embed // config.n_key_value_heads, bias=False
+         )
+         self.w_v = nn.Linear(
+             config.n_embed, config.n_embed // config.n_key_value_heads, bias=False
+         )
+         self.c_proj = nn.Linear(
+             config.n_embed, config.n_embed, bias=False
+         )  # [n_embd, n_embd]
+         self.c_proj.NANGPT_SCALE_INIT = 1
+
+         self.n_rep = self.config.n_heads // self.config.n_key_value_heads
+
+         self.resid_dropout = nn.Dropout(config.dropout)
+         self.register_buffer(
+             "bias",
+             torch.tril(torch.ones(config.block_size, config.block_size)).view(
+                 1, 1, config.block_size, config.block_size
+             ),
+         )
+
+     def forward(self, x):
+         B, T, C = x.size()  # [B, T, n_embd]
+
+         # Linear projection and split into Q, K, V
+         # q, k, v = self.c_attn(x).split(self.n_embd, dim=2)  # [B, T, n_embd] each
+         q = self.w_q(x)  # [B, T, 576]
+         k = self.w_k(x)  # [B, T, 192]
+         v = self.w_v(x)  # [B, T, 192]
+
+         # Reshape for multi-head attention
+         k = k.view(
+             B,
+             T,
+             self.config.n_key_value_heads,
+             k.size(-1) // self.config.n_key_value_heads,
+         ).transpose(
+             1, 2
+         )  # [B, 3, T, 64]
+         q = q.view(
+             B, T, self.config.n_heads, q.size(-1) // self.config.n_heads
+         ).transpose(
+             1, 2
+         )  # [B, 9, T, 64]
+         v = v.view(
+             B,
+             T,
+             self.config.n_key_value_heads,
+             v.size(-1) // self.config.n_key_value_heads,
+         ).transpose(
+             1, 2
+         )  # [B, 3, T, 64]
+
+         # repeat k and v for each head
+         k = repeat_kv(k, self.n_rep)
+         v = repeat_kv(v, self.n_rep)
+
+         # # Attention scores
+         # att = (q @ k.transpose(-2, -1)) * (
+         #     1.0 / (k.size(-1) ** 0.5)
+         # )  # [B, n_head, T, T]
+         # att = att.masked_fill(
+         #     self.bias[:, :, :T, :T] == 0, float("-inf")
+         # )  # [B, n_head, T, T]
+         # att = F.softmax(att, dim=-1)  # [B, n_head, T, T]
+
+         # # Weighted sum of values
+         # y = att @ v  # [B, n_head, T, n_embd/n_head]
+
+         # Flash attention
+         y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # Flash attention
+         # Reshape and project
+         y = y.transpose(1, 2).contiguous().view(B, T, C)  # [B, T, n_embd]
+         y = self.c_proj(y)  # [B, T, n_embd]
+         y = self.resid_dropout(y)  # [B, T, n_embd]
+
+         return y
+
+
+ class MLP(nn.Module):
+
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embed, config.mlp_hidden_dim, bias=False)
+         self.silu = nn.SiLU()
+         self.c_proj = nn.Linear(config.mlp_hidden_dim, config.n_embed, bias=False)
+         self.c_proj.NANGPT_SCALE_INIT = 1  # same flag name that _init_weights checks
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.silu(x)
+         x = self.c_proj(x)
+         return x
+
+
+ class LlamaMLP(nn.Module):
+
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.hidden_dim = config.mlp_hidden_dim  # 1536
+         self.w1 = nn.Linear(config.n_embed, self.hidden_dim, bias=False)
+         self.w2 = nn.Linear(self.hidden_dim, config.n_embed, bias=False)
+         self.w3 = nn.Linear(config.n_embed, self.hidden_dim, bias=False)
+
+     def forward(self, x):
+         return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+ class DecoderBlockWithRMSNorm(nn.Module):
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.config = config
+         self.rms_1 = RMSNorm(self.config.n_embed, eps=self.config.rms_norm_eps)
+         self.attn = CausalMultiHeadAttention(config)
+         self.rms_2 = RMSNorm(self.config.n_embed, eps=self.config.rms_norm_eps)
+         self.mlp = LlamaMLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.rms_1(x))
+         x = x + self.mlp(self.rms_2(x))
+         return x
+
+
+ class DecoderBlockWithLayerNorm(nn.Module):
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embed)
+         self.attn = CausalMultiHeadAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embed)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class SmolLM(nn.Module):
+     def __init__(self, config: SmolLMConfig):
+         super().__init__()
+         self.config = config
+         self.wte = nn.Embedding(
+             config.vocab_size, config.n_embed
+         )  # [vocab_size, n_embd]
+         self.wpe = nn.Embedding(
+             config.block_size, config.n_embed
+         )  # [max_seq_len, n_embd]
+         self.drop = nn.Dropout(config.dropout)
+         self.blocks = nn.ModuleList(
+             [DecoderBlockWithRMSNorm(config) for _ in range(config.n_layers)]
+         )
+         self.rms_norm = RMSNorm(config.n_embed, eps=config.rms_norm_eps)  # [n_embd]
+         self.lm_head = nn.Linear(
+             config.n_embed, config.vocab_size, bias=False
+         )  # [n_embd, vocab_size]
+
+         # weight sharing
+         self.wte.weight = self.lm_head.weight
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             if hasattr(module, "NANGPT_SCALE_INIT"):
+                 std *= (2 * self.config.n_layers) ** -0.5
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, idx, targets=None):
+         # idx is of shape (B, T)
+         B, T = idx.size()
+         assert (
+             T <= self.config.block_size
+         ), f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
+
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device)  # shape (T)
+         pos_emb = self.wpe(pos)  # position embeddings of shape (T, n_embd)
+         x = self.wte(idx)  # token embeddings of shape (B, T, n_embd)
+         x = x + pos_emb
+
+         # forward the blocks of the transformer
+         for block in self.blocks:
+             x = block(x)
+         # forward the final norm and the classifier
+         x = self.rms_norm(x)
+         logits = self.lm_head(x)  # (B, T, vocab_size)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+     @torch.no_grad()
+     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
+         """
+         Generate text given a starting sequence of tokens.
+
+         Args:
+             idx (torch.Tensor): Starting token indices, shape (B, T)
+             max_new_tokens (int): Number of tokens to generate
+             temperature (float): Sampling temperature (1.0 = no change, < 1.0 = less random, > 1.0 = more random)
+             top_k (int): If specified, only sample from the top k most probable tokens
+         """
+         for _ in range(max_new_tokens):
+             # if the sequence context is growing too long we must crop it at block_size
+             idx_cond = (
+                 idx
+                 if idx.size(1) <= self.config.block_size
+                 else idx[:, -self.config.block_size :]
+             )
+             # forward the model to get the logits for the index in the sequence
+             logits, _ = self(idx_cond)
+             # pluck the logits at the final step and scale by desired temperature
+             logits = logits[:, -1, :] / temperature
+             # optionally crop the logits to only the top k options
+             if top_k is not None:
+                 v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                 logits[logits < v[:, [-1]]] = -float("Inf")
+             # apply softmax to convert logits to (normalized) probabilities
+             probs = F.softmax(logits, dim=-1)
+             # sample from the distribution
+             idx_next = torch.multinomial(probs, num_samples=1)
+             # append sampled index to the running sequence
+             idx = torch.cat((idx, idx_next), dim=1)
+
+         return idx
+
+
+ class SmolLMLightning(pl.LightningModule):
+     def __init__(self, config: SmolLMConfig, lr, warmup_steps, max_steps):
+         super().__init__()
+         self.save_hyperparameters()
+         self.config = config
+         self.model = SmolLM(self.config)
+         self.criterion = nn.CrossEntropyLoss()
+         self.tokenizer = tokenizer
+         self.generation_prompt = "Once upon a time"
+         self._generating = False
+
+     def forward(self, x):
+         return self.model(x)
+
+     def training_step(self, batch, batch_idx):
+         input_ids = batch["input_ids"]
+         target_ids = batch["labels"]
+         logits, _ = self(input_ids)
+         loss = self.criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
+
+         # Log the loss
+         self.log(
+             "train_loss", loss, prog_bar=True, on_step=True, on_epoch=False, logger=True
+         )
+
+         # Generate text every n steps, but only if we're not already generating
+         if (self.global_step) % log_every_n_steps == 0 and not self._generating:
+             self._generating = True
+             self.generate_and_log_sample()
+             self._generating = False
+
+         return loss
+
+     def generate_and_log_sample(self):
+         """Generate and log a sample of text from the model"""
+         try:
+             # Encode the prompt
+             prompt_ids = self.tokenizer.encode(
+                 self.generation_prompt, return_tensors="pt"
+             ).to(self.device)
+
+             # Generate new tokens
+             generated_ids = self.model.generate(
+                 prompt_ids, max_new_tokens=50, temperature=0.8, top_k=40
+             )
+
+             # Decode the generated tokens
+             generated_text = self.tokenizer.decode(generated_ids[0].tolist())
+
+             # Create a formatted message
+             message = (
+                 f"\n{'='*40}\n"
+                 f"Step {self.global_step} generation:\n"
+                 f"Prompt: {self.generation_prompt}\n"
+                 f"Generated: {generated_text}\n"
+                 f"{'='*40}\n"
+             )
+
+             print(message)
+
+             # Log to WandB
+             if hasattr(self.logger, "experiment"):
+                 self.logger.experiment.log(
+                     {"generated_text": generated_text, "global_step": self.global_step}
+                 )
+         except Exception as e:
+             print(f"Generation failed with error: {str(e)}")
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
+
+         def lr_lambda(current_step):
+             # LambdaLR multiplies the optimizer's base lr by the value returned
+             # here, so return a scale factor rather than an absolute learning rate.
+             if current_step < self.hparams.warmup_steps:
+                 return (current_step + 1) / self.hparams.warmup_steps
+             if current_step > self.hparams.max_steps:
+                 return 0.1
+             decay_ratio = (current_step - self.hparams.warmup_steps) / (
+                 self.hparams.max_steps - self.hparams.warmup_steps
+             )
+             coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
+             return 0.1 + coeff * 0.9  # cosine decay from 1.0 down to 0.1
+
+         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+         # Step the scheduler every optimizer step; the default per-epoch interval
+         # effectively never triggers with this streaming dataloader and max_steps.
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
+         }
+
+
+ if __name__ == "__main__":
+     torch.set_float32_matmul_precision("high")
+
+     dataloader = load_cosmopedia_dataset(batch_size=batch_size, seq_length=block_size)
+
+     # Check if checkpoint exists
+     checkpoint_path = "checkpoints/best-checkpoint.ckpt"
+     if os.path.exists(checkpoint_path):
+         print(f"Loading model from checkpoint: {checkpoint_path}")
+         model = SmolLMLightning.load_from_checkpoint(
+             checkpoint_path,
+             config=SmolLMConfig(),
+             lr=max_lr,
+             warmup_steps=warmup_steps,
+             max_steps=max_steps,
+         )
+     else:
+         print("Starting training from scratch")
+         model = SmolLMLightning(SmolLMConfig(), max_lr, warmup_steps, max_steps)
+
+     # Replace TensorBoard logger with WandB logger
+     wandb_logger = WandbLogger(
+         project="smollm",  # your project name
+         name="transformer_experiment",  # name of the run
+         log_model=True,  # log model checkpoints
+     )
+
+     os.makedirs("checkpoints", exist_ok=True)
+     checkpoint_callback = ModelCheckpoint(
+         dirpath="checkpoints/",
+         filename="best-checkpoint",
+         verbose=True,
+         every_n_train_steps=save_checkpoints_every_n_steps,
+     )
+
+     device = "cpu"
+     if torch.cuda.is_available():
+         device = "cuda"
+     elif torch.backends.mps.is_available():
+         device = "mps"
+     print(f"using device: {device}")
+
+     progress_bar = RichProgressBar(
+         refresh_rate=1,
+         leave=False,
+         theme=RichProgressBarTheme(
+             description="",
+             progress_bar="#6206E0",
+             progress_bar_finished="#6206E0",
+             progress_bar_pulse="#6206E0",
+             batch_progress="",
+             time="dim",
+             processing_speed="dim underline",
+             metrics="italic",
+             metrics_text_delimiter=" ",
+             metrics_format=".3f",
+         ),
+         console_kwargs=None,
+     )
+
+     trainer = pl.Trainer(
+         max_steps=max_steps,
+         accelerator=device,
+         devices=1,
+         callbacks=[
+             LearningRateMonitor(logging_interval="step"),
+             progress_bar,
+             checkpoint_callback,
+         ],
+         precision="bf16-mixed",
+         log_every_n_steps=1,
+         enable_progress_bar=True,
+         enable_model_summary=True,
+         logger=wandb_logger,
+         accumulate_grad_batches=effective_batch_size // batch_size,
+     )
+
+     trainer.fit(model, dataloader)
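As a quick sanity check of the model code above (not part of this commit), SmolLM can be instantiated and exercised on random token ids: with this config a forward pass returns logits of shape (B, T, 49152), and generate appends max_new_tokens ids. Note that importing smollm_training downloads the cosmo2 tokenizer at import time.

# Hypothetical smoke test, assuming smollm_training.py is on the path.
import torch
from smollm_training import SmolLM, SmolLMConfig

config = SmolLMConfig()
model = SmolLM(config)
model.eval()

idx = torch.randint(0, config.vocab_size, (2, 16))  # batch of 2, 16 tokens each
logits, loss = model(idx)                           # loss is None without targets
assert logits.shape == (2, 16, config.vocab_size)

out = model.generate(idx, max_new_tokens=8, temperature=0.8, top_k=40)
assert out.shape == (2, 16 + 8)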