Spaces:
Running
Running
Chidam Gopal
commited on
included state and city in NER
Browse files- infer_location.py +5 -15
infer_location.py
CHANGED
@@ -22,20 +22,6 @@ class LocationFinder:
|
|
22 |
# Load the ONNX model
|
23 |
self.ort_session = ort.InferenceSession(model_path)
|
24 |
|
25 |
-
# State abbreviations list for post-processing
|
26 |
-
self.state_abbr = {
|
27 |
-
"AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA", "HI", "ID", "IL", "IN", "IA", "KS", "KY",
|
28 |
-
"LA", "ME", "MD", "MA", "MI", "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ", "NM", "NY", "NC", "ND",
|
29 |
-
"OH", "OK", "OR", "PA", "RI", "SC", "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY"
|
30 |
-
}
|
31 |
-
|
32 |
-
# # Helper function to correct misclassified state abbreviations
|
33 |
-
# def correct_state_abbreviation(self, token, predicted_label):
|
34 |
-
# if token.upper() in self.state_abbr and predicted_label == "I-CITY":
|
35 |
-
# return "I-STATE"
|
36 |
-
# return predicted_label
|
37 |
-
|
38 |
-
|
39 |
def find_location(self, sequence, verbose=False):
|
40 |
inputs = self.tokenizer(sequence,
|
41 |
return_tensors="np", # ONNX requires inputs in NumPy format
|
@@ -80,6 +66,11 @@ class LocationFinder:
|
|
80 |
state_entities = []
|
81 |
org_entities = []
|
82 |
city_state_entities = []
|
|
|
|
|
|
|
|
|
|
|
83 |
for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids[0], predicted_probs[0])):
|
84 |
if prob > threshold:
|
85 |
if token in ["[CLS]", "[SEP]", "[PAD]"]:
|
@@ -115,7 +106,6 @@ class LocationFinder:
|
|
115 |
return {
|
116 |
'city': city_res,
|
117 |
'state': state_res,
|
118 |
-
'organization': org_res,
|
119 |
}
|
120 |
|
121 |
if __name__ == '__main__':
|
|
|
22 |
# Load the ONNX model
|
23 |
self.ort_session = ort.InferenceSession(model_path)
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
def find_location(self, sequence, verbose=False):
|
26 |
inputs = self.tokenizer(sequence,
|
27 |
return_tensors="np", # ONNX requires inputs in NumPy format
|
|
|
66 |
state_entities = []
|
67 |
org_entities = []
|
68 |
city_state_entities = []
|
69 |
+
|
70 |
+
city_entities = []
|
71 |
+
state_entities = []
|
72 |
+
city_state_entities = []
|
73 |
+
org_entities = []
|
74 |
for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids[0], predicted_probs[0])):
|
75 |
if prob > threshold:
|
76 |
if token in ["[CLS]", "[SEP]", "[PAD]"]:
|
|
|
106 |
return {
|
107 |
'city': city_res,
|
108 |
'state': state_res,
|
|
|
109 |
}
|
110 |
|
111 |
if __name__ == '__main__':
|