ArcaneSVK2 / inference /onnx_model.py
Podtekatel's picture
Initial commit for arcane
046b3c9
raw
history blame contribute delete
451 Bytes
import numpy as np
import onnxruntime
class ONNXModel:
    """Thin callable wrapper around an ``onnxruntime.InferenceSession``.

    The session is created once at construction. Each call feeds a single
    array to the model's first declared input and returns the first output.
    """

    def __init__(self, onnx_mode_path, *, providers=None):
        """Load the ONNX model stored at ``onnx_mode_path``.

        Args:
            onnx_mode_path: Path (``str`` or path-like) to the ``.onnx`` file.
                The parameter name (including the ``mode`` typo) is kept
                unchanged so existing keyword callers keep working.
            providers: Optional sequence of execution providers forwarded to
                ``onnxruntime.InferenceSession``. ``None`` keeps onnxruntime's
                default. Since onnxruntime 1.9, builds with several providers
                enabled (e.g. CUDA) require this to be given explicitly.
        """
        self.path = onnx_mode_path
        # str() so pathlib.Path objects are accepted as well as plain strings.
        self.ort_session = onnxruntime.InferenceSession(
            str(self.path), providers=providers
        )
        # Only the first declared input is ever fed; presumably every model
        # loaded through this class has a single input (TODO confirm).
        self.input_name = self.ort_session.get_inputs()[0].name

    def __call__(self, img):
        """Run the model on ``img`` and return its first output.

        Args:
            img: Array or array-like fed to the model's first input. It is
                converted to ``float32`` (without a copy when it already is a
                float32 ndarray). The expected shape/layout depends on the
                loaded model and is not checked here.

        Returns:
            The first element of the list returned by
            ``InferenceSession.run``.
        """
        ort_inputs = {self.input_name: np.asarray(img, dtype=np.float32)}
        # output_names=None requests every model output; only the first is kept.
        return self.ort_session.run(None, ort_inputs)[0]