This model is a fork of facebook/levit-256, where:
- nn.BatchNorm2d and nn.Conv2d are fused,
- nn.BatchNorm1d and nn.Linear are fused,
- and the optimized model is converted to the ONNX format.
The fusion of layers leverages torch.fx, using the transformations FuseBatchNorm2dInConv2d and FuseBatchNorm1dInLinear, soon to be available out-of-the-box with 🤗 Optimum. Check out the transformation guide: https://huggingface.co/docs/optimum/main/en/fx/optimization#the-transformation-guide
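For reference, below is a minimal sketch of how such a fusion could be reproduced, assuming the two transformations are exposed in optimum.fx.optimization as described in the guide linked above, and that LeViT is supported by transformers' symbolic_trace; the exact export script used for this repository may differ.

import torch
from transformers import AutoModelForImageClassification
from transformers.utils.fx import symbolic_trace

# Assumed import path, following the Optimum transformation guide linked above.
from optimum.fx.optimization import compose, FuseBatchNorm2dInConv2d, FuseBatchNorm1dInLinear

model = AutoModelForImageClassification.from_pretrained("facebook/levit-256")
model.eval()  # folding uses the BatchNorm running statistics

# Trace the model into a torch.fx GraphModule (assumes LeViT is traceable).
traced = symbolic_trace(model, input_names=["pixel_values"])

# Fold each nn.BatchNorm2d into the preceding nn.Conv2d,
# and each nn.BatchNorm1d into the preceding nn.Linear.
transformation = compose(FuseBatchNorm2dInConv2d(), FuseBatchNorm1dInLinear())
fused = transformation(traced)

# Export the fused GraphModule to ONNX.
torch.onnx.export(
    fused,
    (torch.rand(1, 3, 224, 224),),
    "model.onnx",
    input_names=["pixel_values"],
    output_names=["logits"],
    dynamic_axes={"pixel_values": {0: "batch_size"}, "logits": {0: "batch_size"}},
    opset_version=13,
)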
How to use
from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification
from transformers import AutoFeatureExtractor
from PIL import Image
import requests
preprocessor = AutoFeatureExtractor.from_pretrained("fxmarty/levit-256-onnx")
ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx")
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
inputs = preprocessor(images=image, return_tensors="pt")
outputs = ort_model(**inputs)
predicted_class_idx = outputs.logits.argmax(-1).item()
print("Predicted class:", ort_model.config.id2label[predicted_class_idx])
To be safe, also check that the ONNX model returns the same logits as the PyTorch model:
import torch

from optimum.onnxruntime.modeling_ort import ORTModelForImageClassification
from transformers import AutoModelForImageClassification
pt_model = AutoModelForImageClassification.from_pretrained("facebook/levit-256")
pt_model.eval()
ort_model = ORTModelForImageClassification.from_pretrained("fxmarty/levit-256-onnx")
inp = {"pixel_values": torch.rand(1, 3, 224, 224)}
with torch.no_grad():
    res = pt_model(**inp)
    res_ort = ort_model(**inp)

assert torch.allclose(res.logits, res_ort.logits, atol=1e-4)
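To confirm that the fusion actually removed the normalization layers, the exported graph can also be inspected directly. A small sketch, assuming the ONNX weights in this repository are stored under the default Optimum filename model.onnx:

import onnx
from huggingface_hub import hf_hub_download

# "model.onnx" is assumed here; adjust the filename if the repository layout differs.
onnx_path = hf_hub_download(repo_id="fxmarty/levit-256-onnx", filename="model.onnx")
onnx_graph = onnx.load(onnx_path).graph

# Count the remaining BatchNormalization nodes; after folding this should be 0.
bn_nodes = [node for node in onnx_graph.node if node.op_type == "BatchNormalization"]
print("BatchNormalization nodes in the graph:", len(bn_nodes))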
Benchmarking
More than 2x throughput with batch normalization folding and onnxruntime 🔥 (about 2.4x in the run below: 106.9 vs. 44.6 iterations/s).
Below you can find latency percentiles and mean (in ms), as well as the model's throughput (in iterations/s).
PyTorch runtime:
{'latency_50': 22.3024695,
'latency_90': 23.1230725,
'latency_95': 23.2653985,
'latency_99': 23.60095705,
'latency_999': 23.865580469999998,
'latency_mean': 22.442956878923766,
'latency_std': 0.46544295612971265,
'nb_forwards': 446,
'throughput': 44.6}
Optimum-onnxruntime runtime:
{'latency_50': 9.302445,
'latency_90': 9.782875,
'latency_95': 9.9071944,
'latency_99': 11.084606999999997,
'latency_999': 12.035858692000001,
'latency_mean': 9.357703552853133,
'latency_std': 0.4018553286992142,
'nb_forwards': 1069,
'throughput': 106.9}
Run on your own machine with:
from optimum.runs_base import TimeBenchmark
from pprint import pprint

time_benchmark_ort = TimeBenchmark(
    model=ort_model,
    batch_size=1,
    input_length=224,
    model_input_names={"pixel_values"},
    warmup_runs=10,
    duration=10
)
results_ort = time_benchmark_ort.execute()

with torch.no_grad():
    time_benchmark_pt = TimeBenchmark(
        model=pt_model,
        batch_size=1,
        input_length=224,
        model_input_names={"pixel_values"},
        warmup_runs=10,
        duration=10
    )
    results_pt = time_benchmark_pt.execute()

print("PyTorch runtime:\n")
pprint(results_pt)
print("\nOptimum-onnxruntime runtime:\n")
pprint(results_ort)
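The speedup can then be read directly from the two result dictionaries; with the numbers reported above this gives roughly 106.9 / 44.6 ≈ 2.4x:

speedup = results_ort["throughput"] / results_pt["throughput"]
print(f"onnxruntime speedup over PyTorch: {speedup:.2f}x")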