# atom-detection / atoms_detection/dl_detection_evaluation.py
# Author: Romain Graux
# Initial commit with ml code and webapp (b2ffc9b)
import argparse
import os
import random

import numpy as np
import pandas as pd  # fixed: was accidentally imported via networkx.tests.test_convert_pandas
from matplotlib import pyplot as plt
from PIL import Image

from atoms_detection.dl_detection import DLDetection
from atoms_detection.dl_detection_scaled import DLScaled
from atoms_detection.evaluation import Evaluation
from utils.paths import MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH, FE_DATASET, PRED_GT_VIS_PATH
from utils.constants import ModelArgs, Split
from visualizations.prediction_gt_images import plot_gt_pred_on_img, get_gt_coords
def detection_pipeline(args):
    """Run DL atom detection over a sweep of confidence thresholds.

    For each threshold, runs inference on the test split (optionally with the
    experimental ruler-based rescaling), then optionally evaluates predictions
    against ground-truth coordinates and stores prediction-vs-GT overlay
    images as transparent PNGs.

    Args:
        args: parsed CLI namespace with attributes ``extension_name``,
            ``dataset``, ``eval``, ``visualise`` and ``experimental_rescale``.
    """
    extension_name = args.extension_name
    print(f"Storing at {extension_name}")
    architecture = ModelArgs.BASICCNN
    ckpt_filename = os.path.join(MODELS_PATH, "model_sac_cnn.ckpt")
    inference_cache_path = os.path.join(PREDS_PATH, f"dl_detection_{extension_name}")
    # NOTE(review): the original built a fine 0.90-0.99 sweep and immediately
    # overwrote it with this coarse one; only the list below was ever used,
    # so the dead assignment was removed.
    testing_thresholds = [0.8, 0.85, 0.9, 0.95]
    for threshold in testing_thresholds:
        detections_path = os.path.join(DETECTION_PATH, f"dl_detection_{extension_name}",
                                       f"dl_detection_{extension_name}_{threshold}")
        print(f"Detecting atoms on test data with threshold={threshold}...")
        # Both detector classes take the same constructor arguments, so the
        # two duplicated call sites were collapsed into a class selection.
        if args.experimental_rescale:
            print("Using experimental ruler rescaling")
            detection_cls = DLScaled
        else:
            detection_cls = DLDetection
        detection = detection_cls(
            model_name=architecture,
            ckpt_filename=ckpt_filename,
            dataset_csv=args.dataset,
            threshold=threshold,
            detections_path=detections_path,
            inference_cache_path=inference_cache_path,
        )
        detection.run()
        if args.eval:
            logging_filename = os.path.join(LOGS_PATH, f"dl_detection_{extension_name}",
                                            f"dl_detection_{extension_name}_{threshold}.csv")
            evaluation = Evaluation(
                coords_csv=args.dataset,
                predictions_path=detections_path,
                logging_filename=logging_filename,
            )
            evaluation.run()
        if args.visualise:
            vis_folder = os.path.join(PRED_GT_VIS_PATH, f"dl_detection_{extension_name}",
                                      f"dl_detection_{extension_name}_{threshold}")
            # makedirs creates parents too; exist_ok avoids the original's
            # check-then-create race.
            os.makedirs(vis_folder, exist_ok=True)
            # Ground truth is only available when evaluation ran.
            gt_coords_dict = get_gt_coords(evaluation.coordinates_dataset) if args.eval else None
            for image_path in detection.image_dataset.iterate_data(Split.TEST):
                print(image_path)
                img_name = os.path.split(image_path)[-1]
                gt_coords = gt_coords_dict[img_name] if args.eval else None
                # Per-image predictions were written as CSV by detection.run().
                pred_df_path = os.path.join(detections_path, os.path.splitext(img_name)[0] + '.csv')
                df_predicted = pd.read_csv(pred_df_path)
                pred_coords = [(row['x'], row['y']) for _, row in df_predicted.iterrows()]
                img = Image.open(image_path)
                img_arr = np.array(img).astype(np.float32)
                # Min-max normalise to [0, 1] for display.
                img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min())
                plot_gt_pred_on_img(img_normed, gt_coords, pred_coords)
                clean_image_name = os.path.splitext(img_name)[0]
                vis_path = os.path.join(vis_folder, f'{clean_image_name}.png')
                plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0, transparent=True)
                plt.close()
    print(f"Experiment {extension_name} completed")
def get_args():
    """Parse command-line arguments for the detection evaluation pipeline.

    Returns:
        argparse.Namespace with ``extension_name``, ``dataset``, ``eval``,
        ``visualise``, ``experimental_rescale`` and ``feature`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "extension_name",
        type=str,
        help="Experiment extension name",
    )
    parser.add_argument(
        "dataset",
        type=str,
        help="Dataset file upon which to do inference",
    )
    # store_true flags already default to False; the redundant explicit
    # defaults from the original were dropped (behavior-identical).
    parser.add_argument(
        "--eval",
        action='store_true',
        help="Whether to perform evaluation after inference",
    )
    parser.add_argument(
        "--visualise",
        action='store_true',
        help="Whether to store inference results as visual png images",
    )
    parser.add_argument(
        "--experimental_rescale",
        action='store_true',
        help="Whether to rescale inputs based on the ruler of the image as preprocess",
    )
    # NOTE(review): '--feature' had no type/help in the original and appears
    # unused downstream; kept (with explicit defaults) for CLI compatibility.
    parser.add_argument('--feature', type=str, default=None, help="Unused legacy option")
    return parser.parse_args()
# Script entry point: parse the CLI arguments and run the full threshold sweep.
if __name__ == '__main__':
    detection_pipeline(get_args())