Create app.py
app.py
ADDED
@@ -0,0 +1,210 @@
"""
deploy-as-bot\gradio_chatbot.py

Deploys the chatbot model as a Gradio app. Gradio is a basic "deploy" interface that lets other users test your model from a web URL, and it adds some basic functionality such as letting users flag weird responses.
Note that the URL is displayed once the script is run.

Set the working directory to */deploy-as-bot in the terminal before running.
"""
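# A gradio app at its core is just fn + inputs + outputs + launch(). A minimal
# sketch of the pattern this script builds on (gradio 2.x-era API; the lambda
# and names below are illustrative only, not part of this repo):
#
#   import gradio as gr
#   demo = gr.Interface(fn=lambda msg: msg.upper(), inputs="text", outputs="text")
#   demo.launch(share=True)  # prints a public *.gradio.app URL to the terminal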
from utils import remove_trailing_punctuation, DisableLogger
import os
import sys
from os.path import dirname

# add the repo root (the parent of this script's directory) to sys.path
sys.path.append(dirname(dirname(os.path.abspath(__file__))))
import gradio as gr
import logging
import argparse
import time
import warnings
from pathlib import Path
from transformers import pipeline
from datetime import datetime
from ai_single_response import query_gpt_model
logging.basicConfig(
    filename=f"LOGFILE-{Path(__file__).stem}.log",
    filemode="a",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
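# With the config above, log records append to a file named after this
# script's stem, e.g. LOGFILE-gradio_chatbot.log when the file is saved as
# gradio_chatbot.py (the exact name follows the actual filename).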
with DisableLogger():
    from cleantext import clean

warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")
cwd = Path.cwd()
my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects
def gramformer_correct(corrector, qphrase: str):
    """
    gramformer_correct - correct a string using a text2text-generation pipeline model from transformers

    Args:
        corrector (transformers.pipeline): transformers pipeline object, already created with the relevant model
        qphrase (str): text to be corrected

    Returns:
        str: corrected text
    """

    try:
        corrected = corrector(
            clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
        )
        return corrected[0]["generated_text"]
    except Exception:
        print("NOTE - failed to correct with gramformer")
        # fallback is to return the cleaned-up version of the message
        return clean(qphrase)
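# For illustration, a hypothetical call once the corrector pipeline is loaded
# (the sample input/output is an assumption, not taken from this repo):
#
#   corrector = pipeline("text2text-generation", model="prithivida/grammar_error_correcter_v1")
#   gramformer_correct(corrector, "he go to school yesterday")
#   # -> something like "He went to school yesterday"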
def ask_gpt(message: str, sender: str = ""):
    """
    ask_gpt - queries the relevant model with a prompt message and (optional) speaker name.
    Note: this version is modified w.r.t. the gradio local-server deploy.

    Args:
        message (str): prompt message to respond to
        sender (str, optional): speaker, i.e. who said the message. Defaults to "".

    Returns:
        str: model response as a string
    """
    st = time.time()
    prompt = clean(message)  # clean user input
    prompt = prompt.strip()  # get rid of any extra whitespace
    if len(prompt) > 100:
        prompt = prompt[:100]  # truncate
    sender = clean(sender.strip())
    if len(sender) > 2:
        try:
            prompt_speaker = clean(sender)
        except Exception:
            prompt_speaker = None  # fallback
    else:
        prompt_speaker = None  # fallback

    resp = query_gpt_model(
        folder_path=model_loc,
        prompt_msg=prompt,
        speaker=prompt_speaker,
        kparam=150,  # top k responses
        temp=0.75,  # temperature
        top_p=0.65,  # nucleus sampling
    )
    bot_resp = gramformer_correct(
        corrector, qphrase=resp["out_text"]
    )  # correct grammar
    bot_resp = remove_trailing_punctuation(
        bot_resp
    )  # remove trailing punctuation to seem more natural
    rt = round(time.time() - st, 2)
    print(f"took {rt} sec to respond")

    return bot_resp
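# The decoding settings above trade variety against coherence: kparam=150
# (top-k) and top_p=0.65 (nucleus sampling) restrict sampling to the more
# likely tokens, while temp=0.75 keeps some randomness. A hypothetical call:
#
#   ask_gpt("how was your day?", sender="Jane Doe")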
def chat(first_and_last_name, message):
    """
    chat - helper function that makes the whole gradio thing work.

    Args:
        first_and_last_name (str or None): speaker of the prompt, if provided
        message (str): the prompt message to respond to

    Returns:
        str: an HTML string to display
    """
    history = gr.get_state() or []
    response = ask_gpt(message, sender=first_and_last_name)
    history.append(("You: " + message, " GPT-Model: " + response + " [end] "))
    gr.set_state(history)  # save the history
    # wrap each exchange in divs styled by the css classes defined in main
    html = "<div class='chatbox'>"
    for user_msg, resp_msg in history:
        html += f"<div class='user_msg'>{user_msg}</div>"
        html += f"<div class='resp_msg'>{resp_msg}</div>"
    html += "</div>"
    return html
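# gr.get_state() / gr.set_state() are the session-state hooks of the older
# gradio API used here: they persist the list of (user, bot) message tuples
# per visitor, so the full conversation is re-rendered on every turn.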
def get_parser():
    """
    get_parser - a helper function for the argparse module

    Returns:
        argparse.ArgumentParser: the argparser relevant for this script
    """

    parser = argparse.ArgumentParser(
        description="host a chatbot on gradio",
    )
    parser.add_argument(
        "--model",
        required=False,
        type=str,
        default="GPT2_trivNatQAdailydia_774M_175Ksteps",  # folder name of model
        help="folder - relative to the git directory of your repo - that has the model files in it "
        "(pytorch_model.bin + config.json). No models? Run the script download_models.py",
    )

    parser.add_argument(
        "--gram-model",
        required=False,
        type=str,
        default="prithivida/grammar_error_correcter_v1",  # huggingface model
        help="text2text generation model ID from huggingface for the model to correct grammar",
    )

    return parser
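# Example invocation from the deploy-as-bot directory, using the defaults
# defined above:
#
#   python gradio_chatbot.py --model GPT2_trivNatQAdailydia_774M_175Ksteps \
#       --gram-model prithivida/grammar_error_correcter_v1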
if __name__ == "__main__":
    args = get_parser().parse_args()
    default_model = str(args.model)
    model_loc = cwd.parent / default_model
    model_loc = str(model_loc.resolve())
    gram_model = args.gram_model

    # init items for the pipeline
    iface = gr.Interface(
        chat,
        inputs=["text", "text"],
        outputs="html",
        title=f"GPT-Chatbot Demo: {default_model} Model",
        description=f"A basic interface with a GPT2-based model, specifically {default_model}. Treat it like a friend!",
        article="**Important Notes & About:**\n"
        "1. the model can take up to 60 seconds to respond sometimes; patience is a virtue.\n"
        "2. entering a username is completely optional.\n"
        "3. the model started from a pretrained checkpoint and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n ",
        css="""
            .chatbox {display:flex;flex-direction:column}
            .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
            .user_msg {background-color:cornflowerblue;color:white;align-self:start}
            .resp_msg {background-color:lightgray;align-self:self-end}
        """,
        allow_screenshot=True,
        allow_flagging=True,  # allow users to flag responses as inappropriate
        flagging_dir="gradio_data",
        flagging_options=[
            "great response",
            "doesn't make sense",
            "bad/offensive response",
        ],
        enable_queue=True,  # allows for dealing with multiple users simultaneously
        theme="darkhuggingface",
    )

    corrector = pipeline("text2text-generation", model=gram_model, device=-1)
    print("Finished loading the gramformer model - ", datetime.now())
    print(f"using model stored here: \n {model_loc} \n")

    # launch the gradio interface and start the server
    iface.launch(share=True)
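# launch(share=True) prints both URLs to the terminal, roughly (illustrative):
#
#   Running on local URL:  http://127.0.0.1:7860/
#   Running on public URL: https://xxxxx.gradio.app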