|
import json |
|
import os |
|
import random |
|
import pandas as pd |
|
import streamlit as st |
|
from transformers import AutoModelForSequenceClassification,AutoTokenizer,pipeline |
|
from transformers_interpret import SequenceClassificationExplainer |
|
import streamlit.components.v1 as components |
|
|
|
def visualize(text):
    """Render word-level attributions for *text* inside the Streamlit page.

    Loads the ``mlkorra/OGBV-gender-bert-hi-en`` tokenizer and sequence
    classifier, explains the model's prediction on *text* with
    ``transformers_interpret.SequenceClassificationExplainer``, writes the
    visualisation to ``visualize.html`` and embeds it via
    ``streamlit.components.v1.html``.

    Args:
        text: The sentence whose prediction should be explained.
    """
    checkpoint = 'mlkorra/OGBV-gender-bert-hi-en'

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

    cls_explainer = SequenceClassificationExplainer(model, tokenizer)
    # The explainer has to be run on the input before `visualize` can render
    # it; the attributions it returns are not needed here directly.
    cls_explainer(text)

    # `visualize` returns an IPython HTML display object (and also writes the
    # file). components.html needs the raw markup string, which lives in
    # `.data`; fall back to the object itself if it is already a string.
    html = cls_explainer.visualize('visualize.html')
    # The component iframe has a small default height, so allow scrolling to
    # keep long attribution tables readable.
    components.html(getattr(html, 'data', html), scrolling=True)
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def _load_classifier():
    """Build the OGBV text-classification pipeline once and memoize it.

    Takes no arguments, so st.cache keeps a single pipeline instead of one per
    input sentence. ``allow_output_mutation=True`` stops st.cache from hashing
    the returned (large) model object on every call.
    """
    checkpoint = 'mlkorra/OGBV-gender-bert-hi-en'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    return pipeline('sentiment-analysis', model=model, tokenizer=tokenizer)


@st.cache
def load_model(text):
    """Classify *text* with the OGBV model and return the pipeline output.

    Despite its name this returns predictions, not a model. Results are
    memoized per input text by st.cache. The underlying pipeline is loaded
    only once, through ``_load_classifier``, rather than on every new text.

    Args:
        text: The sentence to classify.

    Returns:
        Whatever the transformers pipeline returns for *text* (for a single
        string, a list of ``{'label': ..., 'score': ...}`` dicts).
    """
    nlp = _load_classifier()
    return nlp(text)
|
|
|
|
|
|
|
|
|
|
|
import re |
|
def _load_texts(data):
    """Return the candidate texts for the chosen evaluation dataset as strings.

    "Twitter" reads the ``text`` column of ./input/tweet_list.csv; any other
    choice reads the ``Text`` column of ./input/trac2_hin_test.csv. Missing
    cells are dropped and values coerced to str so the regex cleanup cannot
    fail on NaN.
    """
    if data == "Twitter":
        path, column = "./input/tweet_list.csv", "text"
    else:
        path, column = "./input/trac2_hin_test.csv", "Text"
    return pd.read_csv(path)[column].dropna().astype(str)


def _clean_text(raw_text):
    """Remove @mentions from *raw_text* and drop its first three characters."""
    text = re.sub(r'@[^\s]+', '', raw_text)
    # NOTE(review): the fixed 3-character cut is kept from the original code;
    # presumably it strips a dataset-specific prefix — confirm against the CSVs.
    return text[3:]


def app():
    """Render the OGBV-BERT Streamlit page.

    The sidebar lets the user pick a dataset and either a random or a chosen
    sample text. The cleaned text pre-fills an editable text area, and two
    buttons classify the (possibly edited) text or visualize its attributions.
    """
    st.title("OGBV-BERT")

    data = st.sidebar.radio("Pick the evaluation data :", ('Twitter', 'Trac2020'))
    texts = _load_texts(data)

    pick_random = st.sidebar.checkbox("Pick any random text")

    if texts.empty:
        # Nothing to pick from; still let the user type their own sentence.
        st.sidebar.warning("The selected dataset contains no texts.")
        raw_text = ""
    elif pick_random:
        # Positional access: dropna() may leave gaps in the index labels.
        raw_text = texts.iloc[random.randrange(len(texts))]
    else:
        raw_text = st.sidebar.selectbox("Select any of the following text", texts)

    masked_text = st.text_area("Please type a sentence to classify", _clean_text(raw_text))

    st.sidebar.markdown("""Find out more at [Github](https://github.com/mlkorra/OGBV-detection)""")

    if st.button("Classify"):
        with st.spinner("Classifying the sentence..."):
            pred = load_model(masked_text)
            st.write(pred)

    if st.button('Visualize attributions'):
        with st.spinner("Visualizing .....") :
            visualize(masked_text)
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Render the page when this file is executed as the main script
    # (e.g. via `streamlit run`).
    app()