elapt1c committed on
Commit 63a71b1 · 1 parent: 149637d

Delete HROM_Trainer.py


This is a Space; people can check the GitHub repo if they want to train. We want to centralise the files as much as possible to prevent outdated sources.
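
For reference, a minimal sketch of how training would now be invoked from the GitHub copy of the trainer (assuming it keeps the same module name and train() entry point as the file deleted below):

    # Hypothetical usage sketch: assumes the GitHub copy of HROM_Trainer.py
    # still exposes the train() entry point shown in the deleted file below.
    from HROM_Trainer import train

    train()  # trains the BPE tokenizer if it is missing, then trains the HROM model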

Files changed (1)
  1. HROM_Trainer.py +0 -384
HROM_Trainer.py DELETED
@@ -1,384 +0,0 @@
- import torch
- import torch.nn as nn
- from torch.utils.data import Dataset, DataLoader
- from datasets import load_dataset
- from tokenizers import Tokenizer, models, trainers, pre_tokenizers, processors, decoders
- import math
- import os
- import re
- from datetime import datetime
- from contextlib import nullcontext
-
- # Configuration
- CONFIG = {
-     "dim": 512,
-     "n_layers": 6,
-     "n_heads": 8,
-     "ff_dim": 2048,
-     "dropout": 0.1,
-     "max_seq_len": 1024,
-     "batch_size": 32,
-     "checkpoint_interval": 1000,
-     "debug_interval": 500,
-     "dataset": "daily_dialog",
-     "vocab_size": 32000,
-     "tokenizer_train_samples": 100000,
-     "learning_rate": 1e-4, # Lowered learning rate
-     "max_turns": 6,
-     "max_checkpoints": 5,
-     "num_epochs": 100, # Increased number of epochs
-     "grad_accum_steps": 4 # Gradient accumulation steps
- }
-
- class RotaryEmbedding(nn.Module):
-     def __init__(self, dim):
-         super().__init__()
-         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-         self.register_buffer("inv_freq", inv_freq)
-
-     def forward(self, seq_len):
-         t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
-         freqs = torch.einsum("i, j -> i j", t, self.inv_freq)
-         return torch.cat((freqs, freqs), dim=-1)
-
- def rotate_half(x):
-     x1, x2 = x.chunk(2, dim=-1)
-     return torch.cat((-x2, x1), dim=-1)
-
- def apply_rotary_pos_emb(pos, t):
-     pos = pos.unsqueeze(0).unsqueeze(1)
-     return (t * pos.cos()) + (rotate_half(t) * pos.sin())
-
- class SwiGLU(nn.Module):
-     def forward(self, x):
-         x, gate = x.chunk(2, dim=-1)
-         return x * torch.sigmoid(gate)
-
- class HROMAttention(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.dim = CONFIG["dim"]
-         self.n_heads = CONFIG["n_heads"]
-         self.head_dim = self.dim // self.n_heads
-         self.qkv = nn.Linear(self.dim, 3 * self.dim)
-         self.proj = nn.Linear(self.dim, self.dim)
-         self.rotary = RotaryEmbedding(self.head_dim)
-         self.dropout = nn.Dropout(CONFIG["dropout"])
-
-     def forward(self, x, mask=None):
-         B, T, _ = x.shape
-         qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
-         q, k, v = qkv.unbind(2)
-         q = q.transpose(1, 2)
-         k = k.transpose(1, 2)
-         v = v.transpose(1, 2)
-         pos = self.rotary(T)
-         q = apply_rotary_pos_emb(pos, q)
-         k = apply_rotary_pos_emb(pos, k)
-         attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
-         if mask is not None:
-             mask = mask.unsqueeze(1)
-             attn = attn + mask
-         attn = torch.softmax(attn, dim=-1)
-         attn = self.dropout(attn)
-         out = attn @ v
-         out = out.transpose(1, 2).reshape(B, T, self.dim)
-         return self.proj(out)
-
- class HROMBlock(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.attn = HROMAttention()
-         self.ff = nn.Sequential(
-             nn.Linear(CONFIG["dim"], 2 * CONFIG["ff_dim"]),
-             SwiGLU(),
-             nn.Linear(CONFIG["ff_dim"], CONFIG["dim"])
-         )
-         self.norm1 = nn.LayerNorm(CONFIG["dim"])
-         self.norm2 = nn.LayerNorm(CONFIG["dim"])
-         self.dropout = nn.Dropout(CONFIG["dropout"])
-
-     def forward(self, x, mask=None):
-         x = x + self.dropout(self.attn(self.norm1(x), mask))
-         x = x + self.dropout(self.ff(self.norm2(x)))
-         return x
-
- class HROM(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.embed = nn.Embedding(CONFIG["vocab_size"], CONFIG["dim"])
-         self.blocks = nn.ModuleList([HROMBlock() for _ in range(CONFIG["n_layers"])])
-         self.norm = nn.LayerNorm(CONFIG["dim"])
-         self.head = nn.Linear(CONFIG["dim"], CONFIG["vocab_size"])
-         self.apply(self._init_weights)
-
-     def _init_weights(self, module):
-         if isinstance(module, nn.Linear):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-             if module.bias is not None:
-                 torch.nn.init.zeros_(module.bias)
-
-     def forward(self, x, attention_mask=None):
-         x = self.embed(x)
-         if attention_mask is not None:
-             B, T = attention_mask.shape
-             causal_mask = torch.triu(torch.ones(T, T) * float('-inf'), diagonal=1)
-             causal_mask = causal_mask.to(x.device)
-             pad_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=torch.float32)
-             pad_mask = (1.0 - pad_mask) * torch.finfo(torch.float32).min
-             mask = causal_mask + pad_mask.squeeze(1)
-         else:
-             B, T = x.shape[:2]
-             mask = torch.triu(torch.ones(T, T) * float('-inf'), diagonal=1)
-             mask = mask.to(x.device)
-             mask = mask.unsqueeze(0).expand(B, -1, -1)
-         for block in self.blocks:
-             x = block(x, mask)
-         return self.head(self.norm(x))
-
- class TokenizerTrainer:
-     def __init__(self):
-         self.tokenizer = Tokenizer(models.BPE())
-         self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
-         self.tokenizer.decoder = decoders.ByteLevel()
-         self.special_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<user>", "<assistant>"]
-
-     def train(self, dataset_name):
-         dataset = load_dataset(dataset_name, split=f"train[:{CONFIG['tokenizer_train_samples']}]")
-         text_samples = []
-         for entry in dataset:
-             if "dialog" in entry:
-                 for i, utterance in enumerate(entry["dialog"][:CONFIG["max_turns"]]):
-                     role = "<user>" if i % 2 == 0 else "<assistant>"
-                     text_samples.append(f"{role} {utterance}")
-             else:
-                 text_samples.append(self._clean_text(entry.get("text", "")))
-         trainer = trainers.BpeTrainer(
-             vocab_size=CONFIG["vocab_size"],
-             special_tokens=self.special_tokens,
-             min_frequency=2,
-             show_progress=True
-         )
-         self.tokenizer.train_from_iterator(text_samples, trainer=trainer, length=len(text_samples))
-         self.tokenizer.post_processor = processors.TemplateProcessing(
-             single="$A </s>",
-             pair="$A $B </s>",
-             special_tokens=[("</s>", self.tokenizer.token_to_id("</s>"))],
-         )
-         os.makedirs("tokenizer", exist_ok=True)
-         self.tokenizer.save("tokenizer/hrom_tokenizer.json")
-
-     def _clean_text(self, text):
-         text = re.sub(r'[^\w\s.,!?\'\-:;<>]', '', text)
-         text = re.sub(r'\s+', ' ', text).strip()
-         return text
-
- class ChatDataset(Dataset):
-     def __init__(self, tokenizer):
-         full_dataset = load_dataset(CONFIG["dataset"], split="train")
-         num_samples = min(len(full_dataset), CONFIG["tokenizer_train_samples"])
-         self.dataset = full_dataset.shuffle(seed=42).select(range(num_samples))
-         self.tokenizer = tokenizer
-         self.max_length = CONFIG["max_seq_len"]
-         self.turn_sep = self.tokenizer.token_to_id("</s>")
-
-     def __len__(self):
-         return len(self.dataset)
-
-     def __getitem__(self, idx):
-         entry = self.dataset[idx]
-         formatted = []
-         if "dialog" in entry:
-             dialog = entry["dialog"][:CONFIG["max_turns"]]
-             for i, utterance in enumerate(dialog):
-                 role_token = "<user>" if i % 2 == 0 else "<assistant>"
-                 formatted.extend([
-                     self.tokenizer.token_to_id(role_token),
-                     *self.tokenizer.encode(utterance).ids,
-                     self.turn_sep
-                 ])
-         else:
-             text = entry.get("text", "")
-             formatted.extend([
-                 self.tokenizer.token_to_id("<user>"),
-                 *self.tokenizer.encode(text).ids,
-                 self.turn_sep
-             ])
-         formatted = formatted[:self.max_length-2]
-         formatted = [self.tokenizer.token_to_id("<s>"), *formatted, self.tokenizer.token_to_id("</s>")]
-         return {
-             "input_ids": formatted[:-1],
-             "labels": formatted[1:]
-         }
-
-     @staticmethod
-     def collate_fn(batch):
-         max_len = max(len(item["input_ids"]) for item in batch)
-         pad_id = Tokenizer.from_file("tokenizer/hrom_tokenizer.json").token_to_id("<pad>")
-         inputs, labels, masks = [], [], []
-         for item in batch:
-             pad_len = max_len - len(item["input_ids"])
-             inputs.append(item["input_ids"] + [pad_id] * pad_len)
-             labels.append(item["labels"] + [pad_id] * pad_len)
-             masks.append([1] * len(item["input_ids"]) + [0] * pad_len)
-         return {
-             "input_ids": torch.tensor(inputs),
-             "labels": torch.tensor(labels),
-             "attention_mask": torch.tensor(masks)
-         }
-
- class HROMTrainer:
-     def __init__(self, model, tokenizer):
-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.model = model.to(self.device)
-         if self.device.type == "cuda":
-             self.scaler = torch.cuda.amp.GradScaler()
-         else:
-             self.scaler = None
-         self.optimizer = torch.optim.AdamW(
-             self.model.parameters(),
-             lr=CONFIG["learning_rate"],
-             fused=True if self.device.type == "cuda" else False
-         )
-         self.tokenizer = tokenizer
-
-     def train_step(self, batch):
-         autocast = torch.cuda.amp.autocast if self.device.type == "cuda" else nullcontext
-         with autocast():
-             outputs = self.model(
-                 batch["input_ids"].to(self.device),
-                 attention_mask=batch["attention_mask"].to(self.device)
-             )
-             original_loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.token_to_id("<pad>"))(
-                 outputs.view(-1, CONFIG["vocab_size"]),
-                 batch["labels"].view(-1).to(self.device)
-             )
-         scaled_loss = original_loss / CONFIG["grad_accum_steps"]
-
-         if self.scaler is not None:
-             self.scaler.scale(scaled_loss).backward()
-         else:
-             scaled_loss.backward()
-
-         return original_loss.item()
-
-     def clip_and_step(self):
-         if self.scaler is not None:
-             self.scaler.unscale_(self.optimizer)
-         torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
-
-         if self.scaler is not None:
-             self.scaler.step(self.optimizer)
-             self.scaler.update()
-         else:
-             self.optimizer.step()
-
-         self.optimizer.zero_grad()
-
- class SafetyManager:
-     def __init__(self, model, tokenizer):
-         self.model = model
-         self.tokenizer = tokenizer
-         self.bad_words = ["hate", "kill", "harm"]
-         self.bad_word_ids = [tokenizer.encode(w).ids for w in self.bad_words]
-
-     def content_filter(self, text):
-         tokens = self.tokenizer.encode(text).ids
-         for bad_ids in self.bad_word_ids:
-             if any(tokens[i:i+len(bad_ids)] == bad_ids for i in range(len(tokens))):
-                 return False
-         return True
-
-     def generate_safely(self, prompt, max_length=50):
-         input_ids = self.tokenizer.encode(prompt).ids
-         device = next(self.model.parameters()).device
-         for _ in range(max_length):
-             with torch.no_grad():
-                 logits = self.model(torch.tensor([input_ids]).to(device))
-                 next_token = logits.argmax(-1)[:, -1].item()
-             if next_token == self.tokenizer.token_to_id("</s>"):
-                 break
-             generated = self.tokenizer.decode(input_ids + [next_token])
-             if not self.content_filter(generated):
-                 break
-             input_ids.append(next_token)
-         return self.tokenizer.decode(input_ids)
-
-     def debug_generation(self, prompt="Hello!"):
-         print(f"\nSafety Check Generation:")
-         response = self.generate_safely(prompt)
-         print(f"Prompt: {prompt}\nResponse: {response}")
-
- class CheckpointManager:
-     def __init__(self):
-         self.checkpoint_dir = "checkpoints"
-         os.makedirs(self.checkpoint_dir, exist_ok=True)
-
-     def save(self, model, optimizer, step):
-         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-         path = f"{self.checkpoint_dir}/hrom_{timestamp}_step{step}.pt"
-         torch.save({
-             "model": model.state_dict(),
-             "optimizer": optimizer.state_dict(),
-             "step": step,
-             "config": CONFIG
-         }, path)
-         self._cleanup_old_checkpoints()
-
-     def _cleanup_old_checkpoints(self):
-         checkpoints = sorted(os.listdir(self.checkpoint_dir),
-                              key=lambda x: os.path.getmtime(os.path.join(self.checkpoint_dir, x)))
-         while len(checkpoints) > CONFIG["max_checkpoints"]:
-             os.remove(os.path.join(self.checkpoint_dir, checkpoints[0]))
-             checkpoints = checkpoints[1:]
-
- def train():
-     checkpoint_manager = CheckpointManager()
-     if not os.path.exists("tokenizer/hrom_tokenizer.json"):
-         print("Training tokenizer...")
-         tokenizer_trainer = TokenizerTrainer()
-         tokenizer_trainer.train(CONFIG["dataset"])
-
-     tokenizer = Tokenizer.from_file("tokenizer/hrom_tokenizer.json")
-     model = HROM()
-     print("Downloading and caching the dataset...")
-     _ = load_dataset(CONFIG["dataset"], split="train", download_mode="reuse_cache_if_exists")
-
-     dataset = ChatDataset(tokenizer)
-     dataloader = DataLoader(
-         dataset,
-         batch_size=CONFIG["batch_size"],
-         collate_fn=ChatDataset.collate_fn
-     )
-
-     trainer_obj = HROMTrainer(model, tokenizer)
-     safety = SafetyManager(model, tokenizer)
-
-     step = 0
-     optimizer_step = 0
-     total_loss = 0.0
-     model.train()
-
-     for epoch in range(CONFIG["num_epochs"]):
-         for batch in dataloader:
-             loss = trainer_obj.train_step(batch)
-             total_loss += loss
-             step += 1
-
-             if step % CONFIG["grad_accum_steps"] == 0:
-                 trainer_obj.clip_and_step()
-                 avg_loss = total_loss / CONFIG["grad_accum_steps"]
-                 total_loss = 0.0
-
-                 if optimizer_step % CONFIG["checkpoint_interval"] == 0:
-                     checkpoint_manager.save(model, trainer_obj.optimizer, optimizer_step)
-                     safety.debug_generation()
-
-                 if optimizer_step % CONFIG["debug_interval"] == 0:
-                     print(f"Optimizer Step {optimizer_step} | Loss: {avg_loss:.4f}")
-                     safety.debug_generation("What's the meaning of life?")
-
-                 optimizer_step += 1
-
- if __name__ == "__main__":
-     train()