File size: 2,930 Bytes
6529594
80cacdd
 
 
6529594
80cacdd
 
 
 
 
 
 
6529594
80cacdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6529594
 
 
 
 
80cacdd
6529594
 
80cacdd
6529594
80cacdd
 
6529594
80cacdd
 
6529594
 
80cacdd
6529594
 
b190566
0078639
 
6529594
b57b028
0078639
6529594
 
0078639
6529594
 
0078639
6529594
 
 
 
 
 
 
 
ec04ce5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import html
import os
import time

import gradio as gr
import requests

class AgreementClassifier:
    """Client for a remote text-classification HTTP endpoint.

    Payloads are POSTed as JSON with a bearer token. HTTP 503 responses are
    retried with exponential backoff (up to ``max_retries`` times); any other
    non-success status raises immediately.
    """

    def __init__(self, api_token, api_url, backoff_factor=1, max_retries=5, timeout=30):
        """Store the endpoint configuration.

        Args:
            api_token: Bearer token sent in the ``Authorization`` header.
            api_url: URL the JSON payload is POSTed to.
            backoff_factor: Base delay in seconds; retry number ``n`` (1-based)
                waits ``backoff_factor * 2 ** (n - 1)`` seconds.
            max_retries: Maximum number of retries on HTTP 503 before giving up
                and raising. ``None`` retries indefinitely.
            timeout: Per-request timeout in seconds passed to ``requests.post``;
                ``None`` disables the timeout.
        """
        self.api_token = api_token
        self.api_url = api_url
        self.headers = {"Authorization": f"Bearer {self.api_token}"}
        self.backoff_factor = backoff_factor
        self.max_retries = max_retries
        self.timeout = timeout

    def _wait_time(self, retries):
        """Return the delay in seconds before retry number *retries* (1-based)."""
        return self.backoff_factor * (2 ** (retries - 1))

    def query(self, payload):
        """POST *payload* as JSON and return the decoded JSON response.

        Raises:
            requests.HTTPError: for any non-success status, including a 503
                that persists after ``max_retries`` retries.
            requests.Timeout: if a single request exceeds ``timeout`` seconds.
        """
        retries = 0
        while True:
            response = requests.post(
                self.api_url, headers=self.headers, json=payload, timeout=self.timeout
            )
            # Treat 503 as transient (e.g. the remote model is still loading),
            # but only while the retry budget lasts so the caller is never stuck forever.
            can_retry = self.max_retries is None or retries < self.max_retries
            if response.status_code == 503 and can_retry:
                retries += 1
                wait_time = self._wait_time(retries)
                print(f"503 Service Unavailable. Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
                continue
            # Raises HTTPError for 4xx/5xx (including an exhausted 503); no-op on success.
            response.raise_for_status()
            return response.json()

    def classify_text_topic(self, input_text):
        """Send *input_text* to the endpoint and return its decoded JSON response unchanged."""
        result = self.query(
            {
                "inputs": input_text,
                "parameters": {},
            }
        )
        return result
    
# Build the shared classifier from environment configuration.
# os.getenv returns None for any variable that is not set.
API_URL = os.getenv("API_URL")
API_TOKEN = os.getenv("API_TOKEN")
classifier = AgreementClassifier(api_token=API_TOKEN, api_url=API_URL)

# Background colour shown for each label the classifier can predict.
color_mapping = dict(
    Ablehnung="red",
    Neutral="yellow",
    Zustimmung="green",
)

def classify_text(text):
    """Classify *text* with the remote model and render the top label as HTML.

    Args:
        text: Raw user input from the Gradio textbox.

    Returns:
        An HTML ``<div>`` showing the highest-scoring label on the background
        colour from ``color_mapping`` (gray for labels not in the mapping), or
        a gray hint when *text* is empty or only whitespace.

    Raises:
        ValueError: if the endpoint returns an unexpected object or no predictions.
        Any exception raised by ``classifier.classify_text_topic`` (e.g. HTTP errors).
    """
    # Don't spend an API call on empty input; show a hint instead.
    if not text or not text.strip():
        return "<div style='color:gray;'>Bitte gib einen Text ein.</div>"

    predictions = classifier.classify_text_topic(text)

    # A dict here is an error/status body rather than a list of predictions.
    if isinstance(predictions, dict):
        raise ValueError(f"Unexpected response from classification endpoint: {predictions!r}")
    # Hosted text-classification endpoints (e.g. the Hugging Face Inference API)
    # may return one list of {label, score} dicts per input, i.e. [[...]]; unwrap it.
    if predictions and isinstance(predictions[0], list):
        predictions = predictions[0]
    if not predictions:
        raise ValueError("The classification endpoint returned no predictions.")

    predicted_label = max(predictions, key=lambda p: p["score"])["label"]
    # Fall back to gray instead of crashing on a label we have no colour for.
    color = color_mapping.get(predicted_label, "gray")

    # The label comes from an external service, so escape it before embedding in HTML.
    return (
        f'<div style="background-color: {color}; padding: 10px; border-radius: 5px;">'
        f"{html.escape(str(predicted_label))}</div>"
    )

# Assemble the UI with gr.Blocks so the layout can be narrowed and centred via CSS.
with gr.Blocks(css=".gradio-container { max-width: 400px; margin: auto; }") as interface:
    gr.Markdown("# ePA Classifier")
    gr.Markdown(
        "Gib einen Satz oder Text ein, der in 'Ablehnung', 'Neutral', oder 'Zustimmung' klassifiziert werden soll."
    )

    # Single-line field for the sentence to classify.
    text_input = gr.Textbox(placeholder="Hier Text...", lines=1)

    submit_btn = gr.Button("Klassifizieren")

    # Shows a gray hint until the first classification result replaces it.
    result_output = gr.HTML(value="<div style='color:gray;'>Das Ergebnis wird hier angezeigt</div>")

    # Clicking the button and pressing Enter in the textbox share identical wiring;
    # register them in the same order: button click first, then textbox submit.
    for register_event in (submit_btn.click, text_input.submit):
        register_event(fn=classify_text, inputs=text_input, outputs=result_output)

# Launch the Gradio interface. This call sits at module level, so merely
# importing this file also starts the server.
# NOTE(review): consider guarding with `if __name__ == "__main__":` so the
# module can be imported (e.g. by tests) without launching — confirm how the
# app is deployed/started before changing.
interface.launch()