import html

import gradio as gr
from setfit import SetFitModel
# Hugging Face Hub id of the pre-trained SetFit classifier
# (swap in your own model id to classify with a different model).
MODEL_ID = "CundK/ePA-Classifier-Agreement"

# Loaded once at import time and shared by every request.
model = SetFitModel.from_pretrained(MODEL_ID)
# Model output index -> stance label
# (0 = Ablehnung/rejection, 1 = Neutral, 2 = Befürwortung/approval).
label_mapping = dict(enumerate(("Ablehnung", "Neutral", "Befürwortung")))

# Stance label -> background color associated with that label,
# paired positionally with the labels above (red, yellow, green).
color_mapping = dict(zip(label_mapping.values(), ("red", "yellow", "green")))
# Function to classify text and return the label with background color
def classify_text(text):
# Get predictions from the model
predictions = model([text])
# Convert the prediction tensor to an integer index
label_index = int(predictions[0].item())
# Map the integer index to the corresponding label
label = label_mapping[label_index]
# Return the label with the appropriate background color
return f'
{label}
'
# Create the Gradio interface using Blocks for custom layout
with gr.Blocks() as interface:
    gr.Markdown("# Text Classification with SetFit")
    gr.Markdown("Enter a sentence and get it classified into 'Ablehnung', 'Neutral', or 'Befürwortung'.")

    # Single-row input: Gradio submits a one-row Textbox on Enter, whereas a
    # multi-row Textbox (lines > 1) only submits on Shift+Enter. Textbox has
    # no `multiline` parameter, so it is not passed (recent Gradio versions
    # reject unknown keyword arguments).
    text_input = gr.Textbox(lines=1, placeholder="Enter your text here...")

    # Submit button
    submit_btn = gr.Button("Submit")

    # Result area, showing a hint until the first classification
    result_output = gr.HTML(
        value='<div style="padding: 10px;">The classification result will be displayed here.</div>'
    )

    # Clicking the button and pressing Enter in the text box both run the classifier
    submit_btn.click(fn=classify_text, inputs=text_input, outputs=result_output)
    text_input.submit(fn=classify_text, inputs=text_input, outputs=result_output)

# Launch the interface
interface.launch()