|
import os |
|
from dataclasses import dataclass |
|
from typing import List, Union |
|
|
|
import yaml |
|
from dataclasses_json import DataClassJsonMixin |
|
|
|
from aidisdk import AIDIClient |
|
from aidisdk.experiment import Image, Table |
|
|
|
|
|
@dataclass
class RunConfig(DataClassJsonMixin):
    """Dataclass for config of run.

    Holds the AIDI endpoint, experiment group, dataset/artifact references
    and settings-file path read by the evaluation pipeline. Dataset ids are
    strings of the form ``"dataset://<integer id>"``.
    """

    # AIDI service endpoint; passed to ``AIDIClient(endpoint=...)``.
    endpoint: str
    # NOTE(review): not referenced by the client construction in this file;
    # confirm how (or whether) the token is supplied to AIDIClient.
    token: str
    # Experiment group name passed to ``experiment.init_group``.
    group_name: str
    # Dataset holding the input images ("dataset://<id>").
    images_dataset_id: str
    # Dataset holding ground truth: the GT file for detection; optional
    # (may be empty) for semantic segmentation.
    gt_dataset_id: str
    # Dataset holding label files (semantic segmentation).
    labels_dataset_id: str
    # Dataset holding prediction files (semantic segmentation).
    predictions_dataset_id: str
    # Prediction artifact reference "<artifact name>/<file name>" (detection).
    prediction_name: str
    # Path to the YAML settings file that is loaded and logged as config.
    setting_file_name: str
|
|
|
|
|
@dataclass
class DetectionEvalConfig(DataClassJsonMixin):
    """Dataclass for config of detection evaluation.

    Produced by ``BaseEvaluation.detection_preprocess`` from local paths of
    downloaded data.
    """

    # Local directory containing the first downloaded image file.
    images_dir: str
    # Local path of the first file of the ground-truth dataset.
    gt: str
    # Value returned by the artifact's ``get_file``; presumably a local
    # path to the prediction file — TODO confirm.
    prediction: str
    # Path to the YAML settings file.
    setting: str
|
|
|
|
|
@dataclass
class SemanticSegmentationEvalConfig(DataClassJsonMixin):
    """Dataclass for config of semantic-segmentation evaluation.

    All directory fields are local paths of downloaded datasets. The
    ``images_json`` entry is optional: it is ``None`` when no
    ground-truth dataset was configured for the run.
    """

    # Local directory holding the input images.
    images_dir: str
    # Local directory holding the ground-truth label files.
    labels_dir: str
    # Local directory holding the predicted label files.
    prediction_dir: str
    # Path to the YAML settings file.
    setting: str
    # Optional path to the images JSON file; None when not available.
    images_json: Union[str, None] = None
|
|
|
|
|
@dataclass
class EvalResult(DataClassJsonMixin):
    """Dataclass for result of evaluation.

    ``summary`` is required. ``tables``, ``plots`` and ``images`` are
    optional and may be ``None`` when a category has nothing to log.
    """

    # Summary metrics dict.
    summary: dict
    # Tables to log, or None.
    tables: Union[List[Table], None] = None
    # Plot specs, or None. Each dict carries a "Table" entry (a Table whose
    # ``name`` is used as the plot name) and a "Line" entry.
    plots: Union[List[dict], None] = None
    # Images to log, or None.
    images: Union[List[Image], None] = None
|
|
|
|
|
class BaseEvaluation:
    """Base class tying an evaluation to the AIDI dataset/experiment APIs.

    A ``*_preprocess`` method downloads inputs, initialises the experiment
    group and logs the settings, returning an eval config; ``evaluate``
    (implemented by subclasses) turns that config into an ``EvalResult``;
    ``postprocess`` logs the result to the experiment.
    """

    def __init__(self, run_config: RunConfig):
        """Store the run config and create an AIDI client for its endpoint.

        Args:
            run_config: Run configuration for this evaluation.
        """
        self.run_config = run_config
        # NOTE(review): run_config.token is not passed to the client here;
        # confirm AIDIClient obtains credentials some other way.
        self.client = AIDIClient(endpoint=run_config.endpoint)

    def get_data(self, dataset_id_info: str, file_type: str) -> str:
        """Download a dataset and return a local path derived from its files.

        Args:
            dataset_id_info: Dataset reference ``"dataset://<integer id>"``.
            file_type: ``"gt"`` to return the first downloaded file, or
                ``"images_dir"`` to return the directory containing it.

        Returns:
            The local file path (``"gt"``) or directory path
            (``"images_dir"``).

        Raises:
            ValueError: If the reference is malformed, uses a scheme other
                than ``dataset``, has a non-integer id, or the dataset
                contains no files.
            NotImplementedError: If ``file_type`` is not supported.
        """
        dataset_version, sep, raw_id = dataset_id_info.partition("://")
        if not sep:
            raise ValueError(
                f"malformed dataset reference {dataset_id_info!r}; "
                "expected 'dataset://<id>'"
            )
        if dataset_version != "dataset":
            raise ValueError("dataset version not supported")
        dataset_id = int(raw_id)
        # Reject unsupported modes before paying for a download.
        if file_type not in ("gt", "images_dir"):
            raise NotImplementedError(
                f"file_type {file_type!r} is not supported"
            )

        dataset_interface = self.client.dataset.load(dataset_id)
        data_path_list = dataset_interface.file_list(download=True)
        if not data_path_list:
            raise ValueError(
                f"dataset {dataset_id_info!r} contains no files"
            )

        first_path = data_path_list[0]
        if file_type == "gt":
            return first_path
        # NOTE(review): only the first file's directory is returned; this
        # assumes every downloaded file lives in that same directory.
        return os.path.dirname(first_path)

    def _log_setting_config(self) -> None:
        """Load the YAML settings file and log it as the experiment config."""
        # SECURITY NOTE(review): yaml.Loader can construct arbitrary Python
        # objects; use yaml.safe_load if the settings file may be untrusted.
        with open(self.run_config.setting_file_name, "r") as fid:
            cfg_dict = yaml.load(fid, Loader=yaml.Loader)
        self.client.experiment.log_config(cfg_dict)

    def _load_prediction_file(self):
        """Fetch the prediction file named by ``run_config.prediction_name``.

        ``prediction_name`` is ``"<artifact name>/<file name>"``; everything
        after the first ``/`` is the file name, so nested paths are kept.

        Raises:
            ValueError: If ``prediction_name`` lacks either part.
        """
        prediction_name = self.run_config.prediction_name
        artifact_name, sep, file_name = prediction_name.partition("/")
        if not (sep and artifact_name and file_name):
            raise ValueError(
                f"prediction_name {prediction_name!r} must look like "
                "'<artifact name>/<file name>'"
            )
        artifact = self.client.experiment.use_artifact(name=artifact_name)
        return artifact.get_file(name=file_name)

    def detection_preprocess(self) -> DetectionEvalConfig:
        """Download detection inputs and prepare the experiment.

        Returns:
            A ``DetectionEvalConfig`` with local image dir, GT file,
            prediction file and settings path.
        """
        images_dir = self.get_data(
            self.run_config.images_dataset_id, "images_dir"
        )
        gt_path = self.get_data(self.run_config.gt_dataset_id, "gt")

        # Initialise the experiment group first; the artifact lookup and
        # config logging below go through the same experiment object.
        self.client.experiment.init_group(self.run_config.group_name)
        pr_file = self._load_prediction_file()
        self._log_setting_config()

        return DetectionEvalConfig(
            images_dir=images_dir,
            gt=gt_path,
            prediction=pr_file,
            setting=self.run_config.setting_file_name,
        )

    def semantic_segmentation_preprocess(
        self,
    ) -> SemanticSegmentationEvalConfig:
        """Download semantic-segmentation inputs and prepare the experiment.

        The GT dataset is optional: when ``gt_dataset_id`` is empty,
        ``images_json`` is ``None``.

        Returns:
            A ``SemanticSegmentationEvalConfig`` with local directories,
            the optional images JSON path and the settings path.
        """
        images_dir = self.get_data(
            self.run_config.images_dataset_id, "images_dir"
        )
        labels_dir = self.get_data(
            self.run_config.labels_dataset_id, "images_dir"
        )
        prediction_dir = self.get_data(
            self.run_config.predictions_dataset_id, "images_dir"
        )

        images_json_file = None
        if self.run_config.gt_dataset_id:
            # ``images_json`` names a file, so take the first file ("gt")
            # rather than its directory ("images_dir").
            # NOTE(review): confirm the evaluator expects the file path.
            images_json_file = self.get_data(
                self.run_config.gt_dataset_id, "gt"
            )

        self.client.experiment.init_group(self.run_config.group_name)
        self._log_setting_config()

        return SemanticSegmentationEvalConfig(
            images_dir=images_dir,
            labels_dir=labels_dir,
            prediction_dir=prediction_dir,
            setting=self.run_config.setting_file_name,
            images_json=images_json_file,
        )

    def evaluate(
        self,
        eval_config: Union[
            DetectionEvalConfig, SemanticSegmentationEvalConfig
        ],
    ) -> EvalResult:
        """Run the evaluation; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def postprocess(self, eval_result: EvalResult) -> None:
        """Log an evaluation result to the current experiment.

        The summary is always logged; tables, plots and images are logged
        only when the corresponding field is not ``None``. Each plot dict
        must provide a ``"Table"`` entry (its ``name`` is passed as the plot
        name) and a ``"Line"`` entry.
        """
        self.client.experiment.log_summary(eval_result.summary)
        if eval_result.tables is not None:
            for table in eval_result.tables:
                self.client.experiment.log_table(table)
        if eval_result.plots is not None:
            for plot in eval_result.plots:
                plot_table = plot["Table"]
                self.client.experiment.log_plot(
                    plot_table.name, plot_table, plot["Line"]
                )
        if eval_result.images is not None:
            for image in eval_result.images:
                self.client.experiment.log_image(image)
|
|