File size: 3,441 Bytes
2a2fe0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import streamlit as st
import plotly.graph_objects as go
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
import requests
def search_geonames(location):
    """Query the GeoNames search API for *location* and plot the matches.

    Renders a Plotly mapbox scatter of up to 5 matching places via
    ``st.plotly_chart`` and returns the plotted figure. Returns an empty
    ``go.Figure()`` when the request fails or there are no results.

    Args:
        location: Free-text place name to search for.
    """
    api_endpoint = "http://api.geonames.org/searchJSON"
    username = "zekun"  # NOTE(review): hard-coded GeoNames account name
    params = {
        'q': location,
        'username': username,
        'maxRows': 5
    }
    try:
        # Fix: requests.get has no default timeout and could hang forever.
        response = requests.get(api_endpoint, params=params, timeout=10)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError):
        return go.Figure()

    places = data.get('geonames', [])
    if not places:
        # Fix: the original raised NameError here — with an empty result
        # list, latitude/longitude were never bound before update_layout.
        return go.Figure()

    fig = go.Figure()
    latitude = longitude = 0.0
    for place_info in places:
        latitude = float(place_info.get('lat', 0.0))
        longitude = float(place_info.get('lng', 0.0))
        fig.add_trace(go.Scattermapbox(
            lat=[latitude],
            lon=[longitude],
            mode='markers',
            marker=go.scattermapbox.Marker(
                size=10,
                color='orange',
            ),
            text=[f'Location: {location}'],
            hoverinfo="text",
            hovertemplate='<b>Location</b>: %{text}',
        ))
    fig.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            bearing=0,
            # Center on the last match, as the original code did.
            center=go.layout.mapbox.Center(
                lat=latitude,
                lon=longitude
            ),
            pitch=0,
            zoom=2
        ))
    st.plotly_chart(fig)
    # Fix: return the figure that was actually built, not a throwaway
    # empty one (the sole caller ignores the return value either way).
    return fig
def mapping(location):
    """Announce the selected place in the UI, then render it on the map."""
    message = f"Mapping location: {location}"
    st.write(message)
    search_geonames(location)
def generate_human_readable(tokens, labels):
    """Merge WordPiece tokens into human-readable location strings.

    Tokens starting with ``##`` are glued onto the previous entry;
    tokens labelled 2 (I-Topo, inside a toponym) are appended to the
    previous entry with a space; everything else starts a new entry.
    ``[SEP]`` tokens are skipped.

    Args:
        tokens: Sequence of WordPiece token strings.
        labels: Parallel sequence of integer labels (0=O, 1=B-Topo, 2=I-Topo).

    Returns:
        List of reconstructed location name strings.

    Raises:
        ValueError: if a continuation token arrives with no preceding token.
    """
    ret = []
    for t, lab in zip(tokens, labels):
        if t == '[SEP]':
            continue
        if t.startswith("##"):
            if not ret:
                # Fix: was a bare assert, which -O silently strips.
                raise ValueError("continuation token with no preceding token")
            # Drop only the leading '##' marker (strip('##') also ate
            # trailing '#' characters from the merged text).
            ret[-1] = ret[-1] + t[2:]
        elif lab == 2:
            if not ret:
                raise ValueError("I-Topo token with no preceding token")
            # Fix: t does not start with '##' here, so the original
            # t.strip('##') wrongly removed '#' from both ends of the
            # token (e.g. 'c#' became 'c'). Append the token verbatim.
            ret[-1] = ret[-1] + " " + t
        else:
            ret.append(t)
    return ret
def showOnMap(input_sentence):
    """Extract toponyms (place names) from *input_sentence*.

    Loads the GeoLM toponym-recognition token classifier, keeps every
    token whose predicted label is non-zero (i.e. part of a toponym),
    and merges the WordPiece pieces back into readable names.

    NOTE(review): the tokenizer and model are re-downloaded/loaded on
    every call — consider caching at module level or with
    ``st.cache_resource`` in the Streamlit app.

    Args:
        input_sentence: Raw text to scan for place names.

    Returns:
        List of location name strings, e.g. ['Los Angeles', 'California'].
    """
    model_name = "zekun-li/geolm-base-toponym-recognition"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokens = tokenizer.encode(input_sentence, return_tensors="pt")
    # Inference only — no need to track gradients.
    with torch.no_grad():
        outputs = model(tokens)
    # "id2label": { "0": "O", "1": "B-Topo", "2": "I-Topo" }
    # Fix: the original computed argmax twice and built an intermediate
    # list of label strings that was immediately overwritten (dead code).
    predicted_labels = torch.argmax(outputs.logits, dim=2)
    keep = torch.where(predicted_labels[0] != 0)[0]
    query_tokens = tokens[0][keep]
    query_labels = predicted_labels[0][keep]
    human_readable = generate_human_readable(
        tokenizer.convert_ids_to_tokens(query_tokens), query_labels)
    # e.g. ['Los Angeles', 'L . A .', 'California', 'U . S .', ...]
    return human_readable
def show_on_map():
    """Streamlit page: extract toponyms from a sentence and map one.

    Reads a sentence from a text area, runs toponym recognition, lets the
    user pick one of the detected places, and renders it on a map.
    """
    # Fix: the original bound this to `input`, shadowing the builtin.
    sentence = st.text_area("Enter a sentence:", height=200)
    # NOTE(review): the button's return value was ignored, so the model
    # ran on every rerun; Streamlit reruns the script on click anyway,
    # so the button is kept purely as a submit affordance.
    st.button("Submit")
    if not sentence:
        # Fix: don't load the model / run inference on an empty sentence.
        return
    places = showOnMap(sentence)
    if not places:
        return
    selected_place = st.selectbox("Select a location:", places)
    mapping(selected_place)
# Script entry point: launch the Streamlit toponym-mapping UI.
if __name__ == "__main__":
    show_on_map()
|