File size: 588 Bytes
5fae609 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import fire
from pathlib import Path
from typing import Optional
from transformers.convert_graph_to_onnx import convert, quantize
def convert_model(
    model: str,
    path: Optional[str] = None,
    opset: int = 11,
    pipeline_name: str = "sentiment-analysis",
):
    """Export a PyTorch transformers model to ONNX, then quantize the export.

    Args:
        model: Model to export. If it names an existing local path, it is
            resolved to an absolute path first. Otherwise the string is
            passed through unchanged (e.g. a Hugging Face Hub model id),
            rather than being turned into a non-existent path under the
            current working directory.
        path: Destination ``.onnx`` file. When omitted or empty, defaults to
            ``onnx/<name-of-current-directory>.onnx``.
        opset: ONNX opset version passed to ``convert`` (default 11).
        pipeline_name: transformers pipeline task passed to ``convert``.
            The default ``"sentiment-analysis"`` is needed for
            classification tasks.
    """
    if not path:
        # Name the output after the directory the script is run from.
        folder_name = Path(".").resolve().name
        path = Path("onnx") / f"{folder_name}.onnx"
    output = Path(path)

    model_path = Path(model)
    # Resolve only genuine local paths; Hub ids must reach convert() verbatim.
    model_ref = str(model_path.resolve()) if model_path.exists() else model

    # NOTE(review): transformers' convert() may refuse to write into an
    # output folder that already contains files — confirm for your version.
    convert(
        framework="pt",
        model=model_ref,
        output=output,
        opset=opset,
        pipeline_name=pipeline_name,
    )
    quantize(output)
if __name__ == "__main__":
    # Register only the intended command. A bare fire.Fire() would also expose
    # every module-level name (convert, quantize, Path, ...) as a CLI command.
    # `python <script> convert_model --model ...` keeps working unchanged.
    fire.Fire({"convert_model": convert_model})
|