File size: 3,162 Bytes
e287bc1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import argparse
from code.detection3d import Detection3dEval
from code.eval_handler import EvaluationHandler
from code.semantic_segmentation import SemanticSegmentationEval
from aidisdk import AIDIClient
def process(args):
task_type = args.task_type
endpoint = args.endpoint
token = args.token
experiment_name = args.experiment_name
group_name = args.group_name
run_name = args.run_name
if args.images_dataset_id:
images_dataset_id = args.images_dataset_id
else:
images_dataset_id = None
if args.labels_dataset_id:
labels_dataset_id = args.labels_dataset_id
else:
labels_dataset_id = None
if args.predictions_dataset_id:
predictions_dataset_id = args.predictions_dataset_id
else:
predictions_dataset_id = None
if args.gt_dataset_id:
gt_dataset_id = args.gt_dataset_id
else:
gt_dataset_id = None
prediction_name = args.prediction_name
setting_file_name = args.setting_file_name
if task_type == "Detection_3D":
eval_class = Detection3dEval
elif task_type == "Semantic_Segmentation":
eval_class = SemanticSegmentationEval
else:
raise NotImplementedError
client = AIDIClient(token=token, endpoint=endpoint)
# 任务发起阶段,支持发起一个本地任务和艾迪平台dag任;
# 本文档举例本地是如何结合实验管理来发起;
# 开始初始化一个实验run
# run_name是必填的,可以填写已经存在的或者不存在的run name;
# 当填写的是不存在的run name时,会自动创建;
with client.experiment.init(
experiment_name=experiment_name, run_name=run_name, enabled=True
) as run:
# 进行一次上报,runtime为默认值"local"即可
# config_file可以填写当前任务的配置文件,会自动上传并记录
# 除了runtime和config_file,其他参数均为用户自定义上报内容
# 此处举例上报aidisdk版本
run.log_runtime(
runtime="local",
horizon_hat="1.3.1",
)
EvaluationHandler(
endpoint,
token,
group_name,
images_dataset_id,
gt_dataset_id,
labels_dataset_id,
predictions_dataset_id,
prediction_name,
setting_file_name,
eval_class,
).execute()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--task_type", type=str)
parser.add_argument("--endpoint", type=str)
parser.add_argument("--token", type=str)
parser.add_argument("--group_name", type=str)
parser.add_argument("--experiment_name", type=str)
parser.add_argument("--run_name", type=str)
parser.add_argument("--images_dataset_id", type=str)
parser.add_argument("--gt_dataset_id", type=str)
parser.add_argument("--labels_dataset_id", type=str)
parser.add_argument("--predictions_dataset_id", type=str)
parser.add_argument("--prediction_name", type=str) # pr_name/file_name
parser.add_argument("--setting_file_name", type=str)
args = parser.parse_args()
process(args)
|