File size: 2,675 Bytes
8fe5582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import onnxruntime as ort
from loguru import logger
from onnxruntime.transformers.io_binding_helper import TypeHelper


@dataclass
class ModelInfo:
    """Metadata describing an exported model.

    Attributes:
        base_model: Name of the underlying BERT-style base model, read from
            the ``bert_type`` field of ``metadata.json``.
    """

    base_model: str

    @classmethod
    def from_dir(cls, model_dir: Path) -> ModelInfo:
        """Load model metadata from ``metadata.json`` inside *model_dir*.

        Args:
            model_dir: Directory containing a ``metadata.json`` file with a
                ``bert_type`` key.

        Returns:
            A populated ``ModelInfo`` (or subclass) instance.

        Raises:
            FileNotFoundError: If ``metadata.json`` does not exist.
            KeyError: If the ``bert_type`` key is missing.
        """
        with open(model_dir / "metadata.json", "r", encoding="utf-8") as file:
            data = json.load(file)
        # Use cls(...) rather than the hard-coded class name so subclasses
        # get instances of themselves from this alternate constructor.
        return cls(base_model=data["bert_type"])


class ONNXModel:
    """Thin wrapper around an ``onnxruntime.InferenceSession``.

    Caches session metadata (provider-derived device, per-input numpy dtypes,
    input/output names) so that ``__call__`` can cast inputs and run inference
    with minimal per-call overhead.
    """

    def __init__(self, model: ort.InferenceSession, model_info: ModelInfo) -> None:
        """Wrap an already-loaded session together with its metadata.

        Args:
            model: A live ONNX Runtime inference session.
            model_info: Parsed model metadata (see ``ModelInfo``).
        """
        self.model = model
        self.model_info = model_info
        # _model_path is a private ort attribute; kept as-is since there is
        # no public accessor for the originating file path.
        self.model_path = Path(model._model_path)  # type: ignore
        self.model_name = self.model_path.name

        self.providers = model.get_providers()

        # The first (highest-priority) provider determines the device label.
        if self.providers[0] in ["CUDAExecutionProvider", "TensorrtExecutionProvider"]:
            self.device = "cuda"
        else:
            self.device = "cpu"

        # Map input/output names to the numpy dtypes the session expects,
        # so __call__ can cast caller arrays up front.
        self.io_types = TypeHelper.get_io_numpy_type_map(model)

        self.input_names = [el.name for el in model.get_inputs()]
        # Only the first output is exposed; presumably the model has a single
        # logits/embedding output — confirm if multi-output models are added.
        self.output_name = model.get_outputs()[0].name

    @staticmethod
    def load_session(
        path: str | Path,
        provider: str = "CPUExecutionProvider",
        session_options: ort.SessionOptions | None = None,
        provider_options: dict[str, Any] | None = None,
    ) -> ort.InferenceSession:
        """Create an ``InferenceSession`` with sensible provider fallbacks.

        Args:
            path: Either a direct path to an ``.onnx`` file (as ``str``) or a
                model directory (as ``Path``) containing ``model.onnx``.
            provider: Primary execution provider to request.
            session_options: Optional ORT session options.
            provider_options: Options for the *primary* provider only;
                fallback providers get empty option dicts.

        Returns:
            The loaded inference session.
        """
        providers = [provider]
        # Chain fallbacks: TensorRT -> CUDA, CUDA -> CPU.
        if provider == "TensorrtExecutionProvider":
            providers.append("CUDAExecutionProvider")
        elif provider == "CUDAExecutionProvider":
            providers.append("CPUExecutionProvider")

        # NOTE: a str path is treated as a direct file path, while a Path is
        # treated as a model directory. Callers passing directories should
        # pass a Path (from_dir normalizes for this reason).
        if not isinstance(path, str):
            path = Path(path) / "model.onnx"

        providers_options = None
        if provider_options is not None:
            # ORT requires one options dict per provider in the list.
            providers_options = [provider_options] + [{} for _ in range(len(providers) - 1)]

        session = ort.InferenceSession(
            str(path),
            providers=providers,
            sess_options=session_options,
            provider_options=providers_options,
        )
        logger.info("Session loaded")
        return session

    @classmethod
    def from_dir(cls, model_dir: str | Path) -> ONNXModel:
        """Load model and metadata from a directory.

        Args:
            model_dir: Directory containing ``model.onnx`` and
                ``metadata.json``.

        Returns:
            A ready-to-use ``ONNXModel`` (or subclass) instance.
        """
        # Normalize to Path: a raw str would be treated as a file path by
        # load_session (skipping the model.onnx append) and would crash
        # ModelInfo.from_dir, which uses the `/` path operator.
        model_dir = Path(model_dir)
        # Use cls so subclasses construct instances of themselves.
        return cls(cls.load_session(model_dir), ModelInfo.from_dir(model_dir))

    def __call__(self, **model_inputs: np.ndarray) -> np.ndarray:
        """Run inference, casting each input to the dtype the session expects.

        Args:
            **model_inputs: Arrays keyed by the model's input names.

        Returns:
            The first (and only exposed) model output.
        """
        cast_inputs = {
            input_name: tensor.astype(self.io_types[input_name]) for input_name, tensor in model_inputs.items()
        }

        return self.model.run([self.output_name], cast_inputs)[0]