""" Training script for Dynamic Token-Aware Transformer (DTAT) on enwik8 dataset. Based on NanoGPT's training structure with modifications for token importance awareness. """ import os import time import math import pickle from contextlib import nullcontext import numpy as np import torch from torch.nn.parallel import DistributedDataParallel as DDP from torch.distributed import init_process_group, destroy_process_group import matplotlib.pyplot as plt import wandb from tqdm import tqdm from datetime import datetime from model_dtat import DTATTransformer from config.dtat_config import get_config # ----------------------------------------------------------------------------- # I/O def get_batch(data, block_size, batch_size, device): """Generate a small batch of data of inputs x and targets y.""" ix = torch.randint(len(data) - block_size, (batch_size,)) x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) x, y = x.to(device), y.to(device) return x, y def compute_freq_table(data, vocab_size=256): """Compute frequency table for the dataset.""" freq = np.bincount(data, minlength=vocab_size) return freq / len(data) def visualize_importance(tokens, importance_scores, iter_num): """ Visualize token importance scores """ plt.figure(figsize=(15, 5)) # Detach and move to CPU before converting to numpy scores = importance_scores.detach().squeeze().cpu() plt.bar(range(len(tokens)), scores) plt.title(f'Token Importance Scores (Iteration {iter_num})') plt.xlabel('Token Position') plt.ylabel('Importance Score') # Add token labels if sequence is not too long if len(tokens) <= 50: plt.xticks(range(len(tokens)), tokens, rotation=45) # Save plot to wandb wandb.log({ 'importance_scores': wandb.Image(plt), 'iter': iter_num }) plt.close() # ----------------------------------------------------------------------------- # Training def estimate_loss(model, data, config): out = {} model.eval() losses = torch.zeros(config.eval_iters) for k in range(config.eval_iters): X, Y = get_batch(data, config.block_size, config.batch_size, config.device) with torch.no_grad(): logits, loss, _ = model(X, Y) losses[k] = loss.item() out = losses.mean() model.train() return out def get_lr(it, config): """ Learning rate scheduler with linear warmup and cosine decay """ # Linear warmup if it < config.warmup_iters: return config.learning_rate * it / config.warmup_iters # Cosine decay with minimum learning rate if config.decay_lr: decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters) decay_ratio = min(decay_ratio, 1.0) # Cap at 1.0 coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) return config.min_lr + coeff * (config.learning_rate - config.min_lr) return config.learning_rate def main(): # Initialize distributed training if needed ddp = int(os.environ.get('RANK', -1)) != -1 if ddp: init_process_group(backend='nccl') ddp_rank = int(os.environ['RANK']) ddp_local_rank = int(os.environ['LOCAL_RANK']) device = f'cuda:{ddp_local_rank}' master_process = ddp_rank == 0 seed_offset = ddp_rank assert config.batch_size % torch.cuda.device_count() == 0 config.batch_size = config.batch_size // torch.cuda.device_count() else: device = 'cuda' if torch.cuda.is_available() else 'cpu' master_process = True seed_offset = 0 # Set seed for reproducibility torch.manual_seed(1337 + seed_offset) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True device_type = 'cuda' if 'cuda' in device else 'cpu' # Get config config = get_config() config.device = device # Initialize wandb if master_process: wandb.init(project="enwik8-dtat") wandb.config.update(config.__dict__) # Adjust warmup config.warmup_iters = 2000 # Increased warmup iterations config.learning_rate = 6e-4 # Confirmed learning rate # Data loading print("Loading data...") data_dir = os.path.join('data') train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint8, mode='r') val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint8, mode='r') # Compute frequency table for the training data freq_table = compute_freq_table(train_data) # Model init print("Initializing model...") model = DTATTransformer(config) model.to(device) # Optimizer optimizer = torch.optim.AdamW( model.parameters(), lr=config.learning_rate, betas=(config.beta1, config.beta2), weight_decay=config.weight_decay ) if ddp: model = DDP(model, device_ids=[ddp_local_rank]) # Enable torch compile if available (PyTorch 2.0+) if hasattr(torch, 'compile'): try: model = torch.compile(model) print("Using torch.compile() for faster training") except: print("torch.compile() failed, falling back to default model") # Gradient scaler for mixed precision scaler = torch.cuda.amp.GradScaler(enabled=config.mixed_precision) # Enable cuDNN benchmarking for faster training torch.backends.cudnn.benchmark = True # Create checkpoint directory if it doesn't exist checkpoint_dir = os.path.join('checkpoints', 'dtat') os.makedirs(checkpoint_dir, exist_ok=True) # Training loop print("Starting training...") print(f"Saving checkpoints to: {checkpoint_dir}") # Calculate total steps and epochs total_steps = config.max_iters batch_size = config.batch_size block_size = config.block_size total_epochs = (total_steps * batch_size * block_size) // len(train_data) # Create progress bar pbar = tqdm(range(config.max_iters), desc=f"Training (0/{total_epochs} epochs)") best_val_loss = float('inf') no_improvement = 0 running_mfu = -1.0 t0 = time.time() for iter_num in pbar: # Early stopping check if no_improvement >= config.patience: print(f"\nEarly stopping triggered after {iter_num} iterations") print(f"Best validation loss: {best_val_loss:.4f}") break # Update learning rate lr = get_lr(iter_num, config) for param_group in optimizer.param_groups: param_group['lr'] = lr # Sample a batch of data X, Y = get_batch(train_data, config.block_size, config.batch_size, device) # Mixed precision training with torch.cuda.amp.autocast(enabled=config.mixed_precision): logits, loss, importance_scores = model(X, Y) # Backward pass with gradient scaling optimizer.zero_grad(set_to_none=True) # Slightly faster than zero_grad() scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) scaler.step(optimizer) scaler.update() # Logging if iter_num % config.log_interval == 0: # Calculate current epoch current_tokens = (iter_num + 1) * batch_size * block_size current_epoch = current_tokens / len(train_data) # Calculate gradients and importance stats grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item() importance_mean = importance_scores.mean().item() # Update progress bar pbar.set_description( f"Training ({current_epoch:.1f}/{total_epochs} epochs) | " f"loss: {loss.item():.4f} | " # This is now directly in BPC f"bpc: {loss.item():.2f} | " # Same as loss since it's already BPC f"imp: {importance_mean:.2f} | " f"lr: {lr:.1e} | " f"tokens/sec: {(batch_size * block_size) / (time.time() - t0):.1f}" ) # Log to wandb wandb.log({ "iter": iter_num, "loss": loss.item(), # This is now directly in BPC "bpc": loss.item(), # Same as loss since it's already BPC "lr": lr, "grad_norm": grad_norm, "importance_mean": importance_mean, "epoch": current_epoch, "tokens_per_sec": (batch_size * block_size) / (time.time() - t0), }) # Reset timer t0 = time.time() # Visualize importance scores periodically if iter_num % (config.log_interval * 10) == 0: visualize_importance( X[0].cpu().numpy(), importance_scores[0], iter_num ) # Evaluation if iter_num > 0 and iter_num % config.eval_interval == 0: val_loss = estimate_loss(model, val_data, config) # Check for improvement if val_loss < best_val_loss - config.min_delta: best_val_loss = val_loss no_improvement = 0 print(f"Saved best model at iteration {iter_num} with val_loss: {val_loss:.4f}") torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best.pt')) else: no_improvement += 1 # Log validation metrics wandb.log({ "iter": iter_num, "val_loss": val_loss, "val_bpc": val_loss, "epoch": current_epoch, }) # Save regular checkpoint every 5000 iterations if iter_num % 1000 == 0: checkpoint = { 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'iter_num': iter_num, 'best_val_loss': best_val_loss, 'config': config, } checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_{iter_num:06d}.pt') torch.save(checkpoint, checkpoint_path) print(f"\nSaved checkpoint at iteration {iter_num} to {checkpoint_path}") wandb.finish() if __name__ == '__main__': main()