Spaces:
Running
Running
Chidam Gopal
committed on
Commit
•
f9714b1
1
Parent(s):
8bf96ea
intent and location updates
Browse files
- app.py +11 -1
- infer_intent.py +18 -18
- infer_location.py +104 -0
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
import streamlit.components.v1 as components
|
3 |
from infer_intent import IntentClassifier
|
|
|
4 |
import matplotlib.pyplot as plt
|
5 |
|
6 |
st.title("Intent classifier")
|
@@ -10,11 +11,20 @@ def get_intent_classifier():
|
|
10 |
cls = IntentClassifier()
|
11 |
return cls
|
12 |
|
|
|
|
|
|
|
|
|
|
|
13 |
cls = get_intent_classifier()
|
14 |
query = st.text_input("Enter a query", value="What is the weather today")
|
15 |
pred_result, proba_result = cls.find_intent(query)
|
16 |
|
17 |
-
|
|
|
|
|
|
|
|
|
18 |
keys = list(proba_result.keys())
|
19 |
values = list(proba_result.values())
|
20 |
|
|
|
1 |
import streamlit as st
|
2 |
import streamlit.components.v1 as components
|
3 |
from infer_intent import IntentClassifier
|
4 |
+
from infer_location import LocationFinder
|
5 |
import matplotlib.pyplot as plt
|
6 |
|
7 |
st.title("Intent classifier")
|
|
|
11 |
cls = IntentClassifier()
|
12 |
return cls
|
13 |
|
14 |
+
@st.cache_resource
def get_location_finder():
    """Return the LocationFinder used by the app.

    The function is wrapped in ``st.cache_resource``, so the finder is
    expected to be constructed once and reused rather than rebuilt on
    every call.
    """
    return LocationFinder()
|
18 |
+
|
19 |
cls = get_intent_classifier()
|
20 |
query = st.text_input("Enter a query", value="What is the weather today")
|
21 |
pred_result, proba_result = cls.find_intent(query)
|
22 |
|
23 |
+
ner = get_location_finder()
|
24 |
+
location = ner.find_location(query)
|
25 |
+
|
26 |
+
st.markdown(f"Intent = :green[{pred_result}]")
|
27 |
+
st.markdown(f"Location = :green[{location}]")
|
28 |
keys = list(proba_result.keys())
|
29 |
values = list(proba_result.values())
|
30 |
|
infer_intent.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
from transformers import AutoTokenizer
|
2 |
import onnxruntime as ort
|
|
|
3 |
import numpy as np
|
4 |
import requests
|
5 |
import os
|
@@ -8,16 +8,17 @@ import os
|
|
8 |
class IntentClassifier:
|
9 |
def __init__(self):
|
10 |
self.id2label = {0: 'information_intent',
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
self.label2id = {label:id for id,label in self.id2label.items()}
|
19 |
|
20 |
self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier")
|
|
|
21 |
model_url = "https://huggingface.co/Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier/resolve/main/onnx/model_quantized.onnx"
|
22 |
model_dir_path = "models"
|
23 |
model_path = f"{model_dir_path}/mobilebert-uncased-finetuned-LoRA-intent-classifier_model_quantized.onnx"
|
@@ -35,10 +36,10 @@ class IntentClassifier:
|
|
35 |
|
36 |
def find_intent(self, sequence, verbose=False):
|
37 |
inputs = self.tokenizer(sequence,
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
|
43 |
# Convert inputs to NumPy arrays
|
44 |
onnx_inputs = {k: v for k, v in inputs.items()}
|
@@ -60,12 +61,11 @@ class IntentClassifier:
|
|
60 |
|
61 |
return pred_result, proba_result
|
62 |
|
63 |
-
|
64 |
def main():
|
65 |
text_list = [
|
66 |
-
'floor repair cost',
|
67 |
-
'pet store near me',
|
68 |
-
'who is the us president',
|
69 |
'italian food',
|
70 |
'sandwiches for lunch',
|
71 |
"cheese burger cost",
|
@@ -75,7 +75,7 @@ def main():
|
|
75 |
]
|
76 |
cls = IntentClassifier()
|
77 |
for sequence in text_list:
|
78 |
-
cls.find_intent(sequence)
|
79 |
|
80 |
if __name__ == '__main__':
|
81 |
main()
|
|
|
|
|
1 |
import onnxruntime as ort
|
2 |
+
from transformers import AutoTokenizer
|
3 |
import numpy as np
|
4 |
import requests
|
5 |
import os
|
|
|
8 |
class IntentClassifier:
|
9 |
def __init__(self):
|
10 |
self.id2label = {0: 'information_intent',
|
11 |
+
1: 'yelp_intent',
|
12 |
+
2: 'navigation_intent',
|
13 |
+
3: 'travel_intent',
|
14 |
+
4: 'purchase_intent',
|
15 |
+
5: 'weather_intent',
|
16 |
+
6: 'translation_intent',
|
17 |
+
7: 'unknown'}
|
18 |
+
self.label2id = {label: id for id, label in self.id2label.items()}
|
19 |
|
20 |
self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier")
|
21 |
+
|
22 |
model_url = "https://huggingface.co/Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier/resolve/main/onnx/model_quantized.onnx"
|
23 |
model_dir_path = "models"
|
24 |
model_path = f"{model_dir_path}/mobilebert-uncased-finetuned-LoRA-intent-classifier_model_quantized.onnx"
|
|
|
36 |
|
37 |
def find_intent(self, sequence, verbose=False):
|
38 |
inputs = self.tokenizer(sequence,
|
39 |
+
return_tensors="np", # ONNX requires inputs in NumPy format
|
40 |
+
padding="max_length", # Pad to max length
|
41 |
+
truncation=True, # Truncate if the text is too long
|
42 |
+
max_length=64)
|
43 |
|
44 |
# Convert inputs to NumPy arrays
|
45 |
onnx_inputs = {k: v for k, v in inputs.items()}
|
|
|
61 |
|
62 |
return pred_result, proba_result
|
63 |
|
|
|
64 |
def main():
|
65 |
text_list = [
|
66 |
+
'floor repair cost',
|
67 |
+
'pet store near me',
|
68 |
+
'who is the us president',
|
69 |
'italian food',
|
70 |
'sandwiches for lunch',
|
71 |
"cheese burger cost",
|
|
|
75 |
]
|
76 |
cls = IntentClassifier()
|
77 |
for sequence in text_list:
|
78 |
+
cls.find_intent(sequence, verbose=True)
|
79 |
|
80 |
if __name__ == '__main__':
|
81 |
main()
|
infer_location.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import onnxruntime as ort
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import numpy as np
|
4 |
+
import requests
|
5 |
+
import os
|
6 |
+
|
7 |
+
class LocationFinder:
    """Find a location mention in a text query with an ONNX NER model.

    The constructor loads the tokenizer, downloads the quantized ONNX model
    into ``models/`` when it is not already on disk, and opens an ONNX
    Runtime session. ``find_location`` then tags each token and returns the
    first location entity found.
    """

    # Index -> tag mapping applied to the model's per-token argmax.
    _LABEL_MAP = {
        0: "O",       # Outside any named entity
        1: "B-PER",   # Beginning of a person entity
        2: "I-PER",   # Inside a person entity
        3: "B-ORG",   # Beginning of an organization entity
        4: "I-ORG",   # Inside an organization entity
        5: "B-LOC",   # Beginning of a location entity
        6: "I-LOC",   # Inside a location entity
        7: "B-MISC",  # Beginning of a miscellaneous entity
        8: "I-MISC",  # Inside a miscellaneous entity
    }

    # Tokenizer bookkeeping tokens that never belong to an entity.
    _SPECIAL_TOKENS = frozenset({"[CLS]", "[SEP]", "[PAD]"})

    def __init__(self):
        """Load the tokenizer and the ONNX model, downloading it if needed.

        Raises:
            requests.HTTPError: if the model download returns an error
                status (previously the error body was saved as the model).
        """
        self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/distilbert-uncased-NER-LoRA")
        model_url = "https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx"
        model_dir_path = "models"
        model_path = f"{model_dir_path}/distilbert-uncased-NER-LoRA"
        os.makedirs(model_dir_path, exist_ok=True)
        if not os.path.exists(model_path):
            print("Downloading ONNX model...")
            # Bound the connect/read waits so a stalled server cannot hang
            # start-up, and refuse to cache an HTTP error page as the model.
            response = requests.get(model_url, timeout=(10, 60))
            response.raise_for_status()
            # Write to a side file and rename, so an interrupted write never
            # leaves a truncated file at model_path that later runs would
            # treat as a finished download.
            partial_path = f"{model_path}.part"
            with open(partial_path, "wb") as f:
                f.write(response.content)
            os.replace(partial_path, model_path)
            print("ONNX model downloaded.")

        # Load the ONNX model
        self.ort_session = ort.InferenceSession(model_path)

    @staticmethod
    def _join_wordpieces(pieces):
        """Merge tokens into text, gluing ``##`` continuation pieces to the previous token."""
        # " ##" removes the separator together with the marker, so
        # ["fran", "##cisco"] becomes "francisco" rather than "fran cisco".
        # The second replace drops a marker on a leading piece.
        return " ".join(pieces).replace(" ##", "").replace("##", "")

    def find_location(self, sequence, verbose=False, threshold=0.6):
        """Return the first location entity detected in ``sequence``.

        Args:
            sequence: Query text to analyse.
            verbose: Accepted for interface compatibility; currently unused.
            threshold: A token's top softmax probability must be strictly
                greater than this for its label to be considered.

        Returns:
            The text of the first detected location, or ``None`` when no
            location is found.
        """
        inputs = self.tokenizer(sequence,
                                return_tensors="np",   # ONNX requires inputs in NumPy format
                                padding="max_length",  # Pad to max length
                                truncation=True,       # Truncate if the text is too long
                                max_length=64)
        input_feed = {
            'input_ids': inputs['input_ids'].astype(np.int64),
            'attention_mask': inputs['attention_mask'].astype(np.int64),
        }

        # Run inference with the ONNX model
        outputs = self.ort_session.run(None, input_feed)
        logits = outputs[0]  # Assuming the model output is logits

        # Softmax over the last axis. Subtracting the per-token maximum first
        # keeps np.exp from overflowing on large logits (which would turn the
        # probabilities into NaN and make every token fail the threshold).
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        exp_shifted = np.exp(shifted)
        probabilities = exp_shifted / np.sum(exp_shifted, axis=-1, keepdims=True)

        predicted_ids = np.argmax(logits, axis=-1)
        predicted_probs = np.max(probabilities, axis=-1)

        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        location_entities = []  # finished location strings, in order of appearance
        current_location = []   # tokens of the entity currently being built

        # Only the first sequence in the batch is inspected.
        for token, predicted_id, prob in zip(tokens, predicted_ids[0], predicted_probs[0]):
            # Ignore special tokens like [CLS], [SEP]
            if token in self._SPECIAL_TOKENS:
                continue

            # Tokens at or below the threshold are skipped: they neither
            # extend nor close the entity being built.
            if not prob > threshold:
                continue

            label = self._LABEL_MAP[int(predicted_id)]
            if label == "B-LOC":
                # A new entity begins; store the previous one first.
                if current_location:
                    location_entities.append(self._join_wordpieces(current_location))
                current_location = [token]
            elif label == "I-LOC":
                # Continue the open entity; an I-LOC with no open entity is dropped.
                if current_location:
                    current_location.append(token)
            elif current_location:
                # A confident non-location token closes the open entity.
                location_entities.append(self._join_wordpieces(current_location))
                current_location = []

        # Append the last location entity if it exists
        if current_location:
            location_entities.append(self._join_wordpieces(current_location))

        # Return the first detected location, if any
        return location_entities[0] if location_entities else None
|
98 |
+
|
99 |
+
|
100 |
+
if __name__ == '__main__':
    # Manual smoke test: run one sample query through a fresh finder and
    # print what it extracted.
    query = "weather in seattle"
    location = LocationFinder().find_location(query)
    print(f"query = {query} => {location}")
|