Update README.md
Browse files
README.md
CHANGED
@@ -75,7 +75,72 @@ Users (both direct and downstream) should be made aware of the risks, biases and
|
|
75 |
|
76 |
Use the code below to get started with the model.
|
77 |
|
78 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
## Training Details
|
81 |
|
|
|
75 |
|
76 |
Use the code below to get started with the model.
|
77 |
|
78 |
+
[
|
79 |
+
import torch
|
80 |
+
from transformers import pipeline
|
81 |
+
import os
|
82 |
+
import json
|
83 |
+
|
84 |
+
class GeoLLMBertInference:
    """Run named-entity recognition (NER) on address strings with a fine-tuned BERT model.

    Paths to the tokenizer and model are read from a JSON config file, and the
    class exposes a small three-step workflow:
    raw NER tags (``get_ner_result``) -> per-tag concatenated strings
    (``concatenate_entities``) -> JSON output (``get_json_result``).
    """

    def __init__(self, config_path: str = 'config.json'):
        """Build the Hugging Face ``ner`` pipeline from paths in *config_path*.

        The config file must contain the keys ``project_path``,
        ``tokenizer_path`` and ``model_path``; the latter two are resolved
        relative to ``project_path``.
        """
        with open(config_path, 'r') as config_file:
            config = json.load(config_file)

        self.project_path = config['project_path']
        self.tokenizer_path = os.path.join(self.project_path, config['tokenizer_path'])
        self.model_path = os.path.join(self.project_path, config['model_path'])

        # Check if a GPU is available and set the device accordingly
        # (transformers convention: 0 = first CUDA device, -1 = CPU).
        self.device = 0 if torch.cuda.is_available() else -1

        self.ner_pipeline = pipeline("ner", model=self.model_path, tokenizer=self.tokenizer_path, device=self.device)
        # Raw pipeline output of the most recent get_ner_result call.
        self.result = None
        # Per-tag concatenated words from the most recent concatenate_entities call.
        self.concatenate_result = None

    def get_ner_result(self, address: str):
        """Run NER on *address* (upper-cased first) and cache the raw result.

        Returns the list of entity dicts produced by the transformers pipeline.
        """
        self.result = self.ner_pipeline(address.upper())
        return self.result

    def concatenate_entities(self) -> dict:
        """Merge sub-word tokens of the cached NER result, grouped by entity tag.

        Returns:
            dict mapping entity tag -> concatenated, upper-cased word string.

        Raises:
            ValueError: if ``get_ner_result`` has not been run first.
        """
        if self.result is None:
            raise ValueError("NER result is not available. Please run get_ner_result first.")

        concatenated_result = {}
        for entity in self.result:
            tag = entity['entity']
            # Strip WordPiece continuation markers ('##') and stray commas,
            # and upper-case once here instead of at every use site.
            word = entity['word'].replace('##', '').replace(',', '').upper()
            if tag not in concatenated_result:
                concatenated_result[tag] = word
            else:
                # Idiom fix: the original appended `'' + word` — the empty
                # string was a no-op, so a plain += is behaviorally identical.
                concatenated_result[tag] += word

        self.concatenate_result = concatenated_result
        return self.concatenate_result

    def get_json_result(self) -> str:
        """Return the concatenated result as a pretty-printed JSON string.

        Raises:
            ValueError: if ``concatenate_entities`` has not been run first.
        """
        if self.concatenate_result is None:
            raise ValueError("Concatenated result is not available. Please run concatenate_entities first.")

        return json.dumps(self.concatenate_result, indent=4)
|
125 |
+
|
126 |
+
# Example Usage
if __name__ == "__main__":
    geo_llm = GeoLLMBertInference('code/geo_llm/config.json')

    # Deliberately messy input to exercise the NER model on raw address text.
    address = "16 ChSeAStREtST.CATHARINE"

    raw_entities = geo_llm.get_ner_result(address)
    print(raw_entities)

    merged_entities = geo_llm.concatenate_entities()
    print(merged_entities)

    # Get the concatenated result in JSON format
    json_result = geo_llm.get_json_result()
    data = json.loads(json_result)

    # Print the JSON string
    print(json_result)
|
143 |
+
]
|
144 |
|
145 |
## Training Details
|
146 |
|