# File: classification/benchmarker.py
# Last commit: 17979b3 ("update") by Isaacgv
"""
Author : Bastien GUILLAUME
Version : 0.0.1
Date : 2023-03-16
Title : Benchmark ONNX models described in a gradio_interfacer config file
"""
import os
from config_parser import *
from inferencer import *
from pathlib import Path
def format_examples(task_number, product, product_example, timeout=30):
    """Fetch an example file (unless already cached) and build its example row.

    The file is stored under ``examples/<product>/``. Its name is the last
    ``/``-separated segment of ``product_example``. If that file already
    exists it is reused and no network request is made.

    Args:
        task_number: Zero-based task index; rendered as ``task<n+1>``.
        product: Product name; also the sub-folder of ``examples/``.
        product_example: URL of the example file to download.
        timeout: Seconds to wait for the HTTP request (forwarded to
            ``requests.get``).

    Returns:
        ``[f"task{task_number+1}", product, filepath]``, where ``filepath`` is
        a ``pathlib.Path`` to the local copy of the example.

    Raises:
        requests.HTTPError: if the download answers with an error status.
            Nothing is written to disk in that case.
    """
    examples_folder = Path(f"examples/{product}")
    examples_folder.mkdir(parents=True, exist_ok=True)
    filepath = examples_folder / product_example.split("/")[-1]
    if not filepath.exists():
        # Only touch the network on a cache miss.
        response = requests.get(product_example, timeout=timeout)
        # Never cache an HTTP error page as if it were the example file.
        response.raise_for_status()
        filepath.write_bytes(response.content)
    return [f"task{task_number+1}", product, filepath]
def benchmark_models(task_number, product):
    """Run ``inference`` for each configured model of ``product`` on its examples.

    Reads ``config["tasks"]["task<n+1>"]``. The models come from
    ``["models"][product]``. The example URLs, if any, come from
    ``["examples"][product]`` and are downloaded through ``format_examples``.

    Args:
        task_number: Zero-based task index.
        product: Product key inside the task's ``models`` mapping.

    Returns:
        A flat list of ``inference`` results, grouped by model and, within
        each model, ordered by example. Empty when the product has no
        examples.
    """
    logging.log(level=logging.INFO, msg="Entering benchmark_models")
    task_info = config["tasks"][f"task{task_number+1}"]
    models_to_benchmark = task_info["models"][product]
    example_urls = (task_info.get("examples") or {}).get(product) or []
    product_examples = [
        format_examples(task_number, product, example_url)
        for example_url in example_urls
    ]
    # Pass each model's own index, the way create_interface(..., model_number=0)
    # does for single-model products, so every iteration benchmarks a
    # different model rather than repeating one identical call.
    # NOTE(review): assumes inference()'s 4th positional argument is a model
    # index — confirm against inferencer.inference.
    return [
        inference(task_number, product, product_example, model_number)
        for model_number in range(len(models_to_benchmark))
        for product_example in product_examples
    ]
# Build one gr.TabbedInterface per task (one tab per product). Also keep the
# per-product objects indexed as benchmark_builder_dict[task_name][product].
benchmark_builder_list = []
benchmark_builder_dict = {}
logging.log(level=logging.INFO, msg="Building Interfaces")
logging.log(level=logging.INFO, msg=f"Number of task(s) : {len(tasks)}")
for task_number in range(len(tasks)):
    logging.log(level=logging.INFO, msg=f"Treating task n°{task_number+1}")
    task_name = tasks[task_number]
    # Config keys are 1-based ("task1", "task2", ...) while task_number is 0-based.
    task_models = config["tasks"][f"task{task_number+1}"]["models"]
    benchmark_builder_dict[task_name] = {}
    product_list = list(task_models.keys())
    logging.log(level=logging.DEBUG, msg=f"Products : {product_list}")
    benchmark_builder_product_level_list = []
    for product in product_list:
        logging.log(level=logging.INFO, msg=f"Product : {product}")
        if len(task_models[product]) > 1:
            # NOTE(review): benchmark_models returns a list of inference
            # results, not a gradio interface, yet it is handed to
            # gr.TabbedInterface below as one — confirm this is intended.
            generated_interface = benchmark_models(task_number, product)
        else:
            generated_interface = create_interface(
                task_number=task_number, product=product, model_number=0
            )
        benchmark_builder_dict[task_name][product] = [generated_interface]
        benchmark_builder_product_level_list.append(generated_interface)
    benchmark_builder_list.append(
        gr.TabbedInterface(
            interface_list=benchmark_builder_product_level_list,
            tab_names=product_list,
        )
    )
logging.log(level=logging.INFO, msg="Interfaces ready\n")
logging.log(level=logging.DEBUG, msg=f"Interfaces List {benchmark_builder_list}")