import streamlit as st
import plotly.graph_objects as go
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
import requests
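
# Streamlit demo: detect toponyms (place names) in a sentence with the
# zekun-li/geolm-base-toponym-recognition token-classification model, look the
# selected place up in the GeoNames search API, and plot the candidate
# coordinates on an OpenStreetMap-backed Plotly map.
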
def search_geonames(location):
    # Look up candidate places for `location` in the GeoNames search API and
    # plot them on a map.
    api_endpoint = "http://api.geonames.org/searchJSON"
    username = "zekun"

    params = {
        'q': location,
        'username': username,
        'maxRows': 5
    }

    response = requests.get(api_endpoint, params=params)
    data = response.json()

    # Only build a map if GeoNames returned at least one candidate; this also
    # guarantees that `latitude`/`longitude` are defined when centering below.
    if data.get('geonames'):
        fig = go.Figure()
        for place_info in data['geonames']:
            latitude = float(place_info.get('lat', 0.0))
            longitude = float(place_info.get('lng', 0.0))

            fig.add_trace(go.Scattermapbox(
                lat=[latitude],
                lon=[longitude],
                mode='markers',
                marker=go.scattermapbox.Marker(
                    size=10,
                    color='orange',
                ),
                text=[f'Location: {location}'],
                hoverinfo="text",
                hovertemplate='<b>Location</b>: %{text}',
            ))

        # Center the map on the last candidate returned.
        fig.update_layout(
            mapbox_style="open-street-map",
            hovermode='closest',
            mapbox=dict(
                bearing=0,
                center=go.layout.mapbox.Center(
                    lat=latitude,
                    lon=longitude
                ),
                pitch=0,
                zoom=2
            )
        )

        st.plotly_chart(fig)
        return fig

    st.warning(f"No GeoNames results found for '{location}'.")
    return go.Figure()

def mapping(location):
    # Echo the selected place and hand it to the GeoNames lookup/plotting step.
    st.write(f"Mapping location: {location}")

    search_geonames(location)

def generate_human_readable(tokens, labels):
    # Merge WordPiece pieces and continuation-labelled tokens back into
    # human-readable place names.
    ret = []
    for t, lab in zip(tokens, labels):
        if t == '[SEP]':
            continue

        if t.startswith("##"):
            # Sub-word piece: glue it onto the previous token without a space.
            assert len(ret) > 0
            ret[-1] = ret[-1] + t.strip('##')
        elif lab == 2:
            # Label id 2 marks a continuation of the current toponym, so append
            # the token to the previous entry as a new word.
            assert len(ret) > 0
            ret[-1] = ret[-1] + " " + t.strip('##')
        else:
            ret.append(t)

    return ret

def showOnMap(input_sentence):
    # Run the toponym-recognition model over the sentence and return the list
    # of detected place names.
    model_name = "zekun-li/geolm-base-toponym-recognition"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)

    tokens = tokenizer.encode(input_sentence, return_tensors="pt")

    with torch.no_grad():
        outputs = model(tokens)

    # Per-token predicted label ids; id 0 corresponds to tokens outside a
    # toponym, so those are filtered out below.
    predicted_labels = torch.argmax(outputs.logits, dim=2)

    keep = torch.where(predicted_labels[0] != 0)[0]
    query_tokens = tokens[0][keep]
    query_labels = predicted_labels[0][keep]

    human_readable = generate_human_readable(
        tokenizer.convert_ids_to_tokens(query_tokens.tolist()), query_labels)

    return human_readable

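# Optional sketch (not in the original script): showOnMap() reloads the
# tokenizer and model from the Hugging Face Hub on every Streamlit rerun, which
# is slow. A cached loader such as the hypothetical helper below could be used
# instead; st.cache_resource keeps one copy of the objects across reruns.
@st.cache_resource
def load_toponym_model(model_name="zekun-li/geolm-base-toponym-recognition"):
    # Download (or reuse the local cache of) the tokenizer and model once.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    return tokenizer, model
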
def show_on_map():
    input_sentence = st.text_area("Enter a sentence:", height=200)
    st.button("Submit")  # return value unused; clicking simply re-runs the script

    if not input_sentence.strip():
        st.info("Enter a sentence to detect place names.")
        return

    places = showOnMap(input_sentence)
    if not places:
        st.info("No place names were detected in the input.")
        return

    selected_place = st.selectbox("Select a location:", places)
    mapping(selected_place)

if __name__ == "__main__":
    show_on_map()
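# Launch with the Streamlit CLI, e.g.: streamlit run <path-to-this-script>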