Spaces:
Running
Running
File size: 5,194 Bytes
f9714b1 1c56cce f9714b1 1c56cce f9714b1 1b53210 f9714b1 1b53210 f9714b1 788c760 1b53210 f9714b1 1b53210 624162d 1b53210 f9714b1 788c760 1b53210 f9714b1 1b53210 788c760 1b53210 f9714b1 1b53210 788c760 1b53210 788c760 f9714b1 1b53210 f9714b1 788c760 f9714b1 788c760 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
import requests
import os
# Hugging Face hub revision of the ONNX model to download; it is substituted
# into the `resolve/<revision>/` segment of the model URL in LocationFinder.
VERSION = "v0.1.1"
class LocationFinder:
    """Extract a city and/or state from free text with a DistilBERT NER model.

    On construction the tokenizer is loaded with ``AutoTokenizer.from_pretrained``.
    The quantized ONNX model is downloaded once into ``models/`` and loaded
    into an ``onnxruntime`` inference session.
    """

    # Model output index -> BIO tag.
    LABEL_MAP = {
        0: "O",  # Outside any named entity
        1: "B-PER",  # Beginning of a person entity
        2: "I-PER",  # Inside a person entity
        3: "B-ORG",  # Beginning of an organization entity
        4: "I-ORG",  # Inside an organization entity
        5: "B-CITY",  # Beginning of a city entity
        6: "I-CITY",  # Inside a city entity
        7: "B-STATE",  # Beginning of a state entity
        8: "I-STATE",  # Inside a state entity
        9: "B-CITYSTATE",  # Beginning of a city_state entity
        10: "I-CITYSTATE",  # Inside a city_state entity
    }

    # Entity types (tag without its "B-"/"I-" prefix) that find_location collects.
    _LOCATION_TYPES = ("CITY", "STATE", "CITYSTATE")

    # Tokenizer bookkeeping tokens that never carry an entity.
    _SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})

    def __init__(self):
        """Load the tokenizer and the ONNX model, downloading the model if needed.

        Raises:
            requests.HTTPError: if the model download returns a non-2xx status.
            requests.RequestException: on network failure or timeout while
                downloading.
        """
        # NOTE(review): the tokenizer is not pinned to VERSION. The cached model
        # path does not include VERSION either, so bumping VERSION will not
        # refresh an already-downloaded model. Confirm whether that is intended.
        self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/distilbert-uncased-NER-LoRA")
        model_url = f"https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/{VERSION}/onnx/model_quantized.onnx"
        model_dir_path = "models"
        model_path = f"{model_dir_path}/distilbert-uncased-NER-LoRA"
        os.makedirs(model_dir_path, exist_ok=True)
        if not os.path.exists(model_path):
            self._download(model_url, model_path)
        # Load the ONNX model
        self.ort_session = ort.InferenceSession(model_path)

    @staticmethod
    def _download(url, dest, timeout=60):
        """Stream *url* into *dest*, creating *dest* only if the download succeeds.

        The body is written to ``<dest>.part`` first and then moved into place
        with ``os.replace``. A failed or interrupted download (or an HTTP error
        page) therefore never lands at *dest*. Anything at *dest* would be
        treated as a valid cached model on the next run.
        """
        print("Downloading ONNX model...")
        tmp_path = f"{dest}.part"
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            os.replace(tmp_path, dest)
        finally:
            # The temp file only still exists if something above failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        print("ONNX model downloaded.")

    @staticmethod
    def _softmax(logits):
        """Return a numerically stable softmax of *logits* over the last axis."""
        # Subtracting the per-row max keeps np.exp from overflowing to inf.
        # Overflow would turn the probabilities into NaN.
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        exp = np.exp(shifted)
        return exp / np.sum(exp, axis=-1, keepdims=True)

    @staticmethod
    def _split_city_state(tokens):
        """Join CITYSTATE tokens with spaces and split on "," into (city, state).

        Empty or whitespace-only parts become None. Only the text between the
        first and second comma is used as the state.
        """
        parts = " ".join(tokens).split(",")
        city = parts[0].strip() or None
        state = (parts[1].strip() or None) if len(parts) > 1 else None
        return city, state

    def find_location(self, sequence, verbose=False, *, threshold=0.6):
        """Return the city and state mentioned in *sequence*.

        Args:
            sequence: Input text; tokenized with padding/truncation to 64 tokens.
            verbose: If true, print each non-special token with its predicted
                label and probability.
            threshold: A token is kept only if the probability of its top label
                is strictly greater than this value.

        Returns:
            ``{'city': str | None, 'state': str | None}``. If any CITYSTATE
            tokens are found they take precedence and are split on a comma.
            Otherwise CITY and STATE tokens are joined separately. Label ids
            missing from LABEL_MAP are treated as "O".
        """
        inputs = self.tokenizer(sequence,
                                return_tensors="np",  # ONNX requires inputs in NumPy format
                                padding="max_length",  # Pad to max length
                                truncation=True,  # Truncate if the text is too long
                                max_length=64)
        input_feed = {
            'input_ids': inputs['input_ids'].astype(np.int64),
            'attention_mask': inputs['attention_mask'].astype(np.int64),
        }
        # Run inference. The first output is assumed to be per-token logits of
        # shape (batch, seq_len, num_labels); only batch item 0 is used below.
        logits = self.ort_session.run(None, input_feed)[0]
        predicted_ids = np.argmax(logits, axis=-1)
        predicted_probs = np.max(self._softmax(logits), axis=-1)
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        # Entity type -> collected (sub)word strings, in token order.
        entities = {entity_type: [] for entity_type in self._LOCATION_TYPES}
        for token, label_id, prob in zip(tokens, predicted_ids[0], predicted_probs[0]):
            if token in self._SPECIAL_TOKENS:
                continue
            label = self.LABEL_MAP.get(int(label_id), "O")
            if verbose:
                print(f"{token}\t{label}\t{prob:.3f}")
            if not prob > threshold:
                continue
            # "B-CITY"/"I-CITY" -> "CITY". "O", PER and ORG find no bucket.
            bucket = entities.get(label[2:])
            if bucket is None:
                continue
            # A WordPiece continuation ("##xx") extends the previous word of this type.
            if token.startswith("##") and bucket:
                bucket[-1] += token[2:]
            else:
                bucket.append(token)

        if entities["CITYSTATE"]:
            city_res, state_res = self._split_city_state(entities["CITYSTATE"])
        else:
            # No combined entity: use the separately detected city and state tokens.
            city_res = " ".join(entities["CITY"]).strip() or None
            state_res = " ".join(entities["STATE"]).strip() or None
        return {
            'city': city_res,
            'state': state_res
        }
if __name__ == '__main__':
    # Manual smoke test; the first run downloads the ONNX model.
    demo_query = "weather in san francisco, ca"
    finder = LocationFinder()
    detected = finder.find_location(demo_query)
    print(f"query = {demo_query} => {detected}")
|