# (removed: non-Python web-scrape residue — file-listing header, commit hash, and line-number gutter)
import argparse
import os
import random

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from networkx.tests.test_convert_pandas import pd  # FIXME: fragile — reaches into networkx's private test module
import pandas as pd  # proper import; shadows the fragile alias above with the same module

from atoms_detection.dl_detection import DLDetection
from atoms_detection.dl_detection_scaled import DLScaled
from atoms_detection.evaluation import Evaluation
from utils.paths import MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH, FE_DATASET, PRED_GT_VIS_PATH
from utils.constants import ModelArgs, Split
from visualizations.prediction_gt_images import plot_gt_pred_on_img, get_gt_coords
def detection_pipeline(args):
    """Sweep detection thresholds on the test split of a dataset.

    For each threshold: run DL atom detection (optionally with experimental
    ruler-based rescaling), optionally evaluate predictions against the
    ground-truth coordinates CSV, and optionally save prediction-vs-GT
    overlay images as PNGs.

    Args:
        args: argparse.Namespace with ``extension_name`` (experiment tag used
            in output paths), ``dataset`` (coords CSV path), and boolean flags
            ``eval``, ``visualise``, ``experimental_rescale``.
    """
    extension_name = args.extension_name
    print(f"Storing at {extension_name}")
    architecture = ModelArgs.BASICCNN
    ckpt_filename = os.path.join(MODELS_PATH, "model_sac_cnn.ckpt")
    inference_cache_path = os.path.join(PREDS_PATH, f"dl_detection_{extension_name}")
    # NOTE(review): an earlier finer sweep ([0.9, 0.91, ..., 0.99]) was dead
    # code — it was immediately overwritten by this list, so it was removed.
    testing_thresholds = [0.8, 0.85, 0.9, 0.95]
    for threshold in testing_thresholds:
        detections_path = os.path.join(DETECTION_PATH, f"dl_detection_{extension_name}",
                                       f"dl_detection_{extension_name}_{threshold}")
        print(f"Detecting atoms on test data with threshold={threshold}...")
        # Both detectors take identical kwargs; select the class, build once.
        if args.experimental_rescale:
            print("Using experimental ruler rescaling")
            detector_cls = DLScaled
        else:
            detector_cls = DLDetection
        detection = detector_cls(
            model_name=architecture,
            ckpt_filename=ckpt_filename,
            dataset_csv=args.dataset,
            threshold=threshold,
            detections_path=detections_path,
            inference_cache_path=inference_cache_path
        )
        detection.run()
        if args.eval:
            logging_filename = os.path.join(LOGS_PATH, f"dl_detection_{extension_name}",
                                            f"dl_detection_{extension_name}_{threshold}.csv")
            evaluation = Evaluation(
                coords_csv=args.dataset,
                predictions_path=detections_path,
                logging_filename=logging_filename
            )
            evaluation.run()
        if args.visualise:
            # <PRED_GT_VIS_PATH>/dl_detection_<ext>/dl_detection_<ext>_<threshold>
            # makedirs(exist_ok=True) creates intermediates and avoids the
            # exists()/makedirs() race of the original code.
            vis_folder = os.path.join(PRED_GT_VIS_PATH, f"dl_detection_{extension_name}",
                                      f"dl_detection_{extension_name}_{threshold}")
            os.makedirs(vis_folder, exist_ok=True)
            # Ground-truth overlay is only available when evaluation ran.
            gt_coords_dict = get_gt_coords(evaluation.coordinates_dataset) if args.eval else None
            for image_path in detection.image_dataset.iterate_data(Split.TEST):
                print(image_path)
                img_name = os.path.split(image_path)[-1]
                gt_coords = gt_coords_dict[img_name] if args.eval else None
                pred_df_path = os.path.join(detections_path, os.path.splitext(img_name)[0] + '.csv')
                df_predicted = pd.read_csv(pred_df_path)
                pred_coords = [(row['x'], row['y']) for _, row in df_predicted.iterrows()]
                img = Image.open(image_path)
                img_arr = np.array(img).astype(np.float32)
                # Min-max normalise to [0, 1] for display.
                img_normed = (img_arr - img_arr.min()) / (img_arr.max() - img_arr.min())
                plot_gt_pred_on_img(img_normed, gt_coords, pred_coords)
                clean_image_name = os.path.splitext(img_name)[0]
                vis_path = os.path.join(vis_folder, f'{clean_image_name}.png')
                plt.savefig(vis_path, bbox_inches='tight', pad_inches=0.0, transparent=True)
                plt.close()
    print(f"Experiment {extension_name} completed")
def get_args():
    """Parse command-line arguments for the detection pipeline.

    Returns:
        argparse.Namespace with ``extension_name``, ``dataset``, ``eval``,
        ``visualise``, ``experimental_rescale`` and ``feature``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "extension_name",
        type=str,
        help="Experiment extension name"
    )
    parser.add_argument(
        "dataset",
        type=str,
        help="Dataset file upon which to do inference"
    )
    # store_true flags default to False already; redundant default= removed.
    parser.add_argument(
        "--eval",
        action='store_true',
        help="Whether to perform evaluation after inference"
    )
    parser.add_argument(
        "--visualise",
        action='store_true',
        help="Whether to store inference results as visual png images"
    )
    parser.add_argument(
        "--experimental_rescale",
        action='store_true',
        help="Whether to rescale inputs based on the ruler of the image as preprocess"
    )
    # Kept for CLI backward compatibility: previously declared with no
    # type/help; appears unused by the pipeline — TODO confirm and drop.
    parser.add_argument("--feature", help="Unused placeholder option")
    return parser.parse_args()
if __name__ == '__main__':
    # Script entry point: parse CLI args and run the threshold sweep.
    detection_pipeline(get_args())