atiwari751 committed
Commit 349c73f · 0 Parent(s)

final model

Files changed (8)
  1. .gitattributes +35 -0
  2. .gitignore +4 -0
  3. README.md +13 -0
  4. generate.py +76 -0
  5. input.txt +0 -0
  6. logs.txt +111 -0
  7. prompts.txt +116 -0
  8. train_get2_8_init.py +294 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__
+ .venv
+ gpt2-model
+
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: GPT2 Replica
+ emoji: 👀
+ colorFrom: indigo
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 5.12.0
+ app_file: app.py
+ pinned: false
+ short_description: 'A decoder trained starting with GPT2 weights.'
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
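
The Space metadata above declares an `app.py` entry file that is not included in this commit. A minimal, hypothetical sketch of what such a file might look like is shown below; it assumes the checkpoint, tokenizer, and top-k sampling loop from the generate.py file added later in this commit, and the Gradio wiring is illustrative only, not the author's actual app.

# Hypothetical app.py sketch (not part of this commit): wraps the model from
# train_get2_8_init.py and the sampling scheme from generate.py in a Gradio
# interface, as the Space config (sdk: gradio, app_file: app.py) suggests.
import gradio as gr
import tiktoken
import torch
import torch.nn.functional as F
from train_get2_8_init import GPT, GPTConfig

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT(GPTConfig())
model.load_state_dict(torch.load('best_model.pt', map_location=device))
model.to(device)
model.eval()
enc = tiktoken.get_encoding('gpt2')

def generate(prompt, max_length=100, top_k=50):
    # Top-k sampling of a single continuation, same scheme as generate.py.
    x = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
    while x.size(1) < int(max_length):
        with torch.no_grad():
            logits, _ = model(x)                      # forward returns (logits, loss)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
            ix = torch.multinomial(topk_probs, 1)
            x = torch.cat((x, torch.gather(topk_indices, -1, ix)), dim=1)
    return enc.decode(x[0].tolist())

demo = gr.Interface(
    fn=generate,
    inputs=[gr.Textbox(label="Prompt"), gr.Slider(20, 256, value=100, label="Max length")],
    outputs=gr.Textbox(label="Generated text"),
    title="GPT2 Replica",
)

if __name__ == "__main__":
    demo.launch()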
generate.py ADDED
@@ -0,0 +1,76 @@
+ import torch
+ import torch.nn.functional as F
+ import tiktoken
+ from train_get2_8_init import GPT, GPTConfig  # Import your model architecture
+
+ # Device setup
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ print(f"Using device: {device}")
+
+ # Initialize model and load trained weights
+ model = GPT(GPTConfig())
+ model.load_state_dict(torch.load('best_model.pt'))
+ model.to(device)
+ model.eval()  # Set to evaluation mode
+
+ # Initialize tokenizer
+ enc = tiktoken.get_encoding('gpt2')
+
+ def generate_text(prompt, max_length=100, num_sequences=5, top_k=50, seed=42):
+     """
+     Generate text from a prompt using the trained model.
+     """
+     # Set random seed
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+
+     # Encode the prompt
+     tokens = enc.encode(prompt)
+     x = torch.tensor(tokens).unsqueeze(0).repeat(num_sequences, 1).to(device)
+
+     # Generate text
+     while x.size(1) < max_length:
+         with torch.no_grad():
+             logits = model(x)[0]
+             logits = logits[:, -1, :]
+             probs = F.softmax(logits, dim=-1)
+
+             # Top-k sampling
+             topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
+             ix = torch.multinomial(topk_probs, 1)
+             xcol = torch.gather(topk_indices, -1, ix)
+
+             # Append to sequence
+             x = torch.cat((x, xcol), dim=1)
+
+     # Decode and print results
+     print(f"\nPrompt: {prompt}")
+     print("\nGenerated sequences:")
+     print("-" * 50)
+     for i in range(num_sequences):
+         tokens = x[i, :].tolist()
+         decoded = enc.decode(tokens)
+         print(f"\n{i+1}. {decoded}")
+         print("-" * 50)
+
+ # Interactive prompt loop
+ if __name__ == "__main__":
+     while True:
+         prompt = input("\nEnter your prompt (or 'quit' to exit): ")
+         if prompt.lower() == 'quit':
+             break
+
+         try:
+             max_len = int(input("Max length (default 100): ") or 100)
+             num_seq = int(input("Number of sequences (default 5): ") or 5)
+         except ValueError:
+             print("Using default values due to invalid input")
+             max_len = 100
+             num_seq = 5
+
+         generate_text(
+             prompt=prompt,
+             max_length=max_len,
+             num_sequences=num_seq
+         )
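
Besides the interactive loop, generate_text can also be called programmatically. A tiny assumed usage sketch (importing generate.py runs its module-level code, so best_model.pt must be present alongside it):

# Hypothetical non-interactive use of generate.py; the prompt is taken from
# the sample session recorded in prompts.txt below.
from generate import generate_text

generate_text("O Brother, Where art thou?", max_length=50, num_sequences=5)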
input.txt ADDED
The diff for this file is too large to render. See raw diff
 
logs.txt ADDED
@@ -0,0 +1,111 @@
+ loading weights from local path: ./S12/gpt2-model
+ loaded 338025 tokens
+ 1 epoch = 2640 batches
+ Epoch 1/5: 100%|████████████████████████████████████████████████████████████| 2640/2640 [03:07<00:00, 14.07batch/s, loss=3.5681, batch=2639/2640]
+ Epoch 1 completed. Loss: 3.5681
+ New best loss! Saving model...
+ Epoch 2/5: 100%|████████████████████████████████████████████████████████████| 2640/2640 [03:07<00:00, 14.05batch/s, loss=2.8617, batch=2639/2640]
+ Epoch 2 completed. Loss: 2.8617
+ New best loss! Saving model...
+ Epoch 3/5: 100%|████████████████████████████████████████████████████████████| 2640/2640 [03:07<00:00, 14.05batch/s, loss=1.9555, batch=2639/2640]
+ Epoch 3 completed. Loss: 1.9555
+ New best loss! Saving model...
+ Epoch 4/5: 100%|████████████████████████████████████████████████████████████| 2640/2640 [03:07<00:00, 14.05batch/s, loss=1.2269, batch=2639/2640]
+ Epoch 4 completed. Loss: 1.2269
+ New best loss! Saving model...
+ Epoch 5/5: 100%|████████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 14.00batch/s, loss=0.7809, batch=2639/2640]
+ Epoch 5 completed. Loss: 0.7809
+ New best loss! Saving model...
+ Training completed. Best loss: 0.7809
+ tensor(1.1159, device='cuda:0', grad_fn=<NllLossBackward0>)
+
+ /home/ubuntu/S12/train_get2-8-init.py:236: FutureWarning: You are using `torch.load` with `weights_only=False`...)
+
+ loaded 338025 tokens
+ 1 epoch = 2640 batches
+ Epoch 1/15: 100%|███████████████████████████████████████████████████████████| 2640/2640 [03:06<00:00, 14.17batch/s, loss=0.8081, batch=2639/2640]
+ Epoch 1 completed. Loss: 0.8081
+ New best loss! Saving model...
+ Epoch 2/15: 100%|███████████████████████████████████████████████████████████| 2640/2640 [03:07<00:00, 14.06batch/s, loss=0.5428, batch=2639/2640]
+ Epoch 2 completed. Loss: 0.5428
+ New best loss! Saving model...
+ Epoch 3/15: 100%|███████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 14.02batch/s, loss=0.4627, batch=2639/2640]
+ Epoch 3 completed. Loss: 0.4627
+ New best loss! Saving model...
+ Epoch 4/15: 100%|███████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 13.99batch/s, loss=0.3558, batch=2639/2640]
+ Epoch 4 completed. Loss: 0.3558
+ New best loss! Saving model...
+ Epoch 5/15: 100%|███████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 13.99batch/s, loss=0.3293, batch=2639/2640]
+ Epoch 5 completed. Loss: 0.3293
+ New best loss! Saving model...
+ Epoch 6/15: 100%|███████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 13.99batch/s, loss=0.2828, batch=2639/2640]
+ Epoch 6 completed. Loss: 0.2828
+ New best loss! Saving model...
+ Epoch 7/15: 100%|███████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 14.02batch/s, loss=0.2875, batch=2639/2640]
+ Epoch 7 completed. Loss: 0.2875
+ Epoch 8/15: 100%|███████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 14.01batch/s, loss=0.2383, batch=2639/2640]
+ Epoch 8 completed. Loss: 0.2383
+ New best loss! Saving model...
+ Epoch 9/15: 100%|███████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 14.02batch/s, loss=0.2269, batch=2639/2640]
+ Epoch 9 completed. Loss: 0.2269
+ New best loss! Saving model...
+ Epoch 10/15: 100%|██████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 14.02batch/s, loss=0.2363, batch=2639/2640]
+ Epoch 10 completed. Loss: 0.2363
+ Epoch 11/15: 100%|██████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 14.03batch/s, loss=0.2088, batch=2639/2640]
+ Epoch 11 completed. Loss: 0.2088
+ New best loss! Saving model...
+ Epoch 12/15: 100%|██████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 14.02batch/s, loss=0.1982, batch=2639/2640]
+ Epoch 12 completed. Loss: 0.1982
+ New best loss! Saving model...
+ Epoch 13/15: 100%|██████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 14.03batch/s, loss=0.2130, batch=2639/2640]
+ Epoch 13 completed. Loss: 0.2130
+ Epoch 14/15: 100%|██████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 14.02batch/s, loss=0.2038, batch=2639/2640]
+ Epoch 14 completed. Loss: 0.2038
+ Epoch 15/15: 100%|██████████████████████████████████████████████████████████| 2640/2640 [03:08<00:00, 14.01batch/s, loss=0.2160, batch=2639/2640]
+ Epoch 15 completed. Loss: 0.2160
+ Training completed. Best loss: 0.1982
+
+ Increased max. tokens in a sequence to 64.
+
+ using device: cuda
+ /home/ubuntu/S12/train_get2-8-init.py:236: FutureWarning: You are using `torch.load` with `weights_only=False`...)
+
+ loaded 338025 tokens
+ 1 epoch = 1320 batches
+ Epoch 1/5: 100%|████████████████████████████████████████████████████████████| 1320/1320 [01:59<00:00, 11.01batch/s, loss=0.2214, batch=1319/1320]
+ Epoch 1 completed. Loss: 0.2214
+ New best loss! Saving model...
+ Epoch 2/5: 100%|████████████████████████████████████████████████████████████| 1320/1320 [02:00<00:00, 10.95batch/s, loss=0.1412, batch=1319/1320]
+ Epoch 2 completed. Loss: 0.1412
+ New best loss! Saving model...
+ Epoch 3/5: 100%|████████████████████████████████████████████████████████████| 1320/1320 [02:00<00:00, 10.95batch/s, loss=0.1462, batch=1319/1320]
+ Epoch 3 completed. Loss: 0.1462
+ Epoch 4/5: 100%|████████████████████████████████████████████████████████████| 1320/1320 [02:00<00:00, 10.94batch/s, loss=0.1165, batch=1319/1320]
+ Epoch 4 completed. Loss: 0.1165
+ New best loss! Saving model...
+ Epoch 5/5: 100%|████████████████████████████████████████████████████████████| 1320/1320 [02:00<00:00, 10.94batch/s, loss=0.1074, batch=1319/1320]
+ Epoch 5 completed. Loss: 0.1074
+ New best loss! Saving model...
+ Training completed. Best loss: 0.1074
+
+ Increased max. tokens in a sequence to 128.
+
+ using device: cuda
+ /home/ubuntu/S12/train_get2-8-init.py:236: FutureWarning: You are using `torch.load` with `weights_only=False`...)
+
+ loaded 338025 tokens
+ 1 epoch = 660 batches
+ Epoch 1/5: 100%|████████████████████████████████████████████| 660/660 [01:33<00:00, 7.03batch/s, loss=0.1102, batch=659/660]
+ Epoch 1 completed. Loss: 0.1102
+ New best loss! Saving model...
+ Epoch 2/5: 100%|████████████████████████████████████████████| 660/660 [01:33<00:00, 7.03batch/s, loss=0.0738, batch=659/660]
+ Epoch 2 completed. Loss: 0.0738
+ New best loss! Saving model...
+ Epoch 3/5: 100%|████████████████████████████████████████████| 660/660 [01:33<00:00, 7.02batch/s, loss=0.0526, batch=659/660]
+ Epoch 3 completed. Loss: 0.0526
+ New best loss! Saving model...
+ Epoch 4/5: 100%|████████████████████████████████████████████| 660/660 [01:33<00:00, 7.03batch/s, loss=0.0566, batch=659/660]
+ Epoch 4 completed. Loss: 0.0566
+ Epoch 5/5: 100%|████████████████████████████████████████████| 660/660 [01:34<00:00, 7.01batch/s, loss=0.0587, batch=659/660]
+ Epoch 5 completed. Loss: 0.0587
+ Training completed. Best loss: 0.0526
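
The runs above differ only in the sequence length fed to DataLoaderLite: with the batch size B=4 used in the committed script, the logged batch counts imply T=32 for the first two runs, then T=64, then T=128 (only the last value actually appears in train_get2_8_init.py). A small sanity-check sketch of that arithmetic; the T=32 and T=64 values are inferences from the logs, not code in this commit:

# Batches per epoch is tokens // (B * T). B=4 comes from the committed script;
# T=64 and T=128 come from the "Increased max. tokens" notes, T=32 is inferred.
tokens = 338_025
B = 4
for T in (32, 64, 128):
    print(f"T={T}: {tokens // (B * T)} batches/epoch")  # 2640, 1320, 660 -- matches logs.txt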
prompts.txt ADDED
@@ -0,0 +1,116 @@
+ Enter your prompt (or 'quit' to exit): O Brother, Where art thou?
+ Max length (default 100): 50
+ Number of sequences (default 5):
+
+ Prompt: O Brother, Where art thou?
+
+ Generated sequences:
+ --------------------------------------------------
+
+ 1. O Brother, Where art thou?
+
+ MENENIUS:
+ Being moved, what thou shouldst not,
+ being so far forthrong'd as the manner is, his wounds
+ To the people, beg their stinking breaths.
+ --------------------------------------------------
+
+ 2. O Brother, Where art thou?
+
+ MENENIUS:
+ Being a island, a whip, or a kept-anch!
+
+ SICINIUS:
+ Where is he, hear you?
+
+ MENENIUS
+ --------------------------------------------------
+
+ 3. O Brother, Where art thou?
+
+ MENENIUS:
+ Being moved, am sure, not against Aufidius.
+
+ MARCIUS:
+ That's my noble master!
+ What shall I do? say what;
+ --------------------------------------------------
+
+ 4. O Brother, Where art thou?
+
+ MENENIUS:
+ aw, what am I, that brought thee to this?
+
+ TITAN:
+ A single thing, as I am, that wonders to hear thee speak.
+
+
+ --------------------------------------------------
+
+ 5. O Brother, Where art thou?
+
+ CORIOLANUS:
+ Thy memory is as theirs.
+
+ SEBASTIAN:
+ I am, as ne'er I heard thee go
+ That thou mightst be malap
+ --------------------------------------------------
+
+
+
+
+
+
+ Enter your prompt (or 'quit' to exit): In Brutus, should we trust?
+ Max length (default 100): 50
+ Number of sequences (default 5):
+
+ Prompt: In Brutus, should we trust?
+
+ Generated sequences:
+ --------------------------------------------------
+
+ 1. In Brutus, should we trust? goes not
+ To him or our poorness?
+
+ LUCENTIO:
+ I pray, sir, let's away.
+
+ MARCIUS:
+ Nay, I will give thee
+ --------------------------------------------------
+
+ 2. In Brutus, should we trust? feed and pride.
+
+ SICINIUS:
+ As big as to love and good news the good news
+ fellowsies, which did befall'n us,
+ That none of us
+ --------------------------------------------------
+
+ 3. In Brutus, should we trust? there's no
+ whatsoever you should come toved him o' the air or thine
+ with him put it not a accusers, and thus answer'd:
+ 'True is it, my incorporate
+ --------------------------------------------------
+
+ 4. In Brutus, should we trust? Well, or I'll lean
+ Of Trustwings or when the wars make us.
+
+ First Officer:
+ This isle with Calibans.
+
+ COMINIUS:
+ Spe
+ --------------------------------------------------
+
+ 5. In Brutus, should we trust? folly
+ We three are married, but this is put forth,
+ Like men born by the flood, and wind-footed rage,
+ Then fair Milan's glory moves.
+
+ MARiners:
+
+ --------------------------------------------------
+
train_get2_8_init.py ADDED
@@ -0,0 +1,294 @@
+ # Solving for residual std scaling issue
+ import os
+ import math
+ import time
+ import inspect
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from tqdm import tqdm
+
+ class CausalSelfAttention(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+         # output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+         # regularization
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
+
+     def forward(self, x):
+         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
+         # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
+         qkv = self.c_attn(x)
+         q, k, v = qkv.split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+         att = F.softmax(att, dim=-1)
+         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+         # output projection
+         y = self.c_proj(y)
+         return y
+
+
+ class MLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+         self.gelu = nn.GELU(approximate='tanh')
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         return x
+
+ class Block(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embd)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ @dataclass
+ class GPTConfig:
+     block_size: int = 1024 # max sequence length
+     vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
+     n_layer: int = 12 # number of layers
+     n_head: int = 12 # number of heads
+     n_embd: int = 768 # embedding dimension
+
+
+ class GPT(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f = nn.LayerNorm(config.n_embd),
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+         # weight sharing
+         self.transformer.wte.weight = self.lm_head.weight
+
+         # weight initialization
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             if hasattr(module, 'NANOGPT_SCALE_INIT'):
+                 std *= (2 * self.config.n_layer) ** -0.5
+             torch.nn.init.normal_(module.weight, mean = 0.0, std = std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std = 0.02)
+
+
+
+     def forward(self, idx, targets=None):
+         # idx is of shape (B, T)
+         B, T = idx.size()
+         assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
+         # forward the token and position embeddings
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
+         pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
+         tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
+         x = tok_emb + pos_emb
+         # forward the blocks of the transformer
+         for block in self.transformer.h:
+             x = block(x)
+         # forward the final layernorm and the classifier
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x) # (B, T, vocab_size)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+     @classmethod
+     def from_pretrained(cls, model_path):
+         """Loads pretrained GPT-2 model weights from local directory"""
+         from transformers import GPT2LMHeadModel, GPT2Config
+         print("loading weights from local path: %s" % model_path)
+
+         # n_layer, n_head and n_embd are determined from model_type
+         config_args = dict(n_layer=12, n_head=12, n_embd=768) # 124M params
+         config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
+         config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
+
+         # create a from-scratch initialized minGPT model
+         config = GPTConfig(**config_args)
+         model = GPT(config)
+         sd = model.state_dict()
+         sd_keys = sd.keys()
+         sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')]
+
+         # init a huggingface/transformers model
+         model_hf = GPT2LMHeadModel.from_pretrained(model_path, local_files_only=True)
+         sd_hf = model_hf.state_dict()
+
+         # copy while ensuring all of the parameters are aligned and match in names and shapes
+         sd_keys_hf = sd_hf.keys()
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
+         sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
+         transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
+
+         assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
+         for k in sd_keys_hf:
+             if any(k.endswith(w) for w in transposed):
+                 # special treatment for the Conv1D weights we need to transpose
+                 assert sd_hf[k].shape[::-1] == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k].t())
+             else:
+                 # vanilla copy over the other parameters
+                 assert sd_hf[k].shape == sd[k].shape
+                 with torch.no_grad():
+                     sd[k].copy_(sd_hf[k])
+
+         return model
+
+ device = 'cpu'
+ if torch.cuda.is_available():
+     device = 'cuda'
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+     device = "mps"
+ print(f"using device: {device}")
+
+ # SEED
+ torch.manual_seed(1337)
+ if torch.cuda.is_available():
+     torch.cuda.manual_seed(1337)
+
+ # STOP
+ num_return_sequences = 5
+ max_length = 30
+
+
+
+ import tiktoken
+
+ class DataLoaderLite:
+     def __init__(self, B, T):
+         self.B = B
+         self.T = T
+
+         # at init load tokens from disk and store them in memory
+         with open('input.txt', 'r') as f:
+             text = f.read()
+         enc = tiktoken.get_encoding('gpt2')
+         tokens = enc.encode(text)
+         self.tokens = torch.tensor(tokens)
+         print(f'loaded {len(self.tokens)} tokens')
+         print(f'1 epoch = {len(self.tokens) // (B * T)} batches')
+
+         # state
+         self.current_position = 0
+
+     def next_batch(self):
+         B, T = self.B, self.T
+         buf = self.tokens[self.current_position: self.current_position + B * T + 1]
+         x = (buf[:-1]).view(B, T) # inputs
+         y = (buf[1:]).view(B, T) # targets
+         # advance the position in the tensor
+         self.current_position += B*T
+         # if loading the next batch would be out of bounds, reset
+         if self.current_position + (B * T + 1) > len(self.tokens):
+             self.current_position = 0
+         return x, y
+
+
+ # Move all training-specific code inside main
+ if __name__ == "__main__":
+     # Device setup
+     device = 'cpu'
+     if torch.cuda.is_available():
+         device = 'cuda'
+     elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+         device = "mps"
+     print(f"using device: {device}")
+
+     # SEED
+     torch.manual_seed(1337)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(1337)
+
+     # Initialize model and training
+     model = GPT(GPTConfig())
+     model.load_state_dict(torch.load('best_model.pt'))
+     model.to(device)
+
+     train_loader = DataLoaderLite(B = 4, T = 128)
+
+     # Training loop
+     num_epochs = 5
+     optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4)
+     n_batches = len(train_loader.tokens) // (train_loader.B * train_loader.T)
+     best_loss = float('inf')
+     running_loss = 0.0
+
+     for epoch in range(num_epochs):
+         # Train for one complete epoch with progress bar
+         progress_bar = tqdm(range(n_batches),
+                             desc=f'Epoch {epoch+1}/{num_epochs}',
+                             unit='batch')
+
+         for i in progress_bar:
+             x, y = train_loader.next_batch()
+             x, y = x.to(device), y.to(device)
+             optimizer.zero_grad()
+             logits, loss = model(x, y)
+             loss.backward()
+             optimizer.step()
+
+             # Update running loss and progress bar
+             running_loss = 0.9 * running_loss + 0.1 * loss.item() # Smoothed loss
+             progress_bar.set_postfix({
+                 'loss': f'{running_loss:.4f}',
+                 'batch': f'{i}/{n_batches}'
+             })
+
+         # Print epoch summary
+         print(f'Epoch {epoch+1} completed. Loss: {running_loss:.4f}')
+
+         # Save best model with correct path
+         if running_loss < best_loss:
+             best_loss = running_loss
+             print(f'New best loss! Saving model...')
+             save_path = os.path.join('.', 'best_model.pt')
+             torch.save(model.state_dict(), save_path)
+
+     print(f'Training completed. Best loss: {best_loss:.4f}')
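
The committed __main__ block resumes from best_model.pt and never calls GPT.from_pretrained, yet the first line of logs.txt ("loading weights from local path: ./S12/gpt2-model") shows that classmethod being used. The sketch below reconstructs that presumed bootstrap step; it is inferred from the logs, not code that appears in this commit.

# Presumed bootstrap for the first run in logs.txt (not in this commit's
# __main__): initialize from the local GPT-2 weights using the classmethod
# defined above, then train; the training loop itself saves best_model.pt.
import torch
from train_get2_8_init import GPT

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT.from_pretrained('./S12/gpt2-model')  # prints "loading weights from local path: ..."
model.to(device)
# (from here, the epoch loop in train_get2_8_init.py applies unchanged)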