import os

import gradio as gr
import onnx
import torch
import torch.onnx
from transformers import AutoTokenizer

from VitsModelSplit.vits_model import VitsModel
from VitsModelSplit.vits_model_only_d import Vits_models_only_decoder

# Directory the exported .onnx files are written to.
EXPORT_DIR = "/tmp"

# Tokenizer used to build the example token-id input for the "only_decoder"
# export. NOTE: this is loaded at import time (may download from the Hub).
tokenizer = AutoTokenizer.from_pretrained("wasmdashai/vits-ar")


def create_file(file_path):
    """Return *file_path* if it names an existing regular file, else ``None``.

    Used as the Gradio callback behind the "send" button so that an already
    exported ONNX file can be downloaded by typing its path.

    Args:
        file_path (str): Path typed by the user (may be empty or ``None``).

    Returns:
        str | None: The same path when it is an existing file, otherwise
        ``None`` (the download slot is left empty).
    """
    # NOTE(security): *file_path* comes straight from a public UI textbox
    # (the app is launched with share=True). Serving arbitrary paths relies on
    # Gradio's own allowed-path checks; consider restricting this to EXPORT_DIR.
    # isfile (not exists) so directories are rejected instead of handed to gr.File.
    if not file_path or not os.path.isfile(file_path):
        return None
    return file_path


class OnnxModelConverter:
    """Convert VITS checkpoints to ONNX and expose the conversion through a Gradio UI."""

    def __init__(self):
        # Not read by any method below; kept for callers that may access it.
        self.model = None

    def download_file(self, file_path):
        """Return *file_path* if it is an existing regular file, else ``None``."""
        return create_file(file_path)

    @staticmethod
    def _onnx_path(onnx_filename):
        """Build the output path inside EXPORT_DIR for a user-supplied file name.

        Only the final path component is kept, so a name such as ``"../x"``
        cannot write outside EXPORT_DIR, and a trailing ``.onnx`` typed by the
        user is not doubled.

        Args:
            onnx_filename (str): Desired file name, with or without ``.onnx``.

        Returns:
            str: ``EXPORT_DIR/<name>.onnx``.

        Raises:
            ValueError: If no usable file name remains after sanitising.
        """
        name = os.path.basename((onnx_filename or "").strip())
        if name.endswith(".onnx"):
            name = name[: -len(".onnx")]
        if not name or name in {".", ".."}:
            raise ValueError("Please provide a non-empty ONNX file name.")
        return os.path.join(EXPORT_DIR, f"{name}.onnx")

    @staticmethod
    def _hub_token(token):
        """Normalise the token textbox value: blank input becomes ``None``.

        This avoids passing ``token=""`` to ``from_pretrained``.
        """
        if token is None:
            return None
        return token.strip() or None

    def convert(self, model_name, token, onnx_filename, conversion_type):
        """Dispatch to the requested conversion.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model (blank = none).
            onnx_filename (str): Desired filename for the ONNX output.
            conversion_type (str): One of 'decoder', 'only_decoder' or 'full_model'.

        Returns:
            str | None: Path to the generated ONNX file, or ``None`` if the
            file was not found after export.

        Raises:
            ValueError: If *conversion_type* is not one of the supported values,
                or the file name is empty.
        """
        converters = {
            "decoder": self.convert_decoder,
            "only_decoder": self.convert_only_decoder,
            "full_model": self.convert_full_model,
        }
        converter = converters.get(conversion_type)
        if converter is None:
            raise ValueError("Invalid conversion type. Choose from 'decoder', 'only_decoder', or 'full_model'.")
        return converter(model_name, token, onnx_filename)

    def convert_decoder(self, model_name, token, onnx_filename):
        """Export only the ``decoder`` sub-module of a ``VitsModel`` to ONNX.

        The decoder is traced with a random float tensor of shape
        ``(1, channels, 10)``. ``channels`` is ``model.config.flow_size`` when
        the config defines it, and 192 (the previous hard-coded value) otherwise.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model (blank = none).
            onnx_filename (str): Desired filename for the ONNX output.

        Returns:
            str | None: Path to the generated ONNX file.
        """
        # Validate the name first so a bad name fails before the model download.
        onnx_file = self._onnx_path(onnx_filename)
        model = VitsModel.from_pretrained(model_name, token=self._hub_token(token))
        # NOTE(review): assumes the decoder consumes config.flow_size channels — confirm.
        in_channels = getattr(model.config, "flow_size", 192)
        example_input = torch.randn(1, in_channels, 10)
        torch.onnx.export(
            model.decoder,
            example_input,
            onnx_file,
            opset_version=11,
            input_names=["input"],
            output_names=["output"],
            # NOTE(review): output axis 1 is marked dynamic; if the decoder
            # returns (batch, 1, samples), the time axis is 2 instead — confirm.
            dynamic_axes={
                "input": {0: "batch_size", 2: "seq_len"},
                "output": {0: "batch_size", 1: "sequence_length"},
            },
        )
        return self.download_file(onnx_file)

    def convert_only_decoder(self, model_name, token, onnx_filename):
        """Export a ``Vits_models_only_decoder`` checkpoint to ONNX.

        Unlike :meth:`convert_decoder`, this model is traced with token ids:
        the Arabic sample sentence below is tokenized with the module-level
        ``tokenizer``. The ONNX input is therefore an int64 tensor of shape
        ``(batch_size, sequence_length)``.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model (blank = none).
            onnx_filename (str): Desired filename for the ONNX output.

        Returns:
            str | None: Path to the generated ONNX file.
        """
        onnx_file = self._onnx_path(onnx_filename)
        model = Vits_models_only_decoder.from_pretrained(model_name, token=self._hub_token(token))
        inputs = tokenizer("السلام عليكم كيف الحال", return_tensors="pt")
        example_inputs = inputs.input_ids.type(torch.LongTensor)
        # No opset_version is given, so torch's default opset is used here.
        torch.onnx.export(
            model,
            example_inputs,
            onnx_file,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size", 1: "sequence_length"},
                "output": {0: "batch_size", 1: "sequence_length"},
            },
        )
        return self.download_file(onnx_file)

    def convert_full_model(self, model_name, token, onnx_filename):
        """Export the full ``VitsModel`` (text encoder through decoder) to ONNX.

        The model is traced with random token ids of shape ``(1, 100)``, drawn
        from ``[0, vocab_size)``. ``vocab_size`` is the row count of
        ``model.text_encoder.embed_tokens.weight``.

        Args:
            model_name (str): Name of the model to convert.
            token (str): Access token for loading the model (blank = none).
            onnx_filename (str): Desired filename for the ONNX output.

        Returns:
            str | None: Path to the generated ONNX file.
        """
        onnx_file = self._onnx_path(onnx_filename)
        model = VitsModel.from_pretrained(model_name, token=self._hub_token(token))
        vocab_size = model.text_encoder.embed_tokens.weight.size(0)
        example_input = torch.randint(0, vocab_size, (1, 100), dtype=torch.long)
        torch.onnx.export(
            model,
            example_input,
            onnx_file,
            opset_version=11,
            input_names=["input"],
            output_names=["output"],
            # NOTE(review): only the output batch axis is dynamic; if the output
            # length depends on the input length, its time axis should be too.
            dynamic_axes={
                "input": {0: "batch_size", 1: "sequence_length"},
                "output": {0: "batch_size"},
            },
        )
        return self.download_file(onnx_file)

    def starrt(self):
        """Build the Gradio UI for the converter.

        Returns:
            gr.Blocks: The UI; call ``.launch()`` on it to serve it.
        """
        with gr.Blocks() as demo:
            with gr.Row():
                with gr.Column():
                    text_n_model = gr.Textbox(label="name model")
                    # Password field so the access token is not shown on a shared link.
                    text_n_token = gr.Textbox(label="token", type="password")
                    text_n_onxx = gr.Textbox(label="name model onxx")
                    choice = gr.Dropdown(
                        choices=["decoder", "only_decoder", "full_model"],
                        label="My Dropdown",
                    )
                with gr.Column():
                    btn = gr.Button("convert")
                    gr.Label("return name model onxx")
                    # Receives the path returned by ``convert``.
                    converted_file = gr.File(label="Download ONNX File")
                    btn.click(
                        self.convert,
                        [text_n_model, text_n_token, text_n_onxx, choice],
                        [converted_file],
                    )
                    # Manual download: type the path of an already-exported file.
                    btx = gr.Textbox("namefile")
                    download_button1 = gr.Button("send")
                    download_button = gr.File(label="Download ONNX File")
                    download_button1.click(create_file, [btx], [download_button])
        return demo


c = OnnxModelConverter()
cc = c.starrt()

if __name__ == "__main__":
    cc.launch(share=True)