"""Streamlit app: extract toponyms from free text with a GeoLM token-classification
model and plot candidate matches from the GeoNames API on a map."""

import requests
import streamlit as st
import torch
import plotly.graph_objects as go
from transformers import AutoModelForTokenClassification, AutoTokenizer

GEONAMES_ENDPOINT = "http://api.geonames.org/searchJSON"
GEONAMES_USERNAME = "zekun"
MODEL_NAME = "zekun-li/geolm-base-toponym-recognition"


def search_geonames(location):
    """Query GeoNames for *location* and render up to 5 candidate matches.

    Draws one marker per candidate on an open-street-map Scattermapbox and
    shows it with ``st.plotly_chart``. Returns an empty figure (kept for
    backward compatibility with existing callers).
    """
    params = {
        'q': location,
        'username': GEONAMES_USERNAME,
        'maxRows': 5,
    }
    response = requests.get(GEONAMES_ENDPOINT, params=params)
    data = response.json()

    places = data.get('geonames', [])
    # Guard: the original referenced `latitude`/`longitude` after the loop,
    # which raised NameError when GeoNames returned an empty result list.
    if places:
        fig = go.Figure()
        latitude = longitude = 0.0
        for place_info in places:
            latitude = float(place_info.get('lat', 0.0))
            longitude = float(place_info.get('lng', 0.0))
            fig.add_trace(go.Scattermapbox(
                lat=[latitude],
                lon=[longitude],
                mode='markers',
                marker=go.scattermapbox.Marker(
                    size=10,
                    color='orange',
                ),
                text=[f'Location: {location}'],
                hoverinfo="text",
                hovertemplate='Location: %{text}',
            ))
        # Center on the last candidate returned (matches original behavior).
        fig.update_layout(
            mapbox_style="open-street-map",
            hovermode='closest',
            mapbox=dict(
                bearing=0,
                center=go.layout.mapbox.Center(
                    lat=latitude,
                    lon=longitude,
                ),
                pitch=0,
                zoom=2,
            ))
        st.plotly_chart(fig)

    # Return an empty figure
    return go.Figure()


def mapping(location):
    """Write a short header for *location* and plot its GeoNames candidates."""
    st.write(f"Mapping location: {location}")
    search_geonames(location)


def generate_human_readable(tokens, labels):
    """Merge WordPiece tokens and their labels back into whole location strings.

    '##' continuation pieces are glued onto the previous entry; label 2
    (I-Topo, inside a multi-word toponym) appends with a space; any other
    token starts a new entry. '[SEP]' is skipped.
    """
    ret = []
    for t, lab in zip(tokens, labels):
        if t == '[SEP]':
            continue
        if t.startswith("##"):
            assert len(ret) > 0
            ret[-1] = ret[-1] + t.strip('##')
        elif lab == 2:
            assert len(ret) > 0
            ret[-1] = ret[-1] + " " + t.strip('##')
        else:
            ret.append(t)
    return ret


@st.cache_resource
def _load_toponym_model():
    """Load the tokenizer and model once and reuse them across reruns.

    The original reloaded both from the Hub on every call to showOnMap.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
    return tokenizer, model


def showOnMap(input_sentence):
    """Run toponym recognition on *input_sentence*.

    Returns a list of human-readable location strings, e.g.
    ['Los Angeles', 'California', 'New York City'].
    """
    tokenizer, model = _load_toponym_model()

    tokens = tokenizer.encode(input_sentence, return_tensors="pt")
    with torch.no_grad():  # inference only — no gradients needed
        outputs = model(tokens)

    # "id2label": { "0": "O", "1": "B-Topo", "2": "I-Topo" }
    # The original computed this argmax, mapped it through id2label into
    # strings, then immediately overwrote it with a second argmax — the
    # string mapping was dead code and has been removed.
    predicted_labels = torch.argmax(outputs.logits, dim=2)

    # Keep only tokens tagged as part of a toponym (label != 0, i.e. not "O").
    keep = torch.where(predicted_labels[0] != 0)[0]
    query_tokens = tokens[0][keep]
    query_labels = predicted_labels[0][keep]

    return generate_human_readable(
        tokenizer.convert_ids_to_tokens(query_tokens), query_labels)


def show_on_map():
    """Streamlit entry point: sentence -> recognized toponyms -> map."""
    # Renamed from `input`, which shadowed the builtin.
    user_input = st.text_area("Enter a sentence:", height=200)
    st.button("Submit")

    # Don't run the model on the initial empty text area.
    if not user_input:
        return

    places = showOnMap(user_input)
    if not places:
        st.write("No locations found.")
        return

    selected_place = st.selectbox("Select a location:", places)
    mapping(selected_place)


if __name__ == "__main__":
    show_on_map()