"""Export BAAI/bge-m3 to ONNX format so that it can be run with ONNX Runtime.

By default, the script does not apply any optimizations to the ONNX model.
"""

import copy
import os
import shutil
from collections import OrderedDict
from pathlib import Path
from typing import Annotated, Optional

import torch
import typer
from huggingface_hub import snapshot_download
from optimum.exporters.onnx import onnx_export_from_model
from optimum.exporters.onnx.model_configs import XLMRobertaOnnxConfig
from optimum.onnxruntime import ORTModelForCustomTasks
from torch import Tensor
from transformers import (
    AutoConfig,
    AutoModel,
    PretrainedConfig,
    PreTrainedModel,
    XLMRobertaConfig,
)


class BGEM3InferenceModel(PreTrainedModel):
    """Based on:

    1. https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py
    2. https://huggingface.co/aapot/bge-m3-onnx/blob/main/export_onnx.py

    The main change here is that we inherit from `PreTrainedModel`, which provides
    the `.from_pretrained` and `.push_to_hub` methods. This makes it easy to export
    the model with Optimum and publish the result to the Hugging Face Hub.
    """

    config_class = XLMRobertaConfig
    base_model_prefix = "BGEM3InferenceModel"
    model_tags = ["BAAI/bge-m3"]

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        colbert_dim: int = -1,
    ) -> None:
        super().__init__(config=PretrainedConfig())

        # Download the checkpoint locally so the extra linear heads can be loaded
        # from disk alongside the base model.
        model_dir = snapshot_download(repo_id=model_name)
        self.config = AutoConfig.from_pretrained(model_dir)
        self.model = AutoModel.from_pretrained(model_dir)
        self.colbert_linear = torch.nn.Linear(
            in_features=self.model.config.hidden_size,
            out_features=(
                self.model.config.hidden_size if colbert_dim == -1 else colbert_dim
            ),
        )
        self.sparse_linear = torch.nn.Linear(
            in_features=self.model.config.hidden_size, out_features=1
        )
        # The ColBERT and sparse heads ship as separate state dicts in the repo.
        colbert_state_dict = torch.load(
            os.path.join(model_dir, "colbert_linear.pt"), map_location="cpu"
        )
        sparse_state_dict = torch.load(
            os.path.join(model_dir, "sparse_linear.pt"), map_location="cpu"
        )
        self.colbert_linear.load_state_dict(colbert_state_dict)
        self.sparse_linear.load_state_dict(sparse_state_dict)

    def dense_embedding(self, last_hidden_state: Tensor) -> Tensor:
        # The dense embedding is the hidden state of the [CLS] token.
        return last_hidden_state[:, 0]

    def sparse_embedding(self, last_hidden_state: Tensor) -> Tensor:
        # Per-token sparse weights: one ReLU-activated scalar per token.
        with torch.no_grad():
            return torch.relu(self.sparse_linear(last_hidden_state))

    def colbert_embedding(
        self, last_hidden_state: Tensor, attention_mask: Tensor
    ) -> Tensor:
        # Per-token ColBERT vectors for all tokens except [CLS], zeroed at padding.
        with torch.no_grad():
            colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
            colbert_vecs = colbert_vecs * attention_mask[:, 1:][:, :, None].float()
            return colbert_vecs

    def forward(
        self, input_ids: Tensor, attention_mask: Tensor
    ) -> dict[str, Tensor]:
        """Forward pass of the model with a custom output dict containing dense,
        sparse, and ColBERT embeddings. Dense and ColBERT embeddings are normalized."""
        with torch.no_grad():
            last_hidden_state = self.model(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state

            output = {}
            dense_vecs = self.dense_embedding(last_hidden_state)
            output["dense_vecs"] = torch.nn.functional.normalize(dense_vecs, dim=-1)

            sparse_vecs = self.sparse_embedding(last_hidden_state)
            output["sparse_vecs"] = sparse_vecs

            colbert_vecs = self.colbert_embedding(last_hidden_state, attention_mask)
            output["colbert_vecs"] = torch.nn.functional.normalize(colbert_vecs, dim=-1)

            return output
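
    # Rough output shapes for a batch of B sequences of length T with hidden size H
    # (1024 for bge-m3), following from the code above:
    #   dense_vecs:   (B, H)      normalized [CLS] embedding
    #   sparse_vecs:  (B, T, 1)   per-token ReLU weights
    #   colbert_vecs: (B, T-1, H) normalized per-token embeddings ([CLS] dropped)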


class BGEM3OnnxConfig(XLMRobertaOnnxConfig):
    """Modify `XLMRobertaOnnxConfig` to include the additional outputs of the model
    (dense_vecs, sparse_vecs, colbert_vecs)."""

    @property
    def outputs(self) -> dict[str, dict[int, str]]:
        """
        Dict containing the axis definition of the output tensors to provide to the model.

        Returns:
            `Dict[str, Dict[int, str]]`: A mapping of each output name to a mapping of axis
            position to the axis symbolic name.
        """
        return copy.deepcopy(
            OrderedDict(
                {
                    "dense_vecs": {0: "batch_size", 1: "embedding"},
                    "sparse_vecs": {0: "batch_size", 1: "token", 2: "weight"},
                    "colbert_vecs": {0: "batch_size", 1: "token", 2: "embedding"},
                }
            )
        )
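
    # For reference, the inputs inherited from `XLMRobertaOnnxConfig` should be
    # (an assumption, not verified here):
    #   {"input_ids": {0: "batch_size", 1: "sequence_length"},
    #    "attention_mask": {0: "batch_size", 1: "sequence_length"}}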


def main(
    output: Annotated[
        str,
        typer.Option(help="Path to the directory where the generated ONNX model is stored."),
    ] = "./onnx",
    opset: Annotated[int, typer.Option(help="ONNX opset version number.")] = 17,
    device: Annotated[
        str, typer.Option(help="Device used to perform the export: 'cpu' or 'cuda'.")
    ] = "cpu",
    optimize: Annotated[
        Optional[str],
        typer.Option(
            help=(
                "Run ONNX Runtime optimizations directly during the export. "
                "Some of these optimizations are specific to ONNX Runtime, and "
                "the resulting ONNX file will not be usable with other runtimes such as "
                "OpenVINO or TensorRT. Possible options:\n"
                " - None: No optimization\n"
                " - O1: Basic general optimizations\n"
                " - O2: Basic and extended general optimizations, transformers-specific fusions\n"
                " - O3: Same as O2 with GELU approximation\n"
                " - O4: Same as O3 with mixed precision (fp16, GPU-only, requires `--device cuda`)"
            ),
        ),
    ] = None,
    atol: Annotated[
        Optional[float],
        typer.Option(
            help=(
                "If specified, the absolute difference tolerance when validating the model. "
                "Otherwise, the default atol for the model will be used."
            )
        ),
    ] = None,
    push_to_hub_repo_id: Annotated[
        Optional[str],
        typer.Option(
            help="Hugging Face Hub repo id in `namespace/model_name` format. "
            "If None, the model will not be pushed to the Hugging Face Hub."
        ),
    ] = None,
) -> None:
    model = BGEM3InferenceModel(model_name="BAAI/bge-m3")

    onnx_config = BGEM3OnnxConfig(config=model.config)
    onnx_export_from_model(
        model,
        output=output,
        task="feature-extraction",
        custom_onnx_configs={"model": onnx_config},
        opset=opset,
        optimize=optimize,
        atol=atol,
        device=device,
    )

    # Ship this script and the model card alongside the exported model.
    try:
        shutil.copy(__file__, output)
    except Exception as ex:
        print(f"Error copying script to export directory: {ex}")
    try:
        shutil.copy(str(Path(__file__).parent / "model_card.md"), f"{output}/README.md")
    except Exception as ex:
        print(f"Error copying model card to export directory: {ex}")

    if push_to_hub_repo_id:
        local_onnx_model = ORTModelForCustomTasks.from_pretrained(output)
        local_onnx_model.push_to_hub(
            save_directory=output,
            repository_id=push_to_hub_repo_id,
            use_auth_token=True,
        )
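

# Example invocations (sketches; the script file name is hypothetical):
#
#   python export_onnx.py --output ./onnx --opset 17
#   python export_onnx.py --output ./onnx --optimize O2
#   python export_onnx.py --output ./onnx --optimize O4 --device cuda
#   python export_onnx.py --push-to-hub-repo-id my-namespace/bge-m3-onnx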

if __name__ == "__main__":
    typer.run(main)