from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import GPT, GPTConfig
import tiktoken
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import math
import matplotlib.pyplot as plt
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import os
import signal
import sys
import numpy as np
import time
import logging


def seconds_to_hms(seconds):
    return time.strftime('%H:%M:%S', time.gmtime(seconds))


def signal_handler(sig, frame):
    print('Gracefully stopping the training process')
    if dist.is_initialized():  # only tear down the process group if one exists
        destroy_process_group()
    sys.exit(0)


signal.signal(signal.SIGINT, signal_handler)

manual_seed = 1339
torch.manual_seed(manual_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(manual_seed)

# ***************************#
# Device Configuration
# ***************************#
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
print("Using device:", device)

# ***************************#
# Tokenizer Setup
# ***************************#
enc = tiktoken.get_encoding('gpt2')

lossi = []
val_lossi = []

# ***************************#
# Load Text Data
# ***************************#
# only used as a tokenizer sanity check; training reads the FineWeb shards below
with open("tinyshakespeare.txt", "r") as f:
    text = f.read()
tokens = enc.encode(text)
print(f"Number of tokens: {len(tokens):,}")

# ***************************#
# Set up DDP
# ***************************#
# the torchrun command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run?
if ddp:
    # use of DDP atm demands CUDA, we set the device appropriately according to rank
    assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    # this process will do logging, checkpointing etc.
    master_process = ddp_rank == 0
else:
    # vanilla, non-DDP run
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True

if master_process:
    print(f"ddp: {ddp}, rank: {ddp_rank}, local_rank: {ddp_local_rank}, "
          f"world_size: {ddp_world_size}, master_process: {master_process}")
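# Example launch commands (a sketch; the script name "train_gpt2.py" is an
# assumption, substitute this file's actual name):
#   single GPU / CPU:  python train_gpt2.py
#   8 GPUs, one node:  torchrun --standalone --nproc_per_node=8 train_gpt2.py
# torchrun is what populates RANK, LOCAL_RANK, and WORLD_SIZE checked above.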
# ***************************#
# Model Configuration
# ***************************#
gpt = GPT(GPTConfig(vocab_size=50304), master_process).to(device)
# `device` may be a torch.device or a 'cuda:N' string at this point, so compare on the string form
if str(device).startswith("cuda"):
    gpt.compile()  # torch.compile the module in place (PyTorch >= 2.1)
if ddp:
    gpt = DDP(gpt, device_ids=[ddp_local_rank])
raw_gpt = gpt.module if ddp else gpt  # the unwrapped model, for checkpointing etc.

# ***************************#
# Dataset and Dataloader
# ***************************#
def load_tokens(filename):
    npt = np.load(filename)
    npt = npt.astype(np.int32)  # added after video
    ptt = torch.tensor(npt, dtype=torch.long)
    return ptt


class DataLoader_Custom:
    def __init__(self, B, T, process_rank, num_processes, split, shuffle=False):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.shuffle = shuffle
        assert split in ["train", "val"]

        data_root = "edu_fineweb10B"
        shards = os.listdir(data_root)
        shards = [s for s in shards if split in s]
        shards = sorted(shards)
        shards = [os.path.join(data_root, s) for s in shards]
        self.shards = shards
        assert len(shards) > 0, "No shards found for split {}".format(split)
        if master_process:
            print("Found {} shards for split {}".format(len(shards), split))

        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position:self.current_position + B*T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets: the same tokens shifted one to the right
        # advance past the chunk consumed by all processes this step
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            # wrap to the next shard; parentheses matter (`+ 1 % n` binds the wrong way)
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_position = self.B * self.T * self.process_rank
        return x, y

    def reset(self):
        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank


T = 4
batch_size = 1
total_batch_size = 2**2  # debug value; the full run uses 2**19 = 524,288 tokens
assert total_batch_size % (T * batch_size * ddp_world_size) == 0, \
    "total_batch_size is not divisible by B*T*world_size"
grad_accum_steps = total_batch_size // (T * batch_size * ddp_world_size)
if master_process:
    print("Total desired batch size: {:,}".format(total_batch_size))
    print("gradient accumulation steps: {:,}".format(grad_accum_steps))

# use the global rank (not the local rank) so every process across all nodes
# reads a distinct slice of each shard
train_dataloader = DataLoader_Custom(batch_size, T, ddp_rank, ddp_world_size, "train")
val_dataloader = DataLoader_Custom(batch_size, T, ddp_rank, ddp_world_size, "val")

# ***************************#
# Text Generation Function
# ***************************#
def generate_text(seed_text, model, enc, max_len=100, print_while_generating=True):
    if print_while_generating:
        print(seed_text, end="")
    model.eval()
    with torch.no_grad():
        tokens = enc.encode(seed_text)
        for _ in range(max_len):
            # greedy decoding, conditioned on at most the last T tokens
            x = torch.tensor(tokens[-T:], dtype=torch.long, device=device).unsqueeze(0)
            logits, _ = model(x)
            next_token = torch.argmax(logits[:, -1, :])
            tokens.append(int(next_token))
            if print_while_generating:
                print(enc.decode([int(next_token)]), end="")
    model.train()  # restore training mode for the surrounding loop
    print()
    return enc.decode(tokens)


# ***************************#
# Optimizer Configuration
# ***************************#
# raw_gpt is the unwrapped module in both the DDP and non-DDP cases,
# so a single call replaces the original if/else
optimizer = raw_gpt.configure_optimizers(
    weight_decay=0.1,
    learning_rate=6e-4,
    device=device)
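# Optional sanity check (a debug-only sketch, safe to delete): one batch should
# come back as (B, T) inputs with targets shifted one token to the right.
if master_process:
    _xb, _yb = train_dataloader.next_batch()
    assert _xb.shape == (batch_size, T) and _yb.shape == (batch_size, T)
    assert torch.equal(_xb[0, 1:], _yb[0, :-1])
    train_dataloader.reset()  # rewind so training starts from the first token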
torch.set_float32_matmul_precision('high')

# ***************************#
# Learning Rate Scheduler
# ***************************#
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 715
max_steps = 50  # note: with warmup_steps > max_steps, this debug run never leaves warmup


def get_lr(step):
    # linear warmup, then cosine decay from max_lr down to min_lr
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    if step > max_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)


# Check if the device supports bfloat16 (compute capability 8.0+, i.e. Ampere or newer)
supports_bfloat16 = False
if str(device).startswith("cuda"):
    capability = torch.cuda.get_device_capability()
    if capability[0] >= 8:
        supports_bfloat16 = True
print("Supports bfloat16:", supports_bfloat16)

# ***************************#
# Training Loop
# ***************************#
generate_every = 50
validate_every = 10
save_every = 5
t0 = time.time()

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)
# Add a file handler
file_handler = logging.FileHandler('training_log.txt')
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter('%(message)s'))
logger.addHandler(file_handler)

for step in range(max_steps):
    loss_accum = 0.0
    gpt.zero_grad()
    for minibatchstep in range(grad_accum_steps):
        x, y = train_dataloader.next_batch()
        x, y = x.to(device), y.to(device)
        if supports_bfloat16:
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                logits, loss = gpt(x, y)
        else:
            logits, loss = gpt(x, y)
        loss = loss / grad_accum_steps  # scale so gradients average over the accumulation
        loss_accum += loss.detach()
        if ddp:
            # only sync gradients across ranks on the final micro-step
            gpt.require_backward_grad_sync = (minibatchstep == grad_accum_steps - 1)
        loss.backward()
    if ddp:
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    lossi.append(loss_accum.item())

    norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), 1.0)
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()

    t_current = time.time()
    elapsed_time = t_current - t0
    steps_completed = step + 1
    avg_time_per_step = elapsed_time / steps_completed
    remaining_steps = max_steps - steps_completed
    remaining_time = remaining_steps * avg_time_per_step
    if master_process:
        logger.info(f'Step {step} | Loss: {loss_accum:.6f} | Norm: {norm:.4f} | LR: {lr:.2e} | '
                    f'Time: {seconds_to_hms(elapsed_time)} | Remaining: {seconds_to_hms(remaining_time)} | '
                    f'Avg Time/Step: {avg_time_per_step:.2f}')

    if master_process and step % generate_every == 0:
        generated_text = generate_text("The king said", gpt, enc, max_len=25,
                                       print_while_generating=False)
        logger.info(f'Generated Text at Step {step}: {generated_text}')

    # Validation step
    if step % validate_every == 0:
        if master_process:
            logger.info("Validating...")
        gpt.eval()
        val_loss_accum = 0.0
        val_dataloader.reset()
        with torch.no_grad():
            val_loss_steps = 20
            for _ in range(val_loss_steps):
                x, y = val_dataloader.next_batch()
                x, y = x.to(device), y.to(device)
                if supports_bfloat16:
                    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                        val_logits, val_loss = gpt(x, y)
                else:
                    val_logits, val_loss = gpt(x, y)
                val_loss = val_loss / val_loss_steps
                val_loss_accum += val_loss.detach()
        if ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if master_process:
            logger.info(f'Validation Loss: {val_loss_accum}')
        val_lossi.append(val_loss_accum.item())
        gpt.train()  # back to training mode after evaluation

    if step % save_every == 0 and master_process:
        print("Saving model and loss...")
        torch.save(raw_gpt.state_dict(), "gpt2_step_{}.pth".format(step))
        torch.save(torch.tensor(lossi), "lossi_step_{}.pth".format(step))
        torch.save(torch.tensor(val_lossi), "val_lossi_step_{}.pth".format(step))
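# Resuming from one of the periodic checkpoints later is a one-liner (a sketch;
# the step number in the filename is an assumption about which checkpoints exist):
#   raw_gpt.load_state_dict(torch.load("gpt2_step_45.pth", map_location=device))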
# ***************************#
# Plot Loss
# ***************************#
plot = True
if master_process and plot:
    plt.plot(lossi, label="Train Loss")
    # stretch val_lossi to match the length of lossi
    val_lossi_stretched = np.interp(
        np.linspace(0, len(val_lossi) - 1, len(lossi)),
        np.arange(len(val_lossi)),
        val_lossi
    )
    plt.plot(val_lossi_stretched, label="Validation Loss")
    plt.legend()
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.show()

# Generate Final Text
if master_process:
    print(generate_text("The king said", gpt, enc, max_len=25,
                        print_while_generating=False))

# ***************************#
# Save Model and Loss
# ***************************#
if master_process:
    # save the unwrapped module so the state dict has no DDP "module." prefixes
    torch.save(raw_gpt.state_dict(), "gpt2_shakespeare.pth")
    torch.save(torch.tensor(lossi), "lossi.pth")
    torch.save(torch.tensor(val_lossi), "val_lossi.pth")

# ***************************#
# Cleanup
# ***************************#
if ddp:
    destroy_process_group()
sys.exit(0)
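# Reloading the final weights in a separate session (an illustrative sketch;
# assumes model.py and the saved .pth are on the path, and reuses this script's
# GPT(config, master_process) constructor signature):
#   from model import GPT, GPTConfig
#   m = GPT(GPTConfig(vocab_size=50304), True)
#   m.load_state_dict(torch.load("gpt2_shakespeare.pth", map_location="cpu"))
#   m.eval()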