File size: 399 Bytes
488f448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from fire import Fire
from src.classifier.classifier import get_model
from src.dataset.dataset import prepare_dataset
from src.utils import load_config_file

def trainer(config_path):
    """Train a classifier end to end from a single config file.

    The loaded config is shared by every stage. It is passed to dataset
    preparation and, together with the prepared dataset, to model
    construction. The resulting model is then trained, and its results
    are saved.

    Args:
        config_path: Value forwarded to ``load_config_file``. The accepted
            format is defined by that helper, which is not visible here.

    Returns:
        None. The only result-producing step visible here is the model's
        ``save_model_results()`` call. Where that writes to is defined by
        the model class returned by ``get_model``.
    """
    settings = load_config_file(config_path)
    # The dataset is prepared before the model is built, because get_model
    # consumes it. Argument evaluation preserves that ordering.
    classifier = get_model(settings, prepare_dataset(settings))
    classifier.train()
    classifier.save_model_results()

if __name__ == "__main__":
    # Hand `trainer` to Python Fire, which turns its signature into a CLI:
    # `config_path` can be given positionally or as `--config_path`.
    Fire(trainer)