# Hugging Face Space source by GeorgHCundK (commit ec04ce5: "fix formatting", 2.93 kB).
import gradio as gr
import requests
import time
import os
# Define the AgreementClassifier class
class AgreementClassifier:
    """Thin client for a text-classification HTTP endpoint.

    Sends JSON payloads to ``api_url`` with a bearer token. HTTP 503 responses
    are retried with capped exponential backoff, up to a bounded number of
    attempts.
    """

    def __init__(self, api_token, api_url, backoff_factor=1, max_retries=5,
                 max_backoff=60, timeout=30):
        """Store endpoint configuration.

        Args:
            api_token: Token sent as ``Authorization: Bearer <api_token>``.
            api_url: Endpoint that ``query`` POSTs to.
            backoff_factor: Base delay in seconds. Retry number n waits
                ``backoff_factor * 2 ** (n - 1)`` seconds, capped at
                ``max_backoff``.
            max_retries: How many times a 503 is retried before giving up.
                ``None`` retries forever (the original behavior).
            max_backoff: Upper bound in seconds on any single retry delay.
            timeout: Per-request timeout in seconds passed to ``requests.post``.
        """
        self.api_token = api_token
        self.api_url = api_url
        self.headers = {"Authorization": f"Bearer {self.api_token}"}
        self.backoff_factor = backoff_factor
        self.max_retries = max_retries
        self.max_backoff = max_backoff
        self.timeout = timeout

    def _backoff_delay(self, attempt):
        """Return the delay in seconds before retry number *attempt* (1-based)."""
        return min(self.backoff_factor * (2 ** (attempt - 1)), self.max_backoff)

    def query(self, payload):
        """POST *payload* as JSON and return the decoded JSON response.

        A 503 response is retried up to ``max_retries`` times. The caller
        sleeps ``_backoff_delay(n)`` seconds before retry n.

        Raises:
            requests.HTTPError: for any error status that is not a retried 503.
                This includes a 503 that persists after all retries are used.
            requests.Timeout: if a single request exceeds ``timeout`` seconds.
        """
        retries = 0
        while True:
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=self.timeout,
            )
            can_retry = self.max_retries is None or retries < self.max_retries
            if response.status_code == 503 and can_retry:
                retries += 1
                wait_time = self._backoff_delay(retries)
                print(f"503 Service Unavailable. Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
                continue
            # Raises for 4xx/5xx, including a final 503 once retries run out.
            response.raise_for_status()
            return response.json()

    def classify_text_topic(self, input_text):
        """Send *input_text* to the endpoint and return the raw JSON result."""
        result = self.query(
            {
                "inputs": input_text,
                "parameters": {},
            }
        )
        return result
# Build the shared classifier once, at import time, from environment settings.
# NOTE(review): os.getenv returns None for unset variables, so a missing
# API_TOKEN yields the header "Bearer None" and a missing API_URL only fails
# on the first request. Confirm both are configured in the deployment.
API_TOKEN, API_URL = os.getenv("API_TOKEN"), os.getenv("API_URL")
classifier = AgreementClassifier(
    api_token=API_TOKEN,
    api_url=API_URL,
)
# Background color of the result badge for each label the model can return.
color_mapping = dict(
    Ablehnung="red",     # "rejection"
    Neutral="yellow",
    Zustimmung="green",  # "agreement"
)
# Function to classify text using the API
def classify_text(text):
    """Classify *text* and return an HTML badge for the top-scoring label.

    Args:
        text: Free text entered in the Gradio textbox.

    Returns:
        An HTML ``<div>`` string.
        - Normal input: the div's background color is looked up in
          ``color_mapping`` (light gray for labels missing from it), and its
          content is the HTML-escaped predicted label.
        - Empty or whitespace-only input: a gray hint is returned without
          calling the API.

    Raises:
        ValueError: if the classifier response contains no predictions.
    """
    import html

    # Skip the network round-trip for empty submissions (e.g. Enter on an empty box).
    if not text or not text.strip():
        return "<div style='color:gray;'>Bitte gib einen Text ein.</div>"
    # Get predictions from the classifier
    predictions = classifier.classify_text_topic(text)
    predicted_label = _top_label(predictions)
    # An unexpected label falls back to a neutral color instead of raising KeyError.
    color = color_mapping.get(predicted_label, "lightgray")
    # The label comes from a remote service; escape it before embedding in HTML.
    return (
        f'<div style="background-color: {color}; padding: 10px; '
        f'border-radius: 5px;">{html.escape(str(predicted_label))}</div>'
    )


def _top_label(predictions):
    """Return the highest-scoring label from a classifier response.

    Accepts either a flat list of ``{"label": ..., "score": ...}`` dicts or
    that list wrapped in one outer list (``[[{...}, ...]]``).

    Raises:
        ValueError: if no predictions are present.
    """
    # Some text-classification endpoints wrap the results for a single input
    # in an extra list; unwrap it so both shapes are handled.
    if predictions and isinstance(predictions[0], list):
        predictions = predictions[0]
    if not predictions:
        raise ValueError("Classifier returned no predictions")
    return max(predictions, key=lambda p: p["score"])["label"]
# Narrow, centered single-column page layout.
_PAGE_CSS = ".gradio-container { max-width: 400px; margin: auto; }"

# Assemble the UI with gr.Blocks so the layout can be arranged by hand.
with gr.Blocks(css=_PAGE_CSS) as interface:
    gr.Markdown("# ePA Classifier")
    gr.Markdown(
        "Gib einen Satz oder Text ein, der in 'Ablehnung', 'Neutral', "
        "oder 'Zustimmung' klassifiziert werden soll."
    )
    # One-line input box; pressing Enter also submits (wired below).
    text_input = gr.Textbox(lines=1, placeholder="Hier Text...")
    submit_btn = gr.Button("Klassifizieren")
    # HTML result area, pre-filled with a gray hint until the first run.
    result_output = gr.HTML(
        value="<div style='color:gray;'>Das Ergebnis wird hier angezeigt</div>"
    )
    # The button click and Enter in the textbox run the same classification.
    for trigger in (submit_btn.click, text_input.submit):
        trigger(fn=classify_text, inputs=text_input, outputs=result_output)

# Start the web server.
interface.launch()