File size: 3,676 Bytes
c8f4a50
348dfbf
 
faae8e8
0775e2f
348dfbf
 
 
 
 
0eb19c8
 
a9b4522
348dfbf
 
 
 
 
 
 
 
 
 
31a243b
348dfbf
875560e
 
 
6c59dc8
348dfbf
c8f4a50
348dfbf
c8f4a50
 
348dfbf
c8f4a50
 
348dfbf
b4d9dd6
348dfbf
 
 
 
2495ea7
 
348dfbf
 
 
2495ea7
348dfbf
f0f53dd
 
348dfbf
97a1749
348dfbf
f0f53dd
 
348dfbf
97a1749
 
2495ea7
 
 
b4d9dd6
348dfbf
 
 
6fb9fd9
 
 
 
 
 
 
 
 
e8f5b25
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Zero-Shot Text Classification with Multilingual T5 (mT5)

import streamlit as st
import plotly.graph_objects as go
from mT5Model import runModel

# Built-in sample inputs offered by the "Select Text" radio button.
# Each sample is a single line of text (no embedded newlines); the literals
# are split across source lines purely for readability.

# Turkish news snippet about a Covid-19 variant.
text_1 = (
    "Bilim insanları Botsvana’da Covid-19’un şu ana kadar en çok mutasyona uğramış varyantını tespit etti. "
    "Resmi olarak B.1.1.529 koduyla bilinen bu varyantı ise “Nu varyantı” adı verildi. Uzmanlar bu varyant içerisinde "
    "tam 32 farklı mutasyon tespit edildiğini açıklarken, bu virüsün corona virüsü aşılarına karşı daha dirençli olabileceğini duyurdu."
)

# English news snippet about a World Cup match.
text_2 = (
    "Argentina beat Australia 2-1 on Saturday and will take on the Netherlands in the World Cup quarterfinals. "
    "It was a historic night for Lionel Messi as the Argentine superstar took to the pitch for his 1,000th match for club and country. "
    "He also scored in the match. Messi scored the opening goal in the 35th minute as his low shot in the box beat Australian goalkeeper Mathew Ryan."
)

def list2text(label_list):
    """Return the labels joined into one comma-separated string.

    Args:
        label_list: Iterable of label strings.

    Returns:
        The labels separated by "," with no surrounding spaces; an empty
        string when ``label_list`` is empty.
    """
    # Deliberately not wrapped in a Streamlit cache: joining a handful of
    # short strings is trivial, and ``st.cache`` is deprecated in newer
    # Streamlit releases.
    return ",".join(label_list)

# Preset candidate-label sets offered by the "Select Label List" radio button.
label_list_1 = [
    "dünya",
    "ekonomi",
    "kültür",
    "sağlık",
    "siyaset",
    "spor",
    "teknoloji",
]
label_list_2 = [
    "positive",
    "negative",
    "neutral",
]

# Preset hypothesis templates offered by the "Select Hypothesis" radio button.
# Presumably runModel substitutes each candidate label for "{}" -- confirm in
# mT5Model.
hypothesis_1 = "Bu yazı {} konusundadır"
hypothesis_2 = "This text is in {} subject"

# Page heading.
st.title("Multilingual Zero-Shot Text Classification with mT5")

# Identifier of the checkpoint that is passed to runModel when the user
# presses "Run"; it matches the Hugging Face URL shown in the sidebar.
model_name = "alan-turing-institute/mt5-large-finetuned-mnli-xtreme-xnli"

# Sidebar: reference links for the model and its dataset, written in
# display order.
for _sidebar_line in (
    "For details of used model:",
    "https://huggingface.co/alan-turing-institute/mt5-large-finetuned-mnli-xtreme-xnli",
    "For Xtreme XNLI Dataset:",
    "https://www.tensorflow.org/datasets/catalog/xtreme_xnli",
):
    st.sidebar.write(_sidebar_line)

# Preview of the built-in presets, followed by the three selectors.
# (Heading typo "Hyphothesis" corrected to "Hypothesis".)
st.subheader("Select Text, Label List and Hypothesis")

# NOTE: the return values of these two text areas are discarded, so anything
# the user types into them is not used; they only display the sample texts.
st.text_area("Text #1", text_1, height=128)
st.text_area("Text #2", text_2, height=128)
st.write(f"Label List #1: {list2text(label_list_1)}")
st.write(f"Label List #2: {list2text(label_list_2)}")
st.write(f"Hypothesis #1: {hypothesis_1}")
st.write(f"Hypothesis #2: {hypothesis_2}")

# Each radio returns the selected option string; the code that follows maps
# that string to the actual text, label list and hypothesis template.
text = st.radio("Select Text", ("Text #1", "Text #2", "New Text"))
labels = st.radio("Select Label List", ("Label List #1", "Label List #2", "New Label List"))
hypothesis = st.radio("Select Hypothesis", ("Hypothesis #1", "Hypothesis #2", "New Hypothesis"))

# Turn each radio selection into the value handed to the model. Every chain
# ends in a plain `else` (the "New ..." option) so the target variable is
# always bound, whichever option is selected.

# Text to classify.
if text == "Text #1":
    sequence_to_classify = text_1
elif text == "Text #2":
    sequence_to_classify = text_2
else:  # "New Text"
    sequence_to_classify = st.text_area("New Text", value="", height=128)

# Candidate labels.
if labels == "Label List #1":
    candidate_labels = label_list_1
elif labels == "Label List #2":
    candidate_labels = label_list_2
else:  # "New Label List"
    raw_labels = st.text_area("New Label List (Pls Input as comma-separated)", value="", height=16)
    # Trim surrounding whitespace so "a, b" yields "a" and "b", not "a" and " b".
    candidate_labels = [label.strip() for label in raw_labels.split(",")]

# Hypothesis template. The original tested `labels` here instead of
# `hypothesis`, so choosing "New Hypothesis" never showed the input box and
# left `hypothesis_template` undefined.
if hypothesis == "Hypothesis #1":
    hypothesis_template = hypothesis_1
elif hypothesis == "Hypothesis #2":
    hypothesis_template = hypothesis_2
else:  # "New Hypothesis"
    hypothesis_template = st.text_area("Hypothesis Template for NLI (Pls use similar format of examples)", value="", height=16)
        
# Run the classifier on demand and chart the score returned for each label.
Run_Button = st.button("Run", key=None)
if Run_Button:
    # Refuse to start the model on blank input: an empty text, a label list
    # with no non-blank entry, or an empty template cannot give a useful result.
    has_text = bool(sequence_to_classify.strip())
    has_labels = any(label.strip() for label in candidate_labels)
    has_template = bool(hypothesis_template.strip())

    if not (has_text and has_labels and has_template):
        st.warning("Please provide a text, at least one label and a hypothesis template before running.")
    else:
        # Keep the spinner around the model call only; rendering happens
        # after it has finished.
        with st.spinner('Model is running...'):
            output = runModel(model_name, sequence_to_classify, candidate_labels, hypothesis_template)

        # The result is read as a mapping: keys go on the x axis, values on
        # the y axis of the bar chart.
        output_labels = list(output.keys())
        output_scores = list(output.values())

        st.header("Result")
        fig = go.Figure([go.Bar(x=output_labels, y=output_scores)])
        st.plotly_chart(fig, use_container_width=False, sharing="streamlit")
        st.success('Done!')