# gpt-project / gpt-2 / training_full_dataset.py
# Uploaded by mnmnmnmn ("Upload 15 files"), commit 7fc0f78 (verified), 11.8 kB.
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import GPT, GPTConfig
import tiktoken
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import math
import matplotlib.pyplot as plt
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import os
import signal
import sys
import numpy as np
import time
import logging
def seconds_to_hms(seconds):
    """Format a duration in seconds as ``HH:MM:SS``.

    Unlike ``time.strftime`` over ``time.gmtime``, the hour field is not
    wrapped at 24. Durations of a day or more are therefore shown correctly
    (``90000 -> '25:00:00'``), which matters for the remaining-time estimate
    of a long training run. Fractional seconds are truncated, and negative
    inputs are clamped to zero.
    """
    total = max(0, int(seconds))
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
def signal_handler(sig, frame):
    """SIGINT handler: tear down the process group (if any) and exit with status 0.

    ``destroy_process_group()`` raises when no default process group has been
    initialized, which is always the case in a non-DDP run. It is therefore
    only called when ``torch.distributed`` reports an initialized group.
    """
    print('Gracefully stopping the training process')
    if dist.is_available() and dist.is_initialized():
        destroy_process_group()
    sys.exit(0)
# Seed the CPU RNG, and the CUDA RNG when a GPU is present, so runs are reproducible.
manual_seed = 1339
torch.manual_seed(manual_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(manual_seed)

# Route Ctrl-C (SIGINT) to signal_handler.
signal.signal(signal.SIGINT, signal_handler)
# ***************************#
# Device Configuration
# ***************************#
# Preference order: CUDA, then MPS, then CPU as the fallback.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)
# ***************************#
# Tokenizer Setup
# ***************************#
enc = tiktoken.get_encoding('gpt2')
# Per-step training / validation loss histories, appended to by the training loop.
lossi, val_lossi = [], []
# ***************************#
# Load Text Data
# ***************************#
# NOTE(review): `text` and `tokens` appear unused by the shard-based training
# below (it reads edu_fineweb10B shards). This section only reports the
# token count; confirm before relying on or removing it.
with open("tinyshakespeare.txt", "r") as f:
    text = f.read()
tokens = enc.encode(text)
print("Number of tokens: {:,}".format(len(tokens)))
# ***************************#
# Set up DDP
# ***************************#
# torchrun exports RANK, LOCAL_RANK and WORLD_SIZE; if RANK is absent this is
# a plain single-process run.
ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run?
if not ddp:
    # Single process: it is rank 0 of a world of size 1 and does all logging.
    ddp_rank, ddp_local_rank, ddp_world_size = 0, 0, 1
    master_process = True
else:
    # DDP currently requires CUDA; each process binds to the GPU matching its
    # local rank.
    assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    # Rank 0 is responsible for logging, checkpointing, etc.
    master_process = ddp_rank == 0
if master_process:
    print(f"ddp: {ddp}, rank: {ddp_rank}, local_rank: {ddp_local_rank}, world_size: {ddp_world_size}, master_process: {master_process}")
# ***************************#
# Model Configuration
# ***************************#
# vocab_size 50304 is presumably GPT-2's 50257 padded up to a multiple of 128
# — TODO confirm against model.py.
gpt = GPT(GPTConfig(vocab_size=50304), master_process).to(device)
# `device` is a torch.device in single-process runs but a 'cuda:<rank>' string
# under DDP. Test its string form: comparing against torch.device("cuda")
# never matched the DDP case, so DDP runs were silently left uncompiled.
if str(device).startswith("cuda"):
    gpt.compile()
if ddp:
    gpt = DDP(gpt, device_ids=[ddp_local_rank])
# The unwrapped model, used for optimizer setup and checkpointing.
raw_gpt = gpt.module if ddp else gpt
# ***************************#
# Dataset and Dataloader
# ***************************#
def load_tokens(filename):
    """Load a token shard saved with ``np.save`` and return it as a ``torch.long`` tensor.

    The array is cast to int32 before being copied into the tensor. This
    presumably handles shards stored as a narrower unsigned dtype — TODO
    confirm against the shard writer.
    """
    as_int32 = np.load(filename).astype(np.int32)  # added after video
    return torch.tensor(as_int32, dtype=torch.long)
class DataLoader_Custom:
    """Sequential token-batch loader over ``.npy`` shards in ``data_root``.

    Shards whose filename contains ``split`` are read in sorted order. Process
    ``process_rank`` (of ``num_processes``) starts at token offset
    ``B*T*process_rank`` and advances by ``B*T*num_processes`` tokens per
    batch, so the processes read interleaved, non-overlapping input windows.
    When the current shard cannot supply another full stride, the next shard
    is loaded, wrapping back to the first shard after the last.

    Args:
        B: batch size (rows per batch).
        T: sequence length (tokens per row).
        process_rank: index of this process among the readers.
        num_processes: total number of reading processes.
        split: ``"train"`` or ``"val"``; selects shards by filename substring.
        shuffle: stored but currently unused.
        data_root: directory holding the shard files.
    """

    def __init__(self, B, T, process_rank, num_processes, split, shuffle=False,
                 data_root="edu_fineweb10B"):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.shuffle = shuffle
        assert split in ["train", "val"]
        names = sorted(s for s in os.listdir(data_root) if split in s)
        self.shards = [os.path.join(data_root, s) for s in names]
        assert len(self.shards) > 0, "No shards found for split {}".format(split)
        if master_process:
            print("Found {} shards for split {}".format(len(self.shards), split))
        self._load_shard(0)

    def _load_shard(self, index):
        """Make shard ``index`` current and rewind to this process's first offset."""
        self.current_shard = index
        self.tokens = load_tokens(self.shards[index])
        self.current_position = self.B * self.T * self.process_rank

    def next_batch(self):
        """Return ``(x, y)``, each ``(B, T)``, with ``y`` being ``x`` shifted by one token."""
        B, T = self.B, self.T
        buf = self.tokens[self.current_position:self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)
        y = buf[1:].view(B, T)
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            # Parenthesised so the index wraps to 0 after the last shard.
            # `current_shard + 1 % n` parsed as `current_shard + 1` and ran
            # past the end of self.shards.
            self._load_shard((self.current_shard + 1) % len(self.shards))
        return x, y

    def reset(self):
        """Rewind to the first shard and this process's starting offset."""
        self._load_shard(0)
T = 4  # sequence length: tokens per sample row
batch_size = 1  # micro-batch size B per forward pass, per process
# Tokens per optimizer step across all processes. This is a tiny debug value;
# presumably the full-scale target is 2**19 = 524,288 tokens.
total_batch_size = 2**2
assert total_batch_size % (T*batch_size*ddp_world_size) == 0, "total_batch_size is not divisible by B*T*world_size"
grad_accum_steps = total_batch_size // (T*batch_size*ddp_world_size)
if master_process:
    print("Total desired batch size: {:,}".format(total_batch_size))
    print("gradient accumulation steps: {:,}".format(grad_accum_steps))
# Shard data by the *global* rank. The local rank repeats on every node, so
# using it would make processes on different nodes read identical batches.
train_dataloader = DataLoader_Custom(batch_size, T, ddp_rank, ddp_world_size, "train")
val_dataloader = DataLoader_Custom(batch_size, T, ddp_rank, ddp_world_size, "val")
# ***************************#
# Text Generation Function
# ***************************#
def generate_text(seed_text, model, enc, max_len=100, print_while_generating=True):
    """Greedily extend ``seed_text`` by ``max_len`` tokens and return the full text.

    Each step feeds the last ``T`` tokens (the module-level sequence length)
    to ``model`` on ``device``. It then appends the argmax of the logits at
    the final position. The model runs in eval mode under ``torch.no_grad()``,
    and its previous train/eval mode is restored afterwards. Calling this
    mid-training therefore does not leave the model in eval mode.

    Args:
        seed_text: prompt string to continue.
        model: called as ``model(x)`` and expected to return ``(logits, _)``
            with rank-3 logits — presumably (batch, seq, vocab); TODO confirm.
        enc: tokenizer exposing ``encode`` / ``decode``.
        max_len: number of tokens to generate.
        print_while_generating: stream the prompt and each new token to stdout,
            ending with a newline.

    Returns:
        The decoded prompt followed by the generated continuation.
    """
    if print_while_generating:
        print(seed_text, end="")
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            tokens = enc.encode(seed_text)
            for _ in range(max_len):
                x = torch.tensor(tokens[-T:], dtype=torch.long, device=device).unsqueeze(0)
                logits, _ = model(x)
                next_token = int(torch.argmax(logits[:, -1, :]))
                tokens.append(next_token)
                if print_while_generating:
                    print(enc.decode([next_token]), end="")
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)
    if print_while_generating:
        print()  # terminate the streamed line
    return enc.decode(tokens)
# ***************************#
# Optimizer Configuration
# ***************************#
# `raw_gpt` is the bare model whether or not DDP wrapped it. One call
# therefore covers both cases; the DDP wrapper does not expose the model's
# configure_optimizers method.
optimizer = raw_gpt.configure_optimizers(
    weight_decay=0.1,
    learning_rate=6e-4,
    device=device,
)
# Allow float32 matmuls to use faster reduced-precision internal math (e.g. TF32) where supported.
torch.set_float32_matmul_precision('high')
# ***************************#
# Learning Rate Scheduler
# ***************************#
max_lr = 6e-4
min_lr = max_lr * 0.1
# NOTE(review): warmup_steps exceeds max_steps. Every step of this run is
# therefore still in linear warmup: the LR never exceeds max_lr * 50/715,
# and the cosine phase is never reached. Confirm whether warmup_steps should
# be scaled down for this short configuration.
warmup_steps = 715
max_steps = 50
def get_lr(step):
    """Return the learning rate for 0-based optimizer step ``step``.

    The schedule has three phases:
    - linear warmup from ``max_lr / warmup_steps`` to ``max_lr`` over the first
      ``warmup_steps`` steps;
    - cosine decay to ``min_lr`` reached at ``max_steps``;
    - constant ``min_lr`` from then on.
    """
    if step < warmup_steps:
        return max_lr * (step+1) / warmup_steps
    # `>=` rather than `>`: at step == max_steps the cosine already yields
    # exactly min_lr. Returning early also avoids a 0/0 below when
    # warmup_steps == max_steps.
    if step >= max_steps:
        return min_lr
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)
# Check if the device supports bfloat16.
# `device` is either a torch.device (single process) or a 'cuda:<rank>' string
# (DDP). Comparing it to the plain string "cuda" matched neither form, so
# bfloat16 autocast was never enabled. Test the string form instead.
supports_bfloat16 = False
if str(device).startswith("cuda"):
    capability = torch.cuda.get_device_capability()
    # bfloat16 needs compute capability major version 8 (Ampere) or newer.
    supports_bfloat16 = capability[0] >= 8
print("Supports bfloat16:", supports_bfloat16)
# ***************************#
# Training Loop
# ***************************#
generate_every = 50  # steps between sample generations
validate_every = 10  # steps between validation passes
save_every = 5  # steps between checkpoints
t0 = time.time()
# Initialize logging. basicConfig installs a console handler on the root
# logger (if it has none yet), and `logger` propagates to it.
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)
# Also log to training_log.txt. The handler uses:
# - UTF-8 explicitly, because logged generated text can contain arbitrary
#   Unicode that the platform's default encoding may not represent;
# - delay=True, so processes that never log (non-master ranks here) do not
#   open the file.
file_handler = logging.FileHandler('training_log.txt', encoding='utf-8', delay=True)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter('%(message)s'))
logger.addHandler(file_handler)
def _forward(model, x, y):
    """Return ``model(x, y)``, run under bfloat16 autocast when ``supports_bfloat16`` is set."""
    if supports_bfloat16:
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            return model(x, y)
    return model(x, y)


for step in range(max_steps):
    # Generation and validation below switch the model to eval mode. Switch
    # back so every optimization step runs in train mode.
    gpt.train()
    loss_accum = 0.0
    gpt.zero_grad()
    for minibatchstep in range(grad_accum_steps):
        x, y = train_dataloader.next_batch()
        x, y = x.to(device), y.to(device)
        if ddp:
            # DDP consults this flag during the *forward* pass, so it must be
            # set before calling the model. Gradients are then all-reduced
            # only on the last micro-step and accumulated locally before that.
            gpt.require_backward_grad_sync = (minibatchstep == grad_accum_steps - 1)
        logits, loss = _forward(gpt, x, y)
        # Scale so the accumulated gradient is the mean over micro-steps.
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()
        loss.backward()
    if ddp:
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)
    lossi.append(loss_accum.item())
    norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), 1.0)
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()
    t_current = time.time()
    elapsed_time = t_current - t0
    steps_completed = step + 1
    avg_time_per_step = elapsed_time / steps_completed
    remaining_steps = max_steps - steps_completed
    remaining_time = remaining_steps * avg_time_per_step
    if master_process:
        logger.info(f'Step {step} | Loss: {loss_accum:.6f} | Norm: {norm:.4f} | LR: {lr:.2e} | Time: {seconds_to_hms(elapsed_time)} | Remaining: {seconds_to_hms(remaining_time)} | Avg Time/Step: {avg_time_per_step:.2f}')
    if master_process and step % generate_every == 0:
        # Generate with the unwrapped model. Only rank 0 runs this, and a
        # forward through the DDP wrapper may issue collectives (e.g. buffer
        # broadcasts) that the other ranks never join.
        generated_text = generate_text("The king said", raw_gpt, enc, max_len=25, print_while_generating=False)
        logger.info(f'Generated Text at Step {step}: {generated_text}')
    # Validation step
    if step % validate_every == 0:
        if master_process:
            logger.info("Validating...")
        gpt.eval()
        val_loss_accum = 0.0
        val_dataloader.reset()
        with torch.no_grad():
            val_loss_steps = 20
            for _ in range(val_loss_steps):
                x, y = val_dataloader.next_batch()
                x, y = x.to(device), y.to(device)
                val_logits, val_loss = _forward(gpt, x, y)
                val_loss = val_loss / val_loss_steps
                val_loss_accum += val_loss.detach()
        if ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if master_process:
            logger.info(f'Validation Loss: {val_loss_accum}')
            val_lossi.append(val_loss_accum.item())
    if step % save_every == 0 and master_process:
        print("Saving model and loss...")
        torch.save(raw_gpt.state_dict(), "gpt2_step_{}.pth".format(step))
        torch.save(torch.tensor(lossi), "lossi_step_{}.pth".format(step))
        torch.save(torch.tensor(val_lossi), "val_lossi_step_{}.pth".format(step))
# ***************************#
# Plot Loss
# ***************************#
plot = True
if master_process and plot:
    plt.plot(lossi, label="Train Loss")
    # val_lossi gets one entry per validation pass, far fewer than lossi.
    # Resample it linearly onto len(lossi) evenly spaced points so both
    # curves share the x-axis.
    n_val = len(val_lossi)
    val_lossi_stretched = np.interp(np.linspace(0, n_val - 1, len(lossi)), np.arange(n_val), val_lossi)
    plt.plot(val_lossi_stretched, label="Validation Loss")
    plt.legend()
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.show()
# Generate Final Text
if master_process:
    # Use the unwrapped model. Only rank 0 reaches this call, and a forward
    # through the DDP wrapper may issue collectives the other ranks never join.
    print(generate_text("The king said", raw_gpt, enc, max_len=25, print_while_generating=False))
# ***************************#
# Save Model and Loss
# ***************************#
if master_process:
    # Save the unwrapped model's weights. Under DDP, gpt.state_dict() prefixes
    # every key with "module.", which a bare GPT instance cannot load directly.
    # NOTE(review): the filename says "shakespeare", but training reads
    # edu_fineweb10B shards; the name is kept for compatibility.
    torch.save(raw_gpt.state_dict(), "gpt2_shakespeare.pth")
    torch.save(torch.tensor(lossi), "lossi.pth")
    torch.save(torch.tensor(val_lossi), "val_lossi.pth")
# ***************************#
# Cleanup
# ***************************#
if ddp:
    # Tear down the process group created by init_process_group for this DDP run.
    destroy_process_group()
# `sys` is already imported at the top of the file.
sys.exit(0)