|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
|
|
# Export sentence-transformers/all-mpnet-base-v2 to ONNX with dynamic batch
# and sequence-length axes.
#
# NOTE: MPNet has no segment (token-type) embeddings. MPNetModel.forward is
# (input_ids, attention_mask, position_ids, ...), so a third positional
# tensor named "token_type_ids" would be traced as position_ids and exported
# under a misleading name. Only the inputs the model actually accepts are
# exported here.

# Sequence length of the dummy tracing inputs; the exported graph is not tied
# to it because axis 1 is declared dynamic below.
max_seq_length = 128

model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")
# Inference mode (disables dropout) so the traced graph matches eval behavior.
model.eval()

# Dummy tracing inputs: they fix dtype (int64) and rank (batch, seq).
# Insertion order must match MPNetModel.forward's positional parameters
# (input_ids, attention_mask) because they are passed positionally below.
inputs = {
    "input_ids": torch.ones(1, max_seq_length, dtype=torch.int64),
    "attention_mask": torch.ones(1, max_seq_length, dtype=torch.int64),
}

# Axis 0 is the batch, axis 1 the sequence; both are left symbolic in ONNX.
symbolic_names = {0: "batch_size", 1: "max_seq_len"}

with torch.no_grad():
    torch.onnx.export(
        model,
        args=tuple(inputs.values()),
        f="pytorch_model.onnx",
        export_params=True,
        # Derived from `inputs` so names and positional order cannot drift apart.
        input_names=list(inputs),
        # MPNetModel (default pooling layer) returns last_hidden_state and
        # pooler_output; name both so the second is not auto-named.
        output_names=["last_hidden_state", "pooler_output"],
        dynamic_axes={
            "input_ids": symbolic_names,
            "attention_mask": symbolic_names,
            "last_hidden_state": symbolic_names,
            "pooler_output": {0: "batch_size"},
        },
    )
|
|
|
|