GeorgHCundK commited on
Commit
80cacdd
1 Parent(s): b57b028

adjust to API funcs

Browse files
Files changed (1) hide show
  1. app.py +45 -17
app.py CHANGED
@@ -1,37 +1,65 @@
1
  import gradio as gr
2
- from setfit import SetFitModel
 
 
3
 
4
- # Load a pre-trained SetFit model (replace with your specific model)
5
- model = SetFitModel.from_pretrained("CundK/ePA-Classifier-Agreement")
 
 
 
 
 
6
 
7
- # Define the mapping of labels
8
- label_mapping = {0: "Ablehnung", 1: "Neutral", 2: "Befürwortung"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # Define the color mapping
11
  color_mapping = {
12
  "Ablehnung": "red",
13
  "Neutral": "yellow",
14
- "Befürwortung": "green"
15
  }
16
 
17
- # Function to classify text and return the label with background color
18
  def classify_text(text):
19
- # Get predictions from the model
20
- predictions = model([text])
21
-
22
- # Convert the prediction tensor to an integer index
23
- label_index = int(predictions[0].item())
24
 
25
- # Map the integer index to the corresponding label
26
- label = label_mapping[label_index]
27
 
28
  # Return the label with the appropriate background color
29
- return f'<div style="background-color: {color_mapping[label]}; padding: 10px; border-radius: 5px;">{label}</div>'
30
 
31
  # Create the Gradio interface using Blocks for custom layout
32
  with gr.Blocks() as interface:
33
- gr.Markdown("# Text Classification with SetFit")
34
- gr.Markdown("Enter a sentence and get it classified into 'Ablehnung', 'Neutral', or 'Befürwortung'.")
35
 
36
  # Input text box set for single-line input
37
  text_input = gr.Textbox(lines=1, placeholder="Enter your text here...") # Single-line input
 
1
  import gradio as gr
2
+ import requests
3
+ import time
4
+ import os
5
 
6
+ # Define the AgreementClassifier class
7
+ class AgreementClassifier:
8
+ def __init__(self, api_token, api_url, backoff_factor=1):
9
+ self.api_token = api_token
10
+ self.api_url = api_url
11
+ self.headers = {"Authorization": f"Bearer {self.api_token}"}
12
+ self.backoff_factor = backoff_factor
13
 
14
+ def query(self, payload):
15
+ retries = 0
16
+ while True:
17
+ response = requests.post(self.api_url, headers=self.headers, json=payload)
18
+ if response.status_code == 503:
19
+ retries += 1
20
+ wait_time = self.backoff_factor * (2 ** (retries - 1))
21
+ print(f"503 Service Unavailable. Retrying in {wait_time} seconds...")
22
+ time.sleep(wait_time)
23
+ else:
24
+ response.raise_for_status()
25
+ return response.json()
26
+
27
+ def classify_text_topic(self, input_text):
28
+ result = self.query(
29
+ {
30
+ "inputs": input_text,
31
+ "parameters": {},
32
+ }
33
+ )
34
+ return result
35
+
36
+ # Initialize the classifier with API token and URL
37
+ API_TOKEN = os.getenv("API_TOKEN")
38
+ API_URL = os.getenv("API_URL")
39
+ classifier = AgreementClassifier(API_TOKEN, API_URL)
40
 
41
  # Define the color mapping
42
  color_mapping = {
43
  "Ablehnung": "red",
44
  "Neutral": "yellow",
45
+ "Zustimmung": "green"
46
  }
47
 
48
+ # Function to classify text using the API
49
  def classify_text(text):
50
+ # Get predictions from the classifier
51
+ predictions = classifier.classify_text_topic(text)
 
 
 
52
 
53
+ # Find the label with the highest score
54
+ predicted_label = max(predictions, key=lambda x: x['score'])['label']
55
 
56
  # Return the label with the appropriate background color
57
+ return f'<div style="background-color: {color_mapping[predicted_label]}; padding: 10px; border-radius: 5px;">{predicted_label}</div>'
58
 
59
  # Create the Gradio interface using Blocks for custom layout
60
  with gr.Blocks() as interface:
61
+ gr.Markdown("# Text Classification with API")
62
+ gr.Markdown("Enter a sentence and get it classified into 'Ablehnung', 'Neutral', or 'Zustimmung'.")
63
 
64
  # Input text box set for single-line input
65
  text_input = gr.Textbox(lines=1, placeholder="Enter your text here...") # Single-line input