Upload ONNX weights

#3 opened by Xenova (HF staff)

Conversion code:

import os
import torch
from transformers import AutoModel, AutoTokenizer
from sklearn.preprocessing import normalize

query_prompt = "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: "
queries = [
    "What are some ways to reduce stress?",
    "What are the benefits of drinking green tea?",
]
queries = [query_prompt + query for query in queries]
# docs do not need any prompts
docs = [
    "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.",
    "Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.",
]

# The path of your model after cloning it
model_dir = "./stella_en_1.5B_v5"

vector_dim = 1024
vector_linear_directory = f"2_Dense_{vector_dim}"
model = AutoModel.from_pretrained(model_dir, trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
vector_linear = torch.nn.Linear(in_features=model.config.hidden_size, out_features=vector_dim)
vector_linear_dict = {
    k.replace("linear.", ""): v for k, v in
    torch.load(os.path.join(model_dir, f"{vector_linear_directory}/pytorch_model.bin"), map_location=torch.device('cpu')).items()
}
vector_linear.load_state_dict(vector_linear_dict)
vector_linear.eval()

model.vector_linear = vector_linear
original_forward = model.forward
def patched_forward(input_ids, attention_mask):
    # Masked mean pooling over the last hidden state, followed by the dense
    # projection, so the exported graph outputs final sentence embeddings
    last_hidden_state = original_forward(input_ids=input_ids, attention_mask=attention_mask)[0]
    last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
    mean_pooled = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    return model.vector_linear(mean_pooled)
model.forward = patched_forward

# Embed the queries
with torch.no_grad():
    input_data = tokenizer(queries, padding="longest", truncation=True, max_length=512, return_tensors="pt")
    outputs = model(**input_data)
    query_vectors = normalize(outputs.cpu().numpy())

# Embed the documents
with torch.no_grad():
    input_data = tokenizer(docs, padding="longest", truncation=True, max_length=512, return_tensors="pt")
    outputs = model(**input_data)
    docs_vectors = normalize(outputs.cpu().numpy())

print(query_vectors.shape, docs_vectors.shape)
# (2, 1024) (2, 1024)

similarities = query_vectors @ docs_vectors.T
print(similarities)
# [[0.8397531  0.29900077]
#  [0.32818374 0.80954516]]
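
For retrieval, the best document for each query is then the row-wise argmax of the similarity matrix (a minimal sketch reusing similarities from above):

import numpy as np

# Each query should match its corresponding document, so this prints [0 1]
print(np.argmax(similarities, axis=1))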

Followed by:

input_data = tokenizer(queries, padding="longest", truncation=True, max_length=512, return_tensors="pt")

# Export the model
torch.onnx.export(model,               # model being run
                  (input_data['input_ids'], input_data['attention_mask']), # model input (or a tuple for multiple inputs)
                  "model.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=14,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input_ids', 'attention_mask'],   # the model's input names
                  output_names = ['sentence_embedding'], # the model's output names
                  dynamic_axes={
                    "input_ids": {0: "batch_size", 1: "sequence_length"},
                    "attention_mask": {0: "batch_size", 1: "sequence_length"},
                    "sentence_embedding": {0: "batch_size"},
                  }
)
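
To sanity-check the export, the ONNX Runtime output can be compared against the patched PyTorch model on the same inputs (a minimal sketch, reusing model and input_data from above):

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
onnx_outputs = session.run(None, {
    "input_ids": input_data["input_ids"].numpy(),
    "attention_mask": input_data["attention_mask"].numpy(),
})[0]

with torch.no_grad():
    torch_outputs = model(**input_data).numpy()

# Small numerical drift is expected; it should stay within a tight tolerance
print(np.allclose(onnx_outputs, torch_outputs, atol=1e-4))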
StellaEncoder org:

Hi, have you tested whether the output of the ONNX model is consistent with the output of the original model?

I was unable to use the model_fp16.onnx model; it fails with the following error:

E           onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from /Users/narayan/.cache/huggingface/hub/models--dunzhang--stella_en_1.5B_v5/snapshots/edbd7b06454d3f4b409c78e5d68c965387d8680c/onnx/model_fp16.onnx failed:Type Error: Type (tensor(float16)) of output arg (/Cast_1_output_0) of node (/Cast_1) does not match expected type (tensor(float)).
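
That Cast type mismatch is typical of a float16 conversion that left some graph inputs/outputs or inserted Cast nodes at the wrong precision. A possible fix (untested here) is to regenerate the fp16 file from the fp32 model.onnx with keep_io_types enabled, using the onnxconverter-common package:

import onnx
from onnxconverter_common import float16

model = onnx.load("model.onnx")
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
# A 1.5B-parameter model exceeds the 2 GB protobuf limit, so store weights externally
onnx.save_model(model_fp16, "model_fp16.onnx", save_as_external_data=True)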

Example usage of the model.onnx file for CPU inference:

from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
from transformers import AutoTokenizer

options = SessionOptions()
options.log_severity_level = 3  # Only errors, ignore warnings
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
# providers must be a list, not a bare string
model = InferenceSession("path/to/model.onnx", options, providers=["CPUExecutionProvider"])

tokenizer = AutoTokenizer.from_pretrained("dunzhang/stella_en_1.5B_v5")
query_prompt = "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: "

texts = ["This is an example query!"]
input_text = tokenizer(
    [query_prompt + text for text in texts],
    return_tensors="pt", 
    padding=True, 
    truncation=True
)
inputs = {
    "input_ids": input_text["input_ids"].cpu().numpy(), 
    "attention_mask": input_text["attention_mask"].cpu().numpy()
}
outputs = model.run(None, inputs)
print(outputs[0])
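
Note that the sklearn normalize call from the conversion script runs outside the exported graph, so the ONNX model returns unnormalized embeddings. For cosine similarity, L2-normalize them first:

import numpy as np

embeddings = outputs[0]
# Equivalent to sklearn.preprocessing.normalize in the conversion script
embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)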

The output above matches what the original model produces on a GPU:

from sentence_transformers import SentenceTransformer
model = SentenceTransformer(
    "dunzhang/stella_en_1.5B_v5",
    device="cuda:0",
    trust_remote_code=True
)
print(model.encode(texts, prompt_name="s2p_query"))
print(model.encode([f"Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: {text}" for text in texts]))
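
To confirm "the same" numerically rather than by eye, compare the two with a tolerance (a minimal sketch; onnx_embeddings stands for outputs[0] from the CPU example above, and the tolerance may need loosening across hardware):

import numpy as np

st_embeddings = model.encode(texts, prompt_name="s2p_query")
print(np.allclose(onnx_embeddings, st_embeddings, atol=1e-3))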
