File size: 6,946 Bytes
df376e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
"""

deploy-as-bot\gradio_chatbot.py

A system, method for deploying to Gradio. Gradio is a basic "deploy" interface which allows for other users to test your model from a web URL. It also enables some basic functionality like user flagging for weird responses.
Note that the URL is displayed once the script is run.

Set the working directory to */deploy-as-bot in terminal before running.

"""
from utils import remove_trailing_punctuation, DisableLogger
import os
import sys
from os.path import dirname

# Append the parent of this script's directory (one level above the folder that
# contains this file) to the import path so modules living there can be imported —
# presumably ai_single_response, which is imported further below; TODO confirm.
# Note: `utils` is imported BEFORE this line, so it must already be importable
# from the script's own directory / the current working directory.
sys.path.append(dirname(dirname(os.path.abspath(__file__))))

import gradio as gr
import logging
import argparse
import time
import warnings
from pathlib import Path
from transformers import pipeline
from datetime import datetime
from ai_single_response import query_gpt_model

# Send INFO-and-above log records to a log file named after this script
# (LOGFILE-<script stem>.log), opened in append mode. The filename is relative,
# so the file is created in the current working directory.
logging.basicConfig(
    filename=f"LOGFILE-{Path(__file__).stem}.log",
    filemode="a",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# DisableLogger comes from the local utils module (not visible here); presumably
# it suppresses log output produced while cleantext is imported — TODO confirm.
with DisableLogger():
    from cleantext import clean

# Ignore warnings whose message mentions gradient_checkpointing.
# NOTE(review): the trailing "*" only repeats the final "g"; ".*gradient_checkpointing.*"
# was probably intended. In practice both patterns match the same messages.
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")

# Working directory at import time; the __main__ block resolves the model folder
# relative to its parent (cwd.parent / <model name>).
cwd = Path.cwd()
my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects; not referenced elsewhere in this file


def gramformer_correct(corrector, qphrase: str):
    """
    gramformer_correct - correct a string using a text2text-generation pipeline model from transformers

    The input is passed through ``clean`` first. If the pipeline call fails for any
    ordinary error, the cleaned (but uncorrected) text is returned instead, so the
    caller always gets a usable string.

    Args:
        corrector (transformers.pipeline): [transformers pipeline object, already created w/ relevant model]
        qphrase (str): [text to be corrected]

    Returns:
        [str]: [corrected text, or the cleaned input if correction failed]
    """
    cleaned = clean(qphrase)  # cleaned once; also serves as the fallback value
    try:
        corrected = corrector(
            cleaned, return_text=True, clean_up_tokenization_spaces=True
        )
        return corrected[0]["generated_text"]
    except Exception:
        # Best-effort: a failed correction must not break the chat reply, but the
        # reason is recorded instead of silently swallowed. KeyboardInterrupt /
        # SystemExit are deliberately NOT caught here.
        logging.warning("gramformer correction failed", exc_info=True)
        print("NOTE - failed to correct with gramformer")
        return cleaned  # fallback is to return the cleaned up version of the message


def ask_gpt(message: str, sender: str = ""):
    """
    ask_gpt - queries the relevant model with a prompt message and (optional) speaker name.
    Note: this version is adapted for the gradio local server deploy; it reads the
    module-level ``model_loc`` and ``corrector`` that the ``__main__`` section sets up.

    Args:
        message (str): prompt message to respond to; it is cleaned, stripped and
            truncated to its first 100 characters.
        sender (str, optional): speaker aka who said the message. Defaults to "".
            ``None`` or a cleaned name of 2 characters or fewer means "no speaker".

    Returns:
        [str]: [model response as a string]
    """
    max_prompt_chars = 100  # longer prompts are truncated to this many characters
    start = time.perf_counter()  # monotonic clock, suited to measuring durations

    prompt = clean(message).strip()[:max_prompt_chars]

    # The name field is optional, so sender may arrive empty or as None; only
    # names longer than 2 characters are passed on to the model.
    sender = clean((sender or "").strip())
    prompt_speaker = sender if len(sender) > 2 else None

    resp = query_gpt_model(
        folder_path=model_loc,
        prompt_msg=prompt,
        speaker=prompt_speaker,
        kparam=150,  # top k responses
        temp=0.75,  # temperature
        top_p=0.65,  # nucleus sampling
    )
    bot_resp = gramformer_correct(corrector, qphrase=resp["out_text"])  # correct grammar
    # remove trailing punctuation to seem more natural
    bot_resp = remove_trailing_punctuation(bot_resp)

    rt = round(time.perf_counter() - start, 2)
    print(f"took {rt} sec to respond")

    return bot_resp


def chat(first_and_last_name, message):
    """
    chat - gradio callback: get a model reply, append the exchange to the session
    history and render the whole conversation as HTML.

    Args:
        first_and_last_name (str or None): [speaker of the prompt, if provided]
        message (str): [the user's message to the model]

    Returns:
        [str]: [an html string with the full conversation: a ``chatbox`` div
                containing one ``user_msg`` / ``resp_msg`` div per turn, matching
                the css passed to gr.Interface]
    """
    from html import escape  # local import keeps this edit self-contained

    history = gr.get_state() or []
    response = ask_gpt(message, sender=first_and_last_name)
    history.append(("You: " + message, " GPT-Model: " + response + " [end] "))
    gr.set_state(history)  # save the history

    # Both user input and model output are arbitrary text: escape them so the
    # browser renders them literally instead of interpreting them as markup.
    rows = "".join(
        f"<div class='user_msg'>{escape(user_msg)}</div>"
        f"<div class='resp_msg'>{escape(resp_msg)}</div>"
        for user_msg, resp_msg in history
    )
    return f"<div class='chatbox'>{rows}</div>"


def get_parser():
    """
    get_parser - build the command-line parser for this script.

    Both options are optional string flags:
        --model       folder name (under the repo root) holding the chat model files
        --gram-model  huggingface ID of the text2text grammar-correction model

    Returns:
        [argparse.ArgumentParser]: [parser configured with the options above]
    """
    cli = argparse.ArgumentParser(description="host a chatbot on gradio")

    # (flag, default, help) for each string option, registered in order below
    option_table = (
        (
            "--model",
            "GPT2_trivNatQAdailydia_774M_175Ksteps",  # folder name of model
            "folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
            "config.json). No models? Run the script download_models.py",
        ),
        (
            "--gram-model",
            "prithivida/grammar_error_correcter_v1",  # huggingface model
            "text2text generation model ID from huggingface for the model to correct grammar",
        ),
    )
    for flag, fallback, description_text in option_table:
        cli.add_argument(
            flag,
            required=False,
            type=str,
            default=fallback,
            help=description_text,
        )

    return cli


if __name__ == "__main__":
    args = get_parser().parse_args()
    default_model = str(args.model)
    model_loc = cwd.parent / default_model
    model_loc = str(model_loc.resolve())
    gram_model = args.gram_model

    # init items for the pipeline
    iface = gr.Interface(
        chat,
        inputs=["text", "text"],
        outputs="html",
        title=f"GPT-Chatbot Demo: {default_model} Model",
        description=f"A basic interface with a GPT2-based model, specifically {default_model}. Treat it like a friend!",
        article="**Important Notes & About:**\n"
        "1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
        "2. entering a username is completely optional.\n"
        "3. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says sshould be fact-checked before being regarded as a true statement.\n ",
        css="""
        .chatbox {display:flex;flex-direction:column}
        .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
        .user_msg {background-color:cornflowerblue;color:white;align-self:start}
        .resp_msg {background-color:lightgray;align-self:self-end}
    """,
        allow_screenshot=True,
        allow_flagging=True,  # allow users to flag responses as inappropriate
        flagging_dir="gradio_data",
        flagging_options=[
            "great response",
            "doesn't make sense",
            "bad/offensive response",
        ],
        enable_queue=True,  # allows for dealing with multiple users simultaneously
        theme="darkhuggingface",
    )

    corrector = pipeline("text2text-generation", model=gram_model, device=-1)
    print("Finished loading the gramformer model - ", datetime.now())
    print(f"using model stored here: \n {model_loc} \n")

    # launch the gradio interface and start the server
    iface.launch(share=True)