Spaces:
Running
Running
Chidam Gopal
commited on
Commit
•
1b53210
1
Parent(s):
624162d
changes to infer_location
Browse files- infer_location.py +34 -37
infer_location.py
CHANGED
@@ -32,7 +32,7 @@ class LocationFinder:
|
|
32 |
'input_ids': inputs['input_ids'].astype(np.int64),
|
33 |
'attention_mask': inputs['attention_mask'].astype(np.int64),
|
34 |
}
|
35 |
-
|
36 |
# Run inference with the ONNX model
|
37 |
outputs = self.ort_session.run(None, input_feed)
|
38 |
logits = outputs[0] # Assuming the model output is logits
|
@@ -44,7 +44,7 @@ class LocationFinder:
|
|
44 |
# Define the threshold for NER probability
|
45 |
threshold = 0.6
|
46 |
|
47 |
-
# Define the label map for city, state,
|
48 |
label_map = {
|
49 |
0: "O", # Outside any named entity
|
50 |
1: "B-PER", # Beginning of a person entity
|
@@ -56,58 +56,55 @@ class LocationFinder:
|
|
56 |
7: "B-STATE", # Beginning of a state entity
|
57 |
8: "I-STATE", # Inside a state entity
|
58 |
9: "B-CITYSTATE", # Beginning of a city_state entity
|
59 |
-
|
60 |
}
|
61 |
|
62 |
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
|
63 |
|
64 |
-
#
|
65 |
-
city_entities = []
|
66 |
-
state_entities = []
|
67 |
-
org_entities = []
|
68 |
-
city_state_entities = []
|
69 |
-
|
70 |
city_entities = []
|
71 |
state_entities = []
|
72 |
city_state_entities = []
|
73 |
-
|
74 |
-
for
|
75 |
if prob > threshold:
|
76 |
if token in ["[CLS]", "[SEP]", "[PAD]"]:
|
77 |
continue
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
city_res = None
|
95 |
|
96 |
-
if
|
97 |
-
|
98 |
-
|
99 |
-
|
|
|
|
|
100 |
else:
|
101 |
-
|
|
|
|
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
# Return the detected entities
|
106 |
return {
|
107 |
'city': city_res,
|
108 |
-
'state': state_res
|
109 |
}
|
110 |
|
|
|
111 |
if __name__ == '__main__':
|
112 |
query = "weather in san francisco, ca"
|
113 |
loc_finder = LocationFinder()
|
|
|
32 |
'input_ids': inputs['input_ids'].astype(np.int64),
|
33 |
'attention_mask': inputs['attention_mask'].astype(np.int64),
|
34 |
}
|
35 |
+
|
36 |
# Run inference with the ONNX model
|
37 |
outputs = self.ort_session.run(None, input_feed)
|
38 |
logits = outputs[0] # Assuming the model output is logits
|
|
|
44 |
# Define the threshold for NER probability
|
45 |
threshold = 0.6
|
46 |
|
47 |
+
# Define the label map for city, state, citystate, etc.
|
48 |
label_map = {
|
49 |
0: "O", # Outside any named entity
|
50 |
1: "B-PER", # Beginning of a person entity
|
|
|
56 |
7: "B-STATE", # Beginning of a state entity
|
57 |
8: "I-STATE", # Inside a state entity
|
58 |
9: "B-CITYSTATE", # Beginning of a city_state entity
|
59 |
+
10: "I-CITYSTATE", # Inside a city_state entity
|
60 |
}
|
61 |
|
62 |
tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
|
63 |
|
64 |
+
# Initialize lists to hold detected entities
|
|
|
|
|
|
|
|
|
|
|
65 |
city_entities = []
|
66 |
state_entities = []
|
67 |
city_state_entities = []
|
68 |
+
|
69 |
+
for token, predicted_id, prob in zip(tokens, predicted_ids[0], predicted_probs[0]):
|
70 |
if prob > threshold:
|
71 |
if token in ["[CLS]", "[SEP]", "[PAD]"]:
|
72 |
continue
|
73 |
+
if label_map[predicted_id] in ["B-CITY", "I-CITY"]:
|
74 |
+
# Handle the case of continuation tokens (like "##" in subwords)
|
75 |
+
if token.startswith("##") and city_entities:
|
76 |
+
city_entities[-1] += token[2:] # Remove "##" and append to the last token
|
77 |
+
else:
|
78 |
+
city_entities.append(token)
|
79 |
+
elif label_map[predicted_id] in ["B-STATE", "I-STATE"]:
|
80 |
+
if token.startswith("##") and state_entities:
|
81 |
+
state_entities[-1] += token[2:]
|
82 |
+
else:
|
83 |
+
state_entities.append(token)
|
84 |
+
elif label_map[predicted_id] in ["B-CITYSTATE", "I-CITYSTATE"]:
|
85 |
+
if token.startswith("##") and city_state_entities:
|
86 |
+
city_state_entities[-1] += token[2:]
|
87 |
+
else:
|
88 |
+
city_state_entities.append(token)
|
|
|
89 |
|
90 |
+
# Combine city_state entities and split into city and state if necessary
|
91 |
+
if city_state_entities:
|
92 |
+
city_state_str = " ".join(city_state_entities)
|
93 |
+
city_state_split = city_state_str.split(",") # Split on comma to separate city and state
|
94 |
+
city_res = city_state_split[0].strip() if city_state_split[0] else None
|
95 |
+
state_res = city_state_split[1].strip() if len(city_state_split) > 1 else None
|
96 |
else:
|
97 |
+
# If no city_state entities, use detected city and state entities separately
|
98 |
+
city_res = " ".join(city_entities).strip() if city_entities else None
|
99 |
+
state_res = " ".join(state_entities).strip() if state_entities else None
|
100 |
|
101 |
+
# Return the detected city and state as separate components
|
|
|
|
|
102 |
return {
|
103 |
'city': city_res,
|
104 |
+
'state': state_res
|
105 |
}
|
106 |
|
107 |
+
|
108 |
if __name__ == '__main__':
|
109 |
query = "weather in san francisco, ca"
|
110 |
loc_finder = LocationFinder()
|