File size: 4,099 Bytes
925e416
 
 
cbd3f4c
0f34ca3
 
 
 
 
 
438fdb3
07b8953
 
925e416
e3527b4
 
 
 
 
438fdb3
5a7c4c9
0f34ca3
438fdb3
0f34ca3
438fdb3
0f34ca3
 
 
 
 
 
 
e3527b4
 
 
0f34ca3
e3527b4
 
 
 
0f34ca3
4a02fcc
e3527b4
 
 
 
 
0e46e88
e3527b4
87806b7
2aa7ddb
e3527b4
 
 
 
 
 
 
 
07b8953
8630a47
07b8953
 
 
 
 
 
 
 
 
 
 
 
e3527b4
07b8953
 
 
925e416
07b8953
 
 
 
 
 
 
 
 
 
c71ee2b
 
733f064
925e416
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Imports
import gradio as gr
from sklearn.linear_model import LogisticRegression
import pickle5 as pickle
import re
import string
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from flair.data import Sentence
from flair.models import SequenceTagger

# file name
#lr_filename = 'lr_021223.pkl'

# Load model from pickle file
# model = pickle.load(open(lr_filename, 'rb'))


# English stopword list from NLTK's stopwords corpus (the corpus is fetched by
# the nltk.download('stopwords') call next to the imports). process_text uses
# it to drop stopwords before lowercasing and stripping punctuation.
# NOTE(review): stopwords.words() returns a list, so membership tests against
# it are linear-time; entries are presumably all lowercase, which would make
# matching against un-lowercased words case-sensitive — confirm.
stop = stopwords.words('english')
# Translation table mapping every ASCII punctuation character to a space.
_PUNCTUATION_TO_SPACE = str.maketrans(
    string.punctuation, " " * len(string.punctuation)
)

def process_text(text, stop_words=None):
    """Normalize raw text for the bag-of-words classifier.

    Drops stopwords, lowercases, replaces ASCII punctuation with spaces and
    collapses runs of whitespace into single spaces.

    Args:
        text: Raw input string.
        stop_words: Optional iterable of words to drop. Defaults to the
            module-level NLTK English stopword list ``stop``.

    Returns:
        The normalized string ("" for empty or all-stopword input).
    """
    if stop_words is None:
        stop_words = stop
    # A set gives O(1) membership instead of scanning the list for every word.
    stop_set = set(stop_words)
    # NOTE(review): stopwords are matched before lowercasing and before
    # punctuation is stripped (unchanged from the original), so tokens such as
    # "The" or "the," are kept. Changing this alters the model's input
    # features; confirm against the training preprocessing first.
    kept = [word for word in text.split() if word not in stop_set]
    # Join the words directly. The previous str(list) round-trip went through
    # repr(), which escapes non-printable characters and leaked fragments such
    # as "u200b" into the text once the backslash was stripped.
    text = " ".join(kept).lower()
    text = text.translate(_PUNCTUATION_TO_SPACE)
    return " ".join(text.split())

# Bag-of-words vectorizer used by predict() to turn processed text into the
# feature matrix passed to the logistic-regression model.
# NOTE(review): this CountVectorizer is never fitted in this file, so calling
# vec.transform() on it raises scikit-learn's NotFittedError. It must be the
# vectorizer fitted at training time (e.g. unpickled alongside the model) for
# the feature count to match what the model expects — TODO confirm source.
vec = CountVectorizer()

def predict(text, filename='lr_021223.pkl'):
    """Classify ``text`` with the pickled logistic-regression model.

    Args:
        text: Raw input string.
        filename: Path of the pickled model; defaults to 'lr_021223.pkl'.

    Returns:
        Whatever ``loaded_model.predict`` returns for the single input
        (per the UI legend: 0 = Other, 1 = Healthcare, 2 = Technology).

    NOTE(review): the features only line up with the model if ``vec`` is the
    vectorizer fitted at training time.
    """
    # SECURITY: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the handle the original left open.
    with open(filename, 'rb') as model_file:
        loaded_model = pickle.load(model_file)
    # Clean the raw string first, then vectorize. The original vectorized
    # first and passed the resulting sparse matrix to process_text, which
    # fails on .split().
    features = vec.transform([process_text(text)])
    return loaded_model.predict(features)

# Flair sequence tagger used by run_ner; loaded once, at import time, from the
# local checkpoint file 'best-model.pt'.
tagger = SequenceTagger.load('best-model.pt')

# Named entity recognition over the user's text.
def run_ner(input_text):
    """Tag named entities in ``input_text`` with the Flair ``tagger``.

    Returns a dict holding the original ``text`` and an ``entities`` list;
    each entry carries the predicted ``entity`` label, the matched ``word``
    and its ``start``/``end`` positions. This is the value shown by the
    ``gr.HighlightedText`` output of the demo.
    """
    sent = Sentence(input_text)
    tagger.predict(sent)
    entities = [
        {
            'entity': span.get_label('ner').value,
            'word': span.text,
            'start': span.start_position,
            'end': span.end_position,
        }
        for span in sent.get_spans('ner')
    ]
    return {"text": input_text, "entities": entities}

# Gradio callback: feeds one input to both models.
def run_models(input_text):
    """Run the classifier and the NER model on ``input_text``.

    Returns a ``(prediction, entities)`` tuple matching the two Gradio
    outputs. The classification is currently a fixed placeholder ``0``.
    """
    # TODO: swap the placeholder for predict(input_text) once the LR model
    # (and its fitted vectorizer) is working.
    placeholder_prediction = 0
    return placeholder_prediction, run_ner(input_text)

# Gradio UI: one text box in; the classification label and the NER highlights
# out, matching the (prediction, entities) tuple returned by run_models.
_input_box = gr.Textbox(lines=10, placeholder='Input text here...', label="Input Text")
_label_box = gr.Textbox(label="Predicted Classification Label: Other: 0, Healthcare: 1, Technology: 2", lines=2, placeholder='Predicted label will appear here...')
_ner_view = gr.HighlightedText(label='Named Entity Recognition Results')

# These examples are just placeholders; once the LR model is working, we can use longer example text such as paragraphs
_examples = [
    'The indictments were announced Tuesday by the Justice Department in Cairo.',
    "In 2019, the men's singles winner was Novak Djokovic who defeated Roger Federer in a tournament taking place in the United Kingdom.",
    'In a study published by the American Heart Association on January 18, researchers at the Johns Hopkins School of Medicine found that meal timing did not impact weight.',
]

demo = gr.Interface(
    fn=run_models,
    title="Text Classification & Named Entity Recognition Demo",
    description="This is a demo of a text classification model using logistic regression as well as a named entity recognition model. Enter in some text or use one of the provided examples. Note that common named entity recognition tags include **geo** (geographical entity), **org** (organization), **per** (person), and **tim** (time).",
    article='*This demo is based on Logistic Regression and Named Entity Recognition models trained by Curtis Pond and Julia Nickerson as part of their FourthBrain capstone project. For more information, check out their [GitHub repo](https://github.com/nickersonj/glg-capstone).*',
    inputs=_input_box,
    outputs=[_label_box, _ner_view],
    examples=_examples,
    allow_flagging='never',
)

# Start the Gradio server only when this file is run as a script
# (e.g. `python app.py`), so merely importing the module does not launch it.
if __name__ == "__main__":
    demo.launch()