Chidam Gopal committed on
Commit
1b53210
1 Parent(s): 624162d

changes to infer_location

Browse files
Files changed (1) hide show
  1. infer_location.py +34 -37
infer_location.py CHANGED
@@ -32,7 +32,7 @@ class LocationFinder:
32
  'input_ids': inputs['input_ids'].astype(np.int64),
33
  'attention_mask': inputs['attention_mask'].astype(np.int64),
34
  }
35
-
36
  # Run inference with the ONNX model
37
  outputs = self.ort_session.run(None, input_feed)
38
  logits = outputs[0] # Assuming the model output is logits
@@ -44,7 +44,7 @@ class LocationFinder:
44
  # Define the threshold for NER probability
45
  threshold = 0.6
46
 
47
- # Define the label map for city, state, organization, citystate
48
  label_map = {
49
  0: "O", # Outside any named entity
50
  1: "B-PER", # Beginning of a person entity
@@ -56,58 +56,55 @@ class LocationFinder:
56
  7: "B-STATE", # Beginning of a state entity
57
  8: "I-STATE", # Inside a state entity
58
  9: "B-CITYSTATE", # Beginning of a city_state entity
59
- 10: "I-CITYSTATE", # Inside a city_state entity
60
  }
61
 
62
  tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
63
 
64
- # List to hold the detected entities (city, state, organization, citystate)
65
- city_entities = []
66
- state_entities = []
67
- org_entities = []
68
- city_state_entities = []
69
-
70
  city_entities = []
71
  state_entities = []
72
  city_state_entities = []
73
- org_entities = []
74
- for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids[0], predicted_probs[0])):
75
  if prob > threshold:
76
  if token in ["[CLS]", "[SEP]", "[PAD]"]:
77
  continue
78
- else:
79
- if label_map[predicted_id] in ["B-CITY", "I-CITY"]:
80
- city_entities.append(token.replace("##", ""))
81
- elif label_map[predicted_id] in ["B-STATE", "I-STATE"]:
82
- state_entities.append(token.replace("##", ""))
83
- elif label_map[predicted_id] in ["B-CITYSTATE", "I-CITYSTATE"]:
84
- city_state_entities.append(token.replace("##", ""))
85
- elif label_map[predicted_id] in ["B-ORG", "I-ORG"]:
86
- org_entities.append(token.replace("##", ""))
87
-
88
- city_state_res = "".join(cs_entity.replace(",", ", ") for cs_entity in city_state_entities) if city_state_entities else None
89
- if city_entities:
90
- city_res = " ".join(city_entities)
91
- elif city_state_res:
92
- city_res = city_state_res.split(", ")[0]
93
- else:
94
- city_res = None
95
 
96
- if state_entities:
97
- state_res = " ".join(state_entities)
98
- elif city_state_res and len(city_state_res) > 0:
99
- state_res = city_state_res.split(", ")[-1]
 
 
100
  else:
101
- state_res = None
 
 
102
 
103
- org_res = " ".join(org_entities) if org_entities else None
104
-
105
- # Return the detected entities
106
  return {
107
  'city': city_res,
108
- 'state': state_res,
109
  }
110
 
 
111
  if __name__ == '__main__':
112
  query = "weather in san francisco, ca"
113
  loc_finder = LocationFinder()
 
32
  'input_ids': inputs['input_ids'].astype(np.int64),
33
  'attention_mask': inputs['attention_mask'].astype(np.int64),
34
  }
35
+
36
  # Run inference with the ONNX model
37
  outputs = self.ort_session.run(None, input_feed)
38
  logits = outputs[0] # Assuming the model output is logits
 
44
  # Define the threshold for NER probability
45
  threshold = 0.6
46
 
47
+ # Define the label map for city, state, citystate, etc.
48
  label_map = {
49
  0: "O", # Outside any named entity
50
  1: "B-PER", # Beginning of a person entity
 
56
  7: "B-STATE", # Beginning of a state entity
57
  8: "I-STATE", # Inside a state entity
58
  9: "B-CITYSTATE", # Beginning of a city_state entity
59
+ 10: "I-CITYSTATE", # Inside a city_state entity
60
  }
61
 
62
  tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
63
 
64
+ # Initialize lists to hold detected entities
 
 
 
 
 
65
  city_entities = []
66
  state_entities = []
67
  city_state_entities = []
68
+
69
+ for token, predicted_id, prob in zip(tokens, predicted_ids[0], predicted_probs[0]):
70
  if prob > threshold:
71
  if token in ["[CLS]", "[SEP]", "[PAD]"]:
72
  continue
73
+ if label_map[predicted_id] in ["B-CITY", "I-CITY"]:
74
+ # Handle the case of continuation tokens (like "##" in subwords)
75
+ if token.startswith("##") and city_entities:
76
+ city_entities[-1] += token[2:] # Remove "##" and append to the last token
77
+ else:
78
+ city_entities.append(token)
79
+ elif label_map[predicted_id] in ["B-STATE", "I-STATE"]:
80
+ if token.startswith("##") and state_entities:
81
+ state_entities[-1] += token[2:]
82
+ else:
83
+ state_entities.append(token)
84
+ elif label_map[predicted_id] in ["B-CITYSTATE", "I-CITYSTATE"]:
85
+ if token.startswith("##") and city_state_entities:
86
+ city_state_entities[-1] += token[2:]
87
+ else:
88
+ city_state_entities.append(token)
 
89
 
90
+ # Combine city_state entities and split into city and state if necessary
91
+ if city_state_entities:
92
+ city_state_str = " ".join(city_state_entities)
93
+ city_state_split = city_state_str.split(",") # Split on comma to separate city and state
94
+ city_res = city_state_split[0].strip() if city_state_split[0] else None
95
+ state_res = city_state_split[1].strip() if len(city_state_split) > 1 else None
96
  else:
97
+ # If no city_state entities, use detected city and state entities separately
98
+ city_res = " ".join(city_entities).strip() if city_entities else None
99
+ state_res = " ".join(state_entities).strip() if state_entities else None
100
 
101
+ # Return the detected city and state as separate components
 
 
102
  return {
103
  'city': city_res,
104
+ 'state': state_res
105
  }
106
 
107
+
108
  if __name__ == '__main__':
109
  query = "weather in san francisco, ca"
110
  loc_finder = LocationFinder()