File size: 1,992 Bytes
6529594 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import gradio as gr
from setfit import SetFitModel
# Hugging Face Hub id of the pre-trained SetFit classifier; change this to
# load a different SetFit model.
_MODEL_ID = "CundK/ePA-Classifier-Agreement"

# Instantiated once when the module is executed; every classification call
# reuses this same instance.
model = SetFitModel.from_pretrained(_MODEL_ID)
# Class names in model-index order: position i of this tuple is the label
# displayed for predicted class index i.
# (German: Ablehnung = rejection, Neutral = neutral, Befürwortung = approval.)
_LABELS = ("Ablehnung", "Neutral", "Befürwortung")

# Integer class index -> label name.
label_mapping = dict(enumerate(_LABELS))

# Label name -> CSS background colour used for the result badge.
color_mapping = dict(zip(_LABELS, ("red", "yellow", "green")))
def classify_text(text):
    """Classify *text* with the SetFit model and return the label as HTML.

    Args:
        text: The sentence entered by the user in the Gradio textbox.

    Returns:
        An HTML ``<div>`` string. For empty or whitespace-only input this is a
        gray hint asking for text, and the model is not called. Otherwise it
        is the predicted label name on the background colour configured for
        that label in ``color_mapping``.

    Raises:
        KeyError: If the model predicts an index that is not a key of
            ``label_mapping``.
    """
    # Guard: a blank submission (e.g. clicking Submit with nothing typed)
    # would otherwise still be classified and displayed as if it were a real
    # prediction.
    if text is None or not text.strip():
        return "<div style='color:gray;'>Please enter some text to classify.</div>"

    # The model is called on a batch containing one sentence; take its
    # single prediction.
    predictions = model([text])
    # NOTE(review): assumes the prediction is a numeric scalar (tensor/array
    # element) exposing .item(); a model configured to emit string labels
    # would need different handling here.
    label_index = int(predictions[0].item())
    label = label_mapping[label_index]

    # The label and colour both come from the fixed module-level tables, never
    # from the user's text, so nothing user-controlled is interpolated into
    # the HTML.
    return (
        f'<div style="background-color: {color_mapping[label]}; '
        f'padding: 10px; border-radius: 5px;">{label}</div>'
    )
# Build the Gradio UI with Blocks so the layout can be customised.
with gr.Blocks() as interface:
    gr.Markdown("# Text Classification with SetFit")
    gr.Markdown("Enter a sentence and get it classified into 'Ablehnung', 'Neutral', or 'Befürwortung'.")

    # gr.Textbox has no `multiline` parameter; single-line mode is selected
    # with lines=1 / max_lines=1. In Gradio's frontend, multi-line boxes
    # (lines > 1) submit on Shift+Enter, while a single-line box submits on
    # plain Enter — confirm against the installed Gradio version.
    text_input = gr.Textbox(lines=1, max_lines=1, placeholder="Enter your text here...")

    submit_btn = gr.Button("Submit")

    # Result area, pre-filled with a hint until the first classification.
    result_output = gr.HTML(value="<div style='color:gray;'>The classification result will be displayed here.</div>")

    # The button and pressing Enter in the textbox both run the same
    # classifier and write into the same result area.
    submit_btn.click(fn=classify_text, inputs=text_input, outputs=result_output)
    text_input.submit(fn=classify_text, inputs=text_input, outputs=result_output)

# Start the web server for the app.
interface.launch()