"""Export a Hugging Face question-answering model to ONNX and produce dynamically quantized variants."""

import logging
import os
import subprocess
import sys

import onnx
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from transformers import AutoModelForQuestionAnswering, AutoTokenizer


class ONNXModelConverter:
    def __init__(self, model_name: str, output_dir: str):
        self.model_name = model_name
        self.output_dir = output_dir
        self.setup_logging()

        os.makedirs(output_dir, exist_ok=True)

        self.logger.info(f"Loading tokenizer {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        self.logger.info(f"Loading model {model_name}...")
        self.model = AutoModelForQuestionAnswering.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        )
        self.model.eval()

    def setup_logging(self):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # Avoid attaching duplicate handlers if the converter is instantiated more than once.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def prepare_dummy_inputs(self):
        # A single short example is enough to trace the graph for export.
        dummy_input = self.tokenizer(
            "Hello, how are you?",
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        )
        return {
            'input_ids': dummy_input['input_ids'],
            'attention_mask': dummy_input['attention_mask'],
            'token_type_ids': dummy_input['token_type_ids'],
        }

    def export_to_onnx(self):
        output_path = os.path.join(self.output_dir, "model.onnx")
        inputs = self.prepare_dummy_inputs()

        # Mark batch and sequence dimensions as dynamic so the exported graph
        # accepts arbitrary batch sizes and sequence lengths at inference time.
        dynamic_axes = {
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'token_type_ids': {0: 'batch_size', 1: 'sequence_length'},
            'start_logits': {0: 'batch_size', 1: 'sequence_length'},
            'end_logits': {0: 'batch_size', 1: 'sequence_length'},
        }

        class ModelWrapper(torch.nn.Module):
            """Unpacks the Hugging Face output object into plain tensors for torch.onnx.export."""

            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids, attention_mask, token_type_ids):
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )
                return outputs.start_logits, outputs.end_logits

        wrapped_model = ModelWrapper(self.model)

        try:
            torch.onnx.export(
                wrapped_model,
                (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids']),
                output_path,
                export_params=True,
                opset_version=14,  # Or a suitable version
                do_constant_folding=True,
                input_names=['input_ids', 'attention_mask', 'token_type_ids'],
                output_names=['start_logits', 'end_logits'],
                dynamic_axes=dynamic_axes,
                verbose=False,
            )
            self.logger.info(f"Model exported to {output_path}")
            return output_path
        except Exception as e:
            self.logger.error(f"ONNX export failed: {str(e)}")
            raise

    def verify_model(self, model_path: str):
        try:
            onnx_model = onnx.load(model_path)
            onnx.checker.check_model(onnx_model)
            self.logger.info("ONNX model verification successful")
            return True
        except Exception as e:
            self.logger.error(f"Model verification failed: {str(e)}")
            return False

    def preprocess_model(self, model_path: str) -> str:
        preprocessed_path = os.path.join(self.output_dir, "model-infer.onnx")
        try:
            # Run onnxruntime's quantization preprocessing (shape inference and graph
            # optimization) so dynamic quantization starts from a cleaned-up model.
            # sys.executable keeps the call inside the current Python environment.
            command = [
                sys.executable, "-m", "onnxruntime.quantization.preprocess",
                "--input", model_path,
                "--output", preprocessed_path,
            ]
            subprocess.run(command, check=True, capture_output=True, text=True)
            self.logger.info(f"Model preprocessing successful. Output saved to {preprocessed_path}")
            return preprocessed_path
        except subprocess.CalledProcessError as e:
            self.logger.error(f"Preprocessing failed: {e.stderr}")
            raise
        except Exception as e:
            self.logger.error(f"Preprocessing failed: {str(e)}")
            raise

    def quantize_model(self, model_path: str):
        # Note: the 4-bit and 16-bit weight types require a recent onnxruntime
        # release, and a failure for any type aborts the whole loop.
        weight_types = {
            'int4': QuantType.QInt4,
            'int8': QuantType.QInt8,
            'uint4': QuantType.QUInt4,
            'uint8': QuantType.QUInt8,
            'uint16': QuantType.QUInt16,
            'int16': QuantType.QInt16,
        }

        all_quantized_paths = []
        for name, weight_type in weight_types.items():
            quantized_path = os.path.join(self.output_dir, f"model_{name}.onnx")
            try:
                quantize_dynamic(
                    model_path,
                    quantized_path,
                    weight_type=weight_type,
                )
                self.logger.info(f"Model quantized ({name}) and saved to {quantized_path}")
                all_quantized_paths.append(quantized_path)
            except Exception as e:
                self.logger.error(f"Quantization ({name}) failed: {str(e)}")
                raise

        return all_quantized_paths

    def convert(self):
        try:
            onnx_path = self.export_to_onnx()

            if self.verify_model(onnx_path):
                # Add preprocessing step before quantization
                preprocessed_path = self.preprocess_model(onnx_path)

                # Use preprocessed model for quantization
                quantized_paths = self.quantize_model(preprocessed_path)

                tokenizer_path = os.path.join(self.output_dir, "tokenizer")
                self.tokenizer.save_pretrained(tokenizer_path)
                self.logger.info(f"Tokenizer saved to {tokenizer_path}")

                return {
                    'onnx_model': onnx_path,
                    'preprocessed_model': preprocessed_path,
                    'quantized_models': quantized_paths,
                    'tokenizer': tokenizer_path,
                }
            else:
                raise Exception("Model verification failed")
        except Exception as e:
            self.logger.error(f"Conversion process failed: {str(e)}")
            raise


if __name__ == "__main__":
    MODEL_NAME = "timpal0l/mdeberta-v3-base-squad2"  # Or any other suitable model
    OUTPUT_DIR = "onnx"

    try:
        converter = ONNXModelConverter(MODEL_NAME, OUTPUT_DIR)
        results = converter.convert()

        print("\nConversion completed successfully!")
        print(f"ONNX model path: {results['onnx_model']}")
        print(f"Preprocessed model path: {results['preprocessed_model']}")
        print(f"Quantized model paths: {results['quantized_models']}")
        print(f"Tokenizer path: {results['tokenizer']}")
    except Exception as e:
        print(f"Conversion failed: {str(e)}")