File size: 1,773 Bytes
925e416
 
 
cbd3f4c
0f34ca3
 
 
 
 
 
438fdb3
925e416
2aa7ddb
21aa3e9
2aa7ddb
438fdb3
c8d5deb
438fdb3
5a7c4c9
0f34ca3
438fdb3
0f34ca3
438fdb3
0f34ca3
 
 
 
 
 
 
 
 
 
 
48b6db6
 
0f34ca3
4a02fcc
f7170b6
4a02fcc
 
925e416
87806b7
c8d5deb
87806b7
2aa7ddb
925e416
 
733f064
7862c87
 
 
f7170b6
925e416
c71ee2b
 
733f064
925e416
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Imports
import gradio as gr
from sklearn.linear_model import LogisticRegression
import pickle5 as pickle
import re
import string
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer

# file name
lr_filename = 'lr_021223.pkl'

# Load model from pickle file. The context manager closes the file handle as
# soon as the model has been deserialized instead of leaking it.
# NOTE: unpickling can execute arbitrary code; only load model files that
# come from a trusted source.
with open(lr_filename, 'rb') as model_file:
    model = pickle.load(model_file)


# Process input text, including removing stopwords, converting to lowercase, and removing punctuation
stop = stopwords.words('english')
# Matches any single ASCII punctuation character; compiled once at import time.
_PUNCTUATION_RE = re.compile(f"[{re.escape(string.punctuation)}]")


def process_text(text, stop_words=None):
    """Normalize raw input text for vectorization.

    Steps, in order:
      1. lowercase the text and split it on whitespace,
      2. drop every word that appears in the stopword collection,
      3. replace each punctuation character with a space,
      4. collapse runs of whitespace into single spaces.

    Args:
        text: The raw input string.
        stop_words: Optional iterable of lowercase words to drop. When
            omitted, the module-level ``stop`` list is used.

    Returns:
        The cleaned, lowercased, single-space-separated string (empty if
        nothing is left).
    """
    if stop_words is None:
        stop_words = stop
    stop_words = frozenset(stop_words)

    # Lowercase *before* filtering: the stopword entries are lowercase, so a
    # capitalized stopword such as "The" would otherwise slip through.
    # NOTE(review): confirm the training-time preprocessing filtered
    # stopwords the same way, otherwise features may differ slightly.
    kept = [word for word in text.lower().split() if word not in stop_words]

    # Join the words directly rather than going through str(list): the list
    # repr escapes non-printable characters (e.g. a zero-width space becomes
    # the literal text "\u200b"), which then leaks junk tokens into the output.
    without_punctuation = _PUNCTUATION_RE.sub(" ", " ".join(kept))
    return " ".join(without_punctuation.split())

# Vectorize input text
vectorizer = CountVectorizer()


def vectorize_text(text):
    """Clean *text* with process_text and return its count-vector matrix.

    NOTE(review): the vectorizer is re-fitted on every call using only the
    single input document, so the vocabulary (and therefore the number of
    feature columns) depends on that input alone. Confirm this matches what
    the loaded model expects; a vectorizer fitted at training time and a
    plain ``transform`` call is presumably what is needed here.
    """
    cleaned = process_text(text)
    return vectorizer.fit_transform([cleaned])

# TODO: Validate the input for the model so that the number of features matches
# Code will go here

# Prediction function
def predict(text):
    """Vectorize *text* and return the loaded model's prediction for it."""
    features = vectorize_text(text)
    return model.predict(features)


# Define interface: one multi-line text input, one text output for the label
_input_box = gr.Textbox(
    lines=10,
    placeholder='Input text here...',
    label="Input Text",
)
_output_box = gr.Textbox(
    label="Predicted Label: Other: 0, Healthcare: 1, Technology: 2",
    lines=2,
    placeholder='Predicted label will appear here...',
)
demo = gr.Interface(
    fn=predict,
    title="Text Classification Demo",
    description="This is a demo of a text classification model using Logistic Regression.",
    inputs=_input_box,
    outputs=_output_box,
    allow_flagging='never',
)

# Launch the interface
demo.launch()