# Training script for DaedalusMobile.
# (Removed non-code extraction artifacts: file-size line, commit hash, line-number gutter.)
import torch
from daedalus_mobile import DaedalusMobile
from tokenizer import DaedalusTokenizer
from config import config
def train(model, device, train_loader, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module exposing ``train_step(batch) -> loss`` where ``loss``
            is a scalar tensor with grad (presumably DaedalusMobile — confirm).
        device: ``torch.device`` that each batch tensor is moved to.
        train_loader: Iterable yielding ``(input_ids, attention_mask, labels)``
            tuples of tensors.
        optimizer: torch optimizer over ``model``'s parameters.

    Returns:
        float: average of ``loss.item()`` over all batches, or ``0.0`` when
        the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        input_ids, attention_mask, labels = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = model.train_step((input_ids, attention_mask, labels))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # Count batches as we go instead of calling len(train_loader): this works
    # for plain iterables/generators too, and avoids ZeroDivisionError when
    # the loader is empty.
    return total_loss / num_batches if num_batches else 0.0
def main():
    """Build the model, data loader, and optimizer from ``config``, then
    train for ``config.epochs`` epochs, printing the mean loss per epoch."""
    device = torch.device(config.device)
    model = DaedalusMobile(config)
    model.to(device)
    # NOTE(review): tokenizer is constructed but never used below — confirm
    # whether it should feed dataset construction or can be dropped.
    tokenizer = DaedalusTokenizer(config)
    # BUG(review): `train_dataset` is not defined anywhere in this file — as
    # written this line raises NameError. The dataset must be built here or
    # imported; cannot reconstruct it from the visible code.
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=config.batch_size, shuffle=True)
    # Optimizer comes from the model itself (Lightning-style hook).
    optimizer = model.configure_optimizers()
    for epoch in range(config.epochs):
        loss = train(model, device, train_loader, optimizer)
        print(f'Epoch {epoch+1}, Loss: {loss:.4f}')
# Entry-point guard: run training only when executed as a script, not on import.
if __name__ == '__main__':
    main()