Chidam Gopal committed on
Commit
f9714b1
1 Parent(s): 8bf96ea

intent and location updates

Files changed (3)
  1. app.py +11 -1
  2. infer_intent.py +18 -18
  3. infer_location.py +104 -0
app.py CHANGED
@@ -1,6 +1,7 @@
 import streamlit as st
 import streamlit.components.v1 as components
 from infer_intent import IntentClassifier
+from infer_location import LocationFinder
 import matplotlib.pyplot as plt
 
 st.title("Intent classifier")
@@ -10,11 +11,20 @@ def get_intent_classifier():
     cls = IntentClassifier()
     return cls
 
+@st.cache_resource
+def get_location_finder():
+    ner = LocationFinder()
+    return ner
+
 cls = get_intent_classifier()
 query = st.text_input("Enter a query", value="What is the weather today")
 pred_result, proba_result = cls.find_intent(query)
 
-st.markdown(f"prediction = :green[{pred_result}]")
+ner = get_location_finder()
+location = ner.find_location(query)
+
+st.markdown(f"Intent = :green[{pred_result}]")
+st.markdown(f"Location = :green[{location}]")
 keys = list(proba_result.keys())
 values = list(proba_result.values())
 
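For context (not part of the diff): `@st.cache_resource` makes Streamlit build the IntentClassifier and LocationFinder once per server process and reuse them across reruns, so the ONNX sessions and tokenizers are not reloaded on every keystroke. A minimal sketch of that caching pattern, with a hypothetical load_finder helper standing in for the cached constructors above:

import streamlit as st

@st.cache_resource              # executed once; later reruns reuse the returned object
def load_finder():
    # stand-in for LocationFinder(), which downloads and opens an ONNX session
    return {"ready": True}

finder = load_finder()          # cheap on reruns: Streamlit hands back the cached instance
st.write(finder["ready"])
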
infer_intent.py CHANGED
@@ -1,5 +1,5 @@
-from transformers import AutoTokenizer
 import onnxruntime as ort
+from transformers import AutoTokenizer
 import numpy as np
 import requests
 import os
@@ -8,16 +8,17 @@ import os
 class IntentClassifier:
     def __init__(self):
         self.id2label = {0: 'information_intent',
-                1: 'yelp_intent',
-                2: 'navigation_intent',
-                3: 'travel_intent',
-                4: 'purchase_intent',
-                5: 'weather_intent',
-                6: 'translation_intent',
-                7: 'unknown'}
-        self.label2id = {label:id for id,label in self.id2label.items()}
+                         1: 'yelp_intent',
+                         2: 'navigation_intent',
+                         3: 'travel_intent',
+                         4: 'purchase_intent',
+                         5: 'weather_intent',
+                         6: 'translation_intent',
+                         7: 'unknown'}
+        self.label2id = {label: id for id, label in self.id2label.items()}
 
         self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier")
+
         model_url = "https://huggingface.co/Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier/resolve/main/onnx/model_quantized.onnx"
         model_dir_path = "models"
         model_path = f"{model_dir_path}/mobilebert-uncased-finetuned-LoRA-intent-classifier_model_quantized.onnx"
@@ -35,10 +36,10 @@ class IntentClassifier:
 
     def find_intent(self, sequence, verbose=False):
         inputs = self.tokenizer(sequence,
-            return_tensors="np", # ONNX requires inputs in NumPy format
-            padding="max_length", # Pad to max length
-            truncation=True, # Truncate if the text is too long
-            max_length=64)
+                                return_tensors="np",  # ONNX requires inputs in NumPy format
+                                padding="max_length",  # Pad to max length
+                                truncation=True,  # Truncate if the text is too long
+                                max_length=64)
 
         # Convert inputs to NumPy arrays
         onnx_inputs = {k: v for k, v in inputs.items()}
@@ -60,12 +61,11 @@ class IntentClassifier:
 
         return pred_result, proba_result
 
-
 def main():
     text_list = [
-                'floor repair cost',
-                'pet store near me',
-                'who is the us president',
+        'floor repair cost',
+        'pet store near me',
+        'who is the us president',
         'italian food',
         'sandwiches for lunch',
         "cheese burger cost",
@@ -75,7 +75,7 @@ def main():
     ]
     cls = IntentClassifier()
     for sequence in text_list:
-        cls.find_intent(sequence)
+        cls.find_intent(sequence, verbose=True)
 
 if __name__ == '__main__':
     main()
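
For reference, a usage sketch of the reordered module (assuming the Hugging Face tokenizer and ONNX model download successfully): find_intent returns the top label plus a dict of per-intent probabilities, which is what app.py plots.

from infer_intent import IntentClassifier

cls = IntentClassifier()
pred, proba = cls.find_intent("floor repair cost", verbose=True)
print(pred)  # one of the eight labels in id2label, e.g. 'information_intent'
for label, p in sorted(proba.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{label:<22s} {p:.3f}")
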
infer_location.py ADDED
@@ -0,0 +1,104 @@
+import onnxruntime as ort
+from transformers import AutoTokenizer
+import numpy as np
+import requests
+import os
+
+class LocationFinder:
+    def __init__(self):
+        self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/distilbert-uncased-NER-LoRA")
+        model_url = "https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx"
+        model_dir_path = "models"
+        model_path = f"{model_dir_path}/distilbert-uncased-NER-LoRA"
+        if not os.path.exists(model_dir_path):
+            os.makedirs(model_dir_path)
+        if not os.path.exists(model_path):
+            print("Downloading ONNX model...")
+            response = requests.get(model_url)
+            with open(model_path, "wb") as f:
+                f.write(response.content)
+            print("ONNX model downloaded.")
+
+        # Load the ONNX model
+        self.ort_session = ort.InferenceSession(model_path)
+
+    def find_location(self, sequence, verbose=False):
+        inputs = self.tokenizer(sequence,
+                                return_tensors="np",  # ONNX requires inputs in NumPy format
+                                padding="max_length",  # Pad to max length
+                                truncation=True,  # Truncate if the text is too long
+                                max_length=64)
+        input_feed = {
+            'input_ids': inputs['input_ids'].astype(np.int64),
+            'attention_mask': inputs['attention_mask'].astype(np.int64),
+        }
+
+        # Run inference with the ONNX model
+        outputs = self.ort_session.run(None, input_feed)
+        logits = outputs[0]  # Assuming the model output is logits
+        probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
+
+        predicted_ids = np.argmax(logits, axis=-1)
+        predicted_probs = np.max(probabilities, axis=-1)
+
+        # Define the threshold for NER probability
+        threshold = 0.6
+
+        label_map = {
+            0: "O",       # Outside any named entity
+            1: "B-PER",   # Beginning of a person entity
+            2: "I-PER",   # Inside a person entity
+            3: "B-ORG",   # Beginning of an organization entity
+            4: "I-ORG",   # Inside an organization entity
+            5: "B-LOC",   # Beginning of a location entity
+            6: "I-LOC",   # Inside a location entity
+            7: "B-MISC",  # Beginning of a miscellaneous entity
+            8: "I-MISC"   # Inside a miscellaneous entity
+        }
+
+        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
+
+        # List to hold the detected location terms
+        location_entities = []
+        current_location = []
+
+        # Loop through each token and its predicted label and probability
+        for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids[0], predicted_probs[0])):
+            label = label_map[predicted_id]
+
+            # Ignore special tokens like [CLS], [SEP]
+            if token in ["[CLS]", "[SEP]", "[PAD]"]:
+                continue
+
+            # Only consider tokens with probability above the threshold
+            if prob > threshold:
+                # If the token is a part of a location entity (B-LOC or I-LOC)
+                if label in ["B-LOC", "I-LOC"]:
+                    if label == "B-LOC":
+                        # If we encounter a B-LOC, we may need to store the previous location
+                        if current_location:
+                            location_entities.append(" ".join(current_location).replace("##", ""))
+                        # Start a new location entity
+                        current_location = [token]
+                    elif label == "I-LOC" and current_location:
+                        # Continue appending to the current location entity
+                        current_location.append(token)
+                else:
+                    # If we encounter a non-location entity, store the current location and reset
+                    if current_location:
+                        location_entities.append(" ".join(current_location).replace("##", ""))
+                    current_location = []
+
+        # Append the last location entity if it exists
+        if current_location:
+            location_entities.append(" ".join(current_location).replace("##", ""))
+
+        # Return the detected location terms
+        return location_entities[0] if location_entities != [] else None
+
+
+if __name__ == '__main__':
+    query = "weather in seattle"
+    loc_finder = LocationFinder()
+    location = loc_finder.find_location(query)
+    print(f"query = {query} => {location}")