from transformers import AutoModel
import torch

max_seq_length = 128

model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model.eval()

# Dummy inputs used only to trace the graph; the shapes are marked dynamic below.
inputs = {"input_ids": torch.ones(1, max_seq_length, dtype=torch.int64),
          "attention_mask": torch.ones(1, max_seq_length, dtype=torch.int64),
          "token_type_ids": torch.zeros(1, max_seq_length, dtype=torch.int64)}

# Mark the batch and sequence dimensions as dynamic on every input and on the
# output, so the exported model accepts any batch size and sequence length.
symbolic_names = {0: "batch_size", 1: "max_seq_len"}
torch.onnx.export(model, args=tuple(inputs.values()), f="pytorch_model.onnx",
                  export_params=True, opset_version=14,
                  input_names=["input_ids", "attention_mask", "token_type_ids"],
                  output_names=["last_hidden_state"],
                  dynamic_axes={"input_ids": symbolic_names,
                                "attention_mask": symbolic_names,
                                "token_type_ids": symbolic_names,
                                "last_hidden_state": symbolic_names})
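A quick way to validate the export is to load it back with ONNX Runtime and feed it dummy inputs of a different shape. This is a minimal sketch, assuming the onnxruntime package is installed; the file name and the input/output names match the torch.onnx.export call above.

import numpy as np
import onnxruntime as ort

# CPUExecutionProvider ships with every onnxruntime build.
session = ort.InferenceSession("pytorch_model.onnx",
                               providers=["CPUExecutionProvider"])

# Thanks to dynamic_axes, the batch size and sequence length are free to
# differ from the shapes used during export.
batch_size, seq_len = 2, 16
feed = {"input_ids": np.ones((batch_size, seq_len), dtype=np.int64),
        "attention_mask": np.ones((batch_size, seq_len), dtype=np.int64),
        "token_type_ids": np.zeros((batch_size, seq_len), dtype=np.int64)}

(last_hidden_state,) = session.run(["last_hidden_state"], feed)
print(last_hidden_state.shape)  # (2, 16, 384): all-MiniLM-L6-v2 has hidden size 384

Note that the exported graph ends at last_hidden_state. To reproduce the sentence embeddings that sentence-transformers returns for this model, the token vectors still need attention-mask-weighted mean pooling (and L2 normalization) applied on top.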