Spaces:
Running
Running
# Imports | |
import gradio as gr | |
from sklearn.linear_model import LogisticRegression | |
import pickle5 as pickle | |
import re | |
import string | |
import nltk | |
from nltk.corpus import stopwords | |
nltk.download('stopwords') | |
from sklearn.feature_extraction.text import CountVectorizer | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
# file name | |
#lr_filename = 'lr_021223.pkl' | |
# Load model from pickle file | |
# model = pickle.load(open(lr_filename, 'rb')) | |
# English stopwords from NLTK (the corpus is fetched by the nltk.download call above).
# process_text drops every token that appears in this collection.
stop = stopwords.words('english')
# Matches any single ASCII punctuation character; compiled once at import time.
_PUNCT_RE = re.compile(f"[{re.escape(string.punctuation)}]")


def process_text(text, stop_words=None):
    """Normalize raw text before vectorization.

    Lowercases the text, replaces every ASCII punctuation character with a
    space, drops stopwords, and collapses whitespace so tokens are separated
    by exactly one space.

    Args:
        text: Raw input string. ``None`` or ``""`` yields ``""``.
        stop_words: Optional iterable of lowercase stopwords to remove.
            Defaults to the module-level NLTK English list ``stop``.

    Returns:
        The cleaned, space-joined tokens ("" if nothing remains).

    NOTE(review): stopwords are now matched after lowercasing and punctuation
    removal, as the intent says. Confirm that the pipeline that trained
    lr_021223.pkl preprocesses text the same way, or retrain, to avoid
    train/serve skew.
    """
    if not text:
        return ""
    # Build a set once per call so each membership test is O(1) instead of a list scan.
    stop_set = frozenset(stop if stop_words is None else stop_words)
    # Lowercase and strip punctuation *before* filtering: the stopword list is
    # lowercase and punctuation-free, so "The" or "it," would otherwise survive.
    cleaned = _PUNCT_RE.sub(" ", text.lower())
    return " ".join(word for word in cleaned.split() if word not in stop_set)
# Bag-of-words vectorizer that turns process_text output into model features.
# NOTE(review): nothing in this file calls fit/fit_transform on this instance,
# so vec.transform() raises sklearn's NotFittedError. It must be replaced by the
# vectorizer fitted during training (e.g. pickled alongside lr_021223.pkl), so
# the feature columns match what the model was trained on.
vec = CountVectorizer()
# Pickled classifier shipped with the app, resolved relative to the working directory.
MODEL_FILENAME = 'lr_021223.pkl'

# Models already read from disk, keyed by filename, so requests reuse them.
_model_cache = {}


def _load_model(filename=MODEL_FILENAME):
    """Return the unpickled object stored at *filename*, reading the file only once.

    NOTE: pickle.load can execute arbitrary code. Only point this at the trusted
    model file bundled with the app, never at a user-supplied path.
    """
    model = _model_cache.get(filename)
    if model is None:
        # `with` closes the file handle deterministically after loading.
        with open(filename, 'rb') as fh:
            model = pickle.load(fh)
        _model_cache[filename] = model
    return model


def predict(text):
    """Classify one piece of raw text with the pickled model.

    The model is loaded (and cached) first. The text is then cleaned with
    ``process_text`` and turned into a single-row feature matrix with the
    module-level ``vec``. Finally, that matrix is passed to the model's
    ``predict``.

    Args:
        text: Raw user input from the Gradio text box.

    Returns:
        The model's ``predict`` output for that single row. This is presumably
        a length-1 array holding the class label (0, 1 or 2 per the UI legend);
        confirm against the trained model.

    NOTE(review): ``vec`` is constructed but never fitted in this file, so
    ``vec.transform`` raises sklearn's ``NotFittedError``. The vectorizer
    fitted at training time must be loaded so the number of features matches
    what the model expects.
    """
    loaded_model = _load_model()
    features = vec.transform([process_text(text)])
    return loaded_model.predict(features)
# Gradio UI: a 10-line text box in, the raw model prediction (shown as text) out.
demo = gr.Interface(
    fn=predict,
    title="Text Classification Demo",
    description="This is a demo of a text classification model using Logistic Regression.",
    inputs=gr.Textbox(lines=10, placeholder='Input text here...', label="Input Text"),
    outputs=gr.Textbox(
        label="Predicted Label: Other: 0, Healthcare: 1, Technology: 2",
        lines=2,
        placeholder='Predicted label will appear here...',
    ),
    allow_flagging='never',
)

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block inside launch().
if __name__ == "__main__":
    demo.launch()