import logging
import os
import subprocess
import sys
from typing import Any, Dict, List

import onnx
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from transformers import AutoModelForQuestionAnswering, AutoTokenizer


class ONNXModelConverter:
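    """Export a Hugging Face question-answering model to ONNX, verify the
    graph, run ONNX Runtime's quantization preprocessing, and produce
    dynamically quantized variants."""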

    def __init__(self, model_name: str, output_dir: str):
        self.model_name = model_name
        self.output_dir = output_dir
        self.setup_logging()

        os.makedirs(output_dir, exist_ok=True)

        self.logger.info(f"Loading tokenizer {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        self.logger.info(f"Loading model {model_name}...")
        self.model = AutoModelForQuestionAnswering.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        )
        self.model.eval()

    def setup_logging(self):
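        """Attach a stream handler with a timestamped format."""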
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # Guard against duplicate handlers if more than one converter is created.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def prepare_dummy_inputs(self):
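        """Tokenize a short example to use as the tracing input for export."""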
        dummy_input = self.tokenizer(
            "Hello, how are you?",
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        )
        # Not every tokenizer emits token_type_ids; fall back to zeros.
        return {
            'input_ids': dummy_input['input_ids'],
            'attention_mask': dummy_input['attention_mask'],
            'token_type_ids': dummy_input.get(
                'token_type_ids',
                torch.zeros_like(dummy_input['input_ids'])
            ),
        }

    def export_to_onnx(self):
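        """Export the model to ONNX with dynamic batch and sequence axes."""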
        output_path = os.path.join(self.output_dir, "model.onnx")
        inputs = self.prepare_dummy_inputs()

        dynamic_axes = {
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'token_type_ids': {0: 'batch_size', 1: 'sequence_length'},
            'start_logits': {0: 'batch_size', 1: 'sequence_length'},
            'end_logits': {0: 'batch_size', 1: 'sequence_length'},
        }

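        # Hugging Face models return a ModelOutput object; torch.onnx.export
        # wants plain tensors, so wrap the model to return a logits tuple.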
        class ModelWrapper(torch.nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids, attention_mask, token_type_ids):
                outputs = self.model(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     token_type_ids=token_type_ids)
                return outputs.start_logits, outputs.end_logits

        wrapped_model = ModelWrapper(self.model)

        try:
            torch.onnx.export(
                wrapped_model,
                (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids']),
                output_path,
                export_params=True,
                opset_version=14,
                do_constant_folding=True,
                input_names=['input_ids', 'attention_mask', 'token_type_ids'],
                output_names=['start_logits', 'end_logits'],
                dynamic_axes=dynamic_axes,
                verbose=False,
            )
            self.logger.info(f"Model exported to {output_path}")
            return output_path
        except Exception as e:
            self.logger.error(f"ONNX export failed: {str(e)}")
            raise

    def verify_model(self, model_path: str) -> bool:
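        """Run the ONNX checker over the exported graph."""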
        try:
            onnx_model = onnx.load(model_path)
            onnx.checker.check_model(onnx_model)
            self.logger.info("ONNX model verification successful")
            return True
        except Exception as e:
            self.logger.error(f"Model verification failed: {str(e)}")
            return False

    def preprocess_model(self, model_path: str) -> str:
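        """Run ONNX Runtime's quantization pre-processing (shape inference and
        graph optimization), which is recommended before quantizing."""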
        preprocessed_path = os.path.join(self.output_dir, "model-infer.onnx")
        try:
            # Use the current interpreter rather than whatever "python" is on PATH.
            command = [
                sys.executable, "-m", "onnxruntime.quantization.preprocess",
                "--input", model_path,
                "--output", preprocessed_path,
            ]
            subprocess.run(command, check=True, capture_output=True, text=True)
            self.logger.info(f"Model preprocessing successful. Output saved to {preprocessed_path}")
            return preprocessed_path
        except subprocess.CalledProcessError as e:
            self.logger.error(f"Preprocessing failed: {e.stderr}")
            raise
        except Exception as e:
            self.logger.error(f"Preprocessing failed: {str(e)}")
            raise

    def quantize_model(self, model_path: str) -> List[str]:
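        """Produce one dynamically quantized copy of the model per weight type."""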
        # Not every weight type is accepted by quantize_dynamic in every
        # onnxruntime release (int8/uint8 are the safest), so skip failures
        # rather than aborting the remaining variants.
        weight_types = {
            'int4': QuantType.QInt4,
            'int8': QuantType.QInt8,
            'uint4': QuantType.QUInt4,
            'uint8': QuantType.QUInt8,
            'int16': QuantType.QInt16,
            'uint16': QuantType.QUInt16,
        }
        all_quantized_paths = []
        for name, quant_type in weight_types.items():
            quantized_path = os.path.join(self.output_dir, f"model_{name}.onnx")

            try:
                quantize_dynamic(
                    model_path,
                    quantized_path,
                    weight_type=quant_type,
                )
                self.logger.info(f"Model quantized ({name}) and saved to {quantized_path}")
                all_quantized_paths.append(quantized_path)
            except Exception as e:
                self.logger.error(f"Quantization ({name}) failed: {str(e)}")
                continue

        return all_quantized_paths

    def convert(self) -> Dict[str, Any]:
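        """Full pipeline: export, verify, preprocess, quantize, save tokenizer."""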
        try:
            onnx_path = self.export_to_onnx()

            if self.verify_model(onnx_path):
                preprocessed_path = self.preprocess_model(onnx_path)
                quantized_paths = self.quantize_model(preprocessed_path)

                tokenizer_path = os.path.join(self.output_dir, "tokenizer")
                self.tokenizer.save_pretrained(tokenizer_path)
                self.logger.info(f"Tokenizer saved to {tokenizer_path}")

                return {
                    'onnx_model': onnx_path,
                    'preprocessed_model': preprocessed_path,
                    'quantized_models': quantized_paths,
                    'tokenizer': tokenizer_path,
                }
            else:
                raise RuntimeError("Model verification failed")

        except Exception as e:
            self.logger.error(f"Conversion process failed: {str(e)}")
            raise


if __name__ == "__main__":
    MODEL_NAME = "timpal0l/mdeberta-v3-base-squad2"
    OUTPUT_DIR = "onnx"

    try:
        converter = ONNXModelConverter(MODEL_NAME, OUTPUT_DIR)
        results = converter.convert()

        print("\nConversion completed successfully!")
        print(f"ONNX model path: {results['onnx_model']}")
        print(f"Preprocessed model path: {results['preprocessed_model']}")
        print(f"Quantized model paths: {results['quantized_models']}")
        print(f"Tokenizer path: {results['tokenizer']}")
    except Exception as e:
        print(f"Conversion failed: {str(e)}")
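
# Minimal inference sketch (assumes onnxruntime is installed, the int8 export
# above succeeded, and paths/tensor names match those used in this script):
#
#   import onnxruntime as ort
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("onnx/tokenizer")
#   session = ort.InferenceSession("onnx/model_int8.onnx",
#                                  providers=["CPUExecutionProvider"])
#   enc = tokenizer("Who wrote it?", "Paul wrote it.", return_tensors="np")
#   start_logits, end_logits = session.run(
#       ['start_logits', 'end_logits'],
#       {'input_ids': enc['input_ids'],
#        'attention_mask': enc['attention_mask'],
#        'token_type_ids': enc['token_type_ids']})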