File size: 3,333 Bytes
b2ffc9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import argparse
import random

import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from torchvision.models import resnet18

from utils.paths import MODELS_PATH, CROPS_PATH, CROPS_DATASET
from utils.constants import ModelArgs, Split, CropsColumns
from atoms_detection.training import train_epoch, val_epoch
from atoms_detection.dataset import ImageClassificationDataset
from atoms_detection.model import BasicCNN


# Seed every RNG source (PyTorch, Python, NumPy) so runs are reproducible.
torch.manual_seed(777)
random.seed(777)
np.random.seed(777)


def get_basic_cnn(*args, **kwargs):
    """Factory for :class:`BasicCNN`; all arguments are forwarded unchanged."""
    return BasicCNN(*args, **kwargs)


def get_resnet(*args, **kwargs):
    """Build a torchvision resnet18 adapted to single-channel input.

    The stock resnet18 stem convolution expects 3-channel RGB images; the
    crops used here are single-channel, so the first layer is swapped out.
    All arguments are forwarded to ``resnet18`` (e.g. ``num_classes``).
    """
    model = resnet18(*args, **kwargs)
    # Bug fix: the original assigned a torch.nn.Conv1d with 2-D tuple
    # arguments — ResNet's stem is a 2-D convolution, so Conv1d with
    # kernel_size=(7, 7) is invalid. Conv2d with the standard ResNet stem
    # hyper-parameters (7x7 kernel, stride 2, padding 3, no bias) is correct.
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    return model


# Maps each supported architecture flag to its model factory.
model_pipeline = {
    ModelArgs.BASICCNN: get_basic_cnn,
    ModelArgs.RESNET18: get_resnet
}

# Number of training epochs to run for each architecture.
epochs_pipeline = {
    ModelArgs.BASICCNN: 12,
    ModelArgs.RESNET18: 3
}


def train_model(model_arg: ModelArgs, crops_dataset: str, crops_path: str, ckpt_filename: str):
    """Train a crop-classification model and save a final checkpoint.

    Args:
        model_arg: architecture to build (key into ``model_pipeline``).
        crops_dataset: path to the CSV listing crop filenames, labels, splits.
        crops_path: directory containing the crop image files.
        ckpt_filename: path where the checkpoint dict is written.
    """

    # Dataset bound to the specific CSV/folder passed into this call.
    class CropsDataset(ImageClassificationDataset):
        @staticmethod
        def get_filenames_labels(split: Split):
            df = pd.read_csv(crops_dataset)
            split_df = df[df[CropsColumns.SPLIT] == split]
            filenames = (crops_path + os.sep + split_df[CropsColumns.FILENAME]).to_list()
            labels = (split_df[CropsColumns.LABEL]).to_list()
            return filenames, labels

    # Device selection: prefer CUDA, fall back to Apple MPS, then CPU.
    # (The original hard-coded MPS — the CUDA check was commented out — yet
    # still gated DataParallel on the CUDA device count, an inconsistency.)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    train_dataset = CropsDataset.train_dataset()
    val_dataset = CropsDataset.val_dataset()
    train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=64)
    model = model_pipeline[model_arg](num_classes=train_dataset.get_n_labels()).to(device)

    # DataParallel only applies on multi-GPU CUDA setups.
    if device.type == "cuda" and torch.cuda.device_count() > 1:
        print("Using {} GPUs!".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)

    loss_function = torch.nn.CrossEntropyLoss(reduction='mean').to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, amsgrad=True)

    epoch = 0  # stays defined even if the configured epoch count is zero
    for epoch in range(epochs_pipeline[model_arg]):
        train_epoch(train_dataloader, model, loss_function, optimizer, device, epoch)
        val_epoch(val_dataloader, model, loss_function, device, epoch)

    if not os.path.exists(MODELS_PATH):
        os.makedirs(MODELS_PATH)

    # Unwrap DataParallel before saving: otherwise every state_dict key is
    # prefixed with "module." and the checkpoint fails to load on
    # single-device setups.
    model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
    state = {
        'state_dict': model_to_save.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch
    }
    torch.save(state, ckpt_filename)


def get_args():
    """Parse CLI arguments: positional experiment name and model architecture."""
    parser = argparse.ArgumentParser()
    parser.add_argument("experiment_name", type=str, help="Experiment name")
    parser.add_argument(
        "model",
        type=ModelArgs,
        choices=list(ModelArgs),
        help="model architecture",
    )
    return parser.parse_args()


if __name__ == "__main__":
    # Hard-coded training run; get_args() exists above but is not used here.
    extension_name = "replicate"
    ckpt_filename = os.path.join(MODELS_PATH, "basic_replicate2.ckpt")
    crops_folder = CROPS_PATH + f"_{extension_name}"
    # NOTE(review): crops_folder is computed but never used — train_model is
    # called with CROPS_PATH (and the default CROPS_DATASET) instead.
    # Presumably crops_folder was intended as the third argument; confirm
    # against the data layout before changing.
    train_model(ModelArgs.BASICCNN, CROPS_DATASET, CROPS_PATH, ckpt_filename)