File size: 1,992 Bytes
6529594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import gradio as gr
from setfit import SetFitModel

# Hugging Face Hub identifier of the fine-tuned SetFit classifier; point this
# at a different checkpoint to serve another model.
MODEL_ID = "CundK/ePA-Classifier-Agreement"
model = SetFitModel.from_pretrained(MODEL_ID)

# Display labels in the order of the class indices the model predicts
# (index 0 -> first entry, and so on).
_STANCE_LABELS = ("Ablehnung", "Neutral", "Befürwortung")

# Class index -> display label.
label_mapping = dict(enumerate(_STANCE_LABELS))

# Display label -> CSS background colour used to render the result box.
color_mapping = dict(zip(_STANCE_LABELS, ("red", "yellow", "green")))

# Function to classify text and return the label with background color
def classify_text(text):
    """Classify *text* with the SetFit model and render the result as HTML.

    Args:
        text: The sentence entered by the user. ``None``, ``""`` and
            whitespace-only strings are treated as "no input".

    Returns:
        An HTML ``<div>`` whose background colour reflects the predicted
        label, or a grey hint message when there is no text to classify.
    """
    # Don't run the model on blank input: it would still emit a label, which
    # the UI would then display as if it were a genuine classification.
    if not text or not text.strip():
        return "<div style='color:gray;'>Please enter some text to classify.</div>"

    # The model is called on a batch of texts; pass a one-element batch.
    predictions = model([text])

    # Convert the first (and only) prediction to a plain int class index.
    label_index = int(predictions[0].item())

    # Map the integer index to the corresponding human-readable label.
    label = label_mapping[label_index]

    # Render the label inside a box coloured according to color_mapping.
    return (
        f'<div style="background-color: {color_mapping[label]}; '
        f'padding: 10px; border-radius: 5px;">{label}</div>'
    )

# Create the Gradio interface using Blocks for custom layout
with gr.Blocks() as interface:
    gr.Markdown("# Text Classification with SetFit")
    gr.Markdown("Enter a sentence and get it classified into 'Ablehnung', 'Neutral', or 'Befürwortung'.")

    # Single-line textbox so that a plain Enter keypress triggers `submit`.
    # Gradio submits on Enter only when lines == 1; with lines > 1, Enter
    # inserts a newline and Shift+Enter submits instead. `gr.Textbox` has no
    # `multiline` parameter, so it cannot be used to get this behaviour.
    text_input = gr.Textbox(lines=1, placeholder="Enter your text here...")

    # Submit button
    submit_btn = gr.Button("Submit")

    # Placeholder for result with an initial message
    result_output = gr.HTML(value="<div style='color:gray;'>The classification result will be displayed here.</div>")

    # Connect the submit button to the classification function
    submit_btn.click(fn=classify_text, inputs=text_input, outputs=result_output)

    # Trigger the classification function when the user presses Enter in the text box
    text_input.submit(fn=classify_text, inputs=text_input, outputs=result_output)

# Launch the interface
interface.launch()