jeremiebasso
commited on
Commit
•
2ffc97e
1
Parent(s):
8fe5582
fix: no need torch
Browse files- onnx_model.py +4 -2
onnx_model.py
CHANGED
@@ -8,7 +8,6 @@ from typing import Any
|
|
8 |
import numpy as np
|
9 |
import onnxruntime as ort
|
10 |
from loguru import logger
|
11 |
-
from onnxruntime.transformers.io_binding_helper import TypeHelper
|
12 |
|
13 |
|
14 |
@dataclass
|
@@ -36,7 +35,10 @@ class ONNXModel:
|
|
36 |
else:
|
37 |
self.device = "cpu"
|
38 |
|
39 |
-
self.io_types =
|
|
|
|
|
|
|
40 |
|
41 |
self.input_names = [el.name for el in model.get_inputs()]
|
42 |
self.output_name = model.get_outputs()[0].name
|
|
|
8 |
import numpy as np
|
9 |
import onnxruntime as ort
|
10 |
from loguru import logger
|
|
|
11 |
|
12 |
|
13 |
@dataclass
|
|
|
35 |
else:
|
36 |
self.device = "cpu"
|
37 |
|
38 |
+
self.io_types = {
|
39 |
+
"input_ids": np.int32,
|
40 |
+
"attention_mask": np.bool_
|
41 |
+
}
|
42 |
|
43 |
self.input_names = [el.name for el in model.get_inputs()]
|
44 |
self.output_name = model.get_outputs()[0].name
|