import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
import os
import logging
import subprocess


class ONNXModelConverter:
    """Exports a Hugging Face question-answering model to ONNX and produces quantized variants."""

    def __init__(self, model_name: str, output_dir: str):
        self.model_name = model_name
        self.output_dir = output_dir
        self.setup_logging()

        os.makedirs(output_dir, exist_ok=True)

        self.logger.info(f"Loading tokenizer {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        self.logger.info(f"Loading model {model_name}...")
        self.model = AutoModelForQuestionAnswering.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32
        )
        self.model.eval()

    def setup_logging(self):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # Avoid attaching duplicate handlers if several converters share this logger.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def prepare_dummy_inputs(self):
        dummy_input = self.tokenizer(
            "Hello, how are you?",
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        )
        return {
            'input_ids': dummy_input['input_ids'],
            'attention_mask': dummy_input['attention_mask'],
            'token_type_ids': dummy_input['token_type_ids']
        }

    def export_to_onnx(self):
        output_path = os.path.join(self.output_dir, "model.onnx")
        inputs = self.prepare_dummy_inputs()

        # Mark batch and sequence dimensions as dynamic so the exported graph
        # accepts arbitrary batch sizes and sequence lengths.
        dynamic_axes = {
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'token_type_ids': {0: 'batch_size', 1: 'sequence_length'},
            'start_logits': {0: 'batch_size', 1: 'sequence_length'},
            'end_logits': {0: 'batch_size', 1: 'sequence_length'},
        }

        # Wrap the model so it returns a plain tuple of tensors, which maps
        # cleanly onto the named ONNX outputs below.
        class ModelWrapper(torch.nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids, attention_mask, token_type_ids):
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
                return outputs.start_logits, outputs.end_logits

        wrapped_model = ModelWrapper(self.model)

        try:
            torch.onnx.export(
                wrapped_model,
                (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids']),
                output_path,
                export_params=True,
                opset_version=14,  # Or a suitable version
                do_constant_folding=True,
                input_names=['input_ids', 'attention_mask', 'token_type_ids'],
                output_names=['start_logits', 'end_logits'],
                dynamic_axes=dynamic_axes,
                verbose=False
            )
            self.logger.info(f"Model exported to {output_path}")
            return output_path
        except Exception as e:
            self.logger.error(f"ONNX export failed: {str(e)}")
            raise

    def verify_model(self, model_path: str):
        try:
            onnx_model = onnx.load(model_path)
            onnx.checker.check_model(onnx_model)
            self.logger.info("ONNX model verification successful")
            return True
        except Exception as e:
            self.logger.error(f"Model verification failed: {str(e)}")
            return False

    def preprocess_model(self, model_path: str) -> str:
        preprocessed_path = os.path.join(self.output_dir, "model-infer.onnx")
        try:
            # Run onnxruntime's pre-processing step (recommended before quantization).
            command = [
                "python", "-m", "onnxruntime.quantization.preprocess",
                "--input", model_path,
                "--output", preprocessed_path
            ]
            subprocess.run(command, check=True, capture_output=True, text=True)
            self.logger.info(f"Model preprocessing successful. Output saved to {preprocessed_path}")
            return preprocessed_path
        except subprocess.CalledProcessError as e:
            self.logger.error(f"Preprocessing failed: {e.stderr}")
            raise
        except Exception as e:
            self.logger.error(f"Preprocessing failed: {str(e)}")
            raise

    def quantize_model(self, model_path: str):
        # Produce dynamically quantized variants at several weight precisions.
        weight_types = {
            'int4': QuantType.QInt4,
            'int8': QuantType.QInt8,
            'uint4': QuantType.QUInt4,
            'uint8': QuantType.QUInt8,
            'uint16': QuantType.QUInt16,
            'int16': QuantType.QInt16,
        }
        all_quantized_paths = []

        for weight_type in weight_types.keys():
            quantized_path = os.path.join(self.output_dir, "model_" + weight_type + ".onnx")
            try:
                quantize_dynamic(
                    model_path,
                    quantized_path,
                    weight_type=weight_types[weight_type]
                )
                self.logger.info(f"Model quantized ({weight_type}) and saved to {quantized_path}")
                all_quantized_paths.append(quantized_path)
            except Exception as e:
                self.logger.error(f"Quantization ({weight_type}) failed: {str(e)}")
                raise

        return all_quantized_paths

    def convert(self):
        try:
            onnx_path = self.export_to_onnx()

            if self.verify_model(onnx_path):
                # Add preprocessing step before quantization
                preprocessed_path = self.preprocess_model(onnx_path)

                # Use preprocessed model for quantization
                quantized_paths = self.quantize_model(preprocessed_path)

                tokenizer_path = os.path.join(self.output_dir, "tokenizer")
                self.tokenizer.save_pretrained(tokenizer_path)
                self.logger.info(f"Tokenizer saved to {tokenizer_path}")

                return {
                    'onnx_model': onnx_path,
                    'preprocessed_model': preprocessed_path,
                    'quantized_models': quantized_paths,
                    'tokenizer': tokenizer_path
                }
            else:
                raise Exception("Model verification failed")
        except Exception as e:
            self.logger.error(f"Conversion process failed: {str(e)}")
            raise
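

# --- Optional usage sketch (not part of the conversion flow) -----------------
# A minimal example, assuming onnxruntime is installed, of how one of the
# exported or quantized models plus the saved tokenizer could be used for
# question answering. The function name `answer_question` and the greedy
# argmax span selection are illustrative assumptions, not something the
# converter above requires.
def answer_question(onnx_path: str, tokenizer_dir: str, question: str, context: str) -> str:
    import onnxruntime as ort

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    # Tokenize as NumPy arrays; cast to int64 because the graph was exported
    # from PyTorch tensors with int64 inputs.
    encoded = tokenizer(question, context, return_tensors="np", truncation=True, max_length=384)
    ort_inputs = {
        inp.name: encoded[inp.name].astype("int64")
        for inp in session.get_inputs()
        if inp.name in encoded
    }

    start_logits, end_logits = session.run(["start_logits", "end_logits"], ort_inputs)

    # Greedy span selection: highest-scoring start and end positions
    # (no check that end >= start; a real decoder would be more careful).
    start = int(start_logits[0].argmax())
    end = int(end_logits[0].argmax())
    answer_ids = encoded["input_ids"][0][start:end + 1]
    return tokenizer.decode(answer_ids, skip_special_tokens=True)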


if __name__ == "__main__":
    MODEL_NAME = "timpal0l/mdeberta-v3-base-squad2"  # Or any other suitable model
    OUTPUT_DIR = "onnx"

    try:
        converter = ONNXModelConverter(MODEL_NAME, OUTPUT_DIR)
        results = converter.convert()

        print("\nConversion completed successfully!")
        print(f"ONNX model path: {results['onnx_model']}")
        print(f"Preprocessed model path: {results['preprocessed_model']}")
        print(f"Quantized model paths: {results['quantized_models']}")
        print(f"Tokenizer path: {results['tokenizer']}")
    except Exception as e:
        print(f"Conversion failed: {str(e)}")