georges commited on
Commit
6529594
1 Parent(s): 12e80fc

initial commit

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from setfit import SetFitModel
3
+
4
+ # Load a pre-trained SetFit model (replace with your specific model)
5
+ model = SetFitModel.from_pretrained("CundK/ePA-Classifier-Agreement")
6
+
7
+ # Define the mapping of labels
8
+ label_mapping = {0: "Ablehnung", 1: "Neutral", 2: "Befürwortung"}
9
+
10
+ # Define the color mapping
11
+ color_mapping = {
12
+ "Ablehnung": "red",
13
+ "Neutral": "yellow",
14
+ "Befürwortung": "green"
15
+ }
16
+
17
+ # Function to classify text and return the label with background color
18
+ def classify_text(text):
19
+ # Get predictions from the model
20
+ predictions = model([text])
21
+
22
+ # Convert the prediction tensor to an integer index
23
+ label_index = int(predictions[0].item())
24
+
25
+ # Map the integer index to the corresponding label
26
+ label = label_mapping[label_index]
27
+
28
+ # Return the label with the appropriate background color
29
+ return f'<div style="background-color: {color_mapping[label]}; padding: 10px; border-radius: 5px;">{label}</div>'
30
+
31
+ # Create the Gradio interface using Blocks for custom layout
32
+ with gr.Blocks() as interface:
33
+ gr.Markdown("# Text Classification with SetFit")
34
+ gr.Markdown("Enter a sentence and get it classified into 'Ablehnung', 'Neutral', or 'Befürwortung'.")
35
+
36
+ # Input text box with multiline set to False to avoid Shift+Enter issue
37
+ text_input = gr.Textbox(lines=2, placeholder="Enter your text here...", multiline=False)
38
+
39
+ # Submit button
40
+ submit_btn = gr.Button("Submit")
41
+
42
+ # Placeholder for result with an initial message
43
+ result_output = gr.HTML(value="<div style='color:gray;'>The classification result will be displayed here.</div>")
44
+
45
+ # Connect the submit button to the classification function
46
+ submit_btn.click(fn=classify_text, inputs=text_input, outputs=result_output)
47
+
48
+ # Trigger the classification function when the user presses Enter in the text box
49
+ text_input.submit(fn=classify_text, inputs=text_input, outputs=result_output)
50
+
51
+ # Launch the interface
52
+ interface.launch()