Spaces:
Sleeping
Sleeping
File size: 3,990 Bytes
b2ffc9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import argparse
import os
from enum import Enum
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from atoms_detection.dataset import CoordinatesDataset
from utils.constants import Split, Catalyst, Method
from utils.paths import DETECTION_LOGS, IMG_PATH, PRED_GT_VIS_PATH, PT_DATASET, FE_DATASET, DETECTION_PATH
from visualizations.utils import plot_gt_pred_on_img
def _select_dataset_and_detection_path(catalyst, method):
    """Return ``(coordinates_dataset, detection_path)`` for a catalyst/method pair.

    Raises:
        NotImplementedError: if the catalyst, or the method for that catalyst,
            is not one of the supported combinations below.
    """
    if catalyst == Catalyst.Pt:
        dataset_path = PT_DATASET
        detection_dirs = {
            # NOTE(review): unlike every other entry this path is not built from
            # DETECTION_PATH; kept verbatim to preserve behaviour - confirm intent.
            Method.DL: "data/detection_data/dl_detection_sac_cnn/dl_detection_sac_cnn_0.89",
            Method.CV: os.path.join(DETECTION_PATH, "cv_detection_trial_0.18"),
            Method.TEM: os.path.join(DETECTION_PATH, "tem_imagenet_pt",
                                     "tem_imagenet_pt_denoise-bg_Gen1GaussianMask"),
        }
    elif catalyst == Catalyst.Fe:
        dataset_path = FE_DATASET
        detection_dirs = {
            Method.DL: os.path.join(DETECTION_PATH, "dl_fe_detection_trial",
                                    "dl_fe_detection_trial_0.97"),
            Method.CV: os.path.join(DETECTION_PATH, "cv_fe_detection_trial",
                                    "cv_fe_detection_trial_0.21"),
            Method.TEM: os.path.join(DETECTION_PATH, "tem_imagenet_fe",
                                     "tem_imagenet_fe_denoise-bg_Gen1GaussianMask"),
        }
    else:
        raise NotImplementedError
    if method not in detection_dirs:
        raise NotImplementedError
    return CoordinatesDataset(dataset_path), detection_dirs[method]


def _normalize_image(image_filename):
    """Load an image file and min-max scale its pixels to [0, 1] as float32.

    A constant image (max == min) is returned as all zeros rather than the
    NaNs that the plain ``(x - min) / (max - min)`` formula would produce.
    """
    # Use a context manager so the underlying file handle is released; the
    # array is materialised before the file is closed.
    with Image.open(image_filename) as img:
        img_arr = np.array(img).astype(np.float32)
    lo = img_arr.min()
    span = img_arr.max() - lo
    if span == 0:
        return np.zeros_like(img_arr)
    return (img_arr - lo) / span


def main(args):
    """Save GT-vs-prediction overlay plots for every matching TEST-split image.

    For each file in the selected detection directory, the matching ``.tif``
    image name is derived from the file's stem. Files whose image is not in
    the TEST split are skipped. For the rest, predicted ``x``/``y`` columns are
    read from the file (as CSV), overlaid with the ground-truth coordinates on
    the normalised image, and saved as
    ``PRED_GT_VIS_PATH/<catalyst>-Catalyst_<method>-Method/<stem>.png``.

    Args:
        args: parsed CLI namespace providing ``catalyst`` and ``method``.

    Raises:
        NotImplementedError: unsupported catalyst or method.
    """
    catalyst = args.catalyst
    method = args.method
    os.makedirs(PRED_GT_VIS_PATH, exist_ok=True)

    coordinates_dataset, detection_path = _select_dataset_and_detection_path(catalyst, method)
    gt_coords_dict = get_gt_coords(coordinates_dataset)

    # Output folder depends only on the run configuration, so create it once.
    vis_folder = os.path.join(PRED_GT_VIS_PATH, f"{catalyst}-Catalyst_{method}-Method")
    os.makedirs(vis_folder, exist_ok=True)

    for name_file in os.listdir(detection_path):
        clean_image_name = os.path.splitext(name_file)[0]
        image_name = clean_image_name + ".tif"
        print(image_name)
        if image_name not in gt_coords_dict:
            continue

        df_predicted = pd.read_csv(os.path.join(detection_path, name_file))
        pred_coords = list(zip(df_predicted['x'], df_predicted['y']))
        img_normed = _normalize_image(os.path.join(IMG_PATH, image_name))

        plot_gt_pred_on_img(img_normed, gt_coords_dict[image_name], pred_coords)
        vis_path = os.path.join(vis_folder, f'{clean_image_name}.png')
        plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0, transparent=True)
        plt.close()
def get_gt_coords(coordinates_dataset):
    """Map each TEST-split image's file basename to its ground-truth coordinates.

    Args:
        coordinates_dataset: dataset exposing ``iterate_data(split)`` yielding
            ``(image_path, coordinates_path)`` pairs and ``load_coordinates(path)``.

    Returns:
        dict keyed by ``os.path.basename(image_path)`` (extension kept as-is),
        valued by whatever ``load_coordinates`` returns for that image.
    """
    return {
        os.path.basename(img_path): coordinates_dataset.load_coordinates(coords_path)
        for img_path, coords_path in coordinates_dataset.iterate_data(Split.TEST)
    }
def get_args():
    """Parse the two positional CLI arguments: ``catalyst`` and ``method``.

    Each argument is converted to its enum type and restricted to that
    enum's members.
    """
    cli = argparse.ArgumentParser()
    arg_specs = (
        ("catalyst", Catalyst, "Select data by catalyst"),
        ("method", Method, "Select method"),
    )
    for dest, enum_cls, help_text in arg_specs:
        cli.add_argument(dest, type=enum_cls, choices=enum_cls, help=help_text)
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run the visualisation.
    main(get_args())
|