Spaces:
Sleeping
Sleeping
#import streamlit as st | |
#from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast | |
#import torch | |
## Load the fine-tuned model and tokenizer | |
#model_name = "fine-tuned-model" | |
#model = DistilBertForSequenceClassification.from_pretrained(model_name) | |
#tokenizer = DistilBertTokenizerFast.from_pretrained(model_name) | |
## Function to classify text | |
#def classify_text(text): | |
# inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512) | |
# with torch.no_grad(): | |
# outputs = model(**inputs) | |
# logits = outputs.logits | |
# predicted_class_id = torch.argmax(logits, dim=1).item() | |
# return "spam" if predicted_class_id == 1 else "ham" | |
## Streamlit app | |
#st.title("Text Message Classification") | |
#st.write("Enter a text message and see if it's classified as spam or ham.") | |
#user_input = st.text_area("Text Message", "") | |
#if st.button("Classify"): | |
# if user_input: | |
# prediction = classify_text(user_input) | |
# st.write(f"The message is classified as: \n **{prediction}**") | |
# else: | |
# st.write("Please enter a text message.") | |
import streamlit as st | |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast | |
import torch | |
# Load the fine-tuned model and tokenizer.
# Directory (or Hub repo id) containing the fine-tuned DistilBERT checkpoint.
model_name = "fine-tuned-model"


@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(name):
    """Load the sequence classifier and its tokenizer from ``name``.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps the returned objects alive across reruns, so
    the checkpoint is read from disk once instead of on every button click.
    ``show_spinner=False`` keeps this call from rendering anything before
    ``st.set_page_config`` runs further down the script.

    Args:
        name: Path or repo id passed straight to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = DistilBertForSequenceClassification.from_pretrained(name)
    loaded_tokenizer = DistilBertTokenizerFast.from_pretrained(name)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(model_name)
# Function to classify text
def classify_text(text):
    """Label a single message as ``"spam"`` or ``"ham"``.

    The message is tokenized into PyTorch tensors (truncated to at most 512
    tokens) and run through the module-level ``model`` with gradient tracking
    disabled. The class with the largest logit wins: index 1 maps to
    ``"spam"``, any other index maps to ``"ham"``.

    Args:
        text: The raw message to classify.

    Returns:
        ``"spam"`` or ``"ham"``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        scores = model(**encoded).logits
    best_index = scores.argmax(dim=1).item()
    if best_index == 1:
        return "spam"
    return "ham"
# Streamlit app
# Kept ahead of every other Streamlit call: older Streamlit releases require
# set_page_config to be the first Streamlit command in the script.
st.set_page_config(page_title="Text Message Classification", page_icon="📧")

# Header
st.title("📧 Text Message Classification")

# Text input area. st.text_area requires a height of at least 68 px (two
# lines of text); smaller values such as 50 are rejected by recent Streamlit.
user_input = st.text_area("Type your message here...", height=68)

# Classify button and result display
if st.button("Classify"):
    # Whitespace-only input is treated like an empty box rather than being
    # sent to the model.
    if user_input.strip():
        prediction = classify_text(user_input)
        if prediction == "ham":
            st.success(f"The message is classified as: **{prediction}**")
        else:
            st.error(f"The message is classified as: **{prediction}**")
    else:
        st.warning("Please enter a text message.")
# Footer: a horizontal rule followed by a one-line credits note.
_FOOTER_MARKDOWN = (
    "\n"
    "---\n"
    "Built with ❤️ using [Streamlit](https://streamlit.io/) and "
    "[Transformers](https://huggingface.co/transformers/).\n"
)
st.markdown(_FOOTER_MARKDOWN)