File size: 3,990 Bytes
b2ffc9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import argparse
import os
from enum import Enum

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

from atoms_detection.dataset import CoordinatesDataset
from utils.constants import Split, Catalyst, Method
from utils.paths import DETECTION_LOGS, IMG_PATH, PRED_GT_VIS_PATH, PT_DATASET, FE_DATASET, DETECTION_PATH
from visualizations.utils import plot_gt_pred_on_img


def main(args):
    """Overlay ground-truth and predicted coordinates on the test images.

    Selects the ground-truth dataset and the detection-output folder that
    correspond to ``args.catalyst`` and ``args.method``. For every file in
    that folder whose matching ``.tif`` image name has ground-truth
    coordinates, the image is min-max normalised, both coordinate sets are
    drawn on it and the figure is saved as a PNG under
    ``PRED_GT_VIS_PATH/<catalyst>-Catalyst_<method>-Method``.

    Args:
        args: Parsed command-line namespace providing ``catalyst`` (compared
            against ``Catalyst`` members) and ``method`` (compared against
            ``Method`` members).

    Raises:
        NotImplementedError: If the catalyst or the method is not one of the
            combinations handled below.
    """
    catalyst = args.catalyst
    method = args.method
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(PRED_GT_VIS_PATH, exist_ok=True)

    if catalyst == Catalyst.Pt:
        coordinates_dataset = CoordinatesDataset(PT_DATASET)
        if method == Method.DL:
            # NOTE(review): unlike the other branches, this path is relative to
            # the working directory instead of being built from DETECTION_PATH
            # -- confirm this is intended.
            detection_path = "data/detection_data/dl_detection_sac_cnn/dl_detection_sac_cnn_0.89"
        elif method == Method.CV:
            detection_path = os.path.join(DETECTION_PATH, "cv_detection_trial_0.18")
        elif method == Method.TEM:
            detection_path = os.path.join(DETECTION_PATH, "tem_imagenet_pt",
                                          "tem_imagenet_pt_denoise-bg_Gen1GaussianMask")
        else:
            raise NotImplementedError

    elif catalyst == Catalyst.Fe:
        coordinates_dataset = CoordinatesDataset(FE_DATASET)
        if method == Method.DL:
            detection_path = os.path.join(DETECTION_PATH, "dl_fe_detection_trial",
                                          "dl_fe_detection_trial_0.97")
        elif method == Method.CV:
            detection_path = os.path.join(DETECTION_PATH, "cv_fe_detection_trial",
                                          "cv_fe_detection_trial_0.21")
        elif method == Method.TEM:
            detection_path = os.path.join(DETECTION_PATH, "tem_imagenet_fe",
                                          "tem_imagenet_fe_denoise-bg_Gen1GaussianMask")
        else:
            raise NotImplementedError

    else:
        raise NotImplementedError

    gt_coords_dict = get_gt_coords(coordinates_dataset)
    # The output folder depends only on the CLI arguments, so build it once.
    vis_folder = os.path.join(PRED_GT_VIS_PATH, f"{catalyst}-Catalyst_{method}-Method")

    for name_file in os.listdir(detection_path):
        # Detection files share their stem with the image they were run on.
        stem = os.path.splitext(name_file)[0]
        image_name = stem + ".tif"
        print(image_name)
        if image_name not in gt_coords_dict:
            continue

        # Load the pixels inside a context manager so the underlying file
        # handle is released right away instead of once per loop iteration
        # leaking until garbage collection.
        with Image.open(os.path.join(IMG_PATH, image_name)) as img:
            img_arr = np.array(img).astype(np.float32)

        gt_coords = gt_coords_dict[image_name]
        df_predicted = pd.read_csv(os.path.join(detection_path, name_file))
        # Column-wise zip instead of the much slower row-wise iterrows().
        pred_coords = list(zip(df_predicted['x'], df_predicted['y']))

        # Min-max normalise to [0, 1]; a constant image has zero range, which
        # would otherwise divide by zero and yield an all-NaN array.
        img_min = img_arr.min()
        value_range = img_arr.max() - img_min
        if value_range > 0:
            img_normed = (img_arr - img_min) / value_range
        else:
            img_normed = np.zeros_like(img_arr)

        plot_gt_pred_on_img(img_normed, gt_coords, pred_coords)

        # Created lazily so no empty folder appears when nothing matches.
        os.makedirs(vis_folder, exist_ok=True)
        vis_path = os.path.join(vis_folder, f'{stem}.png')
        plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0, transparent=True)
        plt.close()



def get_gt_coords(coordinates_dataset):
    """Map each test-split image file name to its ground-truth coordinates.

    Args:
        coordinates_dataset: Object exposing ``iterate_data(split)``, which
            yields ``(image_path, coordinates_path)`` pairs, and
            ``load_coordinates(coordinates_path)``.

    Returns:
        dict: Keys are the base names of the image paths; values are whatever
        ``load_coordinates`` returns for the paired coordinates file.
    """
    return {
        os.path.basename(img_path): coordinates_dataset.load_coordinates(coords_path)
        for img_path, coords_path in coordinates_dataset.iterate_data(Split.TEST)
    }


def get_args():
    """Parse the two positional command-line arguments.

    Returns:
        argparse.Namespace: Namespace with ``catalyst`` (converted with
        ``Catalyst``) and ``method`` (converted with ``Method``).
    """
    cli = argparse.ArgumentParser()
    # (argument name, converter that is also the set of choices, help text)
    positional_specs = (
        ("catalyst", Catalyst, "Select data by catalyst"),
        ("method", Method, "Select method"),
    )
    for arg_name, enum_cls, description in positional_specs:
        cli.add_argument(arg_name, type=enum_cls, choices=enum_cls, help=description)
    return cli.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the CLI arguments and run the visualisation.
    main(get_args())