File size: 5,059 Bytes
b2ffc9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import argparse
import os
import random

import numpy as np
from matplotlib import pyplot as plt
# NOTE(review): importing pandas through a networkx test module is fragile
# (an auto-import accident); the direct import below shadows it with the
# real package — the networkx line can almost certainly be removed.
from networkx.tests.test_convert_pandas import pd
import pandas as pd
from PIL import Image

from atoms_detection.dl_detection import DLDetection
from atoms_detection.dl_detection_scaled import DLScaled
from atoms_detection.evaluation import Evaluation
from utils.paths import MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH, FE_DATASET, PRED_GT_VIS_PATH
from utils.constants import ModelArgs, Split
from visualizations.prediction_gt_images import plot_gt_pred_on_img, get_gt_coords


def detection_pipeline(args):
    """Run DL atom detection across a sweep of confidence thresholds.

    For each threshold the pipeline runs inference over the dataset
    (optionally with experimental ruler-based rescaling), then — depending
    on the flags — evaluates the detections against ground-truth
    coordinates and writes prediction/ground-truth overlay images as PNGs.

    Args:
        args: argparse.Namespace produced by get_args(), with attributes
            extension_name, dataset, eval, visualise and
            experimental_rescale.
    """
    extension_name = args.extension_name
    print(f"Storing at {extension_name}")
    architecture = ModelArgs.BASICCNN
    ckpt_filename = os.path.join(MODELS_PATH, "model_sac_cnn.ckpt")

    # Per-image inference maps are cached here and reused across thresholds,
    # so only the thresholding step is repeated per iteration.
    inference_cache_path = os.path.join(PREDS_PATH, f"dl_detection_{extension_name}")

    # NOTE(review): a longer fine-grained sweep ([0.9..0.99]) was previously
    # assigned here and immediately overwritten — removed as dead code.
    testing_thresholds = [0.8, 0.85, 0.9, 0.95]
    for threshold in testing_thresholds:
        detections_path = os.path.join(DETECTION_PATH, f"dl_detection_{extension_name}",
                                       f"dl_detection_{extension_name}_{threshold}")
        print(f"Detecting atoms on test data with threshold={threshold}...")
        if args.experimental_rescale:
            print("Using experimental ruler rescaling")
            detection = DLScaled(
                model_name=architecture,
                ckpt_filename=ckpt_filename,
                dataset_csv=args.dataset,
                threshold=threshold,
                detections_path=detections_path,
                inference_cache_path=inference_cache_path
            )
        else:
            detection = DLDetection(
                model_name=architecture,
                ckpt_filename=ckpt_filename,
                dataset_csv=args.dataset,
                threshold=threshold,
                detections_path=detections_path,
                inference_cache_path=inference_cache_path
            )
        detection.run()
        if args.eval:
            logging_filename = os.path.join(LOGS_PATH, f"dl_detection_{extension_name}",
                                            f"dl_detection_{extension_name}_{threshold}.csv")
            evaluation = Evaluation(
                coords_csv=args.dataset,
                predictions_path=detections_path,
                logging_filename=logging_filename
            )
            evaluation.run()
        if args.visualise:
            # One folder per experiment, one subfolder per threshold.
            # exist_ok=True replaces the racy exists()+makedirs() pattern.
            vis_folder = os.path.join(PRED_GT_VIS_PATH, f"dl_detection_{extension_name}",
                                      f"dl_detection_{extension_name}_{threshold}")
            os.makedirs(vis_folder, exist_ok=True)

            # Ground-truth coords are only available when evaluation ran.
            if args.eval:
                gt_coords_dict = get_gt_coords(evaluation.coordinates_dataset)

            for image_path in detection.image_dataset.iterate_data(Split.TEST):
                print(image_path)
                img_name = os.path.split(image_path)[-1]
                gt_coords = gt_coords_dict[img_name] if args.eval else None
                # Predictions were written by detection.run() as one CSV per image.
                pred_df_path = os.path.join(detections_path, os.path.splitext(img_name)[0] + '.csv')
                df_predicted = pd.read_csv(pred_df_path)
                pred_coords = [(row['x'], row['y']) for _, row in df_predicted.iterrows()]
                img = Image.open(image_path)
                img_arr = np.array(img).astype(np.float32)
                # Min-max normalise for display; assumes the image is not
                # constant-valued (max > min) — TODO confirm upstream.
                img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min())

                plot_gt_pred_on_img(img_normed, gt_coords, pred_coords)
                clean_image_name = os.path.splitext(img_name)[0]
                vis_path = os.path.join(vis_folder, f'{clean_image_name}.png')
                plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0, transparent=True)
                plt.close()

    print(f"Experiment {extension_name} completed")


def get_args():
    """Parse command-line arguments for the detection pipeline.

    Returns:
        argparse.Namespace with positional arguments ``extension_name``
        and ``dataset``, boolean flags ``--eval``, ``--visualise`` and
        ``--experimental_rescale`` (all default False), and the unused
        ``--feature`` option (defaults to None).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "extension_name",
        type=str,
        help="Experiment extension name"
    )
    parser.add_argument(
        "dataset",
        type=str,
        help="Dataset file upon which to do inference"
    )
    # store_true actions already default to False; the redundant explicit
    # defaults were dropped (behavior unchanged).
    parser.add_argument(
        "--eval",
        action='store_true',
        help="Whether to perform evaluation after inference"
    )
    parser.add_argument(
        "--visualise",
        action='store_true',
        help="Whether to store inference results as visual png images"
    )
    parser.add_argument(
        "--experimental_rescale",
        action='store_true',
        help="Whether to rescale inputs based on the ruler of the image as preprocess"
    )
    # NOTE(review): '--feature' was an incomplete stub (no type/help/action)
    # and is read nowhere in this file; kept for CLI backward compatibility —
    # confirm it can be removed.
    parser.add_argument(
        "--feature",
        help="Unused placeholder option (currently ignored)"
    )
    return parser.parse_args()


if __name__ == '__main__':
    # Parse CLI arguments and launch the threshold-sweep detection pipeline.
    detection_pipeline(get_args())