File size: 4,001 Bytes
850dd49
cc21853
 
 
 
850dd49
5f92565
 
 
ebc9c45
5ab9f7f
 
 
ebc9c45
 
b9d04e8
cc21853
 
 
b9d04e8
5f92565
 
b9d04e8
8937106
f12cfe9
 
c7a6b84
5f92565
ebc9c45
5f92565
 
cc21853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f92565
 
cc21853
 
 
5f92565
 
 
cc21853
 
6441204
 
 
 
 
 
 
 
 
cc21853
 
 
5f92565
c078837
f409800
8937106
 
908a1af
8937106
d1a418a
cc21853
 
5f92565
 
8937106
 
5f92565
 
 
850dd49
 
 
f8dfc8e
cc21853
5f92565
 
83db710
5f92565
cc21853
f409800
 
 
f8dfc8e
 
850dd49
8937106
f8dfc8e
8937106
80b487d
 
8937106
 
f8dfc8e
9bd0746
850dd49
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
from huggingface_hub import hf_hub_download
import json
import tensorflow as tf
import numpy as np


# Load models

# Models offered by the demo. Each entry pairs a human-readable label with the
# model's identifier string. The walrus assignments additionally bind each
# identifier to a module-level name (BAYES, NN, LLM) so the rest of the file
# can refer to the models without repeating the string literals.
MODELS = [
    ("Bayes Enron1 spam", BAYES := "bayes-enron1-spam"),
    ("NN Enron1 spam", NN := "nn-enron1-spam"),
    ("GISTy Enron1 spam", LLM := "gisty-enron1-spam"),
]

# Bayes model: a JSON file fetched from the Hugging Face Hub and parsed into
# `model_probs`, which is looked up by word (and by UNK) at prediction time.
model_probs_path = hf_hub_download(repo_id="tbitai/bayes-enron1-spam", filename="probs.json")
with open(model_probs_path) as f:
    model_probs = json.load(f)

# NN model: a Keras model that is called directly on an array holding the raw email text.
nn_model_path = hf_hub_download(repo_id="tbitai/nn-enron1-spam", filename="nn-enron1-spam.keras")
nn_model = tf.keras.models.load_model(nn_model_path)

# GISTy model: a Keras model that is called on embeddings produced by `st_model`.
llm_model_path = hf_hub_download(repo_id="tbitai/gisty-enron1-spam", filename="gisty-enron1-spam.keras")
llm_model = tf.keras.models.load_model(llm_model_path)
# Sentence Transformers should be imported after Keras models, in order to prevent it from setting Keras to legacy.
from sentence_transformers import SentenceTransformer
st_model = SentenceTransformer("avsolatorio/GIST-large-Embedding-v0")


# Utils for Bayes

# Key under which `model_probs` stores the probability used for words it has no entry for.
UNK = '[UNK]'

def tokenize(text):
    """Split *text* into a list of word tokens.

    Delegates to Keras' ``text_to_word_sequence`` with its default settings.
    """
    to_word_sequence = tf.keras.preprocessing.text.text_to_word_sequence
    return to_word_sequence(text)

def combine(probs):
    """Combine individual spam probabilities into one overall probability.

    Computes ``prod(p) / (prod(p) + prod(1 - p))`` over *probs*.

    Args:
        probs: Sequence of probabilities, each expected to lie in [0, 1].

    Returns:
        The combined probability as a Python float. A probability of exactly 0
        anywhere in *probs* forces the result to 0; an empty sequence gives 0.5.
    """
    p = np.asarray(probs, dtype=float)
    if np.any(p == 0):
        return 0.0
    prod = np.prod(p)
    neg_prod = np.prod(1 - p)
    if prod + neg_prod != 0:
        return float(prod / (prod + neg_prod))

    # Both products vanished. With many factors this is floating-point
    # underflow, so the ratio is recomputed from log-odds instead of assuming
    # the two products were equally small.
    if np.any(p == 1):
        # prod(1 - p) is genuinely zero here, so the evidence is decisive.
        return 1.0
    log_odds = np.sum(np.log(p) - np.log(1 - p))
    # Evaluate the logistic function in the form that cannot overflow.
    if log_odds >= 0:
        return float(1 / (1 + np.exp(-log_odds)))
    odds = np.exp(log_odds)
    return float(odds / (1 + odds))

def get_interesting_probs(probs, intr_threshold):
    """Return the `intr_threshold` probabilities that lie farthest from 0.5.

    The result is ordered from most to least interesting (largest distance
    from 0.5 first); equally distant probabilities keep their input order.
    """
    def interestingness(prob):
        return abs(prob - 0.5)

    ranked = list(probs)
    ranked.sort(key=interestingness, reverse=True)
    return ranked[:intr_threshold]

# Default number of most-interesting word probabilities fed into the Bayes
# combination; the UI describes this as the value Graham uses.
DEFAULT_INTR_THRESHOLD = 15

def unbias(p):
    """Map a probability *p* to ``2p / (p + 1)``.

    Applied to per-word probabilities when the "Unbias" option is enabled.
    """
    doubled = 2 * p
    return doubled / (p + 1)


# Predict functions

def predict_bayes(text, intr_threshold, unbiased=False):
    """Return the Bayes model's spam probability for *text*.

    Each token of *text* is mapped to its probability in `model_probs`
    (optionally passed through `unbias`); tokens without an entry get the
    `UNK` probability as-is. The `intr_threshold` most interesting
    probabilities are then combined into the final result.
    """
    word_probs = []
    for token in tokenize(text):
        try:
            prob = model_probs[token]
        except KeyError:
            # Unknown word: use the fallback probability, never unbiased.
            prob = model_probs[UNK]
        else:
            if unbiased:
                prob = unbias(prob)
        word_probs.append(prob)
    return combine(get_interesting_probs(word_probs, intr_threshold))

def predict_nn(text):
    """Return the NN model's output for *text* as a Python float."""
    batch = np.array([text])  # the model is called on a batch of one
    output = nn_model(batch)
    return float(output[0][0].numpy())

def predict_llm(text):
    """Return the GISTy model's output for *text* as a Python float.

    The text is first encoded with `st_model`; the resulting embedding is
    passed to `llm_model` as a batch of one.
    """
    batch = np.array([st_model.encode(text)])
    score = llm_model(batch)[0][0]
    return float(score.numpy())

def predict(model, input_txt, unbiased, intr_threshold):
    """Dispatch *input_txt* to the predictor selected by *model*.

    `unbiased` and `intr_threshold` are only used by the Bayes model.
    Returns None when *model* matches none of the known identifiers.
    """
    if model == BAYES:
        return predict_bayes(input_txt, intr_threshold, unbiased=unbiased)
    if model == NN:
        return predict_nn(input_txt)
    if model == LLM:
        return predict_llm(input_txt)
    return None


# UI

# Gradio interface around `predict`. The main inputs followed by the additional
# inputs are listed in the same order as `predict`'s parameters:
# model, input_txt, unbiased, intr_threshold.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(choices=MODELS, value=BAYES, label="Model"),
        gr.TextArea(label="Email"),
    ],
    # Bayes-only settings live in a collapsed accordion.
    additional_inputs_accordion=gr.Accordion("Additional configuration for Bayes", open=False),
    additional_inputs=[
        gr.Checkbox(label="Unbias", info="Correct Graham's bias?"),
        gr.Slider(minimum=1, maximum=DEFAULT_INTR_THRESHOLD + 5, step=1, value=DEFAULT_INTR_THRESHOLD, 
                  label="Interestingness threshold", 
                  info=f"How many of the most interesting words to select in the probability calculation? ({DEFAULT_INTR_THRESHOLD} for Graham)"),
    ],
    outputs=[gr.Number(label="Spam probability")],
    title="Bayes or Spam?",
    description="Choose your model, and predict if your email is a spam! 📨",
    # Each example row is [model, email, unbiased, intr_threshold]. The walrus
    # assignments name the sample emails so later rows can reuse them. The
    # NN and LLM rows pass None for the Bayes-only settings, which `predict`
    # does not use for those models.
    examples=[
        [BAYES, enron_email := "Enron actuals for June 26, 2000", False, DEFAULT_INTR_THRESHOLD],
        [BAYES, nerissa_email := "Stop the aging clock\nNerissa", False, DEFAULT_INTR_THRESHOLD],
        [BAYES, nerissa_email, True, DEFAULT_INTR_THRESHOLD],
        [NN, enron_email, None, None],
        [LLM, enron_email, None, None],
    ],
    article="This is a demo of the models in the [Bayes or Spam?](https://github.com/tbitai/bayes-or-spam) project.",
)

# Start the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    demo.launch()