# atom-detection / atoms_detection /dl_full_pipeline.py
# Romain Graux
# Initial commit with ml code and webapp
# b2ffc9b
from typing import List
import argparse
import os
from atoms_detection.create_crop_dataset import create_crops_dataset
from atoms_detection.dl_detection import DLDetection
from atoms_detection.evaluation import Evaluation
from atoms_detection.training_model import train_model
from utils.paths import CROPS_PATH, CROPS_DATASET, MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH
from utils.constants import ModelArgs
def dl_full_pipeline(
    extension_name: str,
    architecture: ModelArgs,
    coords_csv: str,
    thresholds_list: List[float],
    force: bool = False
):
    """Run crop creation, model training, detection and evaluation in order.

    Every stage is skipped when its output path already exists, unless a
    rerun is forced. Training the model forces all later stages to rerun.

    Args:
        extension_name: Experiment suffix inserted into every output path.
        architecture: Model architecture handed to training and detection.
        coords_csv: Coordinates CSV handed to crop creation, detection and
            evaluation.
        thresholds_list: Detection thresholds; detection and evaluation are
            run once per value. ``None`` is accepted and means "no
            thresholds", in which case only the crop and training stages run.
        force: Rerun every stage even when its output already exists.
    """
    # Create crops data
    crops_folder = CROPS_PATH + f"_{extension_name}"
    crops_dataset = CROPS_DATASET.replace(".csv", f"_{extension_name}.csv")
    if force or not os.path.exists(crops_dataset):
        print("Creating crops dataset...")
        create_crops_dataset(crops_folder, coords_csv, crops_dataset)

    # Train the DL model. `rerun` starts as the caller's flag and is switched
    # on once a new checkpoint has been written, so the stages below do not
    # reuse outputs that exist from an earlier checkpoint.
    rerun = force
    ckpt_filename = os.path.join(MODELS_PATH, f"model_{extension_name}.ckpt")
    if rerun or not os.path.exists(ckpt_filename):
        print("Training DL crops model...")
        train_model(architecture, crops_dataset, crops_folder, ckpt_filename)
        rerun = True

    if thresholds_list is None:
        # Iterating None would raise a TypeError only after the (possibly
        # long) training stage has finished; treat it as nothing to detect.
        print("No thresholds given, skipping detection and evaluation.")
        return

    # DL Detection & Evaluation
    detection_name = f"dl_detection_{extension_name}"
    evaluation_name = f"dl_evaluation_{extension_name}"
    # The inference cache path does not depend on the threshold.
    inference_cache_path = os.path.join(PREDS_PATH, detection_name)
    for threshold in thresholds_list:
        detections_path = os.path.join(
            DETECTION_PATH, detection_name, f"{detection_name}_{threshold}"
        )
        if rerun or not os.path.exists(detections_path):
            print(f"Detecting atoms on test data with threshold={threshold}...")
            detection = DLDetection(
                model_name=architecture,
                ckpt_filename=ckpt_filename,
                dataset_csv=coords_csv,
                threshold=threshold,
                detections_path=detections_path,
                inference_cache_path=inference_cache_path
            )
            detection.run()

        logging_filename = os.path.join(
            LOGS_PATH, evaluation_name, f"{evaluation_name}_{threshold}.csv"
        )
        if rerun or not os.path.exists(logging_filename):
            evaluation = Evaluation(
                coords_csv=coords_csv,
                predictions_path=detections_path,
                logging_filename=logging_filename
            )
            evaluation.run()
def get_args():
    """Parse the command-line arguments of the full DL pipeline script.

    Returns:
        The ``argparse.Namespace`` with the attributes ``extension_name``,
        ``architecture``, ``coords_csv``, ``t__thresholds`` and ``force``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "extension_name",
        type=str,
        help="Experiment extension name"
    )
    parser.add_argument(
        "architecture",
        type=ModelArgs,
        choices=ModelArgs,
        help="Architecture name"
    )
    parser.add_argument(
        "coords_csv",
        type=str,
        help="Coordinates CSV file to use as input"
    )
    parser.add_argument(
        "-t",
        "--thresholds",
        # A missing comma used to fuse the two flags into the single option
        # "-t--thresholds", from which argparse derived the attribute name
        # "t__thresholds". The flags are now separate; the attribute name is
        # pinned so code reading the namespace keeps working.
        dest="t__thresholds",
        nargs="+",
        type=float,
        help="One or more detection thresholds to run and evaluate"
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Rerun every stage even when its outputs already exist"
    )
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: parse the CLI, echo the parsed namespace, then run
    # the pipeline with the parsed values.
    cli_args = get_args()
    print(cli_args)
    dl_full_pipeline(
        extension_name=cli_args.extension_name,
        architecture=cli_args.architecture,
        coords_csv=cli_args.coords_csv,
        # `t__thresholds` is the attribute name under which get_args exposes
        # the thresholds option.
        thresholds_list=cli_args.t__thresholds,
        force=cli_args.force,
    )