#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

ai_single_response.py - a script to generate a response to a prompt from a pretrained GPT model

example:
*\gpt2_chatbot> python ai_single_response.py --model "GPT2_conversational_355M_WoW10k" --prompt "hey, what's up?" --time

query_gpt_model is used throughout the code and is the "fundamental" building
block of the bot and how everything works. I would recommend testing this
function with a few different models.

"""
import argparse
import pprint as pp
import sys
import time
import warnings
from datetime import datetime
from pathlib import Path
from typing import Union

import logging

logging.basicConfig(
    filename=f"LOGFILE-{Path(__file__).stem}.log",
    filemode="a",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

from utils import DisableLogger, print_spacer, remove_trailing_punctuation

with DisableLogger():
    from cleantext import clean

warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")

from aitextgen import aitextgen


def extract_response(full_resp: list, plist: list, verbose: bool = False):
    """
    extract_response - helper fn for ai_single_response.py. By default, aitextgen
    returns the prompt and the response; we just want the response.

    Args:
        full_resp (list): the full response from aitextgen
        plist (list): the prompt list
        verbose (bool, optional): Defaults to False.

    Returns:
        full_resp (list): the model-generated responses, without the prompt
    """

    bot_response = []
    for line in full_resp:
        if line.lower() in plist and len(bot_response) < len(plist):
            # this line is part of the prompt - drop it (once) and move on
            first_loc = plist.index(line)
            del plist[first_loc]
            continue
        bot_response.append(line)
    full_resp = [clean(ele, lower=False) for ele in bot_response]

    if verbose:
        print("the isolated responses are:\n")
        pp.pprint(full_resp)
        print_spacer()
        print("the input prompt was:\n")
        pp.pprint(plist)
        print_spacer()

    return full_resp  # list of only the model generated responses


def get_bot_response(
    name_resp: str, model_resp: list, name_spk: str, verbose: bool = False
):
    """
    get_bot_response - gets the bot response to a prompt, checking to ensure that additional
                        statements by the "speaker" are not included in the response.

    Args:
        name_resp (str): the name of the responder
        model_resp (list): the model response
        name_spk (str): the name of the speaker
        verbose (bool, optional): Defaults to False.

    Returns:
        fn_resp (list): the bot response lines, isolated down to just text without the
                        "name tokens" or further messages from the speaker.
    """

    fn_resp = []
    name_counter = 0
    break_safe = False
    for resline in model_resp:
        if name_resp.lower() in resline.lower():
            name_counter += 1
            break_safe = True  # the responder's name tag was seen
            continue
        if ":" in resline and name_resp.lower() not in resline.lower():
            break  # a new "name:" turn that is not the responder
        if name_spk.lower() in resline.lower() and not break_safe:
            break  # the speaker is talking again
        else:
            fn_resp.append(resline)

    if verbose:
        print("the full response is:\n")
        print("\n".join(fn_resp))

    return fn_resp

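
# A minimal sketch (toy data, never called) showing how the two helpers above
# compose: extract_response() strips the prompt lines from the raw generation,
# then get_bot_response() keeps only the responder's own lines. The names and
# strings below are illustrative assumptions, not output from a real model run.
def _example_isolate_reply():
    raw = [
        "person alpha:",
        "hey, what's up?",
        "",
        "person beta:",
        "not much.",
        "person alpha:",
        "cool",
    ]
    prompt = ["person alpha:", "hey, what's up?", "", "person beta:"]
    replies = extract_response(raw, prompt)  # expected: ["not much.", "person alpha:", "cool"]
    reply_only = get_bot_response(
        name_resp="person beta", model_resp=replies, name_spk="person alpha"
    )  # expected: ["not much."]
    return ", ".join(reply_only)
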
def query_gpt_model(
    folder_path: Union[str, Path],
    prompt_msg: str,
    conversation_history: list = None,
    speaker: str = None,
    responder: str = None,
    resp_length: int = 48,
    kparam: int = 20,
    temp: float = 0.4,
    top_p: float = 0.9,
    aitextgen_obj=None,
    verbose: bool = False,
    use_gpu: bool = False,
):
    """
    query_gpt_model - queries the GPT model with a prompt and returns the isolated bot response

    Args:
        folder_path (str or Path): the path to the model folder
        prompt_msg (str): the prompt message
        conversation_history (list, optional): the conversation history. Defaults to None.
        speaker (str, optional): the name of the speaker. Defaults to None.
        responder (str, optional): the name of the responder. Defaults to None.
        resp_length (int, optional): the length of the response in tokens. Defaults to 48.
        kparam (int, optional): the k parameter for top-k sampling. Defaults to 20.
        temp (float, optional): the temperature for the softmax. Defaults to 0.4.
        top_p (float, optional): the top_p parameter for nucleus sampling. Defaults to 0.9.
        aitextgen_obj (aitextgen, optional): a pre-loaded aitextgen object. Defaults to None.
        verbose (bool, optional): Defaults to False.
        use_gpu (bool, optional): Defaults to False.

    Returns:
        model_resp (dict): the model response, as a dict with the following keys:
                            out_text (str) - the generated text
                            full_conv (dict) - the conversation history
    """
    try:
        ai = (
            aitextgen_obj
            if aitextgen_obj
            else aitextgen(
                model_folder=folder_path,
                to_gpu=use_gpu,
            )
        )
    except Exception as e:
        print(f"Unable to initialize aitextgen model: {e}")
        print(
            f"Check model folder: {folder_path}, run the download_models.py script to download the model files"
        )
        sys.exit(1)

    mpath = Path(folder_path)
    mpath_base = mpath.stem  # only want the base name of the model folder for check below
    # these models used person alpha and person beta in training
    mod_ids = ["natqa", "dd", "trivqa", "wow", "conversational"]
    if any(substring in str(mpath_base).lower() for substring in mod_ids):
        speaker = "person alpha" if speaker is None else speaker
        responder = "person beta" if responder is None else responder
    else:
        if verbose:
            print("speaker and responder not set - using default")
        speaker = "person" if speaker is None else speaker
        responder = "george robot" if responder is None else responder

    prompt_list = (
        conversation_history if conversation_history is not None else []
    )  # track conversation
    prompt_list.append(speaker.lower() + ":" + "\n")
    prompt_list.append(prompt_msg.lower() + "\n")
    prompt_list.append("\n")
    prompt_list.append(responder.lower() + ":" + "\n")
    this_prompt = "".join(prompt_list)
    pr_len = len(this_prompt)

    if verbose:
        print("overall prompt:\n")
        pp.pprint(prompt_list)

    # call the model
    print("\n... generating...")
    this_result = ai.generate(
        n=1,
        top_k=kparam,
        batch_size=128,
        # the prompt input counts for text length constraints
        max_length=resp_length + pr_len,
        min_length=16 + pr_len,
        prompt=this_prompt,
        temperature=temp,
        top_p=top_p,
        do_sample=True,
        return_as_list=True,
        use_cache=True,
    )

    if verbose:
        print("\n... generated:\n")
        pp.pprint(this_result)  # for debugging

    # process the full result to get the ~bot response~ piece
    this_result = str(this_result[0]).split("\n")
    input_prompt = this_prompt.split("\n")

    diff_list = extract_response(
        this_result, input_prompt, verbose=verbose
    )  # isolate the responses from the prompts

    # extract the bot response from the model generated text
    bot_dialogue = get_bot_response(
        name_resp=responder, model_resp=diff_list, name_spk=speaker, verbose=verbose
    )
    bot_resp = ", ".join(bot_dialogue)
    bot_resp = remove_trailing_punctuation(
        bot_resp.strip()
    )  # remove trailing punctuation to seem more natural

    if verbose:
        print("\n... bot response:\n")
        pp.pprint(bot_resp)

    prompt_list.append(bot_resp + "\n")
    prompt_list.append("\n")

    conv_history = {}
    for i, line in enumerate(prompt_list):
        if i not in conv_history.keys():
            conv_history[i] = line

    if verbose:
        print("\n... conversation history:\n")
        pp.pprint(conv_history)

    print("\nfinished!")

    # return the bot response and the full conversation
    return {"out_text": bot_resp, "full_conv": conv_history}

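
# A minimal sketch (never called) of using query_gpt_model() programmatically for a
# two-turn exchange, as suggested in the module docstring. The model folder name is
# an assumption - substitute any folder downloaded via download_models.py. The
# checkpoint is loaded once and reused through the aitextgen_obj parameter.
def _example_two_turn_chat(model_folder="distilgpt2-tiny-conversational"):
    ai = aitextgen(model_folder=model_folder, to_gpu=False)  # load once, reuse below
    first = query_gpt_model(
        folder_path=model_folder,
        prompt_msg="hey, what's up?",
        aitextgen_obj=ai,
    )
    print(first["out_text"])
    # continue the conversation by passing the prior turns back in as a list
    second = query_gpt_model(
        folder_path=model_folder,
        prompt_msg="tell me more about that.",
        conversation_history=list(first["full_conv"].values()),
        aitextgen_obj=ai,
    )
    print(second["out_text"])
    return second
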
# Set up the parsing of command-line arguments
def get_parser():
    """
    get_parser [a helper function for the argparse module]

    Returns:
        argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(
        description="submit a message and have a pretrained GPT model respond"
    )
    parser.add_argument(
        "-p",
        "--prompt",
        required=True,  # MUST HAVE A PROMPT
        type=str,
        help="the message the bot is supposed to respond to. Prompt is said by speaker, answered by responder.",
    )
    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="distilgpt2-tiny-conversational",
        help="folder - with respect to git directory of your repo that has the model files in it (pytorch.bin + "
        "config.json). You can also pass the huggingface model name (e.g. distilgpt2)",
    )
    parser.add_argument(
        "-s",
        "--speaker",
        required=False,
        default=None,
        help="who the prompt is from (to the bot). Primarily relevant to bots trained on multi-individual chat data",
    )
    parser.add_argument(
        "-r",
        "--responder",
        required=False,
        default="person beta",
        help="who the responder is. Primarily relevant to bots trained on multi-individual chat data",
    )
    parser.add_argument(
        "--topk",
        required=False,
        type=int,
        default=20,
        help="top-k sampling: only the k most likely next tokens are considered (positive integer). lower = less random responses",
    )
    parser.add_argument(
        "--temp",
        required=False,
        type=float,
        default=0.4,
        help="specify temperature hyperparam (0-1). roughly considered as 'model creativity'",
    )
    parser.add_argument(
        "--topp",
        required=False,
        type=float,
        default=0.9,
        help="nucleus sampling frac (0-1). aka: what fraction of possible options are considered?",
    )
    parser.add_argument(
        "--resp_length",
        required=False,
        type=int,
        default=50,
        help="max length of the response (positive integer)",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        default=False,
        action="store_true",
        help="pass this argument if you want all the printouts",
    )
    parser.add_argument(
        "-rt",
        "--time",
        default=False,
        action="store_true",
        help="pass this argument if you want to know runtime",
    )
    parser.add_argument(
        "--use_gpu",
        required=False,
        action="store_true",
        help="use gpu if available",
    )
    return parser


if __name__ == "__main__":
    # parse the command line arguments
    args = get_parser().parse_args()
    query = args.prompt
    model_dir = str(args.model)
    model_loc = Path.cwd() / model_dir if "/" not in model_dir else model_dir
    spkr = args.speaker
    rspndr = args.responder
    k_results = args.topk
    my_temp = args.temp
    my_top_p = args.topp
    resp_length = args.resp_length
    assert resp_length > 0, "response length must be positive"
    want_verbose = args.verbose
    want_rt = args.time
    use_gpu = args.use_gpu

    st = time.perf_counter()

    resp = query_gpt_model(
        folder_path=model_loc,
        prompt_msg=query,
        speaker=spkr,
        responder=rspndr,
        kparam=k_results,
        temp=my_temp,
        top_p=my_top_p,
        resp_length=resp_length,
        verbose=want_verbose,
        use_gpu=use_gpu,
    )

    output = resp["out_text"]
    pp.pprint(output, indent=4)

    rt = round(time.perf_counter() - st, 1)

    if want_rt:
        print("took {runtime} seconds to generate. \n".format(runtime=rt))

    if want_verbose:
        print("finished - ", datetime.now())
        p_list = resp["full_conv"]
        print("A transcript of your chat is as follows: \n")
        # full_conv is a dict keyed by turn index; the lines are its values
        p_list = [item.strip() for item in p_list.values()]
        pp.pprint(p_list)
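
# Further example invocations (illustrative; the folder passed via -m/--model must
# exist locally - see download_models.py - or be a huggingface model name):
#
#   python ai_single_response.py -p "hey, what's up?" --time
#   python ai_single_response.py -m "GPT2_conversational_355M_WoW10k" -p "how was your day?" --temp 0.6 --topk 30 -v
#   python ai_single_response.py -p "tell me a joke" --use_gpu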