Spaces:
Sleeping
Sleeping
from src.config.config import setup_logging | |
from pipeline import Preprocessor, NYCDataLoader, Trainer, VanillaLSTM, Transformer, VAE, save_model | |
from path_config import RAW_DATA_PATH | |
def train():
    """Run the end-to-end training pipeline for the three models.

    Steps, in order:

    1. Configure logging via ``setup_logging()``.
    2. Preprocess the raw data at ``RAW_DATA_PATH`` with a window size of
       ``seq_length`` (24).
    3. Build the train/validation/test loaders with a batch size of 32.
    4. Train ``VanillaLSTM`` (20 epochs), ``VAE`` (20 epochs) and
       ``Transformer`` (5 epochs) with a learning rate of 0.001, all through
       one shared ``Trainer`` instance.
    5. Pass each trained model to ``save_model`` as soon as its run finishes.

    The test loader returned by ``load_data()`` is not used here.
    """
    seq_length = 24
    # Single source of truth: the loader and every training run must agree
    # on these values, so they are defined once instead of repeated inline.
    batch_size = 32
    learning_rate = 0.001

    setup_logging()

    # Preprocess the data
    preprocessor = Preprocessor()
    preprocessor.preprocess_data(file_path=RAW_DATA_PATH, window_size=seq_length)

    # Load the preprocessed data
    data_loader = NYCDataLoader(batch_size=batch_size)
    train_loader, val_loader, _test_loader = data_loader.load_data()

    # Initialize the Trainer
    trainer = Trainer()

    # One entry per run: (model factory, model_type, n_epochs, output file).
    # Factories rather than instances, so each model is only constructed right
    # before its own run, in the same order as before.
    runs = (
        (VanillaLSTM, "lstm", 20, "lstm_model_small.pth"),
        (lambda: VAE(seq_len=seq_length), "vae", 20, "vae_model_small.pth"),
        (Transformer, "transformer", 5, "transformer_model_small.pth"),
    )

    for build_model, model_type, n_epochs, filename in runs:
        trainer.init_model(model=build_model(), model_type=model_type)
        trainer.config_train(batch_size=batch_size, n_epochs=n_epochs, lr=learning_rate)
        model, _history = trainer.train(train_loader=train_loader, val_loader=val_loader)
        # Save right away so a failure in a later run cannot discard a model
        # that has already finished training.
        save_model(model, filename)
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()