|
--- |
|
license: apache-2.0 |
|
tags: |
|
- vision |
|
- image-classification |
|
datasets: |
|
- imagenet-1k |
|
--- |
|
|
|
This model is a fork of [facebook/levit-256](https://huggingface.co/facebook/levit-256), where: |
|
|
|
* `nn.BatchNorm2d` and `nn.Conv2d` are fused |
|
* `nn.BatchNorm1d` and `nn.Linear` are fused |
|
|
|
and the optimized model is converted to the ONNX format.
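For context, folding an `nn.BatchNorm2d` into the preceding `nn.Conv2d` rewrites the convolution's weight and bias using the batch norm's running statistics, so the normalization layer can be dropped at inference time without changing the outputs. Below is a minimal sketch of the idea (`fuse_conv_bn` is a hypothetical helper, not the exact code used to produce this checkpoint); the `nn.BatchNorm1d`/`nn.Linear` case is analogous.

```python
import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta
    # can be expressed as a single convolution with rescaled weight and bias.
    fused = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,
    )
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    conv_bias = conv.bias.data if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = (conv_bias - bn.running_mean) * scale + bn.bias.data
    return fused
```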
|
|
|
## How to use |
|
|
|
```python |
|
from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification |
|
from transformers import AutoFeatureExtractor |
|
|
|
from PIL import Image |
|
import requests |
|
|
|
preprocessor = AutoFeatureExtractor.from_pretrained("fxmarty/levit-256-onnx") |
|
ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx") |
|
|
|
url = 'http://images.cocodataset.org/val2017/000000039769.jpg' |
|
image = Image.open(requests.get(url, stream=True).raw) |
|
|
|
inputs = preprocessor(images=image, return_tensors="pt") |
|
outputs = ort_model(**inputs)
|
|
|
predicted_class_idx = outputs.logits.argmax(-1).item() |
|
print("Predicted class:", model.config.id2label[predicted_class_idx]) |
|
``` |
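The exported model can also be plugged into the transformers `pipeline` API (a short usage sketch, reusing the `preprocessor`, `ort_model` and `image` loaded above):

```python
from transformers import pipeline

# Build an image-classification pipeline backed by the ONNX Runtime model
onnx_classifier = pipeline(
    "image-classification",
    model=ort_model,
    feature_extractor=preprocessor,
)

print(onnx_classifier(image))
```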
|
|
|
As a sanity check, verify that the ONNX model returns the same logits as the PyTorch model:
|
|
|
```python |
|
from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification |
|
from transformers import AutoModelForImageClassification

import torch
|
|
|
pt_model = AutoModelForImageClassification.from_pretrained("facebook/levit-256") |
|
pt_model.eval() |
|
|
|
ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx") |
|
|
|
inp = {"pixel_values": torch.rand(1, 3, 224, 224)} |
|
|
|
with torch.no_grad():

    res = pt_model(**inp)

res_ort = ort_model(**inp)
|
|
|
assert torch.allclose(res.logits, res_ort.logits, atol=1e-4) |
|
``` |
|
|
|
## Benchmarking |
|
|
|
More than 2x throughput with batch normalization folding and ONNX Runtime 🔥
|
|
|
Below you can find the latency percentiles and mean (in ms), as well as the model's throughput (in iterations/s).
|
|
|
``` |
|
PyTorch runtime: |
|
|
|
{'latency_50': 22.3024695, |
|
'latency_90': 23.1230725, |
|
'latency_95': 23.2653985, |
|
'latency_99': 23.60095705, |
|
'latency_999': 23.865580469999998, |
|
'latency_mean': 22.442956878923766, |
|
'latency_std': 0.46544295612971265, |
|
'nb_forwards': 446, |
|
'throughput': 44.6} |
|
|
|
Optimum-onnxruntime runtime: |
|
|
|
{'latency_50': 9.302445, |
|
'latency_90': 9.782875, |
|
'latency_95': 9.9071944, |
|
'latency_99': 11.084606999999997, |
|
'latency_999': 12.035858692000001, |
|
'latency_mean': 9.357703552853133, |
|
'latency_std': 0.4018553286992142, |
|
'nb_forwards': 1069, |
|
'throughput': 106.9} |
|
|
|
``` |
|
|
|
Reproduce these numbers on your own machine (reusing `pt_model` and `ort_model` from the snippets above) with:
|
|
|
```python |
|
import torch

from optimum.runs_base import TimeBenchmark
|
|
|
from pprint import pprint |
|
|
|
time_benchmark_ort = TimeBenchmark( |
|
model=ort_model, |
|
batch_size=1, |
|
input_length=224, |
|
model_input_names={"pixel_values"}, |
|
warmup_runs=10, |
|
duration=10 |
|
) |
|
|
|
results_ort = time_benchmark_ort.execute() |
|
|
|
with torch.no_grad():

    time_benchmark_pt = TimeBenchmark(

        model=pt_model,

        batch_size=1,

        input_length=224,

        model_input_names={"pixel_values"},

        warmup_runs=10,

        duration=10

    )

    results_pt = time_benchmark_pt.execute()
|
|
|
print("PyTorch runtime:\n") |
|
pprint(results_pt) |
|
|
|
print("\nOptimum-onnxruntime runtime:\n") |
|
pprint(results_ort) |
|
``` |
|
|