|
from __future__ import annotations |
|
|
|
import json |
|
from dataclasses import dataclass |
|
from pathlib import Path |
|
from typing import Any |
|
|
|
import numpy as np |
|
import onnxruntime as ort |
|
from loguru import logger |
|
from onnxruntime.transformers.io_binding_helper import TypeHelper |
|
|
|
|
|
@dataclass
class ModelInfo:
    """Metadata describing an exported model.

    Attributes:
        base_model: Identifier of the base model; ``from_dir`` fills it from
            the ``"bert_type"`` entry of ``metadata.json``.
    """

    base_model: str

    @classmethod
    def from_dir(cls, model_dir: str | Path) -> ModelInfo:
        """Build an instance from the ``metadata.json`` file inside *model_dir*.

        Args:
            model_dir: Directory containing ``metadata.json``. Both ``str``
                and ``Path`` are accepted.

        Returns:
            A new instance whose ``base_model`` is the ``"bert_type"`` value
            stored in the metadata file.

        Raises:
            FileNotFoundError: If ``metadata.json`` does not exist.
            KeyError: If the metadata has no ``"bert_type"`` entry.
        """
        # Coerce first so that a plain string directory works with the "/" operator.
        metadata_path = Path(model_dir) / "metadata.json"
        with open(metadata_path, "r", encoding="utf-8") as file:
            data = json.load(file)
        # Use ``cls`` so that subclasses get an instance of their own type.
        return cls(base_model=data["bert_type"])
|
|
|
|
|
class ONNXModel:
    """Callable wrapper around an ``onnxruntime.InferenceSession``.

    Calling an instance casts the keyword inputs to the dtypes recorded in
    ``io_types`` and returns the session's first output.
    """

    def __init__(self, model: ort.InferenceSession, model_info: ModelInfo) -> None:
        """Cache session details that are needed at inference time.

        Args:
            model: An already created inference session.
            model_info: Metadata associated with the model.
        """
        self.model = model
        self.model_info = model_info
        # NOTE: ``_model_path`` is a private attribute of the session object.
        self.model_path = Path(model._model_path)
        self.model_name = self.model_path.name

        self.providers = model.get_providers()

        # The first provider in the list decides which device is reported.
        gpu_providers = ("CUDAExecutionProvider", "TensorrtExecutionProvider")
        self.device = "cuda" if self.providers[0] in gpu_providers else "cpu"

        # Per-name numpy types as reported by the onnxruntime helper; used in
        # ``__call__`` to cast the inputs.
        self.io_types = TypeHelper.get_io_numpy_type_map(model)

        self.input_names = [el.name for el in model.get_inputs()]
        # Only the first declared output is ever requested.
        self.output_name = model.get_outputs()[0].name

    @staticmethod
    def load_session(
        path: str | Path,
        provider: str = "CPUExecutionProvider",
        session_options: ort.SessionOptions | None = None,
        provider_options: dict[str, Any] | None = None,
    ) -> ort.InferenceSession:
        """Create an inference session for the model at *path*.

        Args:
            path: A ``Path`` is always treated as a model directory and
                ``model.onnx`` is appended to it. A ``str`` is used verbatim
                as the model file, unless it names an existing directory, in
                which case ``model.onnx`` inside it is loaded.
            provider: Preferred execution provider. TensorRT gets CUDA as a
                fallback and CUDA gets CPU as a fallback.
            session_options: Optional session options forwarded to onnxruntime.
            provider_options: Optional options for *provider*; fallback
                providers receive empty option dicts.

        Returns:
            The created ``InferenceSession``.
        """
        providers = [provider]
        if provider == "TensorrtExecutionProvider":
            providers.append("CUDAExecutionProvider")
        elif provider == "CUDAExecutionProvider":
            providers.append("CPUExecutionProvider")

        # Keep the historical rule (Path => directory) and additionally accept
        # a string that points at a directory instead of handing a directory
        # name to onnxruntime as if it were a model file.
        if not isinstance(path, str) or Path(path).is_dir():
            path = Path(path) / "model.onnx"

        providers_options = None
        if provider_options is not None:
            # One options dict per provider: the real one first, then blanks.
            providers_options = [provider_options, *({} for _ in providers[1:])]

        session = ort.InferenceSession(
            str(path),
            providers=providers,
            sess_options=session_options,
            provider_options=providers_options,
        )
        logger.info("Session loaded")
        return session

    @classmethod
    def from_dir(cls, model_dir: str | Path) -> ONNXModel:
        """Load the session and the metadata stored in *model_dir*.

        Args:
            model_dir: Directory holding ``model.onnx`` and ``metadata.json``.

        Returns:
            A wrapper built from the loaded session (default provider) and
            the model metadata.
        """
        # Normalise once so both loaders see a directory ``Path`` even when
        # the caller passed a plain string.
        model_dir = Path(model_dir)
        return cls(cls.load_session(model_dir), ModelInfo.from_dir(model_dir))

    def __call__(self, **model_inputs: np.ndarray):
        """Run the session and return its first output.

        Args:
            **model_inputs: Arrays keyed by model input name. Each array is
                cast to the type recorded for that name in ``io_types``.

        Returns:
            The value produced for ``output_name``.

        Raises:
            KeyError: If a keyword is not present in ``io_types``.
        """
        # ``copy=False`` avoids duplicating arrays that already have the
        # expected dtype; arrays with another dtype are still converted.
        feeds = {
            input_name: tensor.astype(self.io_types[input_name], copy=False)
            for input_name, tensor in model_inputs.items()
        }

        return self.model.run([self.output_name], feeds)[0]
|
|