# app.py — Hugging Face Space demo.
# Streamlit app: extracts place names (toponyms) from a sentence with a
# GeoLM token-classification model and plots GeoNames matches on a map.
import streamlit as st
import plotly.graph_objects as go
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
import requests
def search_geonames(location):
    """Look up *location* on the GeoNames search API and plot matches.

    Renders a Plotly mapbox figure into the Streamlit page with one orange
    marker per GeoNames hit (up to 5), centered on the last hit.

    Parameters
    ----------
    location : str
        Free-text place name to search for.

    Returns
    -------
    plotly.graph_objects.Figure
        Always an empty figure; the populated figure is shown via
        ``st.plotly_chart`` as a side effect (preserves original contract).
    """
    api_endpoint = "http://api.geonames.org/searchJSON"
    username = "zekun"  # GeoNames demo account used by this Space
    params = {
        'q': location,
        'username': username,
        'maxRows': 5
    }
    try:
        # Timeout so a slow/unreachable API cannot hang the Streamlit app.
        response = requests.get(api_endpoint, params=params, timeout=10)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError):
        st.error(f"GeoNames lookup failed for: {location}")
        return go.Figure()

    results = data.get('geonames', [])
    # Guard against an empty result list: the original code referenced
    # latitude/longitude after the loop, raising NameError when no hits.
    if results:
        fig = go.Figure()
        latitude = longitude = 0.0
        for place_info in results:
            latitude = float(place_info.get('lat', 0.0))
            longitude = float(place_info.get('lng', 0.0))
            fig.add_trace(go.Scattermapbox(
                lat=[latitude],
                lon=[longitude],
                mode='markers',
                marker=go.scattermapbox.Marker(
                    size=10,
                    color='orange',
                ),
                text=[f'Location: {location}'],
                hoverinfo="text",
                hovertemplate='<b>Location</b>: %{text}',
            ))
        fig.update_layout(
            mapbox_style="open-street-map",
            hovermode='closest',
            mapbox=dict(
                bearing=0,
                # Center on the last processed hit (matches original behavior).
                center=go.layout.mapbox.Center(
                    lat=latitude,
                    lon=longitude
                ),
                pitch=0,
                zoom=2
            ))
        st.plotly_chart(fig)

    # Return an empty figure
    return go.Figure()
def mapping(location):
    """Echo the chosen location to the page, then map its GeoNames hits."""
    status_message = f"Mapping location: {location}"
    st.write(status_message)
    search_geonames(location)
def generate_human_readable(tokens, labels):
    """Reassemble WordPiece tokens and label ids into place-name strings.

    Tokens starting with ``##`` are sub-word continuations glued onto the
    previous token; a label of 2 (``I-Topo``) continues the previous name
    with a space; anything else starts a new name. ``[SEP]`` is skipped.

    Parameters
    ----------
    tokens : sequence of str
        WordPiece token strings (e.g. from ``convert_ids_to_tokens``).
    labels : sequence of int
        Parallel label ids; 2 means "inside toponym" (I-Topo).

    Returns
    -------
    list of str
        Reconstructed, human-readable names.
    """
    names = []
    for token, label in zip(tokens, labels):
        if token == '[SEP]':
            continue
        # Guard with `names` instead of `assert`: asserts are stripped under
        # `python -O` and crash on malformed sequences; a stray continuation
        # token with nothing to attach to now simply starts a new entry.
        if token.startswith("##") and names:
            names[-1] = names[-1] + token.strip('#')
        elif label == 2 and names:
            names[-1] = names[-1] + " " + token.strip('#')
        else:
            names.append(token)
    return names
def showOnMap(input_sentence):
    """Extract toponyms (place names) from *input_sentence*.

    Runs the ``zekun-li/geolm-base-toponym-recognition`` token-classification
    model and keeps every token whose predicted label is non-zero
    (id2label: {0: "O", 1: "B-Topo", 2: "I-Topo"}).

    Parameters
    ----------
    input_sentence : str
        Raw user text to analyze.

    Returns
    -------
    list of str
        Human-readable place names, e.g. ['Los Angeles', 'California'].
    """
    # NOTE(review): the model/tokenizer are re-downloaded and rebuilt on every
    # call; consider caching (e.g. st.cache_resource) if latency matters.
    model_name = "zekun-li/geolm-base-toponym-recognition"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)

    tokens = tokenizer.encode(input_sentence, return_tensors="pt")
    # Inference only — no_grad avoids building an unused autograd graph.
    with torch.no_grad():
        outputs = model(tokens)

    # Original code first mapped predictions through id2label and then
    # immediately overwrote them with a second identical argmax; the dead
    # string-label pass is removed here.
    predicted_labels = torch.argmax(outputs.logits, dim=2)

    # Indices of tokens predicted as part of a toponym (label != "O").
    keep = torch.where(predicted_labels[0] != 0)[0]
    query_tokens = tokens[0][keep]
    query_labels = predicted_labels[0][keep]

    return generate_human_readable(
        tokenizer.convert_ids_to_tokens(query_tokens), query_labels
    )
def show_on_map():
    """Streamlit page: read a sentence, extract toponyms, map the selection."""
    # Renamed local from `input` — it shadowed the builtin of the same name.
    sentence = st.text_area("Enter a sentence:", height=200)
    # The button only triggers a Streamlit rerun; its return value was
    # ignored in the original, so the pipeline runs on every rerun.
    st.button("Submit")
    # Skip model inference entirely until the user has typed something.
    if not sentence:
        return
    places = showOnMap(sentence)
    # Only offer a selectbox when at least one place was detected.
    if places:
        selected_place = st.selectbox("Select a location:", places)
        mapping(selected_place)
if __name__ == "__main__":
    # Entry point when executed directly (e.g. `streamlit run app.py`).
    show_on_map()