karths commited on
Commit
de24b75
·
verified ·
1 Parent(s): b8c4b59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -7,6 +7,8 @@ from huggingface_hub import login, HfFolder
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
  from scipy.special import softmax
9
  import logging
 
 
10
 
11
  # Setup logging
12
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
@@ -62,7 +64,8 @@ models = {path: AutoModelForSequenceClassification.from_pretrained(path) for pat
62
 
63
  def get_quality_name(model_name):
64
  return quality_mapping.get(model_name.split('/')[-1], "Unknown Quality")
65
-
 
66
  def model_prediction(model, text, device):
67
  model.to(device)
68
  model.eval()
@@ -91,7 +94,7 @@ def main_interface(text):
91
  for model_path, model in models.items():
92
  quality_name = get_quality_name(model_path)
93
  avg_prob = model_prediction(model, text, device)
94
- if avg_prob >= 0.995: # Only consider probabilities >= 0.90
95
  results.append((quality_name, avg_prob))
96
  logging.info(f"Model: {model_path}, Quality: {quality_name}, Average Probability: {avg_prob:.3f}")
97
 
 
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
  from scipy.special import softmax
9
  import logging
10
+ import spaces
11
+
12
 
13
  # Setup logging
14
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
 
64
 
65
  def get_quality_name(model_name):
66
  return quality_mapping.get(model_name.split('/')[-1], "Unknown Quality")
67
+
68
+ @spaces.GPU
69
  def model_prediction(model, text, device):
70
  model.to(device)
71
  model.eval()
 
94
  for model_path, model in models.items():
95
  quality_name = get_quality_name(model_path)
96
  avg_prob = model_prediction(model, text, device)
97
+ if avg_prob >= 0.95: # Only consider probabilities >= 0.90
98
  results.append((quality_name, avg_prob))
99
  logging.info(f"Model: {model_path}, Quality: {quality_name}, Average Probability: {avg_prob:.3f}")
100