File size: 3,287 Bytes
b2ffc9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import List

import argparse
import os

from atoms_detection.create_crop_dataset import create_crops_dataset
from atoms_detection.dl_detection import DLDetection
from atoms_detection.evaluation import Evaluation
from atoms_detection.training_model import train_model
from utils.paths import CROPS_PATH, CROPS_DATASET, MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH
from utils.constants import ModelArgs


def dl_full_pipeline(
        extension_name: str,
        architecture: ModelArgs,
        coords_csv: str,
        thresholds_list: List[float],
        force: bool = False
):
    """Run the full DL pipeline: crops dataset, training, detection, evaluation.

    Every stage is skipped when its output path already exists, unless
    ``force`` is set, in which case every stage is re-run.

    Args:
        extension_name: Experiment suffix used to build every output path.
        architecture: Model architecture passed to training and detection.
        coords_csv: Coordinates CSV passed to crop creation, detection and
            evaluation.
        thresholds_list: Thresholds; detection and evaluation run once per
            value.
        force: Re-run every stage even if its output already exists.
    """
    # Create crops data
    crops_folder = CROPS_PATH + f"_{extension_name}"
    crops_dataset = CROPS_DATASET.replace(".csv", f"_{extension_name}.csv")
    if force or not os.path.exists(crops_dataset):
        print("Creating crops dataset...")
        create_crops_dataset(crops_folder, coords_csv, crops_dataset)

    # Training DL model
    ckpt_filename = os.path.join(MODELS_PATH, f"model_{extension_name}.ckpt")
    if force or not os.path.exists(ckpt_filename):
        print("Training DL crops model...")
        train_model(architecture, crops_dataset, crops_folder, ckpt_filename)

    # DL Detection & Evaluation
    # NOTE: the caller's `force` is honoured here too; it used to be
    # overwritten with True, which made the existence checks below dead code.
    detection_name = f"dl_detection_{extension_name}"
    evaluation_name = f"dl_evaluation_{extension_name}"
    # The inference cache path does not depend on the threshold, so it is
    # built once for all iterations.
    inference_cache_path = os.path.join(PREDS_PATH, detection_name)

    for threshold in thresholds_list:
        detections_path = os.path.join(
            DETECTION_PATH, detection_name, f"{detection_name}_{threshold}"
        )
        if force or not os.path.exists(detections_path):
            print(f"Detecting atoms on test data with threshold={threshold}...")
            detection = DLDetection(
                model_name=architecture,
                ckpt_filename=ckpt_filename,
                dataset_csv=coords_csv,
                threshold=threshold,
                detections_path=detections_path,
                inference_cache_path=inference_cache_path
            )
            detection.run()

        logging_filename = os.path.join(
            LOGS_PATH, evaluation_name, f"{evaluation_name}_{threshold}.csv"
        )
        if force or not os.path.exists(logging_filename):
            evaluation = Evaluation(
                coords_csv=coords_csv,
                predictions_path=detections_path,
                logging_filename=logging_filename
            )
            evaluation.run()


def get_args():
    """Parse the command-line arguments of the full DL pipeline script.

    Returns:
        argparse.Namespace with the attributes ``extension_name``,
        ``architecture``, ``coords_csv``, ``thresholds`` (list of floats)
        and ``force``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "extension_name",
        type=str,
        help="Experiment extension name"
    )
    parser.add_argument(
        "architecture",
        type=ModelArgs,
        choices=ModelArgs,
        help="Architecture name"
    )
    parser.add_argument(
        "coords_csv",
        type=str,
        help="Coordinates CSV file to use as input"
    )
    # "-t" and "--thresholds" must be two separate option strings; without
    # the comma they concatenate into the single flag "-t--thresholds".
    parser.add_argument(
        "-t",
        "--thresholds",
        nargs="+",
        type=float,
        required=True,
        help="Thresholds to run detection and evaluation with"
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Re-run pipeline stages even if their outputs already exist"
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    print(args)
    dl_full_pipeline(args.extension_name, args.architecture, args.coords_csv, args.thresholds, args.force)