import math
import os
import signal
import sys

import matplotlib.pyplot as plt
import tiktoken
import torch
import torch.distributed as dist
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader, DistributedSampler, Subset

from model import GPT, GPTConfig


def signal_handler(sig, frame):
    print('Gracefully stopping the training process')
    # only tear down the process group if one was actually initialized
    if dist.is_initialized():
        destroy_process_group()
    sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)

torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

# ***************************#
# Device Configuration
# ***************************#
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print("Using device:", device)

# ***************************#
# Tokenizer Setup
# ***************************#
enc = tiktoken.get_encoding('gpt2')
lossi = []
val_lossi = []

# ***************************#
# Load Text Data
# ***************************#
with open("tinyshakespeare.txt", "r") as f:
    text = f.read()
tokens = enc.encode(text)
print(f"Number of tokens: {len(tokens):,}")

# ***************************#
# Set up DDP
# ***************************#
# torchrun sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run?
if ddp:
    # DDP here requires CUDA; pin each process to its local GPU
    assert torch.cuda.is_available(), "DDP currently requires CUDA"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    # rank 0 does logging, checkpointing, etc.
    master_process = ddp_rank == 0
else:
    # vanilla, non-DDP run
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True

# "cuda" / "mps" / "cpu"; used for autocast and capability checks below
device_type = "cuda" if device.startswith("cuda") else device

if master_process:
    print(f"ddp: {ddp}, rank: {ddp_rank}, local_rank: {ddp_local_rank}, "
          f"world_size: {ddp_world_size}, master_process: {master_process}")
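# Launch note (a sketch; assumes a single node and that this file is saved as
# train.py). A plain invocation runs the non-DDP branch; torchrun runs the DDP
# branch by populating RANK/LOCAL_RANK/WORLD_SIZE for each process:
#   python train.py
#   torchrun --standalone --nproc_per_node=4 train.py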
# ***************************#
# Model Configuration
# ***************************#
gpt = GPT(GPTConfig(vocab_size=50304), master_process).to(device)
if device_type == "cuda":
    gpt.compile()
if ddp:
    gpt = DDP(gpt, device_ids=[ddp_local_rank])
# keep a handle to the unwrapped module for optimizer config and checkpointing
raw_gpt = gpt.module if ddp else gpt

# ***************************#
# Dataset and Dataloader
# ***************************#
class ShakespeareDataset(Dataset):
    def __init__(self, tokens, seq_len):
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self):
        return len(self.tokens) - self.seq_len - 1

    def __getitem__(self, idx):
        # next-token prediction: y is x shifted one position to the right
        x = torch.tensor(self.tokens[idx:idx + self.seq_len], dtype=torch.long)
        y = torch.tensor(self.tokens[idx + 1:idx + self.seq_len + 1], dtype=torch.long)
        return x, y

# Split the dataset into training and validation sets
def split_dataset(dataset, val_ratio=0.0005):
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(val_ratio * dataset_size)
    train_indices, val_indices = indices[split:], indices[:split]
    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)
    return train_dataset, val_dataset

T = 8
batch_size = 4
total_batch_size = 2**8  # 256 tokens per optimizer step, reached via gradient accumulation
assert total_batch_size % (T * batch_size * ddp_world_size) == 0, \
    "total_batch_size must be divisible by T * batch_size * world_size"
grad_accum_steps = total_batch_size // (T * batch_size * ddp_world_size)
if master_process:
    print("Total desired batch size: {:,}".format(total_batch_size))
    print("gradient accumulation steps: {:,}".format(grad_accum_steps))

dataset = ShakespeareDataset(tokens, T)
train_dataset, val_dataset = split_dataset(dataset)
if ddp:
    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler)
else:
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

if master_process:
    print(f"The training dataloader has {len(train_dataloader):,} individual batches")
    print(f"The validation dataloader has {len(val_dataloader):,} individual batches")

# ***************************#
# Text Generation Function
# ***************************#
def generate_text(seed_text, model, enc, max_len=100, print_while_generating=True):
    model.eval()
    with torch.no_grad():
        tokens = enc.encode(seed_text)
        for _ in range(max_len):
            # feed at most the last T tokens (the model's context window)
            x = torch.tensor(tokens[-T:], dtype=torch.long, device=device).unsqueeze(0)
            logits, _ = model(x)
            # greedy decoding: always take the most likely next token
            next_token = torch.argmax(logits[:, -1, :])
            tokens.append(int(next_token))
            if print_while_generating:
                print(enc.decode([int(next_token)]), end="")
        print()
    return enc.decode(tokens)

# ***************************#
# Optimizer Configuration
# ***************************#
# raw_gpt is the unwrapped module in both the DDP and non-DDP cases
optimizer = raw_gpt.configure_optimizers(
    weight_decay=0.1, learning_rate=6e-4, device=device)

torch.set_float32_matmul_precision('high')

# ***************************#
# Learning Rate Scheduler
# ***************************#
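# Linear warmup followed by cosine decay to a floor. A worked trace with the
# defaults below (max_lr=6e-4, min_lr=6e-5, warmup_steps=10, max_steps=20000):
#   step 0      -> 6e-4 * 1/10 = 6.0e-5              (warming up)
#   step 9      -> 6e-4                               (warmup complete)
#   step ~10005 -> min_lr + 0.5*(max_lr - min_lr)
#                  ~= 3.3e-4                          (cosine midpoint)
#   step 20000+ -> 6e-5                               (floor)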
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 10
max_steps = 20000

def get_lr(step):
    # linear warmup
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # past the end of the schedule: hold at the floor
    if step > max_steps:
        return min_lr
    # cosine decay in between
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

# Check if the device supports bfloat16 autocast (Ampere+, compute capability >= 8)
supports_bfloat16 = False
if device_type == "cuda":
    capability = torch.cuda.get_device_capability()
    if capability[0] >= 8:
        supports_bfloat16 = True

# ***************************#
# Training Loop
# ***************************#
generate_every = 50
validate_every = 5

train_iter = iter(train_dataloader)

for step in range(max_steps):
    optimizer.zero_grad()
    loss_accum = torch.tensor(0.0, device=device)
    for minibatchstep in range(grad_accum_steps):
        # pull the next micro-batch, restarting the dataloader when exhausted
        # (next(iter(dataloader)) would re-fetch the first batch every time)
        try:
            x, y = next(train_iter)
        except StopIteration:
            train_iter = iter(train_dataloader)
            x, y = next(train_iter)
        x, y = x.to(device), y.to(device)
        if supports_bfloat16:
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits, loss = gpt(x, y)
        else:
            logits, loss = gpt(x, y)
        # scale so the accumulated gradient matches one big batch
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        if ddp:
            # only sync gradients across ranks on the last micro-batch
            gpt.require_backward_grad_sync = (minibatchstep == grad_accum_steps - 1)
        loss.backward()
    if ddp:
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    lossi.append(loss_accum.item())
    norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), 1.0)
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()
    if master_process:
        print(f'Step {step}, Loss: {loss_accum.item():.4f}, Norm: {norm:.4f}')
    if step % generate_every == 0 and master_process:
        print(generate_text("The king said", gpt, enc, max_len=25, print_while_generating=False))
        gpt.train()  # generate_text leaves the model in eval mode
    # Validation step
    if step % validate_every == 0:
        if master_process:
            print("Validating...")
        gpt.eval()
        val_loss_accum = torch.tensor(0.0, device=device)
        with torch.no_grad():
            for val_x, val_y in val_dataloader:
                val_x, val_y = val_x.to(device), val_y.to(device)
                if supports_bfloat16:
                    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                        val_logits, val_loss = gpt(val_x, val_y)
                else:
                    val_logits, val_loss = gpt(val_x, val_y)
                val_loss_accum += val_loss.detach()
        # average over batches first, then over ranks
        val_loss_avg = val_loss_accum / len(val_dataloader)
        if ddp:
            dist.all_reduce(val_loss_avg, op=dist.ReduceOp.AVG)
        val_lossi.append(val_loss_avg.item())
        if master_process:
            print(f'Validation Loss: {val_loss_avg.item():.4f}')
        gpt.train()

# ***************************#
# Plot Loss
# ***************************#
if master_process:
    plt.plot(lossi)
    plt.show()

# Generate Final Text
if master_process:
    generate_text("The king said", gpt, enc, max_len=25)

# ***************************#
# Save Model and Loss
# ***************************#
if master_process:
    # save the unwrapped module so the checkpoint loads without DDP prefixes
    torch.save(raw_gpt.state_dict(), "gpt2_shakespeare.pth")
    torch.save(torch.tensor(lossi), "lossi.pth")

# ***************************#
# Cleanup
# ***************************#
if ddp:
    destroy_process_group()
sys.exit(0)
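# Reloading the checkpoint later (a sketch; assumes the same GPTConfig and the
# GPT(config, master_process) constructor from model.py used above):
#   gpt = GPT(GPTConfig(vocab_size=50304), True)
#   gpt.load_state_dict(torch.load("gpt2_shakespeare.pth", map_location="cpu"))
#   gpt.eval()
#   print(generate_text("The king said", gpt, enc, max_len=25))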