Spaces:
Running
Running
import onnxruntime as ort | |
from transformers import AutoTokenizer | |
import numpy as np | |
import requests | |
import os | |
VERSION = "v0.1.1" | |
class LocationFinder: | |
def __init__(self): | |
self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/distilbert-uncased-NER-LoRA") | |
model_url = f"https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/{VERSION}/onnx/model_quantized.onnx" | |
model_dir_path = "models" | |
model_path = f"{model_dir_path}/distilbert-uncased-NER-LoRA" | |
if not os.path.exists(model_dir_path): | |
os.makedirs(model_dir_path) | |
if not os.path.exists(model_path): | |
print("Downloading ONNX model...") | |
response = requests.get(model_url) | |
with open(model_path, "wb") as f: | |
f.write(response.content) | |
print("ONNX model downloaded.") | |
# Load the ONNX model | |
self.ort_session = ort.InferenceSession(model_path) | |
def find_location(self, sequence, verbose=False): | |
inputs = self.tokenizer(sequence, | |
return_tensors="np", # ONNX requires inputs in NumPy format | |
padding="max_length", # Pad to max length | |
truncation=True, # Truncate if the text is too long | |
max_length=64) | |
input_feed = { | |
'input_ids': inputs['input_ids'].astype(np.int64), | |
'attention_mask': inputs['attention_mask'].astype(np.int64), | |
} | |
# Run inference with the ONNX model | |
outputs = self.ort_session.run(None, input_feed) | |
logits = outputs[0] # Assuming the model output is logits | |
probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True) | |
predicted_ids = np.argmax(logits, axis=-1) | |
predicted_probs = np.max(probabilities, axis=-1) | |
# Define the threshold for NER probability | |
threshold = 0.6 | |
# Define the label map for city, state, citystate, etc. | |
label_map = { | |
0: "O", # Outside any named entity | |
1: "B-PER", # Beginning of a person entity | |
2: "I-PER", # Inside a person entity | |
3: "B-ORG", # Beginning of an organization entity | |
4: "I-ORG", # Inside an organization entity | |
5: "B-CITY", # Beginning of a city entity | |
6: "I-CITY", # Inside a city entity | |
7: "B-STATE", # Beginning of a state entity | |
8: "I-STATE", # Inside a state entity | |
9: "B-CITYSTATE", # Beginning of a city_state entity | |
10: "I-CITYSTATE", # Inside a city_state entity | |
} | |
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0]) | |
# Initialize lists to hold detected entities | |
city_entities = [] | |
state_entities = [] | |
city_state_entities = [] | |
for token, predicted_id, prob in zip(tokens, predicted_ids[0], predicted_probs[0]): | |
if prob > threshold: | |
if token in ["[CLS]", "[SEP]", "[PAD]"]: | |
continue | |
if label_map[predicted_id] in ["B-CITY", "I-CITY"]: | |
# Handle the case of continuation tokens (like "##" in subwords) | |
if token.startswith("##") and city_entities: | |
city_entities[-1] += token[2:] # Remove "##" and append to the last token | |
else: | |
city_entities.append(token) | |
elif label_map[predicted_id] in ["B-STATE", "I-STATE"]: | |
if token.startswith("##") and state_entities: | |
state_entities[-1] += token[2:] | |
else: | |
state_entities.append(token) | |
elif label_map[predicted_id] in ["B-CITYSTATE", "I-CITYSTATE"]: | |
if token.startswith("##") and city_state_entities: | |
city_state_entities[-1] += token[2:] | |
else: | |
city_state_entities.append(token) | |
# Combine city_state entities and split into city and state if necessary | |
if city_state_entities: | |
city_state_str = " ".join(city_state_entities) | |
city_state_split = city_state_str.split(",") # Split on comma to separate city and state | |
city_res = city_state_split[0].strip() if city_state_split[0] else None | |
state_res = city_state_split[1].strip() if len(city_state_split) > 1 else None | |
else: | |
# If no city_state entities, use detected city and state entities separately | |
city_res = " ".join(city_entities).strip() if city_entities else None | |
state_res = " ".join(state_entities).strip() if state_entities else None | |
# Return the detected city and state as separate components | |
return { | |
'city': city_res, | |
'state': state_res | |
} | |
if __name__ == '__main__': | |
query = "weather in san francisco, ca" | |
loc_finder = LocationFinder() | |
entities = loc_finder.find_location(query) | |
print(f"query = {query} => {entities}") | |