# zero/pipeline.py
# Commit 857f94a by m3: "chore: change onnxruntime to 1.16" (3.02 kB)
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModel
from transformers.pipelines import PIPELINE_REGISTRY
from huggingface_hub import hf_hub_download
import onnxruntime as ort
import torch
import os
# 1. register AutoConfig
class ONNXBaseConfig(PretrainedConfig):
    """Configuration for ONNX-backed models registered as ``onnx-base``.

    Args:
        model_path: Location of the ONNX file, relative to the model
            directory / repository root. ``None`` when unspecified.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = 'onnx-base'

    def __init__(self, model_path=None, **kwargs):
        # Declare the attribute up front so ``config.model_path`` never raises
        # AttributeError for config files that omit the key.
        self.model_path = model_path
        super().__init__(**kwargs)


AutoConfig.register('onnx-base', ONNXBaseConfig)
# 2. register AutoModel
class ONNXBaseModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that runs inference through ONNX Runtime.

    The module has no torch parameters of its own: ``forward`` feeds its input
    to an ``onnxruntime.InferenceSession`` built from the ONNX file named by
    ``config.model_path``.
    """

    config_class = ONNXBaseConfig

    # ONNX file name used when the config does not name one.
    _default_onnx_file = 'model.onnx'
    # Hub download options that must also apply to the ONNX file so that it
    # comes from the same revision/cache as the config.
    _hub_download_kwargs = ('revision', 'cache_dir', 'token', 'local_files_only')

    def __init__(self, config, base_path=None):
        """Build the wrapper and, when the ONNX file is present, open a session.

        Args:
            config: An ``ONNXBaseConfig``; ``config.model_path`` names the ONNX
                file relative to ``base_path`` (falls back to ``model.onnx``).
            base_path: Directory holding the ONNX file. If it is falsy or the
                file does not exist, no session is created and ``forward``
                raises until the model is loaded properly.
        """
        super().__init__(config)
        self.session = None
        # Path of the ONNX file the session was built from; used by
        # ``save_pretrained`` to copy the graph instead of re-exporting it.
        self._onnx_path = None
        if base_path:
            model_file = getattr(config, 'model_path', None) or self._default_onnx_file
            model_path = os.path.join(base_path, model_file)
            if os.path.exists(model_path):
                self.session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
                self._onnx_path = model_path

    def forward(self, input=None, **kwargs):
        """Run the ONNX session on ``input``.

        Args:
            input: Value fed to the graph input named ``'input'``. Torch tensors
                are converted to NumPy arrays first, because ONNX Runtime's
                Python API expects NumPy arrays.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The list of all graph outputs returned by ``InferenceSession.run``.

        Raises:
            RuntimeError: If no ONNX session has been loaded.
            ValueError: If ``input`` is ``None``.
        """
        session = getattr(self, 'session', None)
        if session is None:
            raise RuntimeError(
                'No ONNX session is loaded; create the model with '
                '`from_pretrained` or pass a `base_path` containing the ONNX file.'
            )
        if input is None:
            raise ValueError("`input` is required to run the ONNX model.")
        if isinstance(input, torch.Tensor):
            input = input.detach().cpu().numpy()
        return session.run(None, {'input': input})

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save the config/weights and the ONNX graph into ``save_directory``.

        The graph is written to ``config.model_path`` (or ``model.onnx``) so a
        later ``from_pretrained`` finds it. If this instance was loaded from an
        ONNX file, that file is copied verbatim. Otherwise the module is
        exported with ``torch.onnx.export``, which only works when ``forward``
        is traceable torch code (e.g. in an overriding subclass).
        """
        import shutil

        super().save_pretrained(save_directory=save_directory, **kwargs)
        model_file = getattr(self.config, 'model_path', None) or self._default_onnx_file
        onnx_file_path = os.path.join(save_directory, model_file)
        # ``model_path`` may point into a sub-directory.
        os.makedirs(os.path.dirname(onnx_file_path), exist_ok=True)

        source_path = getattr(self, '_onnx_path', None)
        if source_path is not None:
            # Exporting would trace ``forward``, which calls ONNX Runtime rather
            # than torch ops, so copy the already-built graph instead.
            try:
                shutil.copyfile(source_path, onnx_file_path)
            except shutil.SameFileError:
                pass  # saving back into the directory it was loaded from
            return

        dummy_input = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
        torch.onnx.export(self, dummy_input, onnx_file_path,
                          input_names=['input'], output_names=['output'],
                          dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the config and ONNX file from a local directory or a Hub repo.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id.
            *model_args: Accepted for API compatibility; ignored.
            **kwargs: Forwarded to ``AutoConfig.from_pretrained``. The options
                ``revision``, ``cache_dir``, ``token`` and ``local_files_only``
                are also forwarded to ``hf_hub_download`` so the ONNX file comes
                from the same place as the config.

        Returns:
            A new instance whose session is built from the resolved ONNX file.
        """
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        if getattr(config, 'model_path', None) is None:
            config.model_path = cls._default_onnx_file

        if os.path.isdir(pretrained_model_name_or_path):
            base_path = pretrained_model_name_or_path
        else:
            hub_kwargs = {key: kwargs[key] for key in cls._hub_download_kwargs if key in kwargs}
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename='config.json', **hub_kwargs)
            # Files of one snapshot share a root directory, so the ONNX file
            # resolves relative to the directory holding config.json.
            base_path = os.path.dirname(config_path)
            hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config.model_path, **hub_kwargs)
        return cls(config, base_path=base_path)

    @property
    def device(self):
        """Return ``cuda`` when torch sees a GPU, otherwise ``cpu``.

        NOTE(review): the session is created with ``CPUExecutionProvider``
        only, so this does not reflect where inference actually runs. It is
        kept as-is because pipeline device handling may depend on it; confirm
        before changing.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        return torch.device(device)


AutoModel.register(ONNXBaseConfig, ONNXBaseModel)
# 3. register Pipeline
from transformers.pipelines import Pipeline
class ONNXBasePipeline(Pipeline):
    """Minimal pipeline that passes its input straight to ``ONNXBaseModel``.

    No pre- or post-processing is done: the raw input is fed to the model as
    ``input`` and the model's outputs are returned unchanged.
    """

    def __init__(self, model, **kwargs):
        # ``device`` may be absent from kwargs (e.g. when the caller does not
        # pass one); indexing would raise KeyError, so default to None.
        self.device_id = kwargs.get('device')
        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        """Accept no call-time parameters for any stage."""
        return {}, {}, {}

    def preprocess(self, input):
        """Wrap the raw input as the model's ``input`` keyword argument."""
        return {'input': input}

    def _forward(self, model_input):
        """Run the model on the preprocessed input without tracking gradients."""
        with torch.no_grad():
            outputs = self.model(**model_input)
        return outputs

    def postprocess(self, model_outputs):
        """Return the model outputs unchanged."""
        return model_outputs


PIPELINE_REGISTRY.register_pipeline(
    task='onnx-base',
    pipeline_class=ONNXBasePipeline,
    pt_model=ONNXBaseModel,
)