# Streamlit demo UI for zero-shot Turkish text classification.
#
# The page lets the user pick a classification method and model, edit the
# candidate labels and the prompt template, and shows a horizontal bar chart
# of per-label probabilities. The probabilities rendered here are hard-coded
# demo values: nothing in this script invokes a model.
from typing import List

import pandas as pd
import plotly.express as px
import streamlit as st

from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS

# Hard-coded demo probabilities, one per default candidate label.
DEMO_PROBABILITIES = (0.86, 0.10, 0.01, 0.02, 0.01)

# NOTE(review): recent Streamlit releases document a 68 px minimum height for
# text areas and reject smaller values; confirm against the pinned version.
TEXT_AREA_HEIGHT = 68


def _parse_labels(raw_labels: str) -> List[str]:
    """Split a comma-separated label string into clean label names.

    Surrounding whitespace is stripped and empty entries (for example from a
    trailing comma or an empty text box) are dropped.
    """
    return [part.strip() for part in raw_labels.split(",") if part.strip()]


def _demo_probabilities(n_labels: int) -> List[float]:
    """Return demo probabilities matching the number of candidate labels.

    The fixed demo values are used when the label count matches them;
    otherwise a uniform distribution is returned so the label and probability
    columns always have the same length.
    """
    if n_labels <= 0:
        return []
    if n_labels == len(DEMO_PROBABILITIES):
        return list(DEMO_PROBABILITIES)
    return [1.0 / n_labels] * n_labels


st.title("Zero-shot Turkish Text Classification")

method_selection = st.radio(
    "Select a zero-shot classification method.",
    [
        METHOD_OPTIONS["nli"],
        METHOD_OPTIONS["nsp"],
    ],
)

# The two methods are mutually exclusive, so only one selectbox is rendered.
if method_selection == METHOD_OPTIONS["nli"]:
    model = st.selectbox(
        "Select a natural language inference model.", NLI_MODEL_OPTIONS
    )
elif method_selection == METHOD_OPTIONS["nsp"]:
    model = st.selectbox(
        "Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS
    )

st.header("Configure prompts and labels")
col1, col2 = st.columns(2)

col1.subheader("Candidate labels")
labels = col1.text_area(
    label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.",
    value="spor,dünya,siyaset,ekonomi,kültür ve sanat",
    height=TEXT_AREA_HEIGHT,
)

col2.subheader("Prompt template")
prompt_template = col2.text_area(
    label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
    value="Bu metin {} kategorisine aittir",
    height=TEXT_AREA_HEIGHT,
)

col1.header("Make predictions")
# Empty header keeps the right column vertically aligned with the left one.
col2.header("")
col1.text_area("", value="Enter some text to classify.")
col1.button("Predict")

label_names = _parse_labels(labels)

if not label_names:
    # Without labels there is nothing to plot; building the frame would fail.
    col2.warning("Enter at least one candidate label to see predictions.")
else:
    probs = _demo_probabilities(len(label_names))
    data = pd.DataFrame(
        {"labels": label_names, "probability": probs}
    ).sort_values(by="probability", ascending=False)

    chart = px.bar(
        data,
        x="probability",
        y="labels",
        color="labels",
        orientation="h",
        height=290,
        width=500,
    ).update_layout(
        {
            "xaxis": {
                "title": "probability",
                "visible": True,
                "showticklabels": True,
            },
            "yaxis": {"title": None, "visible": True, "showticklabels": True},
            "margin": dict(
                l=10,  # left
                r=10,  # right
                t=50,  # top
                b=10,  # bottom
            ),
            "showlegend": False,
        }
    )
    col2.plotly_chart(chart)