|
import torch |
|
from daedalus_mobile import DaedalusMobile |
|
from tokenizer import DaedalusTokenizer |
|
from config import config |
|
|
|
def train(model, device, train_loader, optimizer):
    """Run a single training epoch and return the mean per-batch loss.

    Args:
        model: module exposing ``train_step(batch) -> loss`` where *batch*
            is a ``(input_ids, attention_mask, labels)`` tuple.
        device: target ``torch.device`` for the batch tensors.
        train_loader: iterable yielding ``(input_ids, attention_mask, labels)``.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        float: sum of per-batch losses divided by the number of batches.
    """
    model.train()
    running = 0.0
    for input_ids, attention_mask, labels in train_loader:
        # Move the whole batch to the target device before the forward pass.
        moved = tuple(t.to(device) for t in (input_ids, attention_mask, labels))
        optimizer.zero_grad()
        loss = model.train_step(moved)
        loss.backward()
        optimizer.step()
        running += loss.item()
    return running / len(train_loader)
|
|
|
def main(train_dataset=None):
    """Train a DaedalusMobile model for ``config.epochs`` epochs.

    Args:
        train_dataset: dataset yielding ``(input_ids, attention_mask, labels)``
            examples, fed to a ``DataLoader``. BUG FIX: the original body
            referenced a module-level name ``train_dataset`` that is never
            defined anywhere in this file, so ``main()`` always crashed with
            ``NameError``. It is now an explicit parameter with a clear error
            when omitted.

    Raises:
        ValueError: if ``train_dataset`` is not provided.
    """
    device = torch.device(config.device)
    model = DaedalusMobile(config)
    model.to(device)
    # NOTE(review): tokenizer is constructed but never used below —
    # presumably intended to build/encode the dataset; confirm and wire up.
    tokenizer = DaedalusTokenizer(config)
    if train_dataset is None:
        raise ValueError(
            'train_dataset must be supplied to main(); the original code '
            'referenced an undefined module-level name.'
        )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
    )
    optimizer = model.configure_optimizers()
    for epoch in range(config.epochs):
        loss = train(model, device, train_loader, optimizer)
        print(f'Epoch {epoch+1}, Loss: {loss:.4f}')
|
|
|
# Entry-point guard: run training only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':

    main()