"""Preprocess the raw data, then train and save the LSTM, VAE and Transformer models.

Run as a script: preprocesses the raw data, builds the data loaders, trains
each model from ``pipeline`` in turn and saves its weights to disk.
"""
from src.config.config import setup_logging
from pipeline import (
    NYCDataLoader,
    Preprocessor,
    Trainer,
    Transformer,
    VAE,
    VanillaLSTM,
    save_model,
)
from path_config import RAW_DATA_PATH

# Window length used both when building input sequences and to size the VAE.
SEQ_LENGTH = 24
# Shared by the data loader and the trainer config so the two values cannot drift apart.
BATCH_SIZE = 32
LEARNING_RATE = 0.001


def train():
    """Preprocess the raw data, then train and save every model.

    Each model is saved as soon as its own training finishes, so a failure in a
    later stage does not discard models that have already been trained.
    """
    setup_logging()

    # Preprocess the raw data into windows of SEQ_LENGTH steps.
    preprocessor = Preprocessor()
    preprocessor.preprocess_data(file_path=RAW_DATA_PATH, window_size=SEQ_LENGTH)

    # Load the preprocessed data; the test split is not used by this script.
    data_loader = NYCDataLoader(batch_size=BATCH_SIZE)
    train_loader, val_loader, _test_loader = data_loader.load_data()

    trainer = Trainer()

    # Each stage: (model factory, model_type, n_epochs, output file).
    # Factories keep model construction lazy, so each model is built only
    # when its stage starts, exactly as in the original sequential code.
    stages = (
        (VanillaLSTM, "lstm", 20, "lstm_model_small.pth"),
        (lambda: VAE(seq_len=SEQ_LENGTH), "vae", 20, "vae_model_small.pth"),
        (Transformer, "transformer", 5, "transformer_model_small.pth"),
    )
    for make_model, model_type, n_epochs, filename in stages:
        trainer.init_model(model=make_model(), model_type=model_type)
        trainer.config_train(batch_size=BATCH_SIZE, n_epochs=n_epochs, lr=LEARNING_RATE)
        model, _history = trainer.train(train_loader=train_loader, val_loader=val_loader)
        # Persist immediately so earlier results survive a crash in a later stage.
        save_model(model, filename)


if __name__ == "__main__":
    train()