# Author: azamat
# Last commit: e0e99ed ("Final fix")
# File size: 1.85 kB
import re
import gradio as gr
from transformers import pipeline
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
# Pre-compiled patterns used by `process_tweet`. Raw strings avoid the
# invalid escape-sequence warnings ("\." / "\s") that plain literals trigger.
_LINK_RE = re.compile(r"(www\.\S+)|(https?://\S+)")
_USERNAME_RE = re.compile(r"@\S+")
_WHITESPACE_RE = re.compile(r"\s+")
_HASHTAG_RE = re.compile(r"#(\S+)")


def process_tweet(tweet):
    """Normalise raw tweet text before it is fed to the models.

    Steps, in order:
      1. remove links (``www.…`` and ``http(s)://…`` tokens);
      2. remove ``@username`` mentions;
      3. collapse every run of whitespace into a single space;
      4. turn ``#hashtag`` into the bare word ``hashtag``;
      5. trim surrounding spaces and quote characters.

    Args:
        tweet: raw tweet text.

    Returns:
        The cleaned text. It may be empty if the tweet held only links or
        mentions.
    """
    # The previous pattern ``www\.[\s]+`` required *whitespace* right after
    # "www.", so ordinary links such as "www.example.com" were never removed.
    tweet = _LINK_RE.sub("", tweet)
    tweet = _USERNAME_RE.sub("", tweet)
    tweet = _WHITESPACE_RE.sub(" ", tweet)
    tweet = _HASHTAG_RE.sub(r"\1", tweet)
    # Strip spaces as well as quotes. Removing a leading mention or link
    # otherwise leaves a dangling space (e.g. "@user hi" -> " hi").
    return tweet.strip(" '\"")
# Hugging Face Hub id of the coordinate-regression checkpoint. The tokenizer
# and the model must come from the same checkpoint, so the id is defined once.
COORDINATES_MODEL_ID = "azamat/geocoder_model_xlm_roberta_50"
# Hub id of the relevancy classifier, served through the text-classification
# ("sentiment-analysis") pipeline. It is a separate checkpoint.
RELEVANCY_MODEL_ID = "azamat/geocoder_model"

# Loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(COORDINATES_MODEL_ID)
relevancy_pipeline = pipeline("sentiment-analysis", model=RELEVANCY_MODEL_ID)
coordinates_model = AutoModelForSequenceClassification.from_pretrained(
    COORDINATES_MODEL_ID,
)
def predict_relevancy(text):
    """Run the relevancy classifier on ``text``.

    Args:
        text: pre-processed tweet text.

    Returns:
        ``(label, score)`` taken from the first result the pipeline returns.
    """
    first_result = relevancy_pipeline(text)[0]
    label = first_result['label']
    score = first_result['score']
    return label, score
def predict_coordinates(text):
    """Predict coordinates for a tweet with the coordinates model.

    Args:
        text: pre-processed tweet text.

    Returns:
        ``(lat, lon)`` as Python floats rounded to 3 decimal places. They are
        read from logits ``[0][0]`` and ``[0][1]`` of the model output.
    """
    # Local import: torch is already required by ``return_tensors="pt"``.
    # It is needed here only to disable autograd.
    import torch

    encoding = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )
    # Inference only: skip building the autograd graph. Otherwise it keeps
    # activations alive for a backward pass that never happens.
    with torch.no_grad():
        outputs = coordinates_model(**encoding)
    # outputs[0] is the logits tensor; presumably shape (1, 2) holding
    # [lat, lon] regression values — TODO confirm the model's num_labels.
    logits = outputs[0]
    return round(logits[0][0].item(), 3), round(logits[0][1].item(), 3)
def predict(text):
    """Gradio handler: classify a tweet's geo-relevance and, if relevant, geolocate it.

    Args:
        text: raw tweet text entered in the UI.

    Returns:
        A human-readable summary. It always includes the classifier
        confidence as a percentage rounded to 2 decimals. When the label is
        'relevant', it also includes the predicted coordinates.
    """
    text = process_tweet(text)
    relevancy_label, relevancy_score = predict_relevancy(text)
    # Round once so both branches report confidence the same way. The
    # negative branch previously printed the raw float (e.g. 97.1234567%).
    confidence = round(relevancy_score * 100, 2)
    if relevancy_label == 'relevant':
        lat, lon = predict_coordinates(text)
        return (
            f"Confident for {confidence}% that tweet has the geolocation relevant information.\n"
            f"Predicted coordinates are: lat: {lat} lon: {lon}"
        )
    return f"Confident for {confidence}% that tweet does not have the geolocation relevant information."
# Wrap `predict` in a text-in / text-out Gradio UI and start serving it.
iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
)
iface.launch()