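"""Locate city/state mentions in a text query.

A quantized DistilBERT NER model (Mozilla/distilbert-uncased-NER-LoRA) is
downloaded once, run locally through ONNX Runtime, and its BIO-tagged tokens
are merged back into city and state strings.
"""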
import os

import numpy as np
import onnxruntime as ort
import requests
from transformers import AutoTokenizer

VERSION = "v0.1.1"


class LocationFinder:
    def __init__(self):
        # Pin the tokenizer to the same release tag as the ONNX model so the two cannot drift apart
        self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/distilbert-uncased-NER-LoRA", revision=VERSION)
        model_url = f"https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/{VERSION}/onnx/model_quantized.onnx"
        model_dir_path = "models"
        model_path = f"{model_dir_path}/distilbert-uncased-NER-LoRA"
        os.makedirs(model_dir_path, exist_ok=True)
        if not os.path.exists(model_path):
            print("Downloading ONNX model...")
            response = requests.get(model_url)
            response.raise_for_status()  # fail loudly instead of caching an HTML error page as the model
            with open(model_path, "wb") as f:
                f.write(response.content)
            print("ONNX model downloaded.")

        # Load the ONNX model
        self.ort_session = ort.InferenceSession(model_path)

    def find_location(self, sequence, verbose=False):
        """Return {'city': ..., 'state': ...} for the given text, with None for anything not detected."""
        inputs = self.tokenizer(sequence,
                                return_tensors="np",  # ONNX requires inputs in NumPy format
                                padding="max_length",  # Pad to max length
                                truncation=True,       # Truncate if the text is too long
                                max_length=64)
        # The exported graph declares int64 inputs, so cast the tokenizer's output explicitly
        input_feed = {
            'input_ids': inputs['input_ids'].astype(np.int64),
            'attention_mask': inputs['attention_mask'].astype(np.int64),
        }

        # Run inference with the ONNX model
        outputs = self.ort_session.run(None, input_feed)
        logits = outputs[0]  # Assuming the model output is logits
        # Softmax over the label dimension; shift by the max logit for numerical stability
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        probabilities = np.exp(shifted) / np.sum(np.exp(shifted), axis=-1, keepdims=True)
        
        predicted_ids = np.argmax(logits, axis=-1)
        predicted_probs = np.max(probabilities, axis=-1)
        
        # Define the threshold for NER probability
        threshold = 0.6
        
        # Define the label map for city, state, citystate, etc.
        label_map = {
            0: "O",        # Outside any named entity
            1: "B-PER",    # Beginning of a person entity
            2: "I-PER",    # Inside a person entity
            3: "B-ORG",    # Beginning of an organization entity
            4: "I-ORG",    # Inside an organization entity
            5: "B-CITY",   # Beginning of a city entity
            6: "I-CITY",   # Inside a city entity
            7: "B-STATE",  # Beginning of a state entity
            8: "I-STATE",  # Inside a state entity
            9: "B-CITYSTATE",   # Beginning of a city_state entity
        10: "I-CITYSTATE",   # Inside a city_state entity
        }
        
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        # Initialize lists to hold detected entities
        city_entities = []
        state_entities = []
        city_state_entities = []
        
        for token, predicted_id, prob in zip(tokens, predicted_ids[0], predicted_probs[0]):
            if prob > threshold:
                if token in ["[CLS]", "[SEP]", "[PAD]"]:
                    continue
                if label_map[predicted_id] in ["B-CITY", "I-CITY"]:
                    # Handle the case of continuation tokens (like "##" in subwords)
                    if token.startswith("##") and city_entities:
                        city_entities[-1] += token[2:]  # Remove "##" and append to the last token
                    else:
                        city_entities.append(token)
                elif label_map[predicted_id] in ["B-STATE", "I-STATE"]:
                    if token.startswith("##") and state_entities:
                        state_entities[-1] += token[2:]
                    else:
                        state_entities.append(token)
                elif label_map[predicted_id] in ["B-CITYSTATE", "I-CITYSTATE"]:
                    if token.startswith("##") and city_state_entities:
                        city_state_entities[-1] += token[2:]
                    else:
                        city_state_entities.append(token)

        # Combine city_state entities and split into city and state if necessary
        if city_state_entities:
            city_state_str = " ".join(city_state_entities)
            city_state_split = city_state_str.split(",")  # Split on comma to separate city and state
            # Normalize whitespace-only fragments to None
            city_res = city_state_split[0].strip() or None
            state_res = (city_state_split[1].strip() or None) if len(city_state_split) > 1 else None
        else:
            # If no city_state entities, use detected city and state entities separately
            city_res = " ".join(city_entities).strip() if city_entities else None
            state_res = " ".join(state_entities).strip() if state_entities else None

        # Return the detected city and state as separate components
        return {
            'city': city_res,
            'state': state_res
        }


if __name__ == '__main__':
    query = "weather in san francisco, ca"
    loc_finder = LocationFinder()
    entities = loc_finder.find_location(query)
    print(f"query = {query} => {entities}")