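"""Full deep-learning atom-detection pipeline: crop-dataset creation, model
training, detection at one or more thresholds, and evaluation.

Example invocation (the script filename below is assumed, not given by the source):

    python dl_full_pipeline.py <extension_name> <architecture> <coords_csv> -t 0.5 0.7 --force
"""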
import argparse
import os
from typing import List

from atoms_detection.create_crop_dataset import create_crops_dataset
from atoms_detection.dl_detection import DLDetection
from atoms_detection.evaluation import Evaluation
from atoms_detection.training_model import train_model
from utils.paths import CROPS_PATH, CROPS_DATASET, MODELS_PATH, LOGS_PATH, DETECTION_PATH, PREDS_PATH
from utils.constants import ModelArgs


def dl_full_pipeline(
        extension_name: str,
        architecture: ModelArgs,
        coords_csv: str,
        thresholds_list: List[float],
        force: bool = False
):
    """Run crop creation, model training, detection and evaluation end to end."""
    # Create crops data
    crops_folder = CROPS_PATH + f"_{extension_name}"
    crops_dataset = CROPS_DATASET.replace(".csv", f"_{extension_name}.csv")
    if force or not os.path.exists(crops_dataset):
        print("Creating crops dataset...")
        create_crops_dataset(crops_folder, coords_csv, crops_dataset)

    # Training DL model
    ckpt_filename = os.path.join(MODELS_PATH, f"model_{extension_name}.ckpt")
    if force or not os.path.exists(ckpt_filename):
        print("Training DL crops model...")
        train_model(architecture, crops_dataset, crops_folder, ckpt_filename)
        # A freshly trained model invalidates cached detections and evaluations
        force = True

    # DL Detection & Evaluation, one pass per threshold
    for threshold in thresholds_list:
        inference_cache_path = os.path.join(PREDS_PATH, f"dl_detection_{extension_name}")
        detections_path = os.path.join(
            DETECTION_PATH,
            f"dl_detection_{extension_name}",
            f"dl_detection_{extension_name}_{threshold}"
        )
        if force or not os.path.exists(detections_path):
            print(f"Detecting atoms on test data with threshold={threshold}...")
            detection = DLDetection(
                model_name=architecture,
                ckpt_filename=ckpt_filename,
                dataset_csv=coords_csv,
                threshold=threshold,
                detections_path=detections_path,
                inference_cache_path=inference_cache_path
            )
            detection.run()

        logging_filename = os.path.join(
            LOGS_PATH,
            f"dl_evaluation_{extension_name}",
            f"dl_evaluation_{extension_name}_{threshold}.csv"
        )
        if force or not os.path.exists(logging_filename):
            evaluation = Evaluation(
                coords_csv=coords_csv,
                predictions_path=detections_path,
                logging_filename=logging_filename
            )
            evaluation.run()


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "extension_name",
        type=str,
        help="Experiment extension name"
    )
    parser.add_argument(
        "architecture",
        type=ModelArgs,
        choices=ModelArgs,
        help="Architecture name"
    )
    parser.add_argument(
        "coords_csv",
        type=str,
        help="Coordinates CSV file to use as input"
    )
    parser.add_argument(
        "-t",
        "--thresholds",
        nargs="+",
        type=float,
        help="Detection threshold(s) to evaluate"
    )
    parser.add_argument(
        "--force",
        action="store_true"
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    print(args)
    dl_full_pipeline(args.extension_name, args.architecture, args.coords_csv, args.thresholds, args.force)