Spaces:
Running
Running
File size: 1,773 Bytes
925e416 cbd3f4c 0f34ca3 438fdb3 925e416 2aa7ddb 21aa3e9 2aa7ddb 438fdb3 c8d5deb 438fdb3 5a7c4c9 0f34ca3 438fdb3 0f34ca3 438fdb3 0f34ca3 48b6db6 0f34ca3 4a02fcc f7170b6 4a02fcc 925e416 87806b7 c8d5deb 87806b7 2aa7ddb 925e416 733f064 7862c87 f7170b6 925e416 c71ee2b 733f064 925e416 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
# Imports
import gradio as gr
from sklearn.linear_model import LogisticRegression
import pickle5 as pickle
import re
import string
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
# File name of the pickled model loaded below.
lr_filename = 'lr_021223.pkl'
# Load model from pickle file. The context manager closes the file handle,
# which the previous pickle.load(open(...)) one-liner leaked.
# SECURITY: unpickling can execute arbitrary code; only ship/load this file
# from a trusted source.
with open(lr_filename, 'rb') as model_file:
    model = pickle.load(model_file)
# Process input text: convert to lowercase, replace punctuation with spaces,
# and remove English stopwords.
stop = stopwords.words('english')
# Frozen copy of the stopword list for O(1) membership tests; `stop` itself
# stays a list for any code that expects the original type.
_STOP_WORDS = frozenset(stop)
# Matches any single character from string.punctuation.
_PUNCT_RE = re.compile(f"[{re.escape(string.punctuation)}]")


def process_text(text):
    """Normalize raw input text before vectorization.

    The text is lowercased. Every character in ``string.punctuation`` is
    replaced with a space, and the result is split on whitespace. Tokens
    found in the NLTK English stopword list are then dropped.

    NOTE(review): this must mirror the preprocessing used when the model was
    trained. Confirm against the training code.

    Args:
        text: Raw user input string.

    Returns:
        The remaining tokens joined by single spaces ("" if none remain).
    """
    # Lowercase and strip punctuation *before* filtering so that "The" and
    # "the," are recognized as stopwords (the NLTK entries are lowercase).
    cleaned = _PUNCT_RE.sub(" ", text.lower())
    # Join tokens directly. The previous str(list) round-trip injected repr
    # escape sequences (e.g. "\u200b" became a "u200b" token) for
    # non-printable characters.
    return " ".join(word for word in cleaned.split() if word not in _STOP_WORDS)
# Vectorize input text.
# NOTE(review): `vectorizer` must be the CountVectorizer that was fitted on
# the model's training corpus (e.g. persisted next to lr_021223.pkl and
# loaded here). A fresh, unfitted instance like the one created below has no
# vocabulary, so vectorize_text() raises until it is replaced.
vectorizer = CountVectorizer()


def vectorize_text(text):
    """Convert raw text into the model's bag-of-words feature matrix.

    Args:
        text: Raw user input; it is normalized with process_text() first.

    Returns:
        A sparse count matrix with one row, whose columns follow the
        vectorizer's fitted vocabulary.

    Raises:
        sklearn.exceptions.NotFittedError: if `vectorizer` has no fitted or
            provided vocabulary. It subclasses ValueError, so existing
            ValueError handling still applies.
    """
    processed = process_text(text)
    # transform(), not fit_transform(): refitting on a single request built a
    # brand-new vocabulary per call. The columns then never lined up with the
    # features the model was trained on, and the shared module-level
    # vectorizer was mutated on every request.
    return vectorizer.transform([processed])
# TODO: make the model input valid so that its number of features matches the
# training data (this requires the vectorizer fitted at training time).
# Prediction function used as the Gradio callback.
def predict(text):
    """Classify *text* with the loaded model.

    The text is vectorized with vectorize_text() and passed to
    ``model.predict``. The model's output is returned unchanged; for a
    scikit-learn classifier this is an array holding one label.
    """
    features = vectorize_text(text)
    return model.predict(features)
# Define interface: a multi-line text input mapped through predict() to a
# text output that shows the predicted class id.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(
        label="Input Text",
        lines=10,
        placeholder='Input text here...',
    ),
    outputs=gr.Textbox(
        label="Predicted Label: Other: 0, Healthcare: 1, Technology: 2",
        lines=2,
        placeholder='Predicted label will appear here...',
    ),
    title="Text Classification Demo",
    description="This is a demo of a text classification model using Logistic Regression.",
    allow_flagging='never',
)
# Start the Gradio app.
demo.launch()
|