import gradio as gr
from huggingface_hub import hf_hub_download
import json
import tensorflow as tf
import numpy as np


# Load models

MODELS = [
    ("Bayes Enron1 spam", BAYES := "bayes-enron1-spam"),
    ("NN Enron1 spam", NN := "nn-enron1-spam"),
    ("GISTy Enron1 spam", LLM := "gisty-enron1-spam"),
]
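# Each entry above is a (UI label, model ID) pair; the walrus assignments also
# bind each ID to a constant (BAYES, NN, LLM) for dispatch in predict() below.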

model_probs_path = hf_hub_download(repo_id="tbitai/bayes-enron1-spam", filename="probs.json")
with open(model_probs_path) as f:
    model_probs = json.load(f)

nn_model_path = hf_hub_download(repo_id="tbitai/nn-enron1-spam", filename="nn-enron1-spam.keras")
nn_model = tf.keras.models.load_model(nn_model_path)

llm_model_path = hf_hub_download(repo_id="tbitai/gisty-enron1-spam", filename="gisty-enron1-spam.keras")
llm_model = tf.keras.models.load_model(llm_model_path)
# Sentence Transformers must be imported after the Keras models are loaded,
# to prevent it from switching Keras to legacy mode.
from sentence_transformers import SentenceTransformer
st_model = SentenceTransformer("avsolatorio/GIST-large-Embedding-v0")


# Utils for Bayes

UNK = '[UNK]'

def tokenize(text):
    return tf.keras.preprocessing.text.text_to_word_sequence(text)

def combine(probs):
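    """Combine per-word spam probabilities with Graham's naive Bayes formula:
    prod(p) / (prod(p) + prod(1 - p)).
    """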
    if any(p == 0 for p in probs):
        return 0
    prod = np.prod(probs)
    neg_prod = np.prod([1 - p for p in probs])
    if prod + neg_prod == 0:  # Still possible when both products underflow in floating point
        return 0.5  # Assume that prod and neg_prod are equally small
    return prod / (prod + neg_prod)

def get_interesting_probs(probs, intr_threshold):
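    """Select the intr_threshold probabilities farthest from 0.5,
    i.e. the words that carry the strongest evidence either way.
    """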
    return sorted(probs,
                  key=lambda p: abs(p - 0.5),
                  reverse=True)[:intr_threshold]

DEFAULT_INTR_THRESHOLD = 15

def unbias(p):
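    """Correct Graham's bias of doubling ham counts:
    if p = s / (s + 2h), then 2p / (p + 1) = s / (s + h).
    """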
    return (2 * p) / (p + 1)


# Predict functions

def predict_bayes(text, intr_threshold, unbiased=False):
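    """Spam probability of text under the Bayes model, combining the
    intr_threshold most interesting word probabilities. Words not seen
    in training fall back to the [UNK] probability.
    """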
    words = tokenize(text)
    probs = []
    for w in words:
        try:
            p = model_probs[w]
            if unbiased:
                p = unbias(p)
        except KeyError:
            p = model_probs[UNK]
        probs.append(p)
    interesting_probs = get_interesting_probs(probs, intr_threshold)
    return combine(interesting_probs)

def predict_nn(text):
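    # The Keras model takes raw strings, so text vectorization happens inside the model.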
    return float(nn_model(np.array([text]))[0][0].numpy())

def predict_llm(text):
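    # Embed the text with the GIST sentence transformer, then classify the embedding.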
    embedding = st_model.encode(text)
    return float(llm_model(np.array([embedding]))[0][0].numpy())

def predict(model, input_txt, unbiased, intr_threshold):
    if model == BAYES:
        return predict_bayes(input_txt, unbiased=unbiased, intr_threshold=intr_threshold)
    elif model == NN:
        return predict_nn(input_txt)
    elif model == LLM:
        return predict_llm(input_txt)


# UI

demo = gr.Interface(
    theme=gr.themes.Origin(  # Gradio 4-like
        primary_hue="yellow",
    ),
    fn=predict,
    inputs=[
        gr.Dropdown(choices=MODELS, value=BAYES, label="Model",
                    # FIXME: Font size should be smaller by default. Remove workaround when fixed in Gradio: https://github.com/gradio-app/gradio/issues/9642
                    info="<small>Learn more about the models [here](https://huggingface.co/collections/tbitai/bayes-or-spam-6700033fa145e298ec849249)</small>"),
        gr.TextArea(label="Email"),
    ],
    additional_inputs_accordion=gr.Accordion("Additional configuration for Bayes", open=False),
    additional_inputs=[
        gr.Checkbox(label="Unbias", info="<small>Correct Graham's bias?</small>"),
        gr.Slider(minimum=1, maximum=DEFAULT_INTR_THRESHOLD + 5, step=1, value=DEFAULT_INTR_THRESHOLD,
                  label="Interestingness threshold",
                  info=f"<small>How many of the most interesting words to use in the probability calculation? ({DEFAULT_INTR_THRESHOLD} for Graham)</small>"),
    ],
    outputs=[gr.Number(label="Spam probability")],
    title="Bayes or Spam?",
    description="Choose your model, and predict whether your email is spam! 📨",
    examples=[
        [NN, "Enron actuals for June 26, 2000", None, None],
        [BAYES, "Stop the aging clock\nNerissa", True, DEFAULT_INTR_THRESHOLD],
    ],
    article="This is a demo of the models in the [Bayes or Spam?](https://github.com/tbitai/bayes-or-spam) project.",
)

if __name__ == "__main__":
    demo.launch()