mnmnmnmn committed on
Commit
7fc0f78
1 Parent(s): 361f315

Upload 15 files

gpt-2/.DS_Store ADDED
Binary file (6.15 kB).
 
gpt-2/__pycache__/model.cpython-311.pyc ADDED
Binary file (13.5 kB).
 
gpt-2/dataloader.py ADDED
@@ -0,0 +1,74 @@
+ import os
+ import multiprocessing as mp
+ import numpy as np
+ import tiktoken
+ from datasets import load_dataset  # pip install datasets
+ from tqdm import tqdm  # pip install tqdm
+
+ # ------------------------------------------
+ local_dir = "edu_fineweb10B"
+ remote_name = "sample-10BT"
+ shard_size = int(1e8)  # 100M tokens per shard, total of 100 shards
+
+ # create the local cache directory if it doesn't exist yet
+ DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), local_dir)
+ os.makedirs(DATA_CACHE_DIR, exist_ok=True)
+
+ # download the dataset
+ fw = load_dataset("HuggingFaceFW/fineweb-edu", name=remote_name, split="train")
+
+ # init the tokenizer
+ enc = tiktoken.get_encoding("gpt2")
+ eot = enc._special_tokens['<|endoftext|>']  # end-of-text token
+ def tokenize(doc):
+     # tokenizes a single document and returns a numpy array of uint16 tokens
+     tokens = [eot]  # the special token delimits all documents
+     tokens.extend(enc.encode_ordinary(doc["text"]))
+     tokens_np = np.array(tokens)
+     assert (0 <= tokens_np).all() and (tokens_np < 2**16).all(), "token dictionary too large for uint16"
+     tokens_np_uint16 = tokens_np.astype(np.uint16)
+     return tokens_np_uint16
+
+ def write_datafile(filename, tokens_np):
+     np.save(filename, tokens_np)  # np.save appends the .npy extension
+
+ if __name__ == '__main__':
+     # tokenize all documents and write output shards, each of shard_size tokens (last shard has remainder)
+     nprocs = max(1, os.cpu_count() // 2)
+     with mp.Pool(nprocs) as pool:
+         shard_index = 0
+         # preallocate buffer to hold current shard
+         all_tokens_np = np.empty((shard_size,), dtype=np.uint16)
+         token_count = 0
+         progress_bar = None
+         for tokens in pool.imap(tokenize, fw, chunksize=16):
+
+             # is there enough space in the current shard for the new tokens?
+             if token_count + len(tokens) < shard_size:
+                 # simply append tokens to current shard
+                 all_tokens_np[token_count:token_count+len(tokens)] = tokens
+                 token_count += len(tokens)
+                 # update progress bar
+                 if progress_bar is None:
+                     progress_bar = tqdm(total=shard_size, unit="tokens", desc=f"Shard {shard_index}")
+                 progress_bar.update(len(tokens))
+             else:
+                 # write the current shard and start a new one
+                 split = "val" if shard_index == 0 else "train"
+                 filename = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}")
+                 # split the document into whatever fits in this shard; the remainder goes to the next one
+                 remainder = shard_size - token_count
+                 if progress_bar is not None: progress_bar.update(remainder)  # guard: the bar may not exist yet
+                 all_tokens_np[token_count:token_count+remainder] = tokens[:remainder]
+                 write_datafile(filename, all_tokens_np)
+                 shard_index += 1
+                 progress_bar = None
+                 # populate the next shard with the leftovers of the current doc
+                 all_tokens_np[0:len(tokens)-remainder] = tokens[remainder:]
+                 token_count = len(tokens)-remainder
+
+         # write any remaining tokens as the last shard
+         if token_count != 0:
+             split = "val" if shard_index == 0 else "train"
+             filename = os.path.join(DATA_CACHE_DIR, f"edufineweb_{split}_{shard_index:06d}")
+             write_datafile(filename, all_tokens_np[:token_count])
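
A quick way to sanity-check a shard this script has written (a minimal sketch; it assumes the first, "val", shard already exists, and relies on np.save having appended .npy to the name write_datafile was given):

  import numpy as np
  import tiktoken

  enc = tiktoken.get_encoding("gpt2")
  shard = np.load("edu_fineweb10B/edufineweb_val_000000.npy")
  print(shard.dtype, shard.shape)         # uint16, up to 100M tokens
  print(enc.decode(shard[:64].tolist()))  # begins at the <|endoftext|> delimiter

Because np.save adds the .npy extension automatically, the shard listing in training_full_dataset.py below picks the files up under that full name.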
gpt-2/gpt2.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
gpt-2/gpt2_final.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d663ee22770b02eb68070f343ebc621f91493e3e8b146e68eca76cbe919a3114
+ size 548294034
gpt-2/load_and_test.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
gpt-2/lossi.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cd7fd7be14551deea46f888e47512faef42894322aacbb833f366fbc382cb05f
+ size 1362
gpt-2/lossi_final.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8b40324e09dcb246105d63bee020a2bf3791c2bc8e06b6374cc550c9c52c907c
+ size 73225
gpt-2/model.py ADDED
@@ -0,0 +1,131 @@
+ # model.py
+
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import inspect
+ @dataclass
+ class GPTConfig:
+     vocab_size: int = 50257
+     block_size: int = 1024
+     n_layer: int = 12
+     n_head: int = 12
+     n_embd: int = 768  # = 64 * 12
+
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                              .view(1, 1, config.block_size, config.block_size))
+
+     def forward(self, x):
+         B, T, C = x.size()
+         qkv = self.c_attn(x)
+         q, k, v = qkv.split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
+         y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         y = self.c_proj(y)
+         return y
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, config.n_embd * 4)
+         self.c_proj = nn.Linear(config.n_embd * 4, config.n_embd)
+         self.gelu = nn.GELU()
+         self.c_proj.NANOGPT_SCALE_INIT = 1  # flag the projection itself so _init_weights sees it
+
+     def forward(self, x):
+         x = self.gelu(self.c_fc(x))
+         x = self.c_proj(x)
+         return x
+
+ class Block(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd)
+         self.ln_2 = nn.LayerNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+ class GPT(nn.Module):
+     def __init__(self, config, master_process):
+         super().__init__()
+         self.master_process = master_process
+         self.config = config
+         self.transformer = nn.ModuleDict(dict(
+             wte = nn.Embedding(config.vocab_size, config.n_embd),
+             wpe = nn.Embedding(config.block_size, config.n_embd),
+             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+             ln_f = nn.LayerNorm(config.n_embd)
+         ))
+         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+         self.transformer.wte.weight = self.lm_head.weight  # weight tying
+         self.apply(self._init_weights)
+         if self.master_process:
+             print(f"Model initialized. Model has {sum(p.numel() for p in self.parameters() if p.requires_grad):,} trainable parameters")
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             std = 0.02
+             if hasattr(module, 'NANOGPT_SCALE_INIT'):
+                 std *= (2 * self.config.n_layer) ** -0.5  # scale down residual-stream projections
+             torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, idx, targets=None):
+         B, T = idx.size()
+         assert T <= self.config.block_size, "Cannot forward, model block size is exhausted."
+         pos = torch.arange(0, T, dtype=torch.long, device=idx.device)
+         pos_emb = self.transformer.wpe(pos)
+         tok_emb = self.transformer.wte(idx)
+         x = tok_emb + pos_emb
+         for block in self.transformer.h:
+             x = block(x)
+         x = self.transformer.ln_f(x)
+         logits = self.lm_head(x)
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+         return logits, loss
+
+     def configure_optimizers(self, weight_decay, learning_rate, device):
+         param_dict = {pn: p for pn, p in self.named_parameters()}
+         param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
+
+         decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
+         nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
+         optim_groups = [
+             {"params": decay_params, "weight_decay": weight_decay},
+             {"params": nodecay_params, "weight_decay": 0.0},
+         ]
+         num_decay_params = sum(p.numel() for p in decay_params)
+         num_nodecay_params = sum(p.numel() for p in nodecay_params)
+         if self.master_process:
+             print(f"Number of decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
+             print(f"Number of non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
+
+         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+         use_fused = fused_available and 'cuda' in str(device)
+         if self.master_process:
+             print(f'Using {"fused" if use_fused else "unfused"} AdamW')
+         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
+         return optimizer
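
A minimal smoke test for this module (a sketch under the definitions above; CPU suffices for a shape check, and master_process=True only enables the parameter-count print):

  import torch
  from model import GPT, GPTConfig

  model = GPT(GPTConfig(), master_process=True)
  idx = torch.randint(0, 50257, (2, 16))  # a (B, T) batch of token ids
  logits, loss = model(idx, targets=idx)
  print(logits.shape)  # torch.Size([2, 16, 50257])
  print(loss.item())   # roughly -ln(1/50257) ≈ 10.8 at init, matching the log below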
gpt-2/tinyshakespeare.txt ADDED
The diff for this file is too large to render. See raw diff
 
gpt-2/training_full_dataset.py ADDED
@@ -0,0 +1,362 @@
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from model import GPT, GPTConfig
+ import tiktoken
+ from torch.utils.data import Dataset, DataLoader, DistributedSampler
+ import math
+ import matplotlib.pyplot as plt
+ from torch.distributed import init_process_group, destroy_process_group
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ import torch.distributed as dist
+ import os
+ import signal
+ import sys
+ import numpy as np
+ import time
+ import logging
+
+ def seconds_to_hms(seconds):
+     return time.strftime('%H:%M:%S', time.gmtime(seconds))
+
+
+ def signal_handler(sig, frame):
+     print('Gracefully stopping the training process')
+     if dist.is_initialized(): destroy_process_group()  # only tear down DDP if it was set up
+     sys.exit(0)
+
+ signal.signal(signal.SIGINT, signal_handler)
+ manual_seed = 1339
+ torch.manual_seed(manual_seed)
+ if torch.cuda.is_available():
+     torch.cuda.manual_seed(manual_seed)
+
+ # ***************************#
+ # Device Configuration
+ # ***************************#
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+     device = torch.device("mps")
+
+ print("Using device:", device)
+
+ # ***************************#
+ # Tokenizer Setup
+ # ***************************#
+ enc = tiktoken.get_encoding('gpt2')
+
+
+ lossi = []
+ val_lossi = []
+
+ # ***************************#
+ # Load Text Data
+ # ***************************#
+ with open("tinyshakespeare.txt", "r") as f:
+     text = f.read()
+ tokens = enc.encode(text)
+ print(f"Number of tokens: {len(tokens):,}")
+ # ***************************#
+ # Set up DDP
+ # ***************************#
+ # the torchrun command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
+ ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run?
+ if ddp:
+     # use of DDP atm demands CUDA; we set the device appropriately according to rank
+     assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
+     init_process_group(backend='nccl')
+     ddp_rank = int(os.environ['RANK'])
+     ddp_local_rank = int(os.environ['LOCAL_RANK'])
+     ddp_world_size = int(os.environ['WORLD_SIZE'])
+     device = f'cuda:{ddp_local_rank}'
+     torch.cuda.set_device(device)
+     # this process will do logging, checkpointing etc.
+     master_process = ddp_rank == 0
+ else:
+     # vanilla, non-DDP run
+     ddp_rank = 0
+     ddp_local_rank = 0
+     ddp_world_size = 1
+     master_process = True
+
+ if master_process:
+     print(f"ddp: {ddp}, rank: {ddp_rank}, local_rank: {ddp_local_rank}, world_size: {ddp_world_size}, master_process: {master_process}")
+
+ # ***************************#
+ # Model Configuration
+ # ***************************#
+
+ gpt = GPT(GPTConfig(vocab_size=50304), master_process).to(device)
+ if str(device).startswith("cuda"):  # also matches 'cuda:N' strings set under DDP
+     gpt.compile()
+ if ddp:
+     gpt = DDP(gpt, device_ids=[ddp_local_rank])
+
+ raw_gpt = gpt.module if ddp else gpt  # always points at the unwrapped model
+
+ # ***************************#
+ # Dataset and Dataloader
+ # ***************************#
+
+ def load_tokens(filename):
+     npt = np.load(filename)
+     npt = npt.astype(np.int32)  # added after video
+     ptt = torch.tensor(npt, dtype=torch.long)
+     return ptt
+
+ class DataLoader_Custom:
+     def __init__(self, B, T, process_rank, num_processes, split, shuffle=False):
+         self.B = B
+         self.T = T
+         self.process_rank = process_rank
+         self.num_processes = num_processes
+         self.shuffle = shuffle
+         assert split in ["train", "val"]
+
+         data_root = "edu_fineweb10B"
+         shards = os.listdir(data_root)
+         shards = [s for s in shards if split in s]
+         shards = sorted(shards)
+         shards = [os.path.join(data_root, s) for s in shards]
+         self.shards = shards
+         assert len(shards) > 0, "No shards found for split {}".format(split)
+         if master_process:
+             print("Found {} shards for split {}".format(len(shards), split))
+         self.current_shard = 0
+         self.tokens = load_tokens(self.shards[self.current_shard])
+         self.current_position = self.B * self.T * self.process_rank
+
+     def next_batch(self):
+         B, T = self.B, self.T
+         buf = self.tokens[self.current_position:self.current_position + B*T+1]
+         x = buf[:-1].view(B, T)
+         y = buf[1:].view(B, T)
+         self.current_position += B*T * self.num_processes
+         if self.current_position + (B*T*self.num_processes+1) > len(self.tokens):
+             self.current_shard = (self.current_shard + 1) % len(self.shards)  # parentheses matter: % binds tighter than +
+             self.tokens = load_tokens(self.shards[self.current_shard])
+             self.current_position = self.B * self.T * self.process_rank
+
+         return x, y
+
+     def reset(self):
+         self.current_shard = 0
+         self.tokens = load_tokens(self.shards[self.current_shard])
+         self.current_position = self.B * self.T * self.process_rank
+
+ T = 4
+ batch_size = 1
+ total_batch_size = 2**2  # tiny debug value; the full run targets 2**19 = 524,288 tokens per step
+ assert total_batch_size % (T*batch_size*ddp_world_size) == 0, "total_batch_size must be divisible by batch_size*T*world_size"
+ grad_accum_steps = total_batch_size // (T*batch_size*ddp_world_size)
+
+ if master_process:
+     print("Total desired batch size: {:,}".format(total_batch_size))
+     print("gradient accumulation steps: {:,}".format(grad_accum_steps))
+
+ train_dataloader = DataLoader_Custom(batch_size, T, ddp_rank, ddp_world_size, "train")  # global rank, so every process reads a distinct slice
+ val_dataloader = DataLoader_Custom(batch_size, T, ddp_rank, ddp_world_size, "val")
+
+
+ # ***************************#
+ # Text Generation Function
+ # ***************************#
+
+
+ def generate_text(seed_text, model, enc, max_len=100, print_while_generating=True):
+     if print_while_generating:
+         print(seed_text, end="")
+     model.eval()
+     with torch.no_grad():
+         tokens = enc.encode(seed_text)
+         for _ in range(max_len):
+             x = torch.tensor(tokens[-T:], dtype=torch.long,
+                              device=device).unsqueeze(0)
+             logits, _ = model(x)
+             next_token = torch.argmax(logits[:, -1, :])  # greedy decoding
+             tokens.append(int(next_token))
+
+             if print_while_generating:
+                 print(enc.decode([int(next_token)]), end="")
+         print()
+
+     return enc.decode(tokens)
+
+
+ # ***************************#
+ # Optimizer Configuration
+ # ***************************#
+ if ddp:
+     optimizer = raw_gpt.configure_optimizers(
+         weight_decay=0.1, learning_rate=6e-4, device=device)
+ else:
+     optimizer = gpt.configure_optimizers(
+         weight_decay=0.1, learning_rate=6e-4, device=device)
+ torch.set_float32_matmul_precision('high')
+ # ***************************#
+ # Learning Rate Scheduler
+ # ***************************#
+ max_lr = 6e-4
+ min_lr = max_lr * 0.1
+ warmup_steps = 715
+ max_steps = 50  # NOTE: warmup_steps > max_steps, so this short run never leaves linear warmup (the LR in training_log.txt only reaches ~4.2e-05)
+
+
+ def get_lr(step):
+     if step < warmup_steps:
+         return max_lr * (step+1) / warmup_steps
+     if step > max_steps:
+         return min_lr
+     decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
+     assert 0 <= decay_ratio <= 1
+     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
+     return min_lr + coeff * (max_lr - min_lr)
+
+
+ # Check if the device supports bfloat16
+ supports_bfloat16 = False
+ if str(device).startswith("cuda"):
+     capability = torch.cuda.get_device_capability()
+     if capability[0] >= 8:  # Ampere (SM 8.0) or newer
+         supports_bfloat16 = True
+
+ print("Supports bfloat16:", supports_bfloat16)
+
+ # ***************************#
+ # Training Loop
+ # ***************************#
+
+ generate_every = 50
+ validate_every = 10
+ save_every = 5
+ t0 = time.time()
+
+ # Initialize logging
+ logging.basicConfig(level=logging.INFO, format='%(message)s')
+ logger = logging.getLogger(__name__)
+
+ # Add a file handler
+ file_handler = logging.FileHandler('training_log.txt')
+ file_handler.setLevel(logging.INFO)
+ file_handler.setFormatter(logging.Formatter('%(message)s'))
+ logger.addHandler(file_handler)
+
+ for step in range(max_steps):
+
+     loss_accum = 0.0
+     gpt.zero_grad()
+     for minibatchstep in range(grad_accum_steps):
+         x, y = train_dataloader.next_batch()
+         x, y = x.to(device), y.to(device)
+
+         if supports_bfloat16:
+             with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
+                 logits, loss = gpt(x, y)
+         else:
+             logits, loss = gpt(x, y)
+
+         loss = loss / grad_accum_steps  # scale so the accumulated gradient matches the large batch
+         loss_accum += loss.detach()
+         if ddp:
+             gpt.require_backward_grad_sync = (minibatchstep == grad_accum_steps - 1)  # only sync grads on the last micro-step
+         loss.backward()
+
+     if ddp:
+         dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
+     lossi.append(loss_accum.item())
+     norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), 1.0)
+     lr = get_lr(step)
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+     optimizer.step()
+     t_current = time.time()
+     elapsed_time = t_current - t0
+     steps_completed = step + 1
+     avg_time_per_step = elapsed_time / steps_completed
+     remaining_steps = max_steps - steps_completed
+     remaining_time = remaining_steps * avg_time_per_step
+
+     if master_process:
+         logger.info(f'Step {step} | Loss: {loss_accum:.6f} | Norm: {norm:.4f} | LR: {lr:.2e} | Time: {seconds_to_hms(elapsed_time)} | Remaining: {seconds_to_hms(remaining_time)} | Avg Time/Step: {avg_time_per_step:.2f}')
+
+     if master_process and step % generate_every == 0:
+         generated_text = generate_text("The king said", gpt, enc, max_len=25, print_while_generating=False)
+         logger.info(f'Generated Text at Step {step}: {generated_text}')
+
+     # Validation step
+     if step % validate_every == 0:
+         if master_process:
+             logger.info("Validating...")
+         gpt.eval()
+         val_loss_accum = 0.0
+         val_dataloader.reset()
+         with torch.no_grad():
+             val_loss_steps = 20
+             for _ in range(val_loss_steps):
+                 x, y = val_dataloader.next_batch()
+                 x, y = x.to(device), y.to(device)
+                 if supports_bfloat16:
+                     with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
+                         val_logits, val_loss = gpt(x, y)
+                 else:
+                     val_logits, val_loss = gpt(x, y)
+                 val_loss = val_loss / val_loss_steps
+                 val_loss_accum += val_loss.detach()
+         if ddp:
+             dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
+         if master_process:
+             logger.info(f'Validation Loss: {val_loss_accum}')
+             val_lossi.append(val_loss_accum.item())
+         gpt.train()  # back to training mode
+
+     if step % save_every == 0 and master_process:
+         print("Saving model and loss...")
+         torch.save(raw_gpt.state_dict(), "gpt2_step_{}.pth".format(step))
+         torch.save(torch.tensor(lossi), "lossi_step_{}.pth".format(step))
+         torch.save(torch.tensor(val_lossi), "val_lossi_step_{}.pth".format(step))
+
+ # ***************************#
+ # Plot Loss
+ # ***************************#
+
+ plot = True
+ if master_process and plot:
+     plt.plot(lossi, label="Train Loss")
+
+     # Stretch val_lossi to match the length of lossi
+     val_lossi_stretched = np.interp(
+         np.linspace(0, len(val_lossi) - 1, len(lossi)),
+         np.arange(len(val_lossi)),
+         val_lossi
+     )
+
+     plt.plot(val_lossi_stretched, label="Validation Loss")
+     plt.legend()
+     plt.xlabel("Step")
+     plt.ylabel("Loss")
+
+     plt.show()
+
+ # Generate Final Text
+ if master_process:
+     print(generate_text("The king said", gpt, enc, max_len=25, print_while_generating=False))
+
+ # ***************************#
+ # Save Model and Loss
+ # ***************************#
+ if master_process:
+     torch.save(raw_gpt.state_dict(), "gpt2_shakespeare.pth")  # save the unwrapped module's weights
+     torch.save(torch.tensor(lossi), "lossi.pth")
+     torch.save(torch.tensor(val_lossi), "val_lossi.pth")
+
+ # ***************************#
+ # Cleanup
+ # ***************************#
+ if ddp:
+     destroy_process_group()
+
+
+ sys.exit(0)  # sys is already imported above
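
To reload one of the checkpoints this script saves (a sketch; it assumes gpt2_final.pth, like the gpt2_step_*.pth files above, holds a plain raw_gpt.state_dict()):

  import torch
  from model import GPT, GPTConfig

  model = GPT(GPTConfig(vocab_size=50304), master_process=True)
  model.load_state_dict(torch.load("gpt2_final.pth", map_location="cpu"))
  model.eval()  # ready for generate_text-style greedy decoding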
gpt-2/training_log.txt ADDED
@@ -0,0 +1,608 @@
+ Step 0 | Loss: 11.070963 | Norm: 48.8176 | LR: 8.39e-07 | Time: 00:00:02 | Remaining: 00:01:39 | Avg Time/Step: 2.03
+ Generated Text at Step 0: The king saidSeptemberSeptember 354 Fill ShameLots may>>>>>>>>umpyurry Apex nurses NEWS159 Vanguard FlemingictionTAJul Jihad LAR $\ underjri Columb
+ Validating...
+ Validation Loss: 10.916313171386719
+ Step 1 | Loss: 11.171237 | Norm: 45.8637 | LR: 1.68e-06 | Time: 00:00:04 | Remaining: 00:01:41 | Avg Time/Step: 2.12
+ Step 2 | Loss: 11.089214 | Norm: 49.5361 | LR: 2.52e-06 | Time: 00:00:04 | Remaining: 00:01:10 | Avg Time/Step: 1.50
+ Validating...
+ Validation Loss: 10.893363952636719
+ Step 3 | Loss: 10.763819 | Norm: 52.8166 | LR: 3.36e-06 | Time: 00:00:05 | Remaining: 00:01:05 | Avg Time/Step: 1.41
+ Step 4 | Loss: 11.204582 | Norm: 47.4927 | LR: 4.20e-06 | Time: 00:00:06 | Remaining: 00:00:54 | Avg Time/Step: 1.21
+ Validating...
+ Validation Loss: 10.86690902709961
+ Step 5 | Loss: 10.957478 | Norm: 41.5032 | LR: 5.03e-06 | Time: 00:00:07 | Remaining: 00:00:58 | Avg Time/Step: 1.32
+ Step 6 | Loss: 10.586459 | Norm: 43.5531 | LR: 5.87e-06 | Time: 00:00:08 | Remaining: 00:00:50 | Avg Time/Step: 1.18
+ Validating...
+ Validation Loss: 10.835768699645996
+ Step 7 | Loss: 11.205253 | Norm: 44.9156 | LR: 6.71e-06 | Time: 00:00:09 | Remaining: 00:00:50 | Avg Time/Step: 1.20
+ Step 8 | Loss: 10.609798 | Norm: 48.2627 | LR: 7.55e-06 | Time: 00:00:09 | Remaining: 00:00:44 | Avg Time/Step: 1.10
+ Validating...
+ Validation Loss: 10.792684555053711
+ Step 9 | Loss: 9.896498 | Norm: 43.1797 | LR: 8.39e-06 | Time: 00:00:11 | Remaining: 00:00:44 | Avg Time/Step: 1.12
+ Step 10 | Loss: 11.131380 | Norm: 44.4814 | LR: 9.23e-06 | Time: 00:00:11 | Remaining: 00:00:40 | Avg Time/Step: 1.04
+ Validating...
+ Validation Loss: 10.749573707580566
+ Step 11 | Loss: 10.463729 | Norm: 47.8602 | LR: 1.01e-05 | Time: 00:00:12 | Remaining: 00:00:40 | Avg Time/Step: 1.06
+ Step 12 | Loss: 10.880756 | Norm: 43.9313 | LR: 1.09e-05 | Time: 00:00:12 | Remaining: 00:00:36 | Avg Time/Step: 1.00
+ Validating...
+ Validation Loss: 10.712495803833008
+ Step 13 | Loss: 9.864075 | Norm: 42.5331 | LR: 1.17e-05 | Time: 00:00:14 | Remaining: 00:00:36 | Avg Time/Step: 1.01
+ Step 14 | Loss: 10.922160 | Norm: 44.6511 | LR: 1.26e-05 | Time: 00:00:14 | Remaining: 00:00:33 | Avg Time/Step: 0.96
+ Validating...
+ Validation Loss: 10.67584228515625
+ Step 15 | Loss: 10.775851 | Norm: 44.4024 | LR: 1.34e-05 | Time: 00:00:15 | Remaining: 00:00:33 | Avg Time/Step: 0.98
+ Step 16 | Loss: 10.330193 | Norm: 43.8886 | LR: 1.43e-05 | Time: 00:00:15 | Remaining: 00:00:30 | Avg Time/Step: 0.94
+ Validating...
+ Validation Loss: 10.615331649780273
+ Step 17 | Loss: 10.270191 | Norm: 44.5217 | LR: 1.51e-05 | Time: 00:00:17 | Remaining: 00:00:30 | Avg Time/Step: 0.97
+ Step 18 | Loss: 10.027596 | Norm: 46.1209 | LR: 1.59e-05 | Time: 00:00:17 | Remaining: 00:00:28 | Avg Time/Step: 0.93
+ Validating...
+ Validation Loss: 10.553497314453125
+ Step 19 | Loss: 10.182181 | Norm: 40.7514 | LR: 1.68e-05 | Time: 00:00:19 | Remaining: 00:00:28 | Avg Time/Step: 0.96
+ Step 20 | Loss: 9.555431 | Norm: 34.3714 | LR: 1.76e-05 | Time: 00:00:19 | Remaining: 00:00:26 | Avg Time/Step: 0.93
+ Validating...
+ Validation Loss: 10.458913803100586
+ Step 21 | Loss: 10.136066 | Norm: 35.4013 | LR: 1.85e-05 | Time: 00:00:20 | Remaining: 00:00:26 | Avg Time/Step: 0.95
+ Step 22 | Loss: 10.260824 | Norm: 35.9827 | LR: 1.93e-05 | Time: 00:00:21 | Remaining: 00:00:25 | Avg Time/Step: 0.93
+ Validating...
+ Validation Loss: 10.345619201660156
+ Step 23 | Loss: 9.837000 | Norm: 34.4205 | LR: 2.01e-05 | Time: 00:00:22 | Remaining: 00:00:24 | Avg Time/Step: 0.94
+ Step 24 | Loss: 10.418470 | Norm: 35.1306 | LR: 2.10e-05 | Time: 00:00:22 | Remaining: 00:00:22 | Avg Time/Step: 0.92
+ Validating...
+ Validation Loss: 10.242090225219727
+ Step 25 | Loss: 10.759716 | Norm: 34.3984 | LR: 2.18e-05 | Time: 00:00:24 | Remaining: 00:00:22 | Avg Time/Step: 0.93
+ Step 26 | Loss: 10.433059 | Norm: 33.6258 | LR: 2.27e-05 | Time: 00:00:24 | Remaining: 00:00:20 | Avg Time/Step: 0.91
+ Validating...
+ Validation Loss: 10.15864372253418
+ Step 27 | Loss: 11.198073 | Norm: 33.6489 | LR: 2.35e-05 | Time: 00:00:25 | Remaining: 00:00:20 | Avg Time/Step: 0.93
+ Step 28 | Loss: 9.453720 | Norm: 30.4983 | LR: 2.43e-05 | Time: 00:00:26 | Remaining: 00:00:18 | Avg Time/Step: 0.90
+ Validating...
+ Validation Loss: 10.089692115783691
+ Step 29 | Loss: 10.043849 | Norm: 30.8429 | LR: 2.52e-05 | Time: 00:00:27 | Remaining: 00:00:18 | Avg Time/Step: 0.92
+ Step 30 | Loss: 10.345837 | Norm: 28.3254 | LR: 2.60e-05 | Time: 00:00:27 | Remaining: 00:00:17 | Avg Time/Step: 0.90
+ Validating...
+ Validation Loss: 10.014737129211426
+ Step 31 | Loss: 9.762772 | Norm: 28.7018 | LR: 2.69e-05 | Time: 00:00:29 | Remaining: 00:00:16 | Avg Time/Step: 0.92
+ Step 32 | Loss: 9.099653 | Norm: 28.1757 | LR: 2.77e-05 | Time: 00:00:30 | Remaining: 00:00:15 | Avg Time/Step: 0.91
+ Validating...
+ Validation Loss: 9.956048011779785
+ Step 33 | Loss: 8.908812 | Norm: 25.8786 | LR: 2.85e-05 | Time: 00:00:31 | Remaining: 00:00:14 | Avg Time/Step: 0.92
+ Step 34 | Loss: 10.699462 | Norm: 25.3921 | LR: 2.94e-05 | Time: 00:00:31 | Remaining: 00:00:13 | Avg Time/Step: 0.90
+ Validating...
+ Validation Loss: 9.902624130249023
+ Step 35 | Loss: 9.239347 | Norm: 25.3455 | LR: 3.02e-05 | Time: 00:00:33 | Remaining: 00:00:12 | Avg Time/Step: 0.92
+ Step 36 | Loss: 10.142147 | Norm: 24.3786 | LR: 3.10e-05 | Time: 00:00:33 | Remaining: 00:00:11 | Avg Time/Step: 0.90
+ Validating...
+ Validation Loss: 9.841948509216309
+ Step 37 | Loss: 10.260188 | Norm: 23.3623 | LR: 3.19e-05 | Time: 00:00:34 | Remaining: 00:00:10 | Avg Time/Step: 0.91
+ Step 38 | Loss: 9.482347 | Norm: 24.0785 | LR: 3.27e-05 | Time: 00:00:35 | Remaining: 00:00:09 | Avg Time/Step: 0.90
+ Validating...
+ Validation Loss: 9.79233169555664
+ Step 39 | Loss: 8.717162 | Norm: 23.1963 | LR: 3.36e-05 | Time: 00:00:36 | Remaining: 00:00:09 | Avg Time/Step: 0.91
+ Step 40 | Loss: 9.536521 | Norm: 21.8829 | LR: 3.44e-05 | Time: 00:00:36 | Remaining: 00:00:08 | Avg Time/Step: 0.89
+ Validating...
+ Validation Loss: 9.746158599853516
+ Step 41 | Loss: 9.760999 | Norm: 21.4380 | LR: 3.52e-05 | Time: 00:00:38 | Remaining: 00:00:07 | Avg Time/Step: 0.91
+ Step 42 | Loss: 9.588884 | Norm: 22.2327 | LR: 3.61e-05 | Time: 00:00:38 | Remaining: 00:00:06 | Avg Time/Step: 0.89
+ Validating...
+ Validation Loss: 9.688400268554688
+ Step 43 | Loss: 8.350541 | Norm: 20.6459 | LR: 3.69e-05 | Time: 00:00:39 | Remaining: 00:00:05 | Avg Time/Step: 0.90
+ Step 44 | Loss: 9.594240 | Norm: 20.0493 | LR: 3.78e-05 | Time: 00:00:39 | Remaining: 00:00:04 | Avg Time/Step: 0.89
+ Validating...
+ Validation Loss: 9.622390747070312
+ Step 45 | Loss: 8.240631 | Norm: 20.1186 | LR: 3.86e-05 | Time: 00:00:41 | Remaining: 00:00:03 | Avg Time/Step: 0.90
+ Step 46 | Loss: 8.915052 | Norm: 20.4390 | LR: 3.94e-05 | Time: 00:00:41 | Remaining: 00:00:02 | Avg Time/Step: 0.88
+ Validating...
+ Validation Loss: 9.558349609375
+ Step 47 | Loss: 8.285755 | Norm: 20.3787 | LR: 4.03e-05 | Time: 00:00:43 | Remaining: 00:00:01 | Avg Time/Step: 0.90
+ Step 48 | Loss: 8.551549 | Norm: 20.1920 | LR: 4.11e-05 | Time: 00:00:43 | Remaining: 00:00:00 | Avg Time/Step: 0.89
+ Validating...
+ Validation Loss: 9.461584091186523
+ Step 49 | Loss: 9.774352 | Norm: 20.2260 | LR: 4.20e-05 | Time: 00:00:45 | Remaining: 00:00:00 | Avg Time/Step: 0.91
+ Step 0 | Loss: 11.070963 | Norm: 48.8176 | LR: 8.39e-07 | Time: 00:00:00 | Remaining: 00:00:46 | Avg Time/Step: 0.95
+ Generated Text at Step 0: The king saidSeptemberSeptember 354 Fill ShameLots may>>>>>>>>umpyurry Apex nurses NEWS159 Vanguard FlemingictionTAJul Jihad LAR $\ underjri Columb
+ Validating...
+ Validation Loss: 10.916313171386719
+ Step 1 | Loss: 11.171237 | Norm: 45.8637 | LR: 1.68e-06 | Time: 00:00:03 | Remaining: 00:01:12 | Avg Time/Step: 1.52
+ Step 2 | Loss: 11.089214 | Norm: 49.5361 | LR: 2.52e-06 | Time: 00:00:03 | Remaining: 00:00:51 | Avg Time/Step: 1.09
+ Validating...
+ Validation Loss: 10.893363952636719
+ Step 3 | Loss: 10.763819 | Norm: 52.8166 | LR: 3.36e-06 | Time: 00:00:04 | Remaining: 00:00:49 | Avg Time/Step: 1.08
+ Step 4 | Loss: 11.204582 | Norm: 47.4927 | LR: 4.20e-06 | Time: 00:00:04 | Remaining: 00:00:41 | Avg Time/Step: 0.91
+ Validating...
+ Validation Loss: 10.86690902709961
+ Step 5 | Loss: 10.957478 | Norm: 41.5032 | LR: 5.03e-06 | Time: 00:00:05 | Remaining: 00:00:40 | Avg Time/Step: 0.93
+ Step 6 | Loss: 10.586459 | Norm: 43.5531 | LR: 5.87e-06 | Time: 00:00:05 | Remaining: 00:00:35 | Avg Time/Step: 0.83
+ Validating...
+ Validation Loss: 10.835768699645996
+ Step 7 | Loss: 11.205253 | Norm: 44.9156 | LR: 6.71e-06 | Time: 00:00:07 | Remaining: 00:00:37 | Avg Time/Step: 0.89
+ Step 8 | Loss: 10.609798 | Norm: 48.2627 | LR: 7.55e-06 | Time: 00:00:07 | Remaining: 00:00:33 | Avg Time/Step: 0.82
+ Validating...
+ Validation Loss: 10.792684555053711
+ Step 9 | Loss: 9.896498 | Norm: 43.1797 | LR: 8.39e-06 | Time: 00:00:08 | Remaining: 00:00:35 | Avg Time/Step: 0.88
+ Step 10 | Loss: 11.131380 | Norm: 44.4814 | LR: 9.23e-06 | Time: 00:00:09 | Remaining: 00:00:31 | Avg Time/Step: 0.82
+ Validating...
+ Validation Loss: 10.749573707580566
+ Step 11 | Loss: 10.463729 | Norm: 47.8602 | LR: 1.01e-05 | Time: 00:00:10 | Remaining: 00:00:32 | Avg Time/Step: 0.86
+ Step 12 | Loss: 10.880756 | Norm: 43.9313 | LR: 1.09e-05 | Time: 00:00:10 | Remaining: 00:00:30 | Avg Time/Step: 0.81
+ Validating...
+ Validation Loss: 10.712495803833008
+ Step 13 | Loss: 9.864075 | Norm: 42.5331 | LR: 1.17e-05 | Time: 00:00:12 | Remaining: 00:00:30 | Avg Time/Step: 0.86
+ Step 14 | Loss: 10.922160 | Norm: 44.6511 | LR: 1.26e-05 | Time: 00:00:12 | Remaining: 00:00:28 | Avg Time/Step: 0.82
+ Validating...
+ Validation Loss: 10.67584228515625
+ Step 15 | Loss: 10.775851 | Norm: 44.4024 | LR: 1.34e-05 | Time: 00:00:13 | Remaining: 00:00:28 | Avg Time/Step: 0.85
+ Step 16 | Loss: 10.330193 | Norm: 43.8886 | LR: 1.43e-05 | Time: 00:00:13 | Remaining: 00:00:26 | Avg Time/Step: 0.81
+ Validating...
+ Validation Loss: 10.615331649780273
+ Step 17 | Loss: 10.270191 | Norm: 44.5217 | LR: 1.51e-05 | Time: 00:00:15 | Remaining: 00:00:26 | Avg Time/Step: 0.84
+ Step 18 | Loss: 10.027596 | Norm: 46.1209 | LR: 1.59e-05 | Time: 00:00:15 | Remaining: 00:00:25 | Avg Time/Step: 0.81
+ Validating...
+ Validation Loss: 10.553497314453125
+ Step 19 | Loss: 10.182181 | Norm: 40.7514 | LR: 1.68e-05 | Time: 00:00:16 | Remaining: 00:00:25 | Avg Time/Step: 0.84
+ Step 20 | Loss: 9.555431 | Norm: 34.3714 | LR: 1.76e-05 | Time: 00:00:16 | Remaining: 00:00:23 | Avg Time/Step: 0.81
+ Validating...
+ Validation Loss: 10.458913803100586
+ Step 21 | Loss: 10.136066 | Norm: 35.4013 | LR: 1.85e-05 | Time: 00:00:18 | Remaining: 00:00:23 | Avg Time/Step: 0.83
+ Step 22 | Loss: 10.260824 | Norm: 35.9827 | LR: 1.93e-05 | Time: 00:00:18 | Remaining: 00:00:21 | Avg Time/Step: 0.80
+ Validating...
+ Validation Loss: 10.345619201660156
+ Step 23 | Loss: 9.837000 | Norm: 34.4205 | LR: 2.01e-05 | Time: 00:00:19 | Remaining: 00:00:21 | Avg Time/Step: 0.82
+ Step 24 | Loss: 10.418470 | Norm: 35.1306 | LR: 2.10e-05 | Time: 00:00:20 | Remaining: 00:00:20 | Avg Time/Step: 0.80
+ Validating...
+ Validation Loss: 10.242090225219727
+ Step 25 | Loss: 10.759716 | Norm: 34.3984 | LR: 2.18e-05 | Time: 00:00:21 | Remaining: 00:00:19 | Avg Time/Step: 0.82
+ Step 26 | Loss: 10.433059 | Norm: 33.6258 | LR: 2.27e-05 | Time: 00:00:21 | Remaining: 00:00:18 | Avg Time/Step: 0.80
+ Validating...
+ Validation Loss: 10.15864372253418
+ Step 27 | Loss: 11.198073 | Norm: 33.6489 | LR: 2.35e-05 | Time: 00:00:22 | Remaining: 00:00:17 | Avg Time/Step: 0.81
+ Step 28 | Loss: 9.453720 | Norm: 30.4983 | LR: 2.43e-05 | Time: 00:00:23 | Remaining: 00:00:16 | Avg Time/Step: 0.79
+ Validating...
+ Validation Loss: 10.089692115783691
+ Step 29 | Loss: 10.043849 | Norm: 30.8429 | LR: 2.52e-05 | Time: 00:00:24 | Remaining: 00:00:16 | Avg Time/Step: 0.81
+ Step 30 | Loss: 10.345837 | Norm: 28.3254 | LR: 2.60e-05 | Time: 00:00:24 | Remaining: 00:00:15 | Avg Time/Step: 0.79
+ Validating...
+ Validation Loss: 10.014737129211426
+ Step 31 | Loss: 9.762772 | Norm: 28.7018 | LR: 2.69e-05 | Time: 00:00:25 | Remaining: 00:00:14 | Avg Time/Step: 0.81
+ Step 32 | Loss: 9.099653 | Norm: 28.1757 | LR: 2.77e-05 | Time: 00:00:26 | Remaining: 00:00:13 | Avg Time/Step: 0.79
+ Validating...
+ Validation Loss: 9.956048011779785
+ Step 33 | Loss: 8.908812 | Norm: 25.8786 | LR: 2.85e-05 | Time: 00:00:27 | Remaining: 00:00:12 | Avg Time/Step: 0.81
+ Step 34 | Loss: 10.699462 | Norm: 25.3921 | LR: 2.94e-05 | Time: 00:00:27 | Remaining: 00:00:11 | Avg Time/Step: 0.79
+ Validating...
+ Validation Loss: 9.902624130249023
+ Step 35 | Loss: 9.239347 | Norm: 25.3455 | LR: 3.02e-05 | Time: 00:00:28 | Remaining: 00:00:11 | Avg Time/Step: 0.80
+ Step 36 | Loss: 10.142147 | Norm: 24.3786 | LR: 3.10e-05 | Time: 00:00:29 | Remaining: 00:00:10 | Avg Time/Step: 0.79
+ Validating...
+ Validation Loss: 9.841948509216309
+ Step 37 | Loss: 10.260188 | Norm: 23.3623 | LR: 3.19e-05 | Time: 00:00:30 | Remaining: 00:00:09 | Avg Time/Step: 0.80
+ Step 38 | Loss: 9.482347 | Norm: 24.0785 | LR: 3.27e-05 | Time: 00:00:30 | Remaining: 00:00:08 | Avg Time/Step: 0.79
+ Validating...
+ Validation Loss: 9.79233169555664
+ Step 39 | Loss: 8.717162 | Norm: 23.1963 | LR: 3.36e-05 | Time: 00:00:32 | Remaining: 00:00:08 | Avg Time/Step: 0.81
+ Step 40 | Loss: 9.536521 | Norm: 21.8829 | LR: 3.44e-05 | Time: 00:00:32 | Remaining: 00:00:07 | Avg Time/Step: 0.79
+ Validating...
+ Validation Loss: 9.746158599853516
+ Step 41 | Loss: 9.760999 | Norm: 21.4380 | LR: 3.52e-05 | Time: 00:00:33 | Remaining: 00:00:06 | Avg Time/Step: 0.81
+ Step 42 | Loss: 9.588884 | Norm: 22.2327 | LR: 3.61e-05 | Time: 00:00:34 | Remaining: 00:00:05 | Avg Time/Step: 0.79
+ Validating...
+ Validation Loss: 9.688400268554688
+ Step 43 | Loss: 8.350541 | Norm: 20.6459 | LR: 3.69e-05 | Time: 00:00:35 | Remaining: 00:00:04 | Avg Time/Step: 0.81
+ Step 44 | Loss: 9.594240 | Norm: 20.0493 | LR: 3.78e-05 | Time: 00:00:36 | Remaining: 00:00:04 | Avg Time/Step: 0.80
+ Validating...
+ Validation Loss: 9.622390747070312
+ Step 45 | Loss: 8.240631 | Norm: 20.1186 | LR: 3.86e-05 | Time: 00:00:37 | Remaining: 00:00:03 | Avg Time/Step: 0.81
+ Step 46 | Loss: 8.915052 | Norm: 20.4390 | LR: 3.94e-05 | Time: 00:00:37 | Remaining: 00:00:02 | Avg Time/Step: 0.80
+ Validating...
+ Validation Loss: 9.558349609375
+ Step 47 | Loss: 8.285755 | Norm: 20.3787 | LR: 4.03e-05 | Time: 00:00:39 | Remaining: 00:00:01 | Avg Time/Step: 0.81
+ Step 48 | Loss: 8.551549 | Norm: 20.1920 | LR: 4.11e-05 | Time: 00:00:39 | Remaining: 00:00:00 | Avg Time/Step: 0.80
+ Validating...
+ Validation Loss: 9.461584091186523
+ Step 49 | Loss: 9.774352 | Norm: 20.2260 | LR: 4.20e-05 | Time: 00:00:40 | Remaining: 00:00:00 | Avg Time/Step: 0.81
+ Step 0 | Loss: 11.070963 | Norm: 48.8176 | LR: 8.39e-07 | Time: 00:00:00 | Remaining: 00:00:44 | Avg Time/Step: 0.91
+ Generated Text at Step 0: The king saidSeptemberSeptember 354 Fill ShameLots may>>>>>>>>umpyurry Apex nurses NEWS159 Vanguard FlemingictionTAJul Jihad LAR $\ underjri Columb
+ Validating...
+ Validation Loss: 10.916313171386719
+ Step 1 | Loss: 11.171237 | Norm: 45.8637 | LR: 1.68e-06 | Time: 00:00:02 | Remaining: 00:01:11 | Avg Time/Step: 1.50
+ Step 2 | Loss: 11.089214 | Norm: 49.5361 | LR: 2.52e-06 | Time: 00:00:03 | Remaining: 00:00:50 | Avg Time/Step: 1.07
+ Validating...
+ Validation Loss: 10.893363952636719
+ Step 3 | Loss: 10.763819 | Norm: 52.8166 | LR: 3.36e-06 | Time: 00:00:04 | Remaining: 00:00:49 | Avg Time/Step: 1.08
+ Step 4 | Loss: 11.204582 | Norm: 47.4927 | LR: 4.20e-06 | Time: 00:00:04 | Remaining: 00:00:40 | Avg Time/Step: 0.91
+ Validating...
+ Validation Loss: 10.86690902709961
+ Step 5 | Loss: 10.957478 | Norm: 41.5032 | LR: 5.03e-06 | Time: 00:00:05 | Remaining: 00:00:40 | Avg Time/Step: 0.92
+ Step 6 | Loss: 10.586459 | Norm: 43.5531 | LR: 5.87e-06 | Time: 00:00:05 | Remaining: 00:00:35 | Avg Time/Step: 0.83
+ Validating...
+ Validation Loss: 10.835768699645996
+ Step 7 | Loss: 11.205253 | Norm: 44.9156 | LR: 6.71e-06 | Time: 00:00:07 | Remaining: 00:00:36 | Avg Time/Step: 0.88
+ Step 8 | Loss: 10.609798 | Norm: 48.2627 | LR: 7.55e-06 | Time: 00:00:07 | Remaining: 00:00:33 | Avg Time/Step: 0.81
+ Validating...
+ Validation Loss: 10.792684555053711
+ Step 9 | Loss: 9.896498 | Norm: 43.1797 | LR: 8.39e-06 | Time: 00:00:08 | Remaining: 00:00:34 | Avg Time/Step: 0.86
+ Step 10 | Loss: 11.131380 | Norm: 44.4814 | LR: 9.23e-06 | Time: 00:00:08 | Remaining: 00:00:31 | Avg Time/Step: 0.80
+ Validating...
+ Validation Loss: 10.749573707580566
+ Step 11 | Loss: 10.463729 | Norm: 47.8602 | LR: 1.01e-05 | Time: 00:00:10 | Remaining: 00:00:32 | Avg Time/Step: 0.84
+ Step 12 | Loss: 10.880756 | Norm: 43.9313 | LR: 1.09e-05 | Time: 00:00:10 | Remaining: 00:00:29 | Avg Time/Step: 0.80
+ Validating...
+ Validation Loss: 10.712495803833008
+ Step 13 | Loss: 9.864075 | Norm: 42.5331 | LR: 1.17e-05 | Time: 00:00:11 | Remaining: 00:00:30 | Avg Time/Step: 0.85
+ Step 14 | Loss: 10.922160 | Norm: 44.6511 | LR: 1.26e-05 | Time: 00:00:12 | Remaining: 00:00:28 | Avg Time/Step: 0.81
+ Validating...
+ Validation Loss: 10.67584228515625
+ Step 15 | Loss: 10.775851 | Norm: 44.4024 | LR: 1.34e-05 | Time: 00:00:13 | Remaining: 00:00:28 | Avg Time/Step: 0.84
+ Step 16 | Loss: 10.330193 | Norm: 43.8886 | LR: 1.43e-05 | Time: 00:00:13 | Remaining: 00:00:26 | Avg Time/Step: 0.80
+ Validating...
+ Validation Loss: 10.615331649780273
+ Step 17 | Loss: 10.270191 | Norm: 44.5217 | LR: 1.51e-05 | Time: 00:00:15 | Remaining: 00:00:27 | Avg Time/Step: 0.84
+ Step 18 | Loss: 10.027596 | Norm: 46.1209 | LR: 1.59e-05 | Time: 00:00:15 | Remaining: 00:00:25 | Avg Time/Step: 0.81
+ Validating...
+ Validation Loss: 10.553497314453125
+ Step 19 | Loss: 10.182181 | Norm: 40.7514 | LR: 1.68e-05 | Time: 00:00:16 | Remaining: 00:00:25 | Avg Time/Step: 0.85
+ Step 20 | Loss: 9.555431 | Norm: 34.3714 | LR: 1.76e-05 | Time: 00:00:17 | Remaining: 00:00:23 | Avg Time/Step: 0.82
+ Validating...
+ Validation Loss: 10.458913803100586
+ Step 21 | Loss: 10.136066 | Norm: 35.4013 | LR: 1.85e-05 | Time: 00:00:18 | Remaining: 00:00:23 | Avg Time/Step: 0.84
+ Step 22 | Loss: 10.260824 | Norm: 35.9827 | LR: 1.93e-05 | Time: 00:00:18 | Remaining: 00:00:21 | Avg Time/Step: 0.81
+ Validating...
+ Validation Loss: 10.345619201660156
+ Step 23 | Loss: 9.837000 | Norm: 34.4205 | LR: 2.01e-05 | Time: 00:00:20 | Remaining: 00:00:22 | Avg Time/Step: 0.85
+ Step 24 | Loss: 10.418470 | Norm: 35.1306 | LR: 2.10e-05 | Time: 00:00:20 | Remaining: 00:00:20 | Avg Time/Step: 0.83
+ Validating...
+ Validation Loss: 10.242090225219727
+ Step 25 | Loss: 10.759716 | Norm: 34.3984 | LR: 2.18e-05 | Time: 00:00:22 | Remaining: 00:00:20 | Avg Time/Step: 0.85
+ Step 26 | Loss: 10.433059 | Norm: 33.6258 | LR: 2.27e-05 | Time: 00:00:22 | Remaining: 00:00:19 | Avg Time/Step: 0.85
+ Validating...
+ Validation Loss: 10.15864372253418
+ Step 27 | Loss: 11.198073 | Norm: 33.6489 | LR: 2.35e-05 | Time: 00:00:24 | Remaining: 00:00:18 | Avg Time/Step: 0.86
+ Step 28 | Loss: 9.453720 | Norm: 30.4983 | LR: 2.43e-05 | Time: 00:00:24 | Remaining: 00:00:17 | Avg Time/Step: 0.84
+ Validating...
+ Validation Loss: 10.089692115783691
+ Step 29 | Loss: 10.043849 | Norm: 30.8429 | LR: 2.52e-05 | Time: 00:00:25 | Remaining: 00:00:17 | Avg Time/Step: 0.86
+ Step 30 | Loss: 10.345837 | Norm: 28.3254 | LR: 2.60e-05 | Time: 00:00:26 | Remaining: 00:00:15 | Avg Time/Step: 0.84
+ Validating...
+ Validation Loss: 10.014737129211426
+ Step 31 | Loss: 9.762772 | Norm: 28.7018 | LR: 2.69e-05 | Time: 00:00:27 | Remaining: 00:00:15 | Avg Time/Step: 0.86
+ Step 32 | Loss: 9.099653 | Norm: 28.1757 | LR: 2.77e-05 | Time: 00:00:27 | Remaining: 00:00:14 | Avg Time/Step: 0.84
+ Validating...
+ Validation Loss: 9.956048011779785
+ Step 33 | Loss: 8.908812 | Norm: 25.8786 | LR: 2.85e-05 | Time: 00:00:29 | Remaining: 00:00:13 | Avg Time/Step: 0.87
+ Step 34 | Loss: 10.699462 | Norm: 25.3921 | LR: 2.94e-05 | Time: 00:00:29 | Remaining: 00:00:12 | Avg Time/Step: 0.85
+ Validating...
+ Validation Loss: 9.902624130249023
+ Step 35 | Loss: 9.239347 | Norm: 25.3455 | LR: 3.02e-05 | Time: 00:00:31 | Remaining: 00:00:12 | Avg Time/Step: 0.87
+ Step 36 | Loss: 10.142147 | Norm: 24.3786 | LR: 3.10e-05 | Time: 00:00:31 | Remaining: 00:00:11 | Avg Time/Step: 0.85
+ Validating...
+ Validation Loss: 9.841948509216309
+ Step 37 | Loss: 10.260188 | Norm: 23.3623 | LR: 3.19e-05 | Time: 00:00:32 | Remaining: 00:00:10 | Avg Time/Step: 0.87
+ Step 38 | Loss: 9.482347 | Norm: 24.0785 | LR: 3.27e-05 | Time: 00:00:33 | Remaining: 00:00:09 | Avg Time/Step: 0.85
+ Validating...
+ Validation Loss: 9.79233169555664
+ Step 39 | Loss: 8.717162 | Norm: 23.1963 | LR: 3.36e-05 | Time: 00:00:34 | Remaining: 00:00:08 | Avg Time/Step: 0.86
+ Step 40 | Loss: 9.536521 | Norm: 21.8829 | LR: 3.44e-05 | Time: 00:00:34 | Remaining: 00:00:07 | Avg Time/Step: 0.84
+ Validating...
+ Validation Loss: 9.746158599853516
+ Step 41 | Loss: 9.760999 | Norm: 21.4380 | LR: 3.52e-05 | Time: 00:00:35 | Remaining: 00:00:06 | Avg Time/Step: 0.85
+ Step 42 | Loss: 9.588884 | Norm: 22.2327 | LR: 3.61e-05 | Time: 00:00:36 | Remaining: 00:00:05 | Avg Time/Step: 0.84
+ Validating...
+ Validation Loss: 9.688400268554688
+ Step 43 | Loss: 8.350541 | Norm: 20.6459 | LR: 3.69e-05 | Time: 00:00:37 | Remaining: 00:00:05 | Avg Time/Step: 0.85
+ Step 44 | Loss: 9.594240 | Norm: 20.0493 | LR: 3.78e-05 | Time: 00:00:37 | Remaining: 00:00:04 | Avg Time/Step: 0.83
+ Validating...
+ Validation Loss: 9.622390747070312
+ Step 45 | Loss: 8.240631 | Norm: 20.1186 | LR: 3.86e-05 | Time: 00:00:38 | Remaining: 00:00:03 | Avg Time/Step: 0.84
+ Step 46 | Loss: 8.915052 | Norm: 20.4390 | LR: 3.94e-05 | Time: 00:00:39 | Remaining: 00:00:02 | Avg Time/Step: 0.83
+ Validating...
+ Validation Loss: 9.558349609375
+ Step 47 | Loss: 8.285755 | Norm: 20.3787 | LR: 4.03e-05 | Time: 00:00:40 | Remaining: 00:00:01 | Avg Time/Step: 0.84
+ Step 48 | Loss: 8.551549 | Norm: 20.1920 | LR: 4.11e-05 | Time: 00:00:40 | Remaining: 00:00:00 | Avg Time/Step: 0.83
+ Validating...
+ Validation Loss: 9.461584091186523
+ Step 49 | Loss: 9.774352 | Norm: 20.2260 | LR: 4.20e-05 | Time: 00:00:42 | Remaining: 00:00:00 | Avg Time/Step: 0.84
+ Step 0 | Loss: 11.070963 | Norm: 48.8176 | LR: 8.39e-07 | Time: 00:00:00 | Remaining: 00:00:42 | Avg Time/Step: 0.87
+ Generated Text at Step 0: The king saidSeptemberSeptember 354 Fill ShameLots may>>>>>>>>umpyurry Apex nurses NEWS159 Vanguard FlemingictionTAJul Jihad LAR $\ underjri Columb
+ Validating...
+ Validation Loss: 10.916313171386719
+ Step 1 | Loss: 11.171237 | Norm: 45.8637 | LR: 1.68e-06 | Time: 00:00:03 | Remaining: 00:01:12 | Avg Time/Step: 1.51
+ Step 2 | Loss: 11.089214 | Norm: 49.5361 | LR: 2.52e-06 | Time: 00:00:03 | Remaining: 00:00:50 | Avg Time/Step: 1.08
+ Step 3 | Loss: 10.763819 | Norm: 52.8166 | LR: 3.36e-06 | Time: 00:00:03 | Remaining: 00:00:40 | Avg Time/Step: 0.88
+ Step 4 | Loss: 11.204582 | Norm: 47.4927 | LR: 4.20e-06 | Time: 00:00:04 | Remaining: 00:00:36 | Avg Time/Step: 0.81
+ Step 5 | Loss: 10.957478 | Norm: 41.5032 | LR: 5.03e-06 | Time: 00:00:04 | Remaining: 00:00:32 | Avg Time/Step: 0.73
+ Step 6 | Loss: 10.586459 | Norm: 43.5531 | LR: 5.87e-06 | Time: 00:00:04 | Remaining: 00:00:29 | Avg Time/Step: 0.68
+ Step 7 | Loss: 11.205253 | Norm: 44.9156 | LR: 6.71e-06 | Time: 00:00:05 | Remaining: 00:00:26 | Avg Time/Step: 0.64
+ Step 8 | Loss: 10.609798 | Norm: 48.2627 | LR: 7.55e-06 | Time: 00:00:05 | Remaining: 00:00:25 | Avg Time/Step: 0.61
+ Step 9 | Loss: 9.896498 | Norm: 43.1797 | LR: 8.39e-06 | Time: 00:00:05 | Remaining: 00:00:23 | Avg Time/Step: 0.59
+ Step 10 | Loss: 11.131380 | Norm: 44.4814 | LR: 9.23e-06 | Time: 00:00:06 | Remaining: 00:00:22 | Avg Time/Step: 0.57
+ Validating...
+ Validation Loss: 10.749573707580566
+ Step 11 | Loss: 10.463729 | Norm: 47.8602 | LR: 1.01e-05 | Time: 00:00:07 | Remaining: 00:00:23 | Avg Time/Step: 0.62
+ Step 12 | Loss: 10.880756 | Norm: 43.9313 | LR: 1.09e-05 | Time: 00:00:07 | Remaining: 00:00:22 | Avg Time/Step: 0.60
+ Step 13 | Loss: 9.864075 | Norm: 42.5331 | LR: 1.17e-05 | Time: 00:00:08 | Remaining: 00:00:20 | Avg Time/Step: 0.57
+ Step 14 | Loss: 10.922160 | Norm: 44.6511 | LR: 1.26e-05 | Time: 00:00:08 | Remaining: 00:00:19 | Avg Time/Step: 0.57
+ Step 15 | Loss: 10.775851 | Norm: 44.4024 | LR: 1.34e-05 | Time: 00:00:08 | Remaining: 00:00:18 | Avg Time/Step: 0.55
+ Step 16 | Loss: 10.330193 | Norm: 43.8886 | LR: 1.43e-05 | Time: 00:00:09 | Remaining: 00:00:18 | Avg Time/Step: 0.55
+ Step 17 | Loss: 10.270191 | Norm: 44.5217 | LR: 1.51e-05 | Time: 00:00:09 | Remaining: 00:00:17 | Avg Time/Step: 0.53
+ Step 18 | Loss: 10.027596 | Norm: 46.1209 | LR: 1.59e-05 | Time: 00:00:10 | Remaining: 00:00:16 | Avg Time/Step: 0.53
+ Step 19 | Loss: 10.182181 | Norm: 40.7514 | LR: 1.68e-05 | Time: 00:00:10 | Remaining: 00:00:15 | Avg Time/Step: 0.52
+ Step 20 | Loss: 9.555431 | Norm: 34.3714 | LR: 1.76e-05 | Time: 00:00:10 | Remaining: 00:00:14 | Avg Time/Step: 0.51
+ Validating...
+ Validation Loss: 10.458913803100586
+ Step 21 | Loss: 10.136066 | Norm: 35.4013 | LR: 1.85e-05 | Time: 00:00:12 | Remaining: 00:00:15 | Avg Time/Step: 0.56
+ Step 22 | Loss: 10.260824 | Norm: 35.9827 | LR: 1.93e-05 | Time: 00:00:12 | Remaining: 00:00:14 | Avg Time/Step: 0.54
+ Step 23 | Loss: 9.837000 | Norm: 34.4205 | LR: 2.01e-05 | Time: 00:00:12 | Remaining: 00:00:14 | Avg Time/Step: 0.54
+ Step 24 | Loss: 10.418470 | Norm: 35.1306 | LR: 2.10e-05 | Time: 00:00:13 | Remaining: 00:00:13 | Avg Time/Step: 0.53
+ Step 25 | Loss: 10.759716 | Norm: 34.3984 | LR: 2.18e-05 | Time: 00:00:13 | Remaining: 00:00:12 | Avg Time/Step: 0.53
+ Step 26 | Loss: 10.433059 | Norm: 33.6258 | LR: 2.27e-05 | Time: 00:00:14 | Remaining: 00:00:12 | Avg Time/Step: 0.53
+ Step 27 | Loss: 11.198073 | Norm: 33.6489 | LR: 2.35e-05 | Time: 00:00:14 | Remaining: 00:00:11 | Avg Time/Step: 0.53
+ Step 28 | Loss: 9.453720 | Norm: 30.4983 | LR: 2.43e-05 | Time: 00:00:15 | Remaining: 00:00:10 | Avg Time/Step: 0.52
+ Step 29 | Loss: 10.043849 | Norm: 30.8429 | LR: 2.52e-05 | Time: 00:00:15 | Remaining: 00:00:10 | Avg Time/Step: 0.52
+ Step 30 | Loss: 10.345837 | Norm: 28.3254 | LR: 2.60e-05 | Time: 00:00:16 | Remaining: 00:00:09 | Avg Time/Step: 0.52
+ Validating...
+ Validation Loss: 10.014737129211426
+ Step 31 | Loss: 9.762772 | Norm: 28.7018 | LR: 2.69e-05 | Time: 00:00:17 | Remaining: 00:00:09 | Avg Time/Step: 0.54
+ Step 32 | Loss: 9.099653 | Norm: 28.1757 | LR: 2.77e-05 | Time: 00:00:17 | Remaining: 00:00:09 | Avg Time/Step: 0.53
+ Step 33 | Loss: 8.908812 | Norm: 25.8786 | LR: 2.85e-05 | Time: 00:00:17 | Remaining: 00:00:08 | Avg Time/Step: 0.53
+ Step 34 | Loss: 10.699462 | Norm: 25.3921 | LR: 2.94e-05 | Time: 00:00:18 | Remaining: 00:00:07 | Avg Time/Step: 0.52
+ Step 35 | Loss: 9.239347 | Norm: 25.3455 | LR: 3.02e-05 | Time: 00:00:18 | Remaining: 00:00:07 | Avg Time/Step: 0.52
+ Step 36 | Loss: 10.142147 | Norm: 24.3786 | LR: 3.10e-05 | Time: 00:00:19 | Remaining: 00:00:06 | Avg Time/Step: 0.52
+ Step 37 | Loss: 10.260188 | Norm: 23.3623 | LR: 3.19e-05 | Time: 00:00:19 | Remaining: 00:00:06 | Avg Time/Step: 0.52
+ Step 38 | Loss: 9.482347 | Norm: 24.0785 | LR: 3.27e-05 | Time: 00:00:20 | Remaining: 00:00:05 | Avg Time/Step: 0.52
+ Step 39 | Loss: 8.717162 | Norm: 23.1963 | LR: 3.36e-05 | Time: 00:00:20 | Remaining: 00:00:05 | Avg Time/Step: 0.52
+ Step 40 | Loss: 9.536521 | Norm: 21.8829 | LR: 3.44e-05 | Time: 00:00:21 | Remaining: 00:00:04 | Avg Time/Step: 0.51
+ Validating...
+ Validation Loss: 9.746158599853516
+ Step 41 | Loss: 9.760999 | Norm: 21.4380 | LR: 3.52e-05 | Time: 00:00:22 | Remaining: 00:00:04 | Avg Time/Step: 0.53
+ Step 42 | Loss: 9.588884 | Norm: 22.2327 | LR: 3.61e-05 | Time: 00:00:22 | Remaining: 00:00:03 | Avg Time/Step: 0.53
+ Step 43 | Loss: 8.350541 | Norm: 20.6459 | LR: 3.69e-05 | Time: 00:00:23 | Remaining: 00:00:03 | Avg Time/Step: 0.53
+ Step 44 | Loss: 9.594240 | Norm: 20.0493 | LR: 3.78e-05 | Time: 00:00:23 | Remaining: 00:00:02 | Avg Time/Step: 0.52
+ Step 45 | Loss: 8.240631 | Norm: 20.1186 | LR: 3.86e-05 | Time: 00:00:23 | Remaining: 00:00:02 | Avg Time/Step: 0.52
+ Step 46 | Loss: 8.915052 | Norm: 20.4390 | LR: 3.94e-05 | Time: 00:00:24 | Remaining: 00:00:01 | Avg Time/Step: 0.51
+ Step 47 | Loss: 8.285755 | Norm: 20.3787 | LR: 4.03e-05 | Time: 00:00:24 | Remaining: 00:00:01 | Avg Time/Step: 0.51
+ Step 48 | Loss: 8.551549 | Norm: 20.1920 | LR: 4.11e-05 | Time: 00:00:24 | Remaining: 00:00:00 | Avg Time/Step: 0.51
+ Step 49 | Loss: 9.774352 | Norm: 20.2260 | LR: 4.20e-05 | Time: 00:00:25 | Remaining: 00:00:00 | Avg Time/Step: 0.51
365
+ Step 0 | Loss: 11.633898 | Norm: 44.3633 | LR: 8.39e-07 | Time: 00:00:00 | Remaining: 00:00:42 | Avg Time/Step: 0.86
+ Generated Text at Step 0: The king said ginger stupid 194 idi shrugged outperPLIC Pitch chapter chapter 169 Drac darkest darkesttic Suk encrypted outperGraphics bisexual PitchBC1987 Cobra drives
+ Validating...
+ Validation Loss: 10.924361228942871
+ Step 1 | Loss: 10.875149 | Norm: 51.7900 | LR: 1.68e-06 | Time: 00:00:02 | Remaining: 00:01:09 | Avg Time/Step: 1.45
+ Step 2 | Loss: 10.981276 | Norm: 44.8046 | LR: 2.52e-06 | Time: 00:00:03 | Remaining: 00:00:49 | Avg Time/Step: 1.06
+ Step 3 | Loss: 10.517224 | Norm: 49.3383 | LR: 3.36e-06 | Time: 00:00:03 | Remaining: 00:00:39 | Avg Time/Step: 0.86
+ Step 4 | Loss: 11.220371 | Norm: 50.4130 | LR: 4.20e-06 | Time: 00:00:03 | Remaining: 00:00:32 | Avg Time/Step: 0.73
+ Step 5 | Loss: 11.176923 | Norm: 47.4072 | LR: 5.03e-06 | Time: 00:00:03 | Remaining: 00:00:28 | Avg Time/Step: 0.65
+ Step 6 | Loss: 10.935453 | Norm: 45.3805 | LR: 5.87e-06 | Time: 00:00:04 | Remaining: 00:00:25 | Avg Time/Step: 0.59
+ Step 7 | Loss: 10.582232 | Norm: 46.3087 | LR: 6.71e-06 | Time: 00:00:04 | Remaining: 00:00:23 | Avg Time/Step: 0.56
+ Step 8 | Loss: 11.022345 | Norm: 43.2213 | LR: 7.55e-06 | Time: 00:00:04 | Remaining: 00:00:22 | Avg Time/Step: 0.55
+ Step 9 | Loss: 10.926727 | Norm: 44.3769 | LR: 8.39e-06 | Time: 00:00:05 | Remaining: 00:00:21 | Avg Time/Step: 0.53
+ Step 10 | Loss: 10.986204 | Norm: 45.1195 | LR: 9.23e-06 | Time: 00:00:05 | Remaining: 00:00:20 | Avg Time/Step: 0.52
+ Validating...
+ Validation Loss: 10.735955238342285
+ Step 11 | Loss: 11.179207 | Norm: 47.9544 | LR: 1.01e-05 | Time: 00:00:07 | Remaining: 00:00:22 | Avg Time/Step: 0.59
+ Step 12 | Loss: 10.763081 | Norm: 43.9656 | LR: 1.09e-05 | Time: 00:00:07 | Remaining: 00:00:20 | Avg Time/Step: 0.56
+ Step 13 | Loss: 10.720469 | Norm: 43.3385 | LR: 1.17e-05 | Time: 00:00:07 | Remaining: 00:00:19 | Avg Time/Step: 0.54
+ Step 14 | Loss: 11.064083 | Norm: 43.1475 | LR: 1.26e-05 | Time: 00:00:07 | Remaining: 00:00:18 | Avg Time/Step: 0.53
+ Step 15 | Loss: 10.534277 | Norm: 44.0213 | LR: 1.34e-05 | Time: 00:00:08 | Remaining: 00:00:17 | Avg Time/Step: 0.52
+ Step 16 | Loss: 10.638024 | Norm: 42.8165 | LR: 1.43e-05 | Time: 00:00:08 | Remaining: 00:00:16 | Avg Time/Step: 0.51
+ Step 17 | Loss: 11.247206 | Norm: 39.1321 | LR: 1.51e-05 | Time: 00:00:09 | Remaining: 00:00:16 | Avg Time/Step: 0.51
+ Step 18 | Loss: 10.615942 | Norm: 45.8611 | LR: 1.59e-05 | Time: 00:00:09 | Remaining: 00:00:15 | Avg Time/Step: 0.50
+ Step 19 | Loss: 10.434818 | Norm: 35.2029 | LR: 1.68e-05 | Time: 00:00:09 | Remaining: 00:00:14 | Avg Time/Step: 0.49
+ Step 20 | Loss: 9.872961 | Norm: 37.0101 | LR: 1.76e-05 | Time: 00:00:10 | Remaining: 00:00:14 | Avg Time/Step: 0.49
+ Validating...
+ Validation Loss: 10.45177936553955
+ Step 21 | Loss: 10.303642 | Norm: 38.1966 | LR: 1.85e-05 | Time: 00:00:11 | Remaining: 00:00:14 | Avg Time/Step: 0.53
+ Step 22 | Loss: 10.344124 | Norm: 36.7267 | LR: 1.93e-05 | Time: 00:00:11 | Remaining: 00:00:13 | Avg Time/Step: 0.51
+ Step 23 | Loss: 10.358528 | Norm: 33.4473 | LR: 2.01e-05 | Time: 00:00:12 | Remaining: 00:00:13 | Avg Time/Step: 0.50
+ Step 24 | Loss: 10.899721 | Norm: 33.1147 | LR: 2.10e-05 | Time: 00:00:12 | Remaining: 00:00:12 | Avg Time/Step: 0.50
+ Step 25 | Loss: 10.167845 | Norm: 32.0061 | LR: 2.18e-05 | Time: 00:00:12 | Remaining: 00:00:11 | Avg Time/Step: 0.50
+ Step 26 | Loss: 10.658374 | Norm: 32.8027 | LR: 2.27e-05 | Time: 00:00:13 | Remaining: 00:00:11 | Avg Time/Step: 0.49
+ Step 27 | Loss: 11.409204 | Norm: 32.2853 | LR: 2.35e-05 | Time: 00:00:13 | Remaining: 00:00:10 | Avg Time/Step: 0.49
+ Step 28 | Loss: 9.699551 | Norm: 28.3168 | LR: 2.43e-05 | Time: 00:00:14 | Remaining: 00:00:10 | Avg Time/Step: 0.48
+ Step 29 | Loss: 10.293508 | Norm: 29.6286 | LR: 2.52e-05 | Time: 00:00:14 | Remaining: 00:00:09 | Avg Time/Step: 0.48
+ Step 30 | Loss: 10.796824 | Norm: 32.4335 | LR: 2.60e-05 | Time: 00:00:14 | Remaining: 00:00:09 | Avg Time/Step: 0.48
+ Validating...
+ Validation Loss: 10.094558715820312
+ Step 31 | Loss: 9.871226 | Norm: 29.4167 | LR: 2.69e-05 | Time: 00:00:16 | Remaining: 00:00:09 | Avg Time/Step: 0.50
+ Step 32 | Loss: 9.355142 | Norm: 29.9377 | LR: 2.77e-05 | Time: 00:00:16 | Remaining: 00:00:08 | Avg Time/Step: 0.49
+ Step 33 | Loss: 9.169601 | Norm: 28.7526 | LR: 2.85e-05 | Time: 00:00:16 | Remaining: 00:00:07 | Avg Time/Step: 0.49
+ Step 34 | Loss: 11.027575 | Norm: 25.0330 | LR: 2.94e-05 | Time: 00:00:17 | Remaining: 00:00:07 | Avg Time/Step: 0.49
+ Step 35 | Loss: 9.624268 | Norm: 26.2367 | LR: 3.02e-05 | Time: 00:00:17 | Remaining: 00:00:06 | Avg Time/Step: 0.48
+ Step 36 | Loss: 10.801857 | Norm: 26.3915 | LR: 3.10e-05 | Time: 00:00:17 | Remaining: 00:00:06 | Avg Time/Step: 0.48
+ Step 37 | Loss: 10.625546 | Norm: 24.5588 | LR: 3.19e-05 | Time: 00:00:18 | Remaining: 00:00:05 | Avg Time/Step: 0.48
+ Step 38 | Loss: 9.325054 | Norm: 23.9140 | LR: 3.27e-05 | Time: 00:00:18 | Remaining: 00:00:05 | Avg Time/Step: 0.48
+ Step 39 | Loss: 8.672618 | Norm: 22.6858 | LR: 3.36e-05 | Time: 00:00:18 | Remaining: 00:00:04 | Avg Time/Step: 0.47
+ Step 40 | Loss: 9.316482 | Norm: 23.7950 | LR: 3.44e-05 | Time: 00:00:19 | Remaining: 00:00:04 | Avg Time/Step: 0.47
+ Validating...
+ Validation Loss: 9.847529411315918
+ Step 41 | Loss: 9.895099 | Norm: 22.9021 | LR: 3.52e-05 | Time: 00:00:20 | Remaining: 00:00:03 | Avg Time/Step: 0.49
+ Step 42 | Loss: 9.908270 | Norm: 22.2757 | LR: 3.61e-05 | Time: 00:00:20 | Remaining: 00:00:03 | Avg Time/Step: 0.48
+ Step 43 | Loss: 8.863647 | Norm: 21.6877 | LR: 3.69e-05 | Time: 00:00:21 | Remaining: 00:00:02 | Avg Time/Step: 0.48
+ Step 44 | Loss: 9.615014 | Norm: 22.1502 | LR: 3.78e-05 | Time: 00:00:21 | Remaining: 00:00:02 | Avg Time/Step: 0.48
+ Step 45 | Loss: 7.558504 | Norm: 20.6337 | LR: 3.86e-05 | Time: 00:00:21 | Remaining: 00:00:01 | Avg Time/Step: 0.48
+ Step 46 | Loss: 9.626184 | Norm: 22.2072 | LR: 3.94e-05 | Time: 00:00:22 | Remaining: 00:00:01 | Avg Time/Step: 0.47
+ Step 47 | Loss: 8.649675 | Norm: 21.2089 | LR: 4.03e-05 | Time: 00:00:22 | Remaining: 00:00:00 | Avg Time/Step: 0.47
+ Step 48 | Loss: 8.570056 | Norm: 21.1816 | LR: 4.11e-05 | Time: 00:00:23 | Remaining: 00:00:00 | Avg Time/Step: 0.47
+ Step 49 | Loss: 9.796856 | Norm: 20.8208 | LR: 4.20e-05 | Time: 00:00:23 | Remaining: 00:00:00 | Avg Time/Step: 0.47
+ Step 0 | Loss: 11.633898 | Norm: 44.3633 | LR: 8.39e-07 | Time: 00:00:00 | Remaining: 00:00:44 | Avg Time/Step: 0.91
+ Generated Text at Step 0: The king said ginger stupid 194 idi shrugged outperPLIC Pitch chapter chapter 169 Drac darkest darkesttic Suk encrypted outperGraphics bisexual PitchBC1987 Cobra drives
+ Validating...
+ Validation Loss: 10.924361228942871
+ Step 1 | Loss: 10.875149 | Norm: 51.7900 | LR: 1.68e-06 | Time: 00:00:03 | Remaining: 00:01:34 | Avg Time/Step: 1.97
+ Step 2 | Loss: 10.981276 | Norm: 44.8046 | LR: 2.52e-06 | Time: 00:00:04 | Remaining: 00:01:05 | Avg Time/Step: 1.39
+ Step 3 | Loss: 10.517224 | Norm: 49.3383 | LR: 3.36e-06 | Time: 00:00:04 | Remaining: 00:00:50 | Avg Time/Step: 1.10
+ Step 4 | Loss: 11.220371 | Norm: 50.4130 | LR: 4.20e-06 | Time: 00:00:04 | Remaining: 00:00:41 | Avg Time/Step: 0.92
+ Step 5 | Loss: 11.176923 | Norm: 47.4072 | LR: 5.03e-06 | Time: 00:00:04 | Remaining: 00:00:35 | Avg Time/Step: 0.80
+ Step 6 | Loss: 10.935453 | Norm: 45.3805 | LR: 5.87e-06 | Time: 00:00:05 | Remaining: 00:00:31 | Avg Time/Step: 0.72
+ Step 7 | Loss: 10.582232 | Norm: 46.3087 | LR: 6.71e-06 | Time: 00:00:05 | Remaining: 00:00:27 | Avg Time/Step: 0.66
+ Step 8 | Loss: 11.022345 | Norm: 43.2213 | LR: 7.55e-06 | Time: 00:00:05 | Remaining: 00:00:25 | Avg Time/Step: 0.62
+ Step 9 | Loss: 10.926727 | Norm: 44.3769 | LR: 8.39e-06 | Time: 00:00:06 | Remaining: 00:00:24 | Avg Time/Step: 0.61
+ Step 10 | Loss: 10.986204 | Norm: 45.1195 | LR: 9.23e-06 | Time: 00:00:06 | Remaining: 00:00:22 | Avg Time/Step: 0.59
+ Validating...
+ Validation Loss: 10.735955238342285
+ Step 11 | Loss: 11.179207 | Norm: 47.9544 | LR: 1.01e-05 | Time: 00:00:07 | Remaining: 00:00:25 | Avg Time/Step: 0.66
+ Step 12 | Loss: 10.763081 | Norm: 43.9656 | LR: 1.09e-05 | Time: 00:00:08 | Remaining: 00:00:23 | Avg Time/Step: 0.63
+ Step 13 | Loss: 10.720469 | Norm: 43.3385 | LR: 1.17e-05 | Time: 00:00:08 | Remaining: 00:00:22 | Avg Time/Step: 0.63
+ Step 14 | Loss: 11.064083 | Norm: 43.1475 | LR: 1.26e-05 | Time: 00:00:09 | Remaining: 00:00:21 | Avg Time/Step: 0.62
+ Step 15 | Loss: 10.534277 | Norm: 44.0213 | LR: 1.34e-05 | Time: 00:00:09 | Remaining: 00:00:20 | Avg Time/Step: 0.60
+ Step 16 | Loss: 10.638024 | Norm: 42.8165 | LR: 1.43e-05 | Time: 00:00:10 | Remaining: 00:00:20 | Avg Time/Step: 0.61
+ Step 17 | Loss: 11.247206 | Norm: 39.1321 | LR: 1.51e-05 | Time: 00:00:10 | Remaining: 00:00:18 | Avg Time/Step: 0.59
+ Step 18 | Loss: 10.615942 | Norm: 45.8611 | LR: 1.59e-05 | Time: 00:00:11 | Remaining: 00:00:18 | Avg Time/Step: 0.59
+ Step 19 | Loss: 10.434818 | Norm: 35.2029 | LR: 1.68e-05 | Time: 00:00:11 | Remaining: 00:00:17 | Avg Time/Step: 0.58
+ Step 20 | Loss: 9.872961 | Norm: 37.0101 | LR: 1.76e-05 | Time: 00:00:11 | Remaining: 00:00:16 | Avg Time/Step: 0.57
+ Validating...
+ Validation Loss: 10.45177936553955
+ Step 21 | Loss: 10.303642 | Norm: 38.1966 | LR: 1.85e-05 | Time: 00:00:13 | Remaining: 00:00:16 | Avg Time/Step: 0.60
+ Step 22 | Loss: 10.344124 | Norm: 36.7267 | LR: 1.93e-05 | Time: 00:00:13 | Remaining: 00:00:15 | Avg Time/Step: 0.59
+ Step 23 | Loss: 10.358528 | Norm: 33.4473 | LR: 2.01e-05 | Time: 00:00:14 | Remaining: 00:00:15 | Avg Time/Step: 0.59
+ Step 24 | Loss: 10.899721 | Norm: 33.1147 | LR: 2.10e-05 | Time: 00:00:14 | Remaining: 00:00:14 | Avg Time/Step: 0.59
+ Step 25 | Loss: 10.167845 | Norm: 32.0061 | LR: 2.18e-05 | Time: 00:00:14 | Remaining: 00:00:13 | Avg Time/Step: 0.58
+ Step 26 | Loss: 10.658374 | Norm: 32.8027 | LR: 2.27e-05 | Time: 00:00:15 | Remaining: 00:00:13 | Avg Time/Step: 0.59
+ Step 27 | Loss: 11.409204 | Norm: 32.2853 | LR: 2.35e-05 | Time: 00:00:16 | Remaining: 00:00:12 | Avg Time/Step: 0.58
+ Step 28 | Loss: 9.699551 | Norm: 28.3168 | LR: 2.43e-05 | Time: 00:00:16 | Remaining: 00:00:12 | Avg Time/Step: 0.58
+ Step 29 | Loss: 10.293508 | Norm: 29.6286 | LR: 2.52e-05 | Time: 00:00:17 | Remaining: 00:00:11 | Avg Time/Step: 0.58
+ Step 30 | Loss: 10.796824 | Norm: 32.4335 | LR: 2.60e-05 | Time: 00:00:17 | Remaining: 00:00:10 | Avg Time/Step: 0.57
+ Validating...
+ Validation Loss: 10.094558715820312
+ Step 31 | Loss: 9.871226 | Norm: 29.4167 | LR: 2.69e-05 | Time: 00:00:19 | Remaining: 00:00:10 | Avg Time/Step: 0.60
+ Step 32 | Loss: 9.355142 | Norm: 29.9377 | LR: 2.77e-05 | Time: 00:00:19 | Remaining: 00:00:09 | Avg Time/Step: 0.58
+ Step 33 | Loss: 9.169601 | Norm: 28.7526 | LR: 2.85e-05 | Time: 00:00:19 | Remaining: 00:00:09 | Avg Time/Step: 0.58
+ Step 34 | Loss: 11.027575 | Norm: 25.0330 | LR: 2.94e-05 | Time: 00:00:20 | Remaining: 00:00:08 | Avg Time/Step: 0.58
+ Step 35 | Loss: 9.624268 | Norm: 26.2367 | LR: 3.02e-05 | Time: 00:00:20 | Remaining: 00:00:08 | Avg Time/Step: 0.58
+ Step 36 | Loss: 10.801857 | Norm: 26.3915 | LR: 3.10e-05 | Time: 00:00:21 | Remaining: 00:00:07 | Avg Time/Step: 0.58
+ Step 37 | Loss: 10.625546 | Norm: 24.5588 | LR: 3.19e-05 | Time: 00:00:22 | Remaining: 00:00:06 | Avg Time/Step: 0.58
+ Step 38 | Loss: 9.325054 | Norm: 23.9140 | LR: 3.27e-05 | Time: 00:00:22 | Remaining: 00:00:06 | Avg Time/Step: 0.57
+ Step 39 | Loss: 8.672618 | Norm: 22.6858 | LR: 3.36e-05 | Time: 00:00:22 | Remaining: 00:00:05 | Avg Time/Step: 0.57
+ Step 40 | Loss: 9.316482 | Norm: 23.7950 | LR: 3.44e-05 | Time: 00:00:23 | Remaining: 00:00:05 | Avg Time/Step: 0.56
+ Validating...
+ Validation Loss: 9.847529411315918
+ Step 41 | Loss: 9.895099 | Norm: 22.9021 | LR: 3.52e-05 | Time: 00:00:24 | Remaining: 00:00:04 | Avg Time/Step: 0.58
+ Step 42 | Loss: 9.908270 | Norm: 22.2757 | LR: 3.61e-05 | Time: 00:00:24 | Remaining: 00:00:04 | Avg Time/Step: 0.57
+ Step 43 | Loss: 8.863647 | Norm: 21.6877 | LR: 3.69e-05 | Time: 00:00:25 | Remaining: 00:00:03 | Avg Time/Step: 0.57
+ Step 44 | Loss: 9.615014 | Norm: 22.1502 | LR: 3.78e-05 | Time: 00:00:25 | Remaining: 00:00:02 | Avg Time/Step: 0.57
+ Step 45 | Loss: 7.558504 | Norm: 20.6337 | LR: 3.86e-05 | Time: 00:00:25 | Remaining: 00:00:02 | Avg Time/Step: 0.56
+ Step 46 | Loss: 9.626184 | Norm: 22.2072 | LR: 3.94e-05 | Time: 00:00:26 | Remaining: 00:00:01 | Avg Time/Step: 0.56
+ Step 47 | Loss: 8.649675 | Norm: 21.2089 | LR: 4.03e-05 | Time: 00:00:26 | Remaining: 00:00:01 | Avg Time/Step: 0.55
+ Step 48 | Loss: 8.570056 | Norm: 21.1816 | LR: 4.11e-05 | Time: 00:00:27 | Remaining: 00:00:00 | Avg Time/Step: 0.55
+ Step 49 | Loss: 9.796856 | Norm: 20.8208 | LR: 4.20e-05 | Time: 00:00:27 | Remaining: 00:00:00 | Avg Time/Step: 0.55
+ Step 0 | Loss: 11.633898 | Norm: 44.3633 | LR: 8.39e-07 | Time: 00:00:00 | Remaining: 00:00:42 | Avg Time/Step: 0.86
+ Generated Text at Step 0: The king said ginger stupid 194 idi shrugged outperPLIC Pitch chapter chapter 169 Drac darkest darkesttic Suk encrypted outperGraphics bisexual PitchBC1987 Cobra drives
+ Validating...
+ Validation Loss: 10.924361228942871
+ Step 1 | Loss: 10.875149 | Norm: 51.7900 | LR: 1.68e-06 | Time: 00:00:03 | Remaining: 00:01:31 | Avg Time/Step: 1.91
+ Step 2 | Loss: 10.981276 | Norm: 44.8046 | LR: 2.52e-06 | Time: 00:00:03 | Remaining: 00:01:01 | Avg Time/Step: 1.32
+ Step 3 | Loss: 10.517224 | Norm: 49.3383 | LR: 3.36e-06 | Time: 00:00:04 | Remaining: 00:00:48 | Avg Time/Step: 1.04
+ Step 4 | Loss: 11.220371 | Norm: 50.4130 | LR: 4.20e-06 | Time: 00:00:04 | Remaining: 00:00:39 | Avg Time/Step: 0.88
+ Step 5 | Loss: 11.176923 | Norm: 47.4072 | LR: 5.03e-06 | Time: 00:00:04 | Remaining: 00:00:33 | Avg Time/Step: 0.77
+ Step 6 | Loss: 10.935453 | Norm: 45.3805 | LR: 5.87e-06 | Time: 00:00:05 | Remaining: 00:00:35 | Avg Time/Step: 0.82
+ Step 7 | Loss: 10.582232 | Norm: 46.3087 | LR: 6.71e-06 | Time: 00:00:05 | Remaining: 00:00:31 | Avg Time/Step: 0.74
+ Step 8 | Loss: 11.022345 | Norm: 43.2213 | LR: 7.55e-06 | Time: 00:00:06 | Remaining: 00:00:28 | Avg Time/Step: 0.68
+ Step 9 | Loss: 10.926727 | Norm: 44.3769 | LR: 8.39e-06 | Time: 00:00:06 | Remaining: 00:00:25 | Avg Time/Step: 0.64
+ Step 10 | Loss: 10.986204 | Norm: 45.1195 | LR: 9.23e-06 | Time: 00:00:06 | Remaining: 00:00:23 | Avg Time/Step: 0.60
+ Validating...
+ Validation Loss: 10.735955238342285
+ Step 11 | Loss: 11.179207 | Norm: 47.9544 | LR: 1.01e-05 | Time: 00:00:08 | Remaining: 00:00:27 | Avg Time/Step: 0.72
+ Step 12 | Loss: 10.763081 | Norm: 43.9656 | LR: 1.09e-05 | Time: 00:00:08 | Remaining: 00:00:25 | Avg Time/Step: 0.68
+ Step 13 | Loss: 10.720469 | Norm: 43.3385 | LR: 1.17e-05 | Time: 00:00:09 | Remaining: 00:00:23 | Avg Time/Step: 0.65
+ Step 14 | Loss: 11.064083 | Norm: 43.1475 | LR: 1.26e-05 | Time: 00:00:09 | Remaining: 00:00:22 | Avg Time/Step: 0.64
+ Step 15 | Loss: 10.534277 | Norm: 44.0213 | LR: 1.34e-05 | Time: 00:00:09 | Remaining: 00:00:21 | Avg Time/Step: 0.62
+ Step 16 | Loss: 10.638024 | Norm: 42.8165 | LR: 1.43e-05 | Time: 00:00:11 | Remaining: 00:00:22 | Avg Time/Step: 0.67
+ Step 17 | Loss: 11.247206 | Norm: 39.1321 | LR: 1.51e-05 | Time: 00:00:11 | Remaining: 00:00:20 | Avg Time/Step: 0.65
+ Step 18 | Loss: 10.615942 | Norm: 45.8611 | LR: 1.59e-05 | Time: 00:00:12 | Remaining: 00:00:19 | Avg Time/Step: 0.64
+ Step 19 | Loss: 10.434818 | Norm: 35.2029 | LR: 1.68e-05 | Time: 00:00:12 | Remaining: 00:00:19 | Avg Time/Step: 0.64
+ Step 20 | Loss: 9.872961 | Norm: 37.0101 | LR: 1.76e-05 | Time: 00:00:13 | Remaining: 00:00:18 | Avg Time/Step: 0.63
+ Validating...
+ Validation Loss: 10.45177936553955
+ Step 21 | Loss: 10.303642 | Norm: 38.1966 | LR: 1.85e-05 | Time: 00:00:15 | Remaining: 00:00:19 | Avg Time/Step: 0.70
+ Step 22 | Loss: 10.344124 | Norm: 36.7267 | LR: 1.93e-05 | Time: 00:00:15 | Remaining: 00:00:18 | Avg Time/Step: 0.68
+ Step 23 | Loss: 10.358528 | Norm: 33.4473 | LR: 2.01e-05 | Time: 00:00:16 | Remaining: 00:00:17 | Avg Time/Step: 0.67
+ Step 24 | Loss: 10.899721 | Norm: 33.1147 | LR: 2.10e-05 | Time: 00:00:16 | Remaining: 00:00:16 | Avg Time/Step: 0.66
+ Step 25 | Loss: 10.167845 | Norm: 32.0061 | LR: 2.18e-05 | Time: 00:00:16 | Remaining: 00:00:15 | Avg Time/Step: 0.65
+ Step 26 | Loss: 10.658374 | Norm: 32.8027 | LR: 2.27e-05 | Time: 00:00:18 | Remaining: 00:00:15 | Avg Time/Step: 0.68
+ Step 27 | Loss: 11.409204 | Norm: 32.2853 | LR: 2.35e-05 | Time: 00:00:18 | Remaining: 00:00:14 | Avg Time/Step: 0.66
+ Step 28 | Loss: 9.699551 | Norm: 28.3168 | LR: 2.43e-05 | Time: 00:00:18 | Remaining: 00:00:13 | Avg Time/Step: 0.65
+ Step 29 | Loss: 10.293508 | Norm: 29.6286 | LR: 2.52e-05 | Time: 00:00:19 | Remaining: 00:00:12 | Avg Time/Step: 0.64
+ Step 30 | Loss: 10.796824 | Norm: 32.4335 | LR: 2.60e-05 | Time: 00:00:19 | Remaining: 00:00:12 | Avg Time/Step: 0.63
+ Validating...
+ Validation Loss: 10.094558715820312
+ Step 31 | Loss: 9.871226 | Norm: 29.4167 | LR: 2.69e-05 | Time: 00:00:21 | Remaining: 00:00:12 | Avg Time/Step: 0.68
+ Step 32 | Loss: 9.355142 | Norm: 29.9377 | LR: 2.77e-05 | Time: 00:00:22 | Remaining: 00:00:11 | Avg Time/Step: 0.67
+ Step 33 | Loss: 9.169601 | Norm: 28.7526 | LR: 2.85e-05 | Time: 00:00:22 | Remaining: 00:00:10 | Avg Time/Step: 0.66
+ Step 34 | Loss: 11.027575 | Norm: 25.0330 | LR: 2.94e-05 | Time: 00:00:22 | Remaining: 00:00:09 | Avg Time/Step: 0.65
+ Step 35 | Loss: 9.624268 | Norm: 26.2367 | LR: 3.02e-05 | Time: 00:00:23 | Remaining: 00:00:08 | Avg Time/Step: 0.64
+ Step 36 | Loss: 10.801857 | Norm: 26.3915 | LR: 3.10e-05 | Time: 00:00:24 | Remaining: 00:00:08 | Avg Time/Step: 0.66
+ Step 37 | Loss: 10.625546 | Norm: 24.5588 | LR: 3.19e-05 | Time: 00:00:24 | Remaining: 00:00:07 | Avg Time/Step: 0.66
+ Step 38 | Loss: 9.325054 | Norm: 23.9140 | LR: 3.27e-05 | Time: 00:00:25 | Remaining: 00:00:07 | Avg Time/Step: 0.65
+ Step 39 | Loss: 8.672618 | Norm: 22.6858 | LR: 3.36e-05 | Time: 00:00:26 | Remaining: 00:00:06 | Avg Time/Step: 0.65
+ Step 40 | Loss: 9.316482 | Norm: 23.7950 | LR: 3.44e-05 | Time: 00:00:26 | Remaining: 00:00:05 | Avg Time/Step: 0.65
+ Validating...
+ Validation Loss: 9.847529411315918
+ Step 41 | Loss: 9.895099 | Norm: 22.9021 | LR: 3.52e-05 | Time: 00:00:28 | Remaining: 00:00:05 | Avg Time/Step: 0.68
+ Step 42 | Loss: 9.908270 | Norm: 22.2757 | LR: 3.61e-05 | Time: 00:00:28 | Remaining: 00:00:04 | Avg Time/Step: 0.67
+ Step 43 | Loss: 8.863647 | Norm: 21.6877 | LR: 3.69e-05 | Time: 00:00:29 | Remaining: 00:00:03 | Avg Time/Step: 0.66
+ Step 44 | Loss: 9.615014 | Norm: 22.1502 | LR: 3.78e-05 | Time: 00:00:29 | Remaining: 00:00:03 | Avg Time/Step: 0.65
+ Step 45 | Loss: 7.558504 | Norm: 20.6337 | LR: 3.86e-05 | Time: 00:00:29 | Remaining: 00:00:02 | Avg Time/Step: 0.65
+ Step 46 | Loss: 9.626184 | Norm: 22.2072 | LR: 3.94e-05 | Time: 00:00:31 | Remaining: 00:00:01 | Avg Time/Step: 0.66
+ Step 47 | Loss: 8.649675 | Norm: 21.2089 | LR: 4.03e-05 | Time: 00:00:31 | Remaining: 00:00:01 | Avg Time/Step: 0.66
+ Step 48 | Loss: 8.570056 | Norm: 21.1816 | LR: 4.11e-05 | Time: 00:00:32 | Remaining: 00:00:00 | Avg Time/Step: 0.66
+ Step 49 | Loss: 9.796856 | Norm: 20.8208 | LR: 4.20e-05 | Time: 00:00:32 | Remaining: 00:00:00 | Avg Time/Step: 0.66
+ Step 0 | Loss: 11.633898 | Norm: 44.3633 | LR: 8.39e-07 | Time: 00:00:00 | Remaining: 00:00:42 | Avg Time/Step: 0.87
+ Generated Text at Step 0: The king said ginger stupid 194 idi shrugged outperPLIC Pitch chapter chapter 169 Drac darkest darkesttic Suk encrypted outperGraphics bisexual PitchBC1987 Cobra drives
+ Validating...
+ Validation Loss: 10.924361228942871
+ Step 1 | Loss: 10.875149 | Norm: 51.7900 | LR: 1.68e-06 | Time: 00:00:03 | Remaining: 00:01:31 | Avg Time/Step: 1.91
+ Step 2 | Loss: 10.981276 | Norm: 44.8046 | LR: 2.52e-06 | Time: 00:00:04 | Remaining: 00:01:03 | Avg Time/Step: 1.35
+ Step 3 | Loss: 10.517224 | Norm: 49.3383 | LR: 3.36e-06 | Time: 00:00:04 | Remaining: 00:00:49 | Avg Time/Step: 1.07
+ Step 4 | Loss: 11.220371 | Norm: 50.4130 | LR: 4.20e-06 | Time: 00:00:04 | Remaining: 00:00:40 | Avg Time/Step: 0.90
+ Step 5 | Loss: 11.176923 | Norm: 47.4072 | LR: 5.03e-06 | Time: 00:00:04 | Remaining: 00:00:34 | Avg Time/Step: 0.79
+ Step 6 | Loss: 10.935453 | Norm: 45.3805 | LR: 5.87e-06 | Time: 00:00:05 | Remaining: 00:00:35 | Avg Time/Step: 0.83
+ Step 7 | Loss: 10.582232 | Norm: 46.3087 | LR: 6.71e-06 | Time: 00:00:06 | Remaining: 00:00:31 | Avg Time/Step: 0.75
+ Step 8 | Loss: 11.022345 | Norm: 43.2213 | LR: 7.55e-06 | Time: 00:00:06 | Remaining: 00:00:28 | Avg Time/Step: 0.70
+ Step 9 | Loss: 10.926727 | Norm: 44.3769 | LR: 8.39e-06 | Time: 00:00:06 | Remaining: 00:00:25 | Avg Time/Step: 0.65
+ Step 10 | Loss: 10.986204 | Norm: 45.1195 | LR: 9.23e-06 | Time: 00:00:06 | Remaining: 00:00:23 | Avg Time/Step: 0.61
+ Validating...
+ Validation Loss: 10.735955238342285
+ Step 11 | Loss: 11.179207 | Norm: 47.9544 | LR: 1.01e-05 | Time: 00:00:08 | Remaining: 00:00:28 | Avg Time/Step: 0.75
+ Step 12 | Loss: 10.763081 | Norm: 43.9656 | LR: 1.09e-05 | Time: 00:00:09 | Remaining: 00:00:26 | Avg Time/Step: 0.71
+ Step 13 | Loss: 10.720469 | Norm: 43.3385 | LR: 1.17e-05 | Time: 00:00:09 | Remaining: 00:00:24 | Avg Time/Step: 0.67
+ Step 14 | Loss: 11.064083 | Norm: 43.1475 | LR: 1.26e-05 | Time: 00:00:09 | Remaining: 00:00:22 | Avg Time/Step: 0.65
+ Step 15 | Loss: 10.534277 | Norm: 44.0213 | LR: 1.34e-05 | Time: 00:00:10 | Remaining: 00:00:21 | Avg Time/Step: 0.63
+ Step 16 | Loss: 10.638024 | Norm: 42.8165 | LR: 1.43e-05 | Time: 00:00:11 | Remaining: 00:00:22 | Avg Time/Step: 0.68
+ Step 17 | Loss: 11.247206 | Norm: 39.1321 | LR: 1.51e-05 | Time: 00:00:11 | Remaining: 00:00:20 | Avg Time/Step: 0.65
+ Step 18 | Loss: 10.615942 | Norm: 45.8611 | LR: 1.59e-05 | Time: 00:00:11 | Remaining: 00:00:19 | Avg Time/Step: 0.63
+ Step 19 | Loss: 10.434818 | Norm: 35.2029 | LR: 1.68e-05 | Time: 00:00:12 | Remaining: 00:00:18 | Avg Time/Step: 0.63
+ Step 20 | Loss: 9.872961 | Norm: 37.0101 | LR: 1.76e-05 | Time: 00:00:12 | Remaining: 00:00:17 | Avg Time/Step: 0.61
+ Validating...
+ Validation Loss: 10.45177936553955
+ Step 21 | Loss: 10.303642 | Norm: 38.1966 | LR: 1.85e-05 | Time: 00:00:15 | Remaining: 00:00:19 | Avg Time/Step: 0.69
+ Step 22 | Loss: 10.344124 | Norm: 36.7267 | LR: 1.93e-05 | Time: 00:00:15 | Remaining: 00:00:18 | Avg Time/Step: 0.67
+ Step 23 | Loss: 10.358528 | Norm: 33.4473 | LR: 2.01e-05 | Time: 00:00:15 | Remaining: 00:00:16 | Avg Time/Step: 0.65
+ Step 24 | Loss: 10.899721 | Norm: 33.1147 | LR: 2.10e-05 | Time: 00:00:15 | Remaining: 00:00:15 | Avg Time/Step: 0.64
+ Step 25 | Loss: 10.167845 | Norm: 32.0061 | LR: 2.18e-05 | Time: 00:00:16 | Remaining: 00:00:15 | Avg Time/Step: 0.63
+ Step 26 | Loss: 10.658374 | Norm: 32.8027 | LR: 2.27e-05 | Time: 00:00:17 | Remaining: 00:00:14 | Avg Time/Step: 0.65
+ Step 27 | Loss: 11.409204 | Norm: 32.2853 | LR: 2.35e-05 | Time: 00:00:17 | Remaining: 00:00:14 | Avg Time/Step: 0.64
+ Step 28 | Loss: 9.699551 | Norm: 28.3168 | LR: 2.43e-05 | Time: 00:00:18 | Remaining: 00:00:13 | Avg Time/Step: 0.62
+ Step 29 | Loss: 10.293508 | Norm: 29.6286 | LR: 2.52e-05 | Time: 00:00:18 | Remaining: 00:00:12 | Avg Time/Step: 0.62
+ Step 30 | Loss: 10.796824 | Norm: 32.4335 | LR: 2.60e-05 | Time: 00:00:18 | Remaining: 00:00:11 | Avg Time/Step: 0.61
+ Validating...
+ Validation Loss: 10.094558715820312
+ Step 31 | Loss: 9.871226 | Norm: 29.4167 | LR: 2.69e-05 | Time: 00:00:21 | Remaining: 00:00:11 | Avg Time/Step: 0.66
+ Step 32 | Loss: 9.355142 | Norm: 29.9377 | LR: 2.77e-05 | Time: 00:00:21 | Remaining: 00:00:10 | Avg Time/Step: 0.65
+ Step 33 | Loss: 9.169601 | Norm: 28.7526 | LR: 2.85e-05 | Time: 00:00:21 | Remaining: 00:00:10 | Avg Time/Step: 0.63
+ Step 34 | Loss: 11.027575 | Norm: 25.0330 | LR: 2.94e-05 | Time: 00:00:21 | Remaining: 00:00:09 | Avg Time/Step: 0.62
+ Step 35 | Loss: 9.624268 | Norm: 26.2367 | LR: 3.02e-05 | Time: 00:00:22 | Remaining: 00:00:08 | Avg Time/Step: 0.62
+ Step 36 | Loss: 10.801857 | Norm: 26.3915 | LR: 3.10e-05 | Time: 00:00:23 | Remaining: 00:00:08 | Avg Time/Step: 0.64
+ Step 37 | Loss: 10.625546 | Norm: 24.5588 | LR: 3.19e-05 | Time: 00:00:23 | Remaining: 00:00:07 | Avg Time/Step: 0.63
+ Step 38 | Loss: 9.325054 | Norm: 23.9140 | LR: 3.27e-05 | Time: 00:00:24 | Remaining: 00:00:06 | Avg Time/Step: 0.62
+ Step 39 | Loss: 8.672618 | Norm: 22.6858 | LR: 3.36e-05 | Time: 00:00:24 | Remaining: 00:00:06 | Avg Time/Step: 0.61
+ Step 40 | Loss: 9.316482 | Norm: 23.7950 | LR: 3.44e-05 | Time: 00:00:24 | Remaining: 00:00:05 | Avg Time/Step: 0.60
+ Validating...
+ Validation Loss: 9.847529411315918
+ Step 41 | Loss: 9.895099 | Norm: 22.9021 | LR: 3.52e-05 | Time: 00:00:26 | Remaining: 00:00:05 | Avg Time/Step: 0.64
+ Step 42 | Loss: 9.908270 | Norm: 22.2757 | LR: 3.61e-05 | Time: 00:00:27 | Remaining: 00:00:04 | Avg Time/Step: 0.63
+ Step 43 | Loss: 8.863647 | Norm: 21.6877 | LR: 3.69e-05 | Time: 00:00:27 | Remaining: 00:00:03 | Avg Time/Step: 0.62
+ Step 44 | Loss: 9.615014 | Norm: 22.1502 | LR: 3.78e-05 | Time: 00:00:27 | Remaining: 00:00:03 | Avg Time/Step: 0.62
+ Step 45 | Loss: 7.558504 | Norm: 20.6337 | LR: 3.86e-05 | Time: 00:00:28 | Remaining: 00:00:02 | Avg Time/Step: 0.61
+ Step 46 | Loss: 9.626184 | Norm: 22.2072 | LR: 3.94e-05 | Time: 00:00:29 | Remaining: 00:00:01 | Avg Time/Step: 0.63
+ Step 47 | Loss: 8.649675 | Norm: 21.2089 | LR: 4.03e-05 | Time: 00:00:29 | Remaining: 00:00:01 | Avg Time/Step: 0.62
+ Step 48 | Loss: 8.570056 | Norm: 21.1816 | LR: 4.11e-05 | Time: 00:00:30 | Remaining: 00:00:00 | Avg Time/Step: 0.61
+ Step 49 | Loss: 9.796856 | Norm: 20.8208 | LR: 4.20e-05 | Time: 00:00:30 | Remaining: 00:00:00 | Avg Time/Step: 0.61
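
The runs above repeat the same loss trajectory because the script below seeds torch with 1337; only the wall-clock columns vary. A minimal sketch for pulling the loss curve back out of a log in this format (the filename is hypothetical):

import re

# Parse "Step N | Loss: X | ..." lines and collect the training-loss curve.
pattern = re.compile(r"Step (\d+) \| Loss: ([\d.]+)")
steps, losses = [], []
with open("training_log.txt") as f:  # hypothetical path for a saved copy of the log
    for line in f:
        m = pattern.search(line)
        if m:
            steps.append(int(m.group(1)))
            losses.append(float(m.group(2)))
if losses:
    print(f"parsed {len(losses)} steps; last loss {losses[-1]:.3f}")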
gpt-2/training_shakespeare.py ADDED
@@ -0,0 +1,298 @@
+ from dataclasses import dataclass
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from model import GPT, GPTConfig
+ import tiktoken
+ from torch.utils.data import Dataset, DataLoader, DistributedSampler
+ import math
+ import matplotlib.pyplot as plt
+ from torch.distributed import init_process_group, destroy_process_group
+ from torch.nn.parallel import DistributedDataParallel as DDP
+ import torch.distributed as dist
+ import os
+
+ import signal
+ import sys
+
+ def signal_handler(sig, frame):
+     print('Gracefully stopping the training process')
+     destroy_process_group()
+     sys.exit(0)
+
+ signal.signal(signal.SIGINT, signal_handler)
+
+ torch.manual_seed(1337)
+ if torch.cuda.is_available():
+     torch.cuda.manual_seed(1337)
+
+ # ***************************#
+ # Device Configuration
+ # ***************************#
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+     device = torch.device("mps")
+
+ print("Using device:", device)
+
+ # ***************************#
+ # Tokenizer Setup
+ # ***************************#
+ enc = tiktoken.get_encoding('gpt2')
+
+ lossi = []
+ val_lossi = []
+
+ # ***************************#
+ # Load Text Data
+ # ***************************#
+ with open("tinyshakespeare.txt", "r") as f:
+     text = f.read()
+ tokens = enc.encode(text)
+ print(f"Number of tokens: {len(tokens):,}")
+
+ # ***************************#
+ # Set up DDP
+ # ***************************#
+ # the torchrun command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
+ ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run?
+ if ddp:
+     # DDP currently requires CUDA; set the device according to the local rank
+     assert torch.cuda.is_available(), "CUDA is required for DDP"
+     init_process_group(backend='nccl')
+     ddp_rank = int(os.environ['RANK'])
+     ddp_local_rank = int(os.environ['LOCAL_RANK'])
+     ddp_world_size = int(os.environ['WORLD_SIZE'])
+     device = f'cuda:{ddp_local_rank}'
+     torch.cuda.set_device(device)
+     # this process will do logging, checkpointing etc.
+     master_process = ddp_rank == 0
+ else:
+     # vanilla, non-DDP run
+     ddp_rank = 0
+     ddp_local_rank = 0
+     ddp_world_size = 1
+     master_process = True
+
+ if master_process:
+     print(f"ddp: {ddp}, rank: {ddp_rank}, local_rank: {ddp_local_rank}, world_size: {ddp_world_size}, master_process: {master_process}")
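+ # (illustrative) a DDP run of this script would be launched via torchrun, which
+ # sets the env vars read above, e.g.:
+ #   torchrun --standalone --nproc_per_node=4 training_shakespeare.py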
+
+ # ***************************#
+ # Model Configuration
+ # ***************************#
+ gpt = GPT(GPTConfig(vocab_size=50304), master_process).to(device)
+ if str(device).startswith("cuda"):  # device is a str under DDP, a torch.device otherwise
+     gpt.compile()
+ if ddp:
+     gpt = DDP(gpt, device_ids=[ddp_local_rank])
+
+ raw_gpt = gpt.module if ddp else gpt  # unwrapped model for optimizer config and checkpointing
+
+ # ***************************#
+ # Dataset and Dataloader
+ # ***************************#
+ from torch.utils.data import Subset
+
+ class ShakespeareDataset(Dataset):
+     def __init__(self, tokens, seq_len):
+         self.tokens = tokens
+         self.seq_len = seq_len
+
+     def __len__(self):
+         return len(self.tokens) - self.seq_len - 1
+
+     def __getitem__(self, idx):
+         x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
+         y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
+         return x, y
+
+ # Split the dataset into training and validation sets
+ def split_dataset(dataset, val_ratio=0.0005):
+     dataset_size = len(dataset)
+     indices = list(range(dataset_size))
+     split = int(val_ratio * dataset_size)
+
+     train_indices, val_indices = indices[split:], indices[:split]
+     train_dataset = Subset(dataset, train_indices)
+     val_dataset = Subset(dataset, val_indices)
+
+     return train_dataset, val_dataset
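+ # (note) adjacent windows overlap by seq_len-1 tokens, so this prefix split still
+ # shares most underlying text between train and val; treat it as a smoke test
+ # rather than a clean held-out set.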
+
+ T = 8
+ batch_size = 4
+ total_batch_size = 2**8  # 256 tokens per optimizer step
+ assert total_batch_size % (T*batch_size*ddp_world_size) == 0, "total_batch_size must be divisible by T * batch_size * world_size"
+ grad_accum_steps = total_batch_size // (T*batch_size*ddp_world_size)
+
+ if master_process:
+     print("Total desired batch size: {:,}".format(total_batch_size))
+     print("gradient accumulation steps: {:,}".format(grad_accum_steps))
+
+ dataset = ShakespeareDataset(tokens, T)
+ train_dataset, val_dataset = split_dataset(dataset)
+
+ if ddp:
+     train_sampler = DistributedSampler(train_dataset)
+     val_sampler = DistributedSampler(val_dataset)
+     train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
+     val_dataloader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler)
+ else:
+     train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+     val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+
+ if master_process:
+     print(f"The training dataloader has {len(train_dataloader):,} individual batches")
+     print(f"The validation dataloader has {len(val_dataloader):,} individual batches")
+
+ # ***************************#
+ # Text Generation Function
+ # ***************************#
+ def generate_text(seed_text, model, enc, max_len=100, print_while_generating=True):
+     model.eval()
+     with torch.no_grad():
+         tokens = enc.encode(seed_text)
+         for _ in range(max_len):
+             x = torch.tensor(tokens[-T:], dtype=torch.long, device=device).unsqueeze(0)
+             logits, _ = model(x)
+             next_token = torch.argmax(logits[:, -1, :])  # greedy decoding
+             tokens.append(int(next_token))
+             if print_while_generating:
+                 print(enc.decode([int(next_token)]), end="")
+         print()
+     model.train()  # restore training mode after generation
+     return enc.decode(tokens)
+
+ # ***************************#
+ # Optimizer Configuration
+ # ***************************#
+ # raw_gpt is the unwrapped model in both the DDP and non-DDP cases
+ optimizer = raw_gpt.configure_optimizers(
+     weight_decay=0.1, learning_rate=6e-4, device=device)
+ torch.set_float32_matmul_precision('high')
+
+ # ***************************#
+ # Learning Rate Scheduler
+ # ***************************#
+ max_lr = 6e-4
+ min_lr = max_lr * 0.1
+ warmup_steps = 10
+ max_steps = 20000
+
+ def get_lr(step):
+     # linear warmup, then cosine decay from max_lr down to min_lr
+     if step < warmup_steps:
+         return max_lr * (step+1) / warmup_steps
+     if step > max_steps:
+         return min_lr
+     decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
+     assert 0 <= decay_ratio <= 1
+     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
+     return min_lr + coeff * (max_lr - min_lr)
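+ # (sanity check) get_lr(0) == 6.0e-05 (warmup start), get_lr(9) == 6.0e-04 (peak
+ # after 10 warmup steps), get_lr(20000) == 6.0e-05 (cosine fully decayed to min_lr)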
+
+ # Check if the device supports bfloat16 (Ampere or newer GPUs)
+ supports_bfloat16 = False
+ if str(device).startswith("cuda"):
+     capability = torch.cuda.get_device_capability()
+     if capability[0] >= 8:
+         supports_bfloat16 = True
+
+ # ***************************#
+ # Training Loop
+ # ***************************#
+ generate_every = 50
+ validate_every = 5
+ train_iter = iter(train_dataloader)  # persistent iterator so batches actually advance
+ for step in range(max_steps):
+     gpt.zero_grad()
+     loss_accum = 0.0
+     for minibatchstep in range(grad_accum_steps):
+         try:
+             x, y = next(train_iter)
+         except StopIteration:
+             # loader exhausted: start a new pass over the training data
+             train_iter = iter(train_dataloader)
+             x, y = next(train_iter)
+         x, y = x.to(device), y.to(device)
+
+         if supports_bfloat16:
+             with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
+                 logits, loss = gpt(x, y)
+         else:
+             logits, loss = gpt(x, y)
+
+         loss = loss / grad_accum_steps
+         loss_accum += loss.detach()
+         if ddp:
+             # only sync gradients on the last micro-step of the accumulation window
+             gpt.require_backward_grad_sync = (minibatchstep == grad_accum_steps - 1)
+         loss.backward()
+
+     if ddp:
+         dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
+     lossi.append(loss_accum.item())
+     norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), 1.0)
+     lr = get_lr(step)
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+     optimizer.step()
+
+     if master_process:
+         print(f'Step {step}, Loss: {loss_accum}, Norm: {norm}')
+
+     if step % generate_every == 0 and master_process:
+         print(generate_text("The king said", gpt, enc, max_len=25, print_while_generating=False))
+
+     # Validation step
+     if step % validate_every == 0:
+         if master_process:
+             print("Validating...")
+         gpt.eval()
+         val_loss_accum = 0.0
+         with torch.no_grad():
+             for val_x, val_y in val_dataloader:
+                 val_x, val_y = val_x.to(device), val_y.to(device)
+                 if supports_bfloat16:
+                     with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
+                         val_logits, val_loss = gpt(val_x, val_y)
+                 else:
+                     val_logits, val_loss = gpt(val_x, val_y)
+                 val_loss_accum += val_loss.detach()
+         if ddp:
+             dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
+         val_loss_avg = val_loss_accum / len(val_dataloader)
+         val_lossi.append(val_loss_avg.item())  # record the averaged loss once per validation pass
+         if master_process:
+             print(f'Validation Loss: {val_loss_avg}')
+         gpt.train()
+
+ # ***************************#
+ # Plot Loss
+ # ***************************#
+ if master_process:
+     plt.plot(lossi)
+     plt.show()
+
+ # Generate Final Text
+ if master_process:
+     generate_text("The king said", gpt, enc, max_len=25)
+
+ # ***************************#
+ # Save Model and Loss
+ # ***************************#
+ if master_process:
+     torch.save(raw_gpt.state_dict(), "gpt2_shakespeare.pth")  # save the unwrapped weights
+     torch.save(torch.tensor(lossi), "lossi.pth")
+
+ # ***************************#
+ # Cleanup
+ # ***************************#
+ if ddp:
+     destroy_process_group()
+
+ sys.exit(0)
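
A minimal sketch for reloading what the script saves (assuming model.py is importable and the checkpoint holds the unwrapped model's weights; a torch.compile'd model may store keys under a `_orig_mod.` prefix):

import torch
from model import GPT, GPTConfig  # same modules the training script uses

state = torch.load("gpt2_shakespeare.pth", map_location="cpu")
gpt = GPT(GPTConfig(vocab_size=50304), True)  # True stands in for the master_process flag
gpt.load_state_dict(state)
gpt.eval()

lossi = torch.load("lossi.pth")
print(f"{len(lossi)} recorded steps, final training loss {lossi[-1].item():.4f}")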
gpt-2/val_lossi.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0931c157e2c170276acc822cad49e2860a2a21eb1bda709f8a0f5baf137e1d56
+ size 1190
gpt-2/val_lossi_final.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:86ec4c787ab69379e81310a0262652d43ad2d84484d35f0db34d7566d606faf7
+ size 1949
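
Both `.pth` entries above are Git LFS pointers, so `git lfs pull` is needed before the real binaries are available. A sketch for plotting them, assuming they were written with torch.save(torch.tensor(...)) like lossi.pth in the training script:

import torch
import matplotlib.pyplot as plt

val_lossi = torch.load("gpt-2/val_lossi.pth")
val_lossi_final = torch.load("gpt-2/val_lossi_final.pth")
plt.plot(val_lossi.tolist(), label="val_lossi")
plt.plot(val_lossi_final.tolist(), label="val_lossi_final")
plt.xlabel("validation pass")
plt.ylabel("loss")
plt.legend()
plt.show()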