Hamza1702 commited on
Commit
df376e8
1 Parent(s): be41157

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -0
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
deploy-as-bot\gradio_chatbot.py

A script for deploying the chatbot model to Gradio. Gradio is a basic "deploy"
interface which allows other users to test your model from a web URL. It also
enables some basic functionality like user flagging for weird responses.
Note that the URL is displayed once the script is run.

Set the working directory to */deploy-as-bot in terminal before running.
"""
from utils import remove_trailing_punctuation, DisableLogger
import os
import sys
from os.path import dirname

# add the repo root (parent of this script's directory) to sys.path so the
# sibling modules (ai_single_response, utils) resolve when run as a script
sys.path.append(dirname(dirname(os.path.abspath(__file__))))

import gradio as gr
import logging
import argparse
import time
import warnings
from pathlib import Path
from transformers import pipeline
from datetime import datetime
from ai_single_response import query_gpt_model

# log to a file named after this script, appending across runs
logging.basicConfig(
    filename=f"LOGFILE-{Path(__file__).stem}.log",
    filemode="a",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# cleantext logs noisily on import; silence it while importing
with DisableLogger():
    from cleantext import clean

# suppress the transformers gradient-checkpointing warning (irrelevant at inference)
warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")

cwd = Path.cwd()
my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects
43
+
44
+
45
def gramformer_correct(corrector, qphrase: str):
    """
    gramformer_correct - correct a string using a text2textgen pipeline model from transformers.

    Args:
        corrector (transformers.pipeline): transformers pipeline object, already created w/ relevant model
        qphrase (str): text to be corrected

    Returns:
        str: corrected text (or the cleaned input text if correction fails)
    """
    try:
        corrected = corrector(
            clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
        )
        return corrected[0]["generated_text"]
    except Exception:
        # was a bare `except:` which also swallowed SystemExit/KeyboardInterrupt;
        # narrowed to Exception while keeping the best-effort fallback behavior
        print("NOTE - failed to correct with gramformer")
        return clean(
            qphrase
        )  # fallback is to return the cleaned up version of the message
67
+
68
+
69
def ask_gpt(message: str, sender: str = ""):
    """
    ask_gpt - queries the relevant model with a prompt message and (optional) speaker name.
    Note: this version is modified w.r.t. the gradio local server deploy.

    Relies on module-level globals set in __main__: ``model_loc`` (model folder
    path) and ``corrector`` (grammar-correction pipeline).

    Args:
        message (str): prompt message to respond to
        sender (str, optional): speaker aka who said the message. Defaults to "".

    Returns:
        str: model response as a string
    """
    st = time.time()
    prompt = clean(message)  # clean user input
    prompt = prompt.strip()  # get rid of any extra whitespace
    if len(prompt) > 100:
        prompt = prompt[:100]  # truncate overly long prompts
    sender = clean(sender.strip())
    if len(sender) > 2:
        try:
            prompt_speaker = clean(sender)
        except Exception:
            # was a bare `except:`; narrowed while keeping the fallback
            prompt_speaker = None
    else:
        prompt_speaker = None  # no usable speaker name provided

    resp = query_gpt_model(
        folder_path=model_loc,
        prompt_msg=prompt,
        speaker=prompt_speaker,
        kparam=150,  # top k responses
        temp=0.75,  # temperature
        top_p=0.65,  # nucleus sampling
    )
    bot_resp = gramformer_correct(
        corrector, qphrase=resp["out_text"]
    )  # correct grammar
    bot_resp = remove_trailing_punctuation(
        bot_resp
    )  # remove trailing punctuation to seem more natural
    rt = round(time.time() - st, 2)
    print(f"took {rt} sec to respond")

    return bot_resp
113
+
114
+
115
def chat(first_and_last_name, message):
    """
    chat - glue callback that makes the whole gradio thing work.

    Args:
        first_and_last_name (str or None): speaker of the prompt, if provided
        message (str): the message to respond to

    Returns:
        str: an html string to display
    """
    history = gr.get_state() or []  # conversation so far (gradio session state)
    response = ask_gpt(message, sender=first_and_last_name)
    history.append(("You: " + message, " GPT-Model: " + response + " [end] "))
    gr.set_state(history)  # save the history
    # render the whole conversation as a single html string
    return "".join(f"{user_msg}{resp_msg}" for user_msg, resp_msg in history)
136
+
137
+
138
def get_parser():
    """
    get_parser - build the command-line argument parser for this script.

    Returns:
        argparse.ArgumentParser: the argparser relevant for this script
    """
    arg_parser = argparse.ArgumentParser(
        description="host a chatbot on gradio",
    )
    arg_parser.add_argument(
        "--model",
        required=False,
        type=str,
        default="GPT2_trivNatQAdailydia_774M_175Ksteps",  # folder name of model
        help=(
            "folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
            "config.json). No models? Run the script download_models.py"
        ),
    )
    arg_parser.add_argument(
        "--gram-model",
        required=False,
        type=str,
        default="prithivida/grammar_error_correcter_v1",  # huggingface model
        help="text2text generation model ID from huggingface for the model to correct grammar",
    )
    return arg_parser
167
+
168
+
169
if __name__ == "__main__":
    args = get_parser().parse_args()
    default_model = str(args.model)
    # model folder is expected to live one level above the current working dir
    model_loc = cwd.parent / default_model
    model_loc = str(model_loc.resolve())
    gram_model = args.gram_model

    # init items for the pipeline; `chat` is invoked once per user request
    iface = gr.Interface(
        chat,
        inputs=["text", "text"],
        outputs="html",
        title=f"GPT-Chatbot Demo: {default_model} Model",
        description=f"A basic interface with a GPT2-based model, specifically {default_model}. Treat it like a friend!",
        article="**Important Notes & About:**\n"
        "1. the model can take up to 60 seconds to respond sometimes, patience is a virtue.\n"
        "2. entering a username is completely optional.\n"
        # fixed user-facing typo: "sshould" -> "should"
        "3. the model started from a pretrained checkpoint, and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n ",
        css="""
    .chatbox {display:flex;flex-direction:column}
    .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
    .user_msg {background-color:cornflowerblue;color:white;align-self:start}
    .resp_msg {background-color:lightgray;align-self:self-end}
""",
        allow_screenshot=True,
        allow_flagging=True,  # allow users to flag responses as inappropriate
        flagging_dir="gradio_data",
        flagging_options=[
            "great response",
            "doesn't make sense",
            "bad/offensive response",
        ],
        enable_queue=True,  # allows for dealing with multiple users simultaneously
        theme="darkhuggingface",
    )

    # load the grammar-correction pipeline on CPU (device=-1)
    corrector = pipeline("text2text-generation", model=gram_model, device=-1)
    print("Finished loading the gramformer model - ", datetime.now())
    print(f"using model stored here: \n {model_loc} \n")

    # launch the gradio interface and start the server (share=True exposes a public URL)
    iface.launch(share=True)