georges
commited on
Commit
•
6529594
1
Parent(s):
12e80fc
initial commit
Browse files
app.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from setfit import SetFitModel
|
3 |
+
|
4 |
+
# Load a pre-trained SetFit model (replace with your specific model)
|
5 |
+
model = SetFitModel.from_pretrained("CundK/ePA-Classifier-Agreement")
|
6 |
+
|
7 |
+
# Define the mapping of labels
|
8 |
+
label_mapping = {0: "Ablehnung", 1: "Neutral", 2: "Befürwortung"}
|
9 |
+
|
10 |
+
# Define the color mapping
|
11 |
+
color_mapping = {
|
12 |
+
"Ablehnung": "red",
|
13 |
+
"Neutral": "yellow",
|
14 |
+
"Befürwortung": "green"
|
15 |
+
}
|
16 |
+
|
17 |
+
# Function to classify text and return the label with background color
|
18 |
+
def classify_text(text):
|
19 |
+
# Get predictions from the model
|
20 |
+
predictions = model([text])
|
21 |
+
|
22 |
+
# Convert the prediction tensor to an integer index
|
23 |
+
label_index = int(predictions[0].item())
|
24 |
+
|
25 |
+
# Map the integer index to the corresponding label
|
26 |
+
label = label_mapping[label_index]
|
27 |
+
|
28 |
+
# Return the label with the appropriate background color
|
29 |
+
return f'<div style="background-color: {color_mapping[label]}; padding: 10px; border-radius: 5px;">{label}</div>'
|
30 |
+
|
31 |
+
# Create the Gradio interface using Blocks for custom layout
|
32 |
+
with gr.Blocks() as interface:
|
33 |
+
gr.Markdown("# Text Classification with SetFit")
|
34 |
+
gr.Markdown("Enter a sentence and get it classified into 'Ablehnung', 'Neutral', or 'Befürwortung'.")
|
35 |
+
|
36 |
+
# Input text box with multiline set to False to avoid Shift+Enter issue
|
37 |
+
text_input = gr.Textbox(lines=2, placeholder="Enter your text here...", multiline=False)
|
38 |
+
|
39 |
+
# Submit button
|
40 |
+
submit_btn = gr.Button("Submit")
|
41 |
+
|
42 |
+
# Placeholder for result with an initial message
|
43 |
+
result_output = gr.HTML(value="<div style='color:gray;'>The classification result will be displayed here.</div>")
|
44 |
+
|
45 |
+
# Connect the submit button to the classification function
|
46 |
+
submit_btn.click(fn=classify_text, inputs=text_input, outputs=result_output)
|
47 |
+
|
48 |
+
# Trigger the classification function when the user presses Enter in the text box
|
49 |
+
text_input.submit(fn=classify_text, inputs=text_input, outputs=result_output)
|
50 |
+
|
51 |
+
# Launch the interface
|
52 |
+
interface.launch()
|