tbitai commited on
Commit
8937106
1 Parent(s): 76143aa
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -3,6 +3,7 @@ from huggingface_hub import hf_hub_download
3
  import json
4
  import tensorflow as tf
5
  import numpy as np
 
6
 
7
 
8
  # Load models
@@ -14,6 +15,10 @@ with open(model_probs_path) as f:
14
  nn_model_path = hf_hub_download(repo_id="tbitai/nn-enron1-spam", filename="nn-enron1-spam.keras")
15
  nn_model = tf.keras.models.load_model(nn_model_path)
16
 
 
 
 
 
17
 
18
  # Utils for Bayes
19
 
@@ -61,9 +66,14 @@ def predict_bayes(text, intr_threshold, unbiased=False):
61
  def predict_nn(text):
62
  return nn_model(np.array([text]))[0][0].numpy()
63
 
 
 
 
 
64
  MODELS = [
65
  BAYES := "Bayes Enron1 spam",
66
  NN := "NN Enron1 spam",
 
67
  ]
68
 
69
  def predict(model, input_txt, unbiased, intr_threshold):
@@ -71,6 +81,8 @@ def predict(model, input_txt, unbiased, intr_threshold):
71
  return predict_bayes(input_txt, unbiased=unbiased, intr_threshold=intr_threshold)
72
  elif model == NN:
73
  return predict_nn(input_txt)
 
 
74
 
75
 
76
  # UI
@@ -90,11 +102,13 @@ demo = gr.Interface(
90
  ],
91
  outputs=[gr.Number(label="Spam probability")],
92
  title="Bayes or Spam?",
93
- description="Choose your model, and predict if your email is a spam! 📨<br>COMING SOON: LLM models.",
94
  examples=[
95
- [BAYES, "Enron actuals for June 26, 2000", False, DEFAULT_INTR_THRESHOLD],
96
  [BAYES, nerissa_email := "Stop the aging clock\nNerissa", False, DEFAULT_INTR_THRESHOLD],
97
  [BAYES, nerissa_email, True, DEFAULT_INTR_THRESHOLD],
 
 
98
  ],
99
  article="This is a demo of the models in the [Bayes or Spam?](https://github.com/tbitai/bayes-or-spam) project.",
100
  )
 
3
  import json
4
  import tensorflow as tf
5
  import numpy as np
6
+ from sentence_transformers import SentenceTransformer
7
 
8
 
9
  # Load models
 
15
  nn_model_path = hf_hub_download(repo_id="tbitai/nn-enron1-spam", filename="nn-enron1-spam.keras")
16
  nn_model = tf.keras.models.load_model(nn_model_path)
17
 
18
+ st_model = SentenceTransformer("avsolatorio/GIST-large-Embedding-v0")
19
+ llm_model_path = hf_hub_download(repo_id="tbitai/gisty-enron1-spam", filename="gisty-enron1-spam.keras")
20
+ llm_model = tf.keras.models.load_model(llm_model_path)
21
+
22
 
23
  # Utils for Bayes
24
 
 
66
  def predict_nn(text):
67
  return nn_model(np.array([text]))[0][0].numpy()
68
 
69
+ def predict_llm(text):
70
+ embedding = st_model.encode(text)
71
+ return llm_model(np.array([embedding]))[0][0].numpy()
72
+
73
  MODELS = [
74
  BAYES := "Bayes Enron1 spam",
75
  NN := "NN Enron1 spam",
76
+ LLM := "GISTy Enron1 spam",
77
  ]
78
 
79
  def predict(model, input_txt, unbiased, intr_threshold):
 
81
  return predict_bayes(input_txt, unbiased=unbiased, intr_threshold=intr_threshold)
82
  elif model == NN:
83
  return predict_nn(input_txt)
84
+ elif model == LLM:
85
+ return predict_llm(input_txt)
86
 
87
 
88
  # UI
 
102
  ],
103
  outputs=[gr.Number(label="Spam probability")],
104
  title="Bayes or Spam?",
105
+ description="Choose your model, and predict if your email is a spam! 📨",
106
  examples=[
107
+ [BAYES, enron_email := "Enron actuals for June 26, 2000", False, DEFAULT_INTR_THRESHOLD],
108
  [BAYES, nerissa_email := "Stop the aging clock\nNerissa", False, DEFAULT_INTR_THRESHOLD],
109
  [BAYES, nerissa_email, True, DEFAULT_INTR_THRESHOLD],
110
+ [NN, enron_email, None, None],
111
+ [LLM, enron_email, None, None],
112
  ],
113
  article="This is a demo of the models in the [Bayes or Spam?](https://github.com/tbitai/bayes-or-spam) project.",
114
  )