"""
Simple training loop; boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""

import time
from collections import defaultdict

import torch
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import CfgNode as CN

class Trainer:

    @staticmethod
    def get_default_config():
        C = CN()
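        # device to train on ('auto' picks CUDA when available, else CPU)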
        C.device = 'auto'
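        # dataloader parameters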
        C.num_workers = 4
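        # optimizer parameters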
        C.max_iters = None
        C.batch_size = 64
        C.learning_rate = 3e-4
        C.betas = (0.9, 0.95)
        C.weight_decay = 0.1
        C.grad_norm_clip = 1.0
        return C

    def __init__(self, config, model, train_dataset):
        self.config = config
        self.model = model
        self.optimizer = None
        self.train_dataset = train_dataset
        self.callbacks = defaultdict(list)
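
        # determine the device we'll train on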
        if config.device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = config.device
        self.model = self.model.to(self.device)
        print("running on device", self.device)
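
        # iteration bookkeeping, exposed on the trainer so callbacks can log progress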
        self.iter_num = 0
        self.iter_time = 0.0
        self.iter_dt = 0.0

    def add_callback(self, onevent: str, callback):
        self.callbacks[onevent].append(callback)

    def set_callback(self, onevent: str, callback):
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str):
        for callback in self.callbacks.get(onevent, []):
            callback(self)

    def run(self):
        model, config = self.model, self.config
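
        # set up the optimizer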
        self.optimizer = model.configure_optimizers(config)
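
        # set up the dataloader; sampling with replacement from an effectively
        # infinite stream makes the loop epoch-free, bounded only by max_iters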
        train_loader = DataLoader(
            self.train_dataset,
            sampler=torch.utils.data.RandomSampler(self.train_dataset, replacement=True, num_samples=int(1e10)),
            shuffle=False,
            pin_memory=True,
            batch_size=config.batch_size,
            num_workers=config.num_workers,
        )

        model.train()
        self.iter_num = 0
        self.iter_time = time.time()
        data_iter = iter(train_loader)
        while True:
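
            # fetch the next batch (x, y) and re-init the iterator if needed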
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)
            batch = [t.to(self.device) for t in batch]
            x, y = batch
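
            # forward the model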
            logits, self.loss = model(x, y)
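
            # backprop and update the parameters, clipping gradients for stability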
            model.zero_grad(set_to_none=True)
            self.loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
            self.optimizer.step()

            self.trigger_callbacks('on_batch_end')
            self.iter_num += 1
            tnow = time.time()
            self.iter_dt = tnow - self.iter_time
            self.iter_time = tnow
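
            # termination condition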
            if config.max_iters is not None and self.iter_num >= config.max_iters:
                break
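

if __name__ == '__main__':
    # A minimal, self-contained usage sketch (not part of the library itself):
    # ToyDataset and ToyModel below are hypothetical stand-ins that satisfy the
    # two interfaces Trainer assumes -- model.configure_optimizers(config)
    # returning a torch.optim.Optimizer and model(x, y) returning
    # (logits, loss) -- plus a Dataset yielding (x, y) tensor pairs.
    import torch.nn as nn
    from torch.utils.data import Dataset

    class ToyDataset(Dataset):
        """Toy regression task: predict the sum of 8 random features."""
        def __init__(self, n=256):
            self.x = torch.randn(n, 8)
            self.y = self.x.sum(dim=1, keepdim=True)
        def __len__(self):
            return len(self.x)
        def __getitem__(self, idx):
            return self.x[idx], self.y[idx]

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 1)
        def configure_optimizers(self, config):
            return torch.optim.AdamW(self.parameters(), lr=config.learning_rate,
                                     betas=config.betas, weight_decay=config.weight_decay)
        def forward(self, x, y):
            logits = self.fc(x)
            loss = nn.functional.mse_loss(logits, y)
            return logits, loss

    config = Trainer.get_default_config()
    config.max_iters = 100   # stop after 100 iterations
    config.num_workers = 0   # keep the smoke run single-process
    trainer = Trainer(config, ToyModel(), ToyDataset())
    # report the loss every 20 iterations via the callback hook
    trainer.add_callback(
        'on_batch_end',
        lambda t: print(f"iter {t.iter_num}: loss {t.loss.item():.4f}") if t.iter_num % 20 == 0 else None,
    )
    trainer.run()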