|
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

from earlystopping import EarlyStopping
from evaluater import evaluate
|
|
|
|
|
def fit(
    epochs,
    lr,
    model,
    train_loader,
    val_loader,
    opt_func=torch.optim.Adam,
    *,
    lr_factor=0.1,
    scheduler_patience=3,
    early_stopping_patience=5,
):
    """Train ``model`` for up to ``epochs`` epochs and return per-epoch metrics.

    Each epoch makes one pass over ``train_loader`` (loss from
    ``model.training_step``, backward, optimizer step), then evaluates on
    ``val_loader`` via ``evaluate``. The resulting ``'val_loss'`` drives a
    ``ReduceLROnPlateau`` scheduler and an ``EarlyStopping`` monitor; the loop
    breaks as soon as the monitor sets ``early_stop``.

    Args:
        epochs: Maximum number of epochs to run.
        lr: Initial learning rate passed to the optimizer.
        model: Module exposing ``training_step(batch)`` (must return a scalar
            loss tensor, since ``backward()`` is called on it without
            arguments) and ``epoch_end(epoch, result)``.
        train_loader: Iterable of training batches.
        val_loader: Validation data, forwarded unchanged to ``evaluate``.
        opt_func: Optimizer class/factory, called as
            ``opt_func(model.parameters(), lr=lr)``.
        lr_factor: Factor by which the plateau scheduler multiplies the LR.
        scheduler_patience: Epochs without val-loss improvement before the
            scheduler reduces the LR.
        early_stopping_patience: ``patience`` value passed to ``EarlyStopping``.

    Returns:
        List with one result dict per completed epoch, as returned by
        ``evaluate`` and augmented with a ``'train_loss'`` entry (``nan`` when
        ``train_loader`` yielded no batches).
    """
    optimizer = opt_func(model.parameters(), lr=lr)
    # `verbose=True` is deprecated for LR schedulers since PyTorch 2.2, so
    # learning-rate reductions are reported manually after each scheduler step.
    scheduler = ReduceLROnPlateau(
        optimizer, mode='min', factor=lr_factor, patience=scheduler_patience
    )
    early_stopping = EarlyStopping(patience=early_stopping_patience, verbose=True)

    history = []
    for epoch in range(epochs):
        # Switch back to training mode every epoch; presumably `evaluate`
        # puts the model in eval mode (not visible here — confirm).
        model.train()
        train_losses = []
        for batch in train_loader:
            # Clear gradients *before* backward so gradients already present on
            # the model when `fit` is called cannot leak into the first step.
            optimizer.zero_grad()
            loss = model.training_step(batch)
            loss.backward()
            optimizer.step()
            # Detach so each batch's autograd graph can be freed; storing the
            # attached tensor would keep every graph of the epoch in memory.
            train_losses.append(loss.detach())

        result = evaluate(model, val_loader)
        # Guard against an empty loader: torch.stack([]) raises.
        result['train_loss'] = (
            torch.stack(train_losses).mean().item() if train_losses else float('nan')
        )
        model.epoch_end(epoch, result)
        history.append(result)

        previous_lrs = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(result['val_loss'])
        for idx, (old_lr, group) in enumerate(zip(previous_lrs, optimizer.param_groups)):
            if group['lr'] < old_lr:
                print(f"Epoch {epoch}: reducing learning rate of group {idx} to {group['lr']:.4e}.")

        early_stopping(result['val_loss'], model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    return history