# Hugging Face Space (page status when captured: Running)
import json

import gradio as gr
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
# Load models

# Bayes model: a JSON mapping that predict_bayes indexes by word (and by the
# UNK key) to obtain a probability for each word.
model_probs_path = hf_hub_download(repo_id="tbitai/bayes-enron1-spam", filename="probs.json")
# JSON text is UTF-8; be explicit instead of relying on the platform default.
with open(model_probs_path, encoding="utf-8") as f:
    model_probs = json.load(f)

# Neural-network model, stored in the ".keras" format.
nn_model_path = hf_hub_download(repo_id="tbitai/nn-enron1-spam", filename="nn-enron1-spam.keras")
nn_model = tf.keras.models.load_model(nn_model_path)
# Utils for Bayes

# Key in model_probs whose probability is used for words the model lacks.
UNK = '[UNK]'
def tokenize(text):
    """Split *text* into a list of word tokens.

    Delegates to Keras' ``text_to_word_sequence`` with its default arguments
    (per the Keras documentation these lowercase the text and strip
    punctuation -- confirm against the installed Keras version).
    """
    return tf.keras.preprocessing.text.text_to_word_sequence(text)
def combine(probs):
    """Combine individual probabilities into a single probability.

    Computes ``prod(p) / (prod(p) + prod(1 - p))`` over *probs*.

    Args:
        probs: Iterable of probabilities (any iterable, including a
            generator; it is materialized once).

    Returns:
        The combined probability as a plain ``float``:
        ``0.0`` if any probability is exactly 0, ``0.5`` if both products
        are 0 (or *probs* is empty), otherwise the ratio above.
    """
    probs = list(probs)  # iterated several times below
    if any(p == 0 for p in probs):
        return 0.0
    prod = np.prod(probs)
    neg_prod = np.prod([1 - p for p in probs])
    total = prod + neg_prod
    if total == 0:  # Still possible due to floating point arithmetic
        return 0.5  # Assume that prod and neg_prod are equally small
    return float(prod / total)
def get_interesting_probs(probs, intr_threshold):
    """Return the probabilities that lie farthest from 0.5.

    Args:
        probs: Iterable of probabilities.
        intr_threshold: Maximum number of probabilities to keep. Coerced to
            ``int`` (UI sliders may deliver floats); negative values are
            treated as 0.

    Returns:
        A new list of at most *intr_threshold* items, ordered by decreasing
        distance from 0.5. Ties keep their original relative order.
    """
    limit = max(int(intr_threshold), 0)
    return sorted(probs, key=lambda p: abs(p - 0.5), reverse=True)[:limit]
# Default number of most-interesting words used in the probability
# calculation (the UI help text describes this value as Graham's).
DEFAULT_INTR_THRESHOLD = 15
def unbias(p):
    """Map probability *p* to ``2p / (p + 1)``.

    The UI exposes this as the "Correct Graham's bias?" option. The mapping
    fixes 0 and 1 and raises every value strictly in between.
    """
    return (2 * p) / (p + 1)
# Predict functions
def predict_bayes(text, intr_threshold, unbiased=False):
    """Return the Bayes model's combined probability for *text*.

    Args:
        text: The email text; it is split into words with ``tokenize``.
        intr_threshold: How many of the most interesting word probabilities
            (those farthest from 0.5) enter the combination.
        unbiased: If true, pass each known word's probability through
            ``unbias`` first.

    Returns:
        The result of ``combine`` over the selected probabilities.
    """
    probs = []
    for word in tokenize(text):
        try:
            p = model_probs[word]
        except KeyError:
            # Unknown word: use the model's UNK probability as-is; it is
            # never passed through unbias.
            p = model_probs[UNK]
        else:
            if unbiased:
                p = unbias(p)
        probs.append(p)
    return combine(get_interesting_probs(probs, intr_threshold))
def predict_nn(text):
    """Return the neural-network model's output for *text*.

    The text is wrapped in a one-element NumPy array before being passed to
    the model; element ``[0][0]`` of the model output is returned after
    conversion with ``.numpy()``.
    """
    batch = np.array([text])
    return nn_model(batch)[0][0].numpy()
# Display names of the selectable models.
BAYES = "Bayes Enron1 spam"
NN = "NN Enron1 spam"
MODELS = [BAYES, NN]
def predict(model, input_txt, unbiased, intr_threshold):
    """Dispatch a prediction to the model selected by name.

    Args:
        model: One of the model display names (``BAYES`` or ``NN``).
        input_txt: The email text to classify.
        unbiased: Bayes only -- whether to unbias word probabilities.
        intr_threshold: Bayes only -- number of interesting words to use.

    Returns:
        The selected model's prediction for *input_txt*.

    Raises:
        ValueError: If *model* is not a known model name (previously this
            case silently returned ``None``).
    """
    if model == BAYES:
        return predict_bayes(input_txt, unbiased=unbiased, intr_threshold=intr_threshold)
    if model == NN:
        return predict_nn(input_txt)
    raise ValueError(f"Unknown model: {model!r}")
# UI

# Example email shown twice: once without and once with unbiasing.
nerissa_email = "Stop the aging clock\nNerissa"

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(choices=MODELS, value=BAYES, label="Model"),
        gr.TextArea(label="Email"),
    ],
    additional_inputs_accordion="Additional configuration for Bayes",
    additional_inputs=[
        gr.Checkbox(label="Unbias", info="Correct Graham's bias?"),
        gr.Slider(
            minimum=1,
            maximum=DEFAULT_INTR_THRESHOLD + 5,
            step=1,
            value=DEFAULT_INTR_THRESHOLD,
            label="Interestingness threshold",
            info=f"How many of the most interesting words to select in the probability calculation? ({DEFAULT_INTR_THRESHOLD} for Graham)",
        ),
    ],
    outputs=[gr.Number(label="Spam probability")],
    title="Bayes or Spam?",
    description="Choose your model, and predict if your email is a spam! 📨<br>COMING SOON: LLM models.",
    # Each example row lists one value per input component, in component
    # order: `inputs` first, then `additional_inputs`
    # (model, email, unbias, interestingness threshold).
    examples=[
        [BAYES, "Enron actuals for June 26, 2000", False, DEFAULT_INTR_THRESHOLD],
        [BAYES, nerissa_email, False, DEFAULT_INTR_THRESHOLD],
        [BAYES, nerissa_email, True, DEFAULT_INTR_THRESHOLD],
    ],
    article="This is a demo of the models in the [Bayes or Spam?](https://github.com/tbitai/bayes-or-spam) project.",
)
# Start the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()