|
import logging
import math
import os
import signal
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import tiktoken
import torch
import torch.distributed as dist
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP

from model import GPT, GPTConfig
|
|
|
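# Format an elapsed duration in seconds as HH:MM:SS, e.g. seconds_to_hms(3661) -> '01:01:01'
# (note: time.gmtime wraps past 24 hours).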
def seconds_to_hms(seconds): |
|
return time.strftime('%H:%M:%S', time.gmtime(seconds)) |
|
|
|
|
|
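# Ctrl+C handler: tear down the process group (if one was initialized) and exit cleanly.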
def signal_handler(sig, frame):
    print('Gracefully stopping the training process')
    if dist.is_available() and dist.is_initialized():
        destroy_process_group()
    sys.exit(0)
|
|
|
signal.signal(signal.SIGINT, signal_handler) |
|
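# Seed the CPU and CUDA RNGs so runs are repeatable.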
manual_seed = 1339 |
|
torch.manual_seed(manual_seed) |
|
if torch.cuda.is_available(): |
|
torch.cuda.manual_seed(manual_seed) |
|
|
|
|
|
|
|
|
|
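# Pick the best available device: CUDA, then Apple MPS, then CPU.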
device = torch.device("cpu") |
|
if torch.cuda.is_available(): |
|
device = torch.device("cuda") |
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): |
|
device = torch.device("mps") |
|
|
|
print("Using device:", device) |
|
|
|
|
|
|
|
|
|
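# GPT-2 byte-pair encoder (50257-token vocab), used for the token-count check below and for sampling.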
enc = tiktoken.get_encoding('gpt2') |
|
|
|
|
|
lossi = [] |
|
val_lossi = [] |
|
|
|
|
|
|
|
|
|
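# Quick sanity check: tokenize tinyshakespeare and report its size.
# The training batches themselves come from the edu_fineweb10B shards loaded further down.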
with open("tinyshakespeare.txt", "r") as f: |
|
text = f.read() |
|
tokens = enc.encode(text) |
|
print(f"Number of tokens: {len(tokens):,}") |
|
|
|
|
|
|
|
|
|
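# Distributed setup: torchrun exports RANK/LOCAL_RANK/WORLD_SIZE; without them, fall back to a single process.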
ddp = int(os.environ.get('RANK', -1)) != -1 |
|
if ddp: |
|
|
|
    assert torch.cuda.is_available(), "DDP training currently requires CUDA"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = torch.device(f'cuda:{ddp_local_rank}')
    torch.cuda.set_device(device)
|
|
|
master_process = ddp_rank == 0 |
|
else: |
|
|
|
ddp_rank = 0 |
|
ddp_local_rank = 0 |
|
ddp_world_size = 1 |
|
master_process = True |
|
|
|
if master_process: |
|
print(f"ddp: {ddp}, rank: {ddp_rank}, local_rank: {ddp_local_rank}, world_size: {ddp_world_size}, master_process: {master_process}") |
|
|
|
|
|
|
|
|
|
|
|
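# Build the model (vocab padded to 50304, a multiple of 128, for friendlier GPU kernels),
# compile it on CUDA, then wrap it in DDP; raw_gpt keeps a handle to the unwrapped module.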
gpt = GPT(GPTConfig(vocab_size=50304), master_process).to(device) |
|
if device.type == "cuda":
    gpt.compile()
|
if ddp: |
|
gpt = DDP(gpt, device_ids=[ddp_local_rank]) |
|
|
|
raw_gpt = gpt.module if ddp else gpt |
|
|
|
|
|
|
|
|
|
|
|
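# Data pipeline: pre-tokenized .npy shards in edu_fineweb10B/.
# Each rank walks the shards sequentially, offset and strided by its rank so ranks never overlap.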
def load_tokens(filename): |
|
npt = np.load(filename) |
|
npt = npt.astype(np.int32) |
|
ptt = torch.tensor(npt, dtype=torch.long) |
|
return ptt |
|
|
|
class DataLoader_Custom: |
|
def __init__(self, B, T, process_rank, num_processes, split, shuffle=False): |
|
self.B = B |
|
self.T = T |
|
self.process_rank = process_rank |
|
self.num_processes = num_processes |
|
self.shuffle = shuffle |
|
assert split in ["train", "val"] |
|
|
|
data_root = "edu_fineweb10B" |
|
shards = os.listdir(data_root) |
|
shards = [s for s in shards if split in s] |
|
shards = sorted(shards) |
|
shards = [os.path.join(data_root, s) for s in shards] |
|
self.shards = shards |
|
assert len(shards) > 0, "No shards found for split {}".format(split) |
|
if master_process: |
|
print("Found {} shards for split {}".format(len(shards), split)) |
|
self.current_shard = 0 |
|
self.tokens = load_tokens(self.shards[self.current_shard]) |
|
self.current_position = self.B * self.T * self.process_rank |
|
|
|
def next_batch(self): |
|
B, T = self.B, self.T |
|
buf = self.tokens[self.current_position:self.current_position + B*T+1] |
|
x = buf[:-1].view(B, T) |
|
y = buf[1:].view(B, T) |
|
self.current_position += B*T * self.num_processes |
|
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            # move to the next shard (wrapping around) and rewind to this rank's offset
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_position = self.B * self.T * self.process_rank
|
|
|
return x, y |
|
|
|
def reset(self): |
|
self.current_shard = 0 |
|
self.tokens = load_tokens(self.shards[self.current_shard]) |
|
self.current_position = self.B * self.T * self.process_rank |
|
|
|
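# Batch geometry: total_batch_size is the number of tokens per optimizer step, reached by
# accumulating grad_accum_steps micro-batches of batch_size*T tokens per rank.
# The values here are tiny, suitable only for a quick smoke test.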
T = 4 |
|
batch_size = 1 |
|
total_batch_size = 2**2 |
|
assert total_batch_size % (batch_size * T * ddp_world_size) == 0, "total_batch_size must be divisible by batch_size * T * ddp_world_size"
|
grad_accum_steps = total_batch_size // (T*batch_size*ddp_world_size) |
|
|
|
if master_process: |
|
print("Total desired batch size: {:,}".format(total_batch_size)) |
|
print("gradient accumulation steps: {:,}".format(grad_accum_steps)) |
|
|
|
train_dataloader = DataLoader_Custom(batch_size, T, ddp_rank, ddp_world_size, "train")
val_dataloader = DataLoader_Custom(batch_size, T, ddp_rank, ddp_world_size, "val")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
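# Greedy decoder used for periodic qualitative samples; it feeds back the last T tokens at each step.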
def generate_text(seed_text, model, enc, max_len=100, print_while_generating=True):
    if print_while_generating:
        print(seed_text, end="")
    was_training = model.training
    model.eval()
    with torch.no_grad():
        tokens = enc.encode(seed_text)
        for _ in range(max_len):
            # greedy decoding: feed the last T tokens, take the argmax of the next-token logits
            x = torch.tensor(tokens[-T:], dtype=torch.long,
                             device=device).unsqueeze(0)
            logits, _ = model(x)
            next_token = torch.argmax(logits[:, -1, :])
            tokens.append(int(next_token))
            if print_while_generating:
                print(enc.decode([int(next_token)]), end="")
    if print_while_generating:
        print()
    if was_training:
        model.train()  # restore training mode so mid-training sampling has no side effects

    return enc.decode(tokens)
|
|
|
|
|
|
|
|
|
|
|
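# Create the optimizer via the model's own configure_optimizers helper; using raw_gpt (the
# unwrapped module) works with and without DDP. 'high' matmul precision enables TF32 on supporting GPUs.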
optimizer = raw_gpt.configure_optimizers(
    weight_decay=0.1, learning_rate=6e-4, device=device)
torch.set_float32_matmul_precision('high')
|
|
|
|
|
|
|
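# Learning-rate schedule: linear warmup to max_lr, then cosine decay to min_lr.
# Note that with warmup_steps (715) > max_steps (50) as set here, the run never leaves the warmup phase.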
max_lr = 6e-4 |
|
min_lr = max_lr * 0.1 |
|
warmup_steps = 715 |
|
max_steps = 50 |
|
|
|
|
|
def get_lr(step): |
|
if step < warmup_steps: |
|
return max_lr * (step+1) / warmup_steps |
|
if step > max_steps: |
|
return min_lr |
|
decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps) |
|
assert 0 <= decay_ratio <= 1 |
|
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) |
|
return min_lr + coeff * (max_lr - min_lr) |
|
|
|
|
|
|
|
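# Mixed precision: enable bfloat16 autocast only where the GPU supports it natively.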
supports_bfloat16 = False
if device.type == "cuda":
    capability = torch.cuda.get_device_capability()
    if capability[0] >= 8:  # Ampere (SM 8.0) and newer support bfloat16
        supports_bfloat16 = True
|
|
|
print("Supports bfloat16:", supports_bfloat16) |
|
|
|
|
|
|
|
|
|
|
|
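# Cadence (in steps) for text generation, validation, and checkpointing, plus the start time used for ETA estimates.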
generate_every = 50 |
|
validate_every = 10 |
|
save_every = 5 |
|
t0 = time.time() |
|
|
|
|
|
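# Log to stdout and to training_log.txt; the logger calls below are guarded so only the master process writes.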
logging.basicConfig(level=logging.INFO, format='%(message)s') |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
file_handler = logging.FileHandler('training_log.txt') |
|
file_handler.setLevel(logging.INFO) |
|
file_handler.setFormatter(logging.Formatter('%(message)s')) |
|
logger.addHandler(file_handler) |
|
|
|
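# Main training loop: accumulate gradients over grad_accum_steps micro-batches, clip the global
# grad norm to 1.0, apply the scheduled learning rate, and log timing/ETA.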
for step in range(max_steps): |
|
|
|
loss_accum = 0.0 |
|
gpt.zero_grad() |
|
for minibatchstep in range(grad_accum_steps): |
|
x, y = train_dataloader.next_batch() |
|
x, y = x.to(device), y.to(device) |
|
|
|
if supports_bfloat16: |
|
with torch.autocast(device_type='cuda', dtype=torch.bfloat16): |
|
logits, loss = gpt(x, y) |
|
else: |
|
logits, loss = gpt(x, y) |
|
|
|
loss = loss / grad_accum_steps |
|
loss_accum += loss.detach() |
|
        if ddp:
            # only all-reduce gradients on the final micro-step of the accumulation window
            gpt.require_backward_grad_sync = (minibatchstep == grad_accum_steps - 1)
|
loss.backward() |
|
|
|
if ddp: |
|
dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG) |
|
lossi.append(loss_accum.item()) |
|
norm = torch.nn.utils.clip_grad_norm_(gpt.parameters(), 1.0) |
|
lr = get_lr(step) |
|
for param_group in optimizer.param_groups: |
|
param_group['lr'] = lr |
|
    optimizer.step()
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for queued GPU work so the step timing below is accurate
    t_current = time.time()
    elapsed_time = t_current - t0
|
steps_completed = step + 1 |
|
avg_time_per_step = elapsed_time / steps_completed |
|
remaining_steps = max_steps - steps_completed |
|
remaining_time = remaining_steps * avg_time_per_step |
|
|
|
if master_process: |
|
logger.info(f'Step {step} | Loss: {loss_accum:.6f} | Norm: {norm:.4f} | LR: {lr:.2e} | Time: {seconds_to_hms(elapsed_time)} | Remaining: {seconds_to_hms(remaining_time)} | Avg Time/Step: {avg_time_per_step:.2f}') |
|
|
|
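    # periodically greedy-decode a short sample so training progress can be eyeballed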
if master_process and step % generate_every == 0: |
|
generated_text = generate_text("The king said", gpt, enc, max_len=25, print_while_generating=False) |
|
logger.info(f'Generated Text at Step {step}: {generated_text}') |
|
|
|
|
|
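    # periodically estimate validation loss over a fixed number of val batches
    # (all ranks participate; the result is averaged across ranks under DDP)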
if step % validate_every == 0: |
|
if master_process: |
|
logger.info("Validating...") |
|
gpt.eval() |
|
val_loss_accum = 0.0 |
|
val_dataloader.reset() |
|
        with torch.no_grad():
            val_loss_steps = 20
|
for _ in range(val_loss_steps): |
|
x, y = val_dataloader.next_batch() |
|
x, y = x.to(device), y.to(device) |
|
if supports_bfloat16: |
|
with torch.autocast(device_type='cuda', dtype=torch.bfloat16): |
|
val_logits, val_loss = gpt(x, y) |
|
else: |
|
val_logits, val_loss = gpt(x, y) |
|
val_loss = val_loss / val_loss_steps |
|
val_loss_accum += val_loss.detach() |
|
if ddp: |
|
dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG) |
|
        if master_process:
            logger.info(f'Validation Loss: {val_loss_accum.item():.6f}')
            val_lossi.append(val_loss_accum.item())
        gpt.train()  # leave eval mode so subsequent training steps run in train mode
|
|
|
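    # periodically checkpoint the unwrapped model weights along with the loss histories (master process only)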
if step % save_every == 0 and master_process: |
|
print("Saving model and loss...") |
|
torch.save(raw_gpt.state_dict(), "gpt2_step_{}.pth".format(step)) |
|
torch.save(torch.tensor(lossi), "lossi_step_{}.pth".format(step)) |
|
torch.save(torch.tensor(val_lossi), "val_lossi_step_{}.pth".format(step)) |
|
|
|
|
|
|
|
|
|
|
|
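# Plot the training curve; validation loss is recorded less often, so it is interpolated onto the training-step axis.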
plot = True |
|
if master_process and plot: |
|
plt.plot(lossi, label="Train Loss") |
|
|
|
|
|
val_lossi_stretched = np.interp( |
|
np.linspace(0, len(val_lossi) - 1, len(lossi)), |
|
np.arange(len(val_lossi)), |
|
val_lossi |
|
) |
|
|
|
plt.plot(val_lossi_stretched, label="Validation Loss") |
|
plt.legend() |
|
plt.xlabel("Step") |
|
plt.ylabel("Loss") |
|
|
|
plt.show() |
|
|
|
|
|
if master_process: |
|
print(generate_text("The king said", gpt, enc, max_len=25, print_while_generating=False)) |
|
|
|
|
|
|
|
|
|
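# Final artifacts: unwrapped model weights plus the full loss histories.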
if master_process: |
|
    torch.save(raw_gpt.state_dict(), "gpt2_shakespeare.pth")
|
torch.save(torch.tensor(lossi), "lossi.pth") |
|
torch.save(torch.tensor(val_lossi), "val_lossi.pth") |
|
|
|
|
|
|
|
|
|
if ddp: |
|
destroy_process_group() |
|
|
|
|
|
sys.exit(0)
|
|