Hamza1702 committed
Commit
c818b2a
1 Parent(s): de6262a

Create ai_single_response.py

Files changed (1)
  1. ai_single_response.py +383 -0
ai_single_response.py ADDED
@@ -0,0 +1,383 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ai_single_response.py - a script to generate a response to a prompt from a pretrained GPT model

example:
*\gpt2_chatbot> python ai_single_response.py --model "GPT2_conversational_355M_WoW10k" --prompt "hey, what's up?" --time

query_gpt_model is used throughout the code and is the fundamental building block of the bot. It is worth testing this function with a few different models.

"""
import argparse
import logging
import pprint as pp
import sys
import time
import warnings
from datetime import datetime
from pathlib import Path
from typing import Union

logging.basicConfig(
    filename=f"LOGFILE-{Path(__file__).stem}.log",
    filemode="a",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

from utils import DisableLogger, print_spacer, remove_trailing_punctuation

with DisableLogger():
    from cleantext import clean

warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")

from aitextgen import aitextgen

def extract_response(full_resp: list, plist: list, verbose: bool = False):
    """
    extract_response - helper fn for ai_single_response.py. By default aitextgen returns the prompt and the response; we only want the response.

    Args:
        full_resp (list): the full response from aitextgen
        plist (list): the prompt list
        verbose (bool, optional): Defaults to False.

    Returns:
        response (list): the response lines, without the prompt
    """
    bot_response = []
    for line in full_resp:
        if line.lower() in plist and len(bot_response) < len(plist):
            # drop the prompt line. plist holds lowercased lines, so index on
            # the lowercased form to avoid a ValueError on mixed-case input
            first_loc = plist.index(line.lower())
            del plist[first_loc]
            continue
        bot_response.append(line)
    full_resp = [clean(ele, lower=False) for ele in bot_response]

    if verbose:
        print("the isolated responses are:\n")
        pp.pprint(full_resp)
        print_spacer()
        print("the input prompt was:\n")
        pp.pprint(plist)
        print_spacer()
    return full_resp  # list of only the model generated responses


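# The following sketch is not in the original file; it illustrates the expected
# behavior of extract_response on a toy transcript. Note that the function
# mutates its plist argument, so a copy is passed. Defined but never called,
# so importing this module stays side-effect free.
def _demo_extract_response():
    raw = ["person alpha:", "hey, what's up?", "person beta:", "not much, you?"]
    prompt = ["person alpha:", "hey, what's up?", "person beta:"]
    # expected: ["not much, you?"] - only the newly generated line survives
    return extract_response(raw, list(prompt))

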
def get_bot_response(
    name_resp: str, model_resp: list, name_spk: str, verbose: bool = False
):
    """
    get_bot_response - gets the bot response to a prompt, checking to ensure that additional statements by the "speaker" are not included in the response.

    Args:
        name_resp (str): the name of the responder
        model_resp (list): the model response
        name_spk (str): the name of the speaker
        verbose (bool, optional): Defaults to False.

    Returns:
        bot_response (list): the bot response, isolated down to just text without the "name tokens" or further messages from the speaker.
    """

    fn_resp = []

    name_counter = 0
    break_safe = False
    for resline in model_resp:
        if name_resp.lower() in resline.lower():
            # the responder's own name tag - skip it, and allow the next line
            # to mention the speaker's name without ending the reply
            name_counter += 1
            break_safe = True
            continue
        if ":" in resline and name_resp.lower() not in resline.lower():
            # a name tag for someone else - the reply is over
            break
        if name_spk.lower() in resline.lower() and not break_safe:
            break
        else:
            fn_resp.append(resline)
    if verbose:
        print("the full response is:\n")
        print("\n".join(fn_resp))

    return fn_resp


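# Another illustrative sketch, not in the original file: the reply is cut off
# as soon as a different speaker's name tag appears. Defined but never called.
def _demo_get_bot_response():
    lines = ["person beta:", "i'm a chatbot!", "person alpha:", "cool"]
    # expected: ["i'm a chatbot!"] - "person alpha:" starts a new turn
    return get_bot_response(
        name_resp="person beta", model_resp=lines, name_spk="person alpha"
    )

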
def query_gpt_model(
    folder_path: Union[str, Path],
    prompt_msg: str,
    conversation_history: list = None,
    speaker: str = None,
    responder: str = None,
    resp_length: int = 48,
    kparam: int = 20,
    temp: float = 0.4,
    top_p: float = 0.9,
    aitextgen_obj=None,
    verbose: bool = False,
    use_gpu: bool = False,
):
    """
    query_gpt_model - queries the GPT model and returns the first response by <responder>

    Args:
        folder_path (str or Path): the path to the model folder
        prompt_msg (str): the prompt message
        conversation_history (list, optional): the conversation history. Defaults to None.
        speaker (str, optional): the name of the speaker. Defaults to None.
        responder (str, optional): the name of the responder. Defaults to None.
        resp_length (int, optional): the length of the response in tokens. Defaults to 48.
        kparam (int, optional): the k parameter for top_k sampling. Defaults to 20.
        temp (float, optional): the temperature for the softmax. Defaults to 0.4.
        top_p (float, optional): the top_p parameter for nucleus sampling. Defaults to 0.9.
        aitextgen_obj (aitextgen, optional): a pre-loaded aitextgen object. Defaults to None.
        verbose (bool, optional): Defaults to False.
        use_gpu (bool, optional): Defaults to False.

    Returns:
        model_resp (dict): the model response, as a dict with the following keys: out_text (str), the generated text, and full_conv (dict), the conversation history
    """

    try:
        ai = (
            aitextgen_obj
            if aitextgen_obj
            else aitextgen(
                model_folder=folder_path,
                to_gpu=use_gpu,
            )
        )
    except Exception as e:
        print(f"Unable to initialize aitextgen model: {e}")
        print(
            f"Check model folder: {folder_path}, run the download_models.py script to download the model files"
        )
        sys.exit(1)

    mpath = Path(folder_path)
    mpath_base = (
        mpath.stem
    )  # only want the base name of the model folder for the check below
    # these models used "person alpha" and "person beta" in training
    mod_ids = ["natqa", "dd", "trivqa", "wow", "conversational"]
    if any(substring in str(mpath_base).lower() for substring in mod_ids):
        speaker = "person alpha" if speaker is None else speaker
        responder = "person beta" if responder is None else responder
    else:
        if verbose:
            print("speaker and responder not set - using defaults")
        speaker = "person" if speaker is None else speaker
        responder = "george robot" if responder is None else responder

    prompt_list = (
        conversation_history if conversation_history is not None else []
    )  # track conversation
    prompt_list.append(speaker.lower() + ":" + "\n")
    prompt_list.append(prompt_msg.lower() + "\n")
    prompt_list.append("\n")
    prompt_list.append(responder.lower() + ":" + "\n")
    this_prompt = "".join(prompt_list)
    pr_len = len(this_prompt)  # NOTE: a character count, used as a length proxy below
    if verbose:
        print("overall prompt:\n")
        pp.pprint(prompt_list)
    # call the model
    print("\n... generating...")
    this_result = ai.generate(
        n=1,
        top_k=kparam,
        batch_size=128,
        # the prompt input counts toward the text length constraints
        max_length=resp_length + pr_len,
        min_length=16 + pr_len,
        prompt=this_prompt,
        temperature=temp,
        top_p=top_p,
        do_sample=True,
        return_as_list=True,
        use_cache=True,
    )
    if verbose:
        print("\n... generated:\n")
        pp.pprint(this_result)  # for debugging
    # process the full result to get the ~bot response~ piece
    this_result = str(this_result[0]).split("\n")
    input_prompt = this_prompt.split("\n")

    diff_list = extract_response(
        this_result, input_prompt, verbose=verbose
    )  # isolate the responses from the prompts
    # extract the bot response from the model generated text
    bot_dialogue = get_bot_response(
        name_resp=responder, model_resp=diff_list, name_spk=speaker, verbose=verbose
    )
    bot_resp = ", ".join(bot_dialogue)
    bot_resp = remove_trailing_punctuation(
        bot_resp.strip()
    )  # remove trailing punctuation to seem more natural
    if verbose:
        print("\n... bot response:\n")
        pp.pprint(bot_resp)
    prompt_list.append(bot_resp + "\n")
    prompt_list.append("\n")
    # index the conversation by line number for the returned history
    conv_history = dict(enumerate(prompt_list))
    if verbose:
        print("\n... conversation history:\n")
        pp.pprint(conv_history)
    print("\nfinished!")

    # return the bot response and the full conversation
    return {"out_text": bot_resp, "full_conv": conv_history}


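# Usage sketch (not part of the original script): load the model once and
# reuse it across calls via aitextgen_obj, avoiding a reload per response.
# The model folder name is illustrative. Defined but never called.
def _demo_reuse_model(model_folder="distilgpt2-tiny-conversational"):
    ai = aitextgen(model_folder=model_folder, to_gpu=False)
    history = None
    for msg in ["hi there!", "what do you like to eat?"]:
        resp = query_gpt_model(
            folder_path=model_folder,
            prompt_msg=msg,
            conversation_history=history,
            aitextgen_obj=ai,  # skip re-initializing the model
        )
        history = list(resp["full_conv"].values())  # carry the chat forward
        print(resp["out_text"])

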
# Set up the parsing of command-line arguments
def get_parser():
    """
    get_parser [a helper function for the argparse module]

    Returns: argparse.ArgumentParser
    """

    parser = argparse.ArgumentParser(
        description="submit a message and have a pretrained GPT model respond"
    )
    parser.add_argument(
        "-p",
        "--prompt",
        required=True,  # MUST HAVE A PROMPT
        type=str,
        help="the message the bot is supposed to respond to. Prompt is said by speaker, answered by responder.",
    )
    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="distilgpt2-tiny-conversational",
        help="folder - with respect to the git directory of your repo - that has the model files in it (pytorch_model.bin + "
        "config.json). You can also pass a huggingface model name (e.g. distilgpt2)",
    )

    parser.add_argument(
        "-s",
        "--speaker",
        required=False,
        default=None,
        help="who the prompt is from (to the bot). Primarily relevant for bots trained on multi-individual chat data",
    )
    parser.add_argument(
        "-r",
        "--responder",
        required=False,
        default="person beta",
        help="who the responder is. Primarily relevant for bots trained on multi-individual chat data",
    )

    parser.add_argument(
        "--topk",
        required=False,
        type=int,
        default=20,
        help="top-k sampling: the number of highest-probability tokens considered at each step (positive integer). lower = less random responses",
    )

    parser.add_argument(
        "--temp",
        required=False,
        type=float,
        default=0.4,
        help="specify temperature hyperparam (0-1). roughly considered 'model creativity'",
    )

    parser.add_argument(
        "--topp",
        required=False,
        type=float,
        default=0.9,
        help="nucleus sampling frac (0-1): the cumulative probability mass of candidate tokens considered",
    )

    parser.add_argument(
        "--resp_length",
        required=False,
        type=int,
        default=50,
        help="max length of the response (positive integer)",
    )

    parser.add_argument(
        "-v",
        "--verbose",
        default=False,
        action="store_true",
        help="pass this argument if you want all the printouts",
    )

    parser.add_argument(
        "-rt",
        "--time",
        default=False,
        action="store_true",
        help="pass this argument if you want to know the runtime",
    )

    parser.add_argument(
        "--use_gpu",
        required=False,
        action="store_true",
        help="use the GPU if available",
    )

    return parser


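# Sketch, not in the original file: the parser can also be exercised
# programmatically (e.g. in a test) without touching sys.argv.
def _demo_parser():
    args = get_parser().parse_args(["--prompt", "hello there", "--temp", "0.6"])
    # the other options fall back to their defaults,
    # e.g. args.model == "distilgpt2-tiny-conversational"
    return args

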
if __name__ == "__main__":
    # parse the command line arguments
    args = get_parser().parse_args()
    query = args.prompt
    model_dir = str(args.model)
    model_loc = Path.cwd() / model_dir if "/" not in model_dir else model_dir
    spkr = args.speaker
    rspndr = args.responder
    k_results = args.topk
    my_temp = args.temp
    my_top_p = args.topp
    resp_length = args.resp_length
    assert resp_length > 0, "response length must be positive"
    want_verbose = args.verbose
    want_rt = args.time
    use_gpu = args.use_gpu

    st = time.perf_counter()

    resp = query_gpt_model(
        folder_path=model_loc,
        prompt_msg=query,
        speaker=spkr,
        responder=rspndr,
        kparam=k_results,
        temp=my_temp,
        top_p=my_top_p,
        resp_length=resp_length,
        verbose=want_verbose,
        use_gpu=use_gpu,
    )

    output = resp["out_text"]
    pp.pprint(output, indent=4)

    rt = round(time.perf_counter() - st, 1)

    if want_rt:
        print(f"took {rt} seconds to generate.\n")

    if want_verbose:
        print("finished - ", datetime.now())
        p_list = resp["full_conv"]
        print("A transcript of your chat is as follows:\n")
        # full_conv is a dict keyed by line index; strip its values, not its keys
        p_list = [item.strip() for item in p_list.values()]
        pp.pprint(p_list)
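# Example invocation, mirroring the module docstring (the default model is
# used when --model is omitted):
#   python ai_single_response.py --prompt "hey, what's up?" --time --verbose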