|
|
|
|
|
|
|
|
|
|
|
from .base_wrapper import ONNXModel, OnnxModelPickable |
|
from pathlib import Path |
|
import torch |
|
|
|
class ModelBase:
    """Load an inference model, choosing the backend wrapper from the file suffix.

    Supported suffixes:

    * ``.engine``          -> TensorRT wrapper, ``model_type == 'trt'``
    * ``.tjm``             -> TorchScript module via ``torch.jit.load``, ``model_type == 'tjm'``
    * ``.onnx`` / ``.bin`` -> ``ONNXModel`` or ``OnnxModelPickable``, ``model_type == 'onnx'``

    Attributes set by ``__init__``: ``model_path``, ``input_dynamic_shape``,
    ``model_type`` and ``model`` (the backend wrapper / module instance).
    """

    def __init__(self, model_info, provider):
        """Build the backend wrapper described by *model_info*.

        Args:
            model_info: Mapping with a required ``'model_path'`` entry and the
                optional entries ``'input_dynamic_shape'`` (default ``None``),
                ``'picklable'`` (default ``False``) and ``'trt_wrapper_self'``.
                Only the *presence* of ``'trt_wrapper_self'`` matters: when the
                key exists, ``TRTWrapperSelf`` is used instead of ``TRTWrapper``
                for ``.engine`` files, regardless of its value.
            provider: Forwarded to the ONNX wrappers; not used by the TensorRT
                or TorchScript backends.

        Raises:
            KeyError: If ``'model_path'`` is missing from *model_info*.
            ValueError: If the model file suffix is not supported.
        """
        self.model_path = model_info['model_path']
        self.input_dynamic_shape = model_info.get('input_dynamic_shape', None)
        picklable = model_info.get('picklable', False)

        suffix = Path(self.model_path).suffix

        if suffix == '.engine':
            self.model_type = 'trt'
            # Select the class through a separately named local. Re-binding
            # ``TRTWrapper`` inside this method would turn it into a local name
            # for the whole function, so the default path would hit
            # UnboundLocalError. Looking the classes up only here also keeps
            # non-TensorRT models from touching them at all.
            if 'trt_wrapper_self' in model_info:
                wrapper_cls = TRTWrapperSelf
            else:
                wrapper_cls = TRTWrapper
            self.model = wrapper_cls(self.model_path)
        elif suffix == '.tjm':
            self.model_type = 'tjm'
            self.model = torch.jit.load(self.model_path)
            self.model.eval()
        elif suffix in ('.onnx', '.bin'):
            self.model_type = 'onnx'
            # Model name = file name cut at the first '.' and then at the first
            # '_', e.g. "dir/det_v2.fp16.onnx" -> "det". Path(...).name also
            # accepts a pathlib.Path, unlike str.split('/').
            model_name = Path(self.model_path).name.split('.')[0].split('_')[0]
            if picklable:
                self.model = OnnxModelPickable(self.model_path, provider=provider)
            else:
                self.model = ONNXModel(
                    self.model_path,
                    provider=provider,
                    input_dynamic_shape=self.input_dynamic_shape,
                    model_name=model_name,
                )
        else:
            # Raising a bare str is itself a TypeError in Python 3 and hides the
            # message; raise a real exception that names the offending file.
            raise ValueError(
                f'check model suffix, supported: .engine/.tjm/.onnx/.bin; '
                f'got {suffix!r} for {self.model_path}'
            )
|
|