# showcase_sujon / app.py
# Author: SujonPro24 — commit 1da409e ("Create app.py", verified)
#import streamlit as st
#from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
#import torch
## Load the fine-tuned model and tokenizer
#model_name = "fine-tuned-model"
#model = DistilBertForSequenceClassification.from_pretrained(model_name)
#tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)
## Function to classify text
#def classify_text(text):
# inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
# with torch.no_grad():
# outputs = model(**inputs)
# logits = outputs.logits
# predicted_class_id = torch.argmax(logits, dim=1).item()
# return "spam" if predicted_class_id == 1 else "ham"
## Streamlit app
#st.title("Text Message Classification")
#st.write("Enter a text message and see if it's classified as spam or ham.")
#user_input = st.text_area("Text Message", "")
#if st.button("Classify"):
# if user_input:
# prediction = classify_text(user_input)
# st.write(f"The message is classified as: \n **{prediction}**")
# else:
# st.write("Please enter a text message.")
import streamlit as st
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
import torch
# Load the fine-tuned model and tokenizer.
model_name = "fine-tuned-model"


@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer(name):
    """Load the fine-tuned DistilBERT classifier and its tokenizer from *name*.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps a single shared copy across reruns and
    sessions, so the weights are read only once per server process instead
    of on every button click. ``show_spinner=False`` stops the cache from
    emitting a UI element before ``st.set_page_config`` runs further down.

    Args:
        name: Path or hub id passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = DistilBertForSequenceClassification.from_pretrained(name)
    loaded_tokenizer = DistilBertTokenizerFast.from_pretrained(name)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(model_name)
def classify_text(text):
    """Label a single text message as ``"spam"`` or ``"ham"``.

    The message is tokenized, with truncation to at most 512 tokens, and
    scored by the module-level ``model``. The highest-scoring class id
    decides the label: id 1 means spam, and any other id means ham.

    Args:
        text: The message to classify.

    Returns:
        ``"spam"`` or ``"ham"``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        scores = model(**encoded).logits
    best_class = scores.argmax(dim=1).item()
    if best_class == 1:
        return "spam"
    return "ham"
# Streamlit app
st.set_page_config(page_title="Text Message Classification", page_icon="📧")

# Header
st.title("📧 Text Message Classification")

# Text input area. Streamlit documents a 68-px minimum for integer heights
# (enough for two lines); smaller values such as 50 raise
# StreamlitAPIException on releases that enforce it.
user_input = st.text_area("Type your message here...", height=68)

# Classify button and result display
if st.button("Classify"):
    # Treat whitespace-only input as empty instead of sending it to the model.
    if user_input.strip():
        prediction = classify_text(user_input)
        message = f"The message is classified as: **{prediction}**"
        # Green box for ham, red box for spam.
        if prediction == "ham":
            st.success(message)
        else:
            st.error(message)
    else:
        st.warning("Please enter a text message.")

# Footer
st.markdown("""
---
Built with ❤️ using [Streamlit](https://streamlit.io/) and [Transformers](https://huggingface.co/transformers/).
""")