File size: 2,858 Bytes
d7a6200 9316bdf d7a6200 0677d87 c6e39b5 9316bdf 2b2a52e 0677d87 d7a6200 5a2a0f4 d7a6200 0677d87 d7a6200 b1cf2dc d7a6200 5a2a0f4 d7a6200 f90f6f8 a4ace7e b1cf2dc d7a6200 b1cf2dc d7a6200 b1cf2dc a4ace7e fdbf9c1 d7a6200 e11f528 d7a6200 f589f6b 0677d87 f589f6b 0677d87 b1cf2dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import json
import os
import random
import pandas as pd
import streamlit as st
from transformers import AutoModelForSequenceClassification,AutoTokenizer,pipeline
from transformers_interpret import SequenceClassificationExplainer
import streamlit.components.v1 as components # Import Streamlit
def visualize(text):
    """Render word-attribution highlights for *text* inside the Streamlit page.

    Loads the ``mlkorra/OGBV-gender-bert-hi-en`` tokenizer and model, runs a
    ``SequenceClassificationExplainer`` over *text*, and embeds the resulting
    visualization with ``components.html``.

    Args:
        text: The sentence whose attributions should be visualized.
    """
    checkpoint = 'mlkorra/OGBV-gender-bert-hi-en'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    cls_explainer = SequenceClassificationExplainer(model, tokenizer)

    # The explainer must be called before visualize(); the attributions
    # themselves are not used here, so the return value is discarded.
    # (The original referenced an undefined name ``masked_text`` here.)
    cls_explainer(text)

    # 'visualize.html' is handed to the explainer as its output file path.
    rendered = cls_explainer.visualize('visualize.html')
    # components.html expects a markup string. The explainer is understood to
    # return an IPython HTML wrapper whose markup lives in ``.data``; fall back
    # to the value itself if it is already a plain string.
    components.html(getattr(rendered, 'data', rendered))
@st.cache
def load_model(text):
    """Classify *text* with the ``mlkorra/OGBV-gender-bert-hi-en`` checkpoint.

    Builds a ``'sentiment-analysis'`` pipeline from the checkpoint's tokenizer
    and sequence-classification model, applies it to *text*, and returns the
    pipeline's output unchanged. The function is wrapped in ``st.cache``.

    Args:
        text: The sentence to classify.

    Returns:
        Whatever the pipeline returns for *text*.
    """
    model_name = 'mlkorra/OGBV-gender-bert-hi-en'
    classifier = pipeline(
        'sentiment-analysis',
        tokenizer=AutoTokenizer.from_pretrained(model_name),
        model=AutoModelForSequenceClassification.from_pretrained(model_name),
    )
    return classifier(text)
import re
def app():
    """Build the OGBV-BERT Streamlit page.

    The sidebar lets the user choose an evaluation dataset and either a random
    row or a specific row from it. The chosen text is cleaned, shown in an
    editable text area, and can then be classified or have its attributions
    visualized through the two buttons.
    """
    st.title("OGBV-BERT")

    data = st.sidebar.radio("Pick the evaluation data :", ('Twitter', 'Trac2020'))
    # Each evaluation set lives in its own CSV and names its text column
    # differently.
    if data == "Twitter":
        target_text_path, text_column = "./input/tweet_list.csv", "text"
    else:
        target_text_path, text_column = "./input/trac2_hin_test.csv", "Text"
    texts = pd.read_csv(target_text_path)[text_column]

    pick_random = st.sidebar.checkbox("Pick any random text")
    if pick_random:
        # Positional lookup: the random number is a row position, not a label.
        chosen_text = texts.iloc[random.randint(0, texts.shape[0] - 1)]
    else:
        chosen_text = st.sidebar.selectbox("Select any of the following text", texts)

    # Remove @mentions. Raw string: '\s' in a normal string literal is an
    # invalid escape sequence and triggers a warning on recent Pythons.
    text = re.sub(r'@[^\s]+', '', chosen_text)
    # Drop the first three characters.
    # NOTE(review): presumably a fixed prefix present in the dataset rows;
    # confirm against the CSV contents.
    text = text[3:]
    masked_text = st.text_area("Please type a sentence to classify", text)

    st.sidebar.markdown("""Find out more at [Github](https://github.com/mlkorra/OGBV-detection)""")

    # NOTE(review): the two buttons are laid out as independent actions;
    # confirm this matches the intended page layout.
    if st.button("Classify"):
        with st.spinner("Classifying the sentence..."):
            pred = load_model(masked_text)
            st.write(pred)

    if st.button('Visualize attributions'):
        with st.spinner("Visualizing ....."):
            visualize(masked_text)
# Run the Streamlit page when this file is executed as a script.
if __name__ == "__main__":
    app()