Spaces:
Runtime error
Runtime error
import pandas as pd | |
import streamlit as st | |
import plotly.express as px | |
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS | |
# --- Page title and method / model selection -------------------------------
st.title("Zero-shot Turkish Text Classification")

# Look the two method entries up once; they serve both as the radio choices
# and as the values the selection is compared against below.
nli_method = METHOD_OPTIONS["nli"]
nsp_method = METHOD_OPTIONS["nsp"]

method_selection = st.radio(
    "Select a zero-shot classification method.",
    [nli_method, nsp_method],
)

# One (method, selectbox prompt, model choices) row per supported method.
# Each row is checked independently, in the order NLI then NSP, and `model`
# is bound by the selectbox of every row that matches the selection.
model_pickers = (
    (
        nli_method,
        "Select a natural language inference model.",
        NLI_MODEL_OPTIONS,
    ),
    (
        nsp_method,
        "Select a BERT model for next sentence prediction.",
        NSP_MODEL_OPTIONS,
    ),
)
for method_label, picker_prompt, model_choices in model_pickers:
    if method_selection == method_label:
        model = st.selectbox(picker_prompt, model_choices)
# --- Prompt and label configuration -----------------------------------------
st.header("Configure prompts and labels")

# Two side-by-side columns: candidate labels on the left, prompt template on
# the right. Both columns are reused by the prediction section further down.
col1, col2 = st.columns(2)

# Comma-separated candidate labels as a single raw string; splitting into
# individual labels happens where the value is consumed.
col1.subheader("Candidate labels")
labels = col1.text_area(
    label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.",
    value="spor,dünya,siyaset,ekonomi,kültür ve sanat",
    # Streamlit rejects text_area heights below 68 px (StreamlitAPIException),
    # so use the smallest accepted height instead of the previous 10.
    height=68,
)

# Template whose "{}" placeholder is replaced by each candidate label.
col2.subheader("Prompt template")
prompt_template = col2.text_area(
    label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
    value="Bu metin {} kategorisine aittir",
    # Same minimum-height constraint as the labels box above.
    height=68,
)
# --- Prediction input and result chart --------------------------------------
col1.header("Make predictions")
# Presumably a spacer so the chart lines up with the input box -- confirm.
col2.header("")
# NOTE(review): the entered text and the button state are not captured or
# used yet; the chart below is driven by placeholder scores only.
# NOTE(review): an empty widget label triggers an accessibility warning on
# recent Streamlit versions; consider a real label with
# label_visibility="collapsed" once the minimum Streamlit version is known.
col1.text_area("", value="Enter some text to classify.")
col1.button("Predict")

# Normalise the user-edited label string: trim whitespace around each label
# and drop blank entries (e.g. from a trailing comma).
candidate_labels = [part.strip() for part in labels.split(",") if part.strip()]

# Placeholder scores standing in for real model output. The user can edit the
# label list freely, so align the scores with the actual number of labels
# (truncate, or pad with 0.0); a length mismatch would make the DataFrame
# constructor raise ValueError.
placeholder_probs = [0.86, 0.10, 0.01, 0.02, 0.01]
label_count = len(candidate_labels)
probs = (placeholder_probs + [0.0] * label_count)[:label_count]

data = pd.DataFrame(
    {"labels": candidate_labels, "probability": probs}
).sort_values(by="probability", ascending=False)

if data.empty:
    # Nothing to plot without at least one label; tell the user instead.
    col2.warning("Enter at least one candidate label.")
else:
    # Horizontal bar chart, one bar per label, coloured by label.
    chart = px.bar(
        data,
        x="probability",
        y="labels",
        color="labels",
        orientation="h",
        height=290,
        width=500,
    ).update_layout(
        {
            "xaxis": {
                "title": "probability",
                "visible": True,
                "showticklabels": True,
            },
            "yaxis": {"title": None, "visible": True, "showticklabels": True},
            "margin": {"l": 10, "r": 10, "t": 50, "b": 10},
            # The y-axis already names every label, so the legend is redundant.
            "showlegend": False,
        }
    )
    col2.plotly_chart(chart)