""" | |
deploy-as-bot\gradio_chatbot.py | |
A system, method for deploying to Gradio. Gradio is a basic "deploy" interface which allows for other users to test your model from a web URL. It also enables some basic functionality like user flagging for weird responses. | |
Note that the URL is displayed once the script is run. | |
Set the working directory to */deploy-as-bot in terminal before running. | |
""" | |
from utils import remove_trailing_punctuation, DisableLogger
import os
import sys
from os.path import dirname

# add the repo root (parent of this script's directory) to sys.path
sys.path.append(dirname(dirname(os.path.abspath(__file__))))
import gradio as gr
import logging
import argparse
import time
import warnings
from pathlib import Path
from transformers import pipeline
from datetime import datetime

from ai_single_response import query_gpt_model

logging.basicConfig(
    filename=f"LOGFILE-{Path(__file__).stem}.log",
    filemode="a",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

with DisableLogger():
    from cleantext import clean

warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*")

cwd = Path.cwd()
my_cwd = str(cwd.resolve())  # string so it can be passed to os.path() objects
def gramformer_correct(corrector, qphrase: str):
    """
    gramformer_correct - correct a string using a text2text-generation pipeline model from transformers

    Args:
        corrector (transformers.pipeline): transformers pipeline object, already created w/ relevant model
        qphrase (str): text to be corrected

    Returns:
        str: corrected text
    """
    try:
        corrected = corrector(
            clean(qphrase), return_text=True, clean_up_tokenization_spaces=True
        )
        return corrected[0]["generated_text"]
    except Exception:
        print("NOTE - failed to correct with gramformer")
        return clean(
            qphrase
        )  # fallback is to return the cleaned-up version of the message
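# Example usage (a sketch with a hypothetical input string; the corrector here
# is built the same way as in __main__ below, using the default grammar model):
#   corrector = pipeline("text2text-generation",
#                        model="prithivida/grammar_error_correcter_v1", device=-1)
#   gramformer_correct(corrector, "he go to school yesterday")
#   -> something like "He went to school yesterday."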
def ask_gpt(message: str, sender: str = ""):
    """
    ask_gpt - queries the relevant model with a prompt message and (optional) speaker name.
    Note: this version is modified w.r.t. the Gradio local server deploy.

    Args:
        message (str): prompt message to respond to
        sender (str, optional): speaker, i.e. who said the message. Defaults to "".

    Returns:
        str: model response as a string
    """
    st = time.time()
    prompt = clean(message)  # clean user input
    prompt = prompt.strip()  # get rid of any extra whitespace
    if len(prompt) > 100:
        prompt = prompt[:100]  # truncate
    sender = clean(sender.strip())
    if len(sender) > 2:
        try:
            prompt_speaker = clean(sender)
        except Exception:
            prompt_speaker = None  # fallback
    else:
        prompt_speaker = None  # fallback

    resp = query_gpt_model(
        folder_path=model_loc,
        prompt_msg=prompt,
        speaker=prompt_speaker,
        kparam=150,  # top-k responses
        temp=0.75,  # temperature
        top_p=0.65,  # nucleus sampling
    )
    bot_resp = gramformer_correct(
        corrector, qphrase=resp["out_text"]
    )  # correct grammar
    bot_resp = remove_trailing_punctuation(
        bot_resp
    )  # remove trailing punctuation to seem more natural
    rt = round(time.time() - st, 2)
    print(f"took {rt} sec to respond")
    return bot_resp
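# Example usage (a sketch; requires the globals `model_loc` and `corrector`
# that are set in __main__ below, and the reply text will vary run to run):
#   ask_gpt("what is the capital of France?", sender="Jane Doe")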
def chat(first_and_last_name, message):
    """
    chat - helper function that makes the whole gradio thing work.

    Args:
        first_and_last_name (str or None): speaker of the prompt, if provided
        message (str): the prompt message to respond to

    Returns:
        str: an HTML string of the conversation history to display
    """
    history = gr.get_state() or []
    response = ask_gpt(message, sender=first_and_last_name)
    history.append(("You: " + message, " GPT-Model: " + response + " [end] "))
    gr.set_state(history)  # save the history
    # wrap the exchanges in divs styled by the CSS classes passed to gr.Interface
    html = "<div class='chatbox'>"
    for user_msg, resp_msg in history:
        html += f"<div class='user_msg'>{user_msg}</div>"
        html += f"<div class='resp_msg'>{resp_msg}</div>"
    html += "</div>"
    return html
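# Example output (a sketch): after one exchange, `chat` returns markup like
#   <div class='chatbox'>
#     <div class='user_msg'>You: hi</div>
#     <div class='resp_msg'> GPT-Model: hello there [end] </div>
#   </div>
# which Gradio renders via the "html" output type configured below.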
def get_parser():
    """
    get_parser - a helper function for the argparse module

    Returns:
        argparse.ArgumentParser: the argparser relevant for this script
    """
    parser = argparse.ArgumentParser(
        description="host a chatbot on gradio",
    )
    parser.add_argument(
        "--model",
        required=False,
        type=str,
        default="GPT2_trivNatQAdailydia_774M_175Ksteps",  # folder name of model
        help="folder - with respect to the git directory of your repo - that has the model files in it "
        "(pytorch_model.bin + config.json). No models? Run the script download_models.py",
    )
    parser.add_argument(
        "--gram-model",
        required=False,
        type=str,
        default="prithivida/grammar_error_correcter_v1",  # huggingface model
        help="text2text-generation model ID from huggingface for the model to correct grammar",
    )
    return parser
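# Example invocation with explicit flags (both values shown are the script's
# own defaults, so this is equivalent to running with no arguments):
#   python gradio_chatbot.py --model GPT2_trivNatQAdailydia_774M_175Ksteps \
#       --gram-model prithivida/grammar_error_correcter_v1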
if __name__ == "__main__":
    args = get_parser().parse_args()
    default_model = str(args.model)
    model_loc = cwd.parent / default_model
    model_loc = str(model_loc.resolve())
    gram_model = args.gram_model

    # init items for the pipeline
    iface = gr.Interface(
        chat,
        inputs=["text", "text"],
        outputs="html",
        title=f"GPT-Chatbot Demo: {default_model} Model",
        description=f"A basic interface with a GPT2-based model, specifically {default_model}. Treat it like a friend!",
        article="**Important Notes & About:**\n"
        "1. the model can take up to 60 seconds to respond sometimes; patience is a virtue.\n"
        "2. entering a username is completely optional.\n"
        "3. the model started from a pretrained checkpoint and was trained on several different datasets. Anything it says should be fact-checked before being regarded as a true statement.\n ",
        css="""
            .chatbox {display:flex;flex-direction:column}
            .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
            .user_msg {background-color:cornflowerblue;color:white;align-self:start}
            .resp_msg {background-color:lightgray;align-self:self-end}
        """,
        allow_screenshot=True,
        allow_flagging=True,  # allow users to flag responses as inappropriate
        flagging_dir="gradio_data",
        flagging_options=[
            "great response",
            "doesn't make sense",
            "bad/offensive response",
        ],
        enable_queue=True,  # allows for dealing with multiple users simultaneously
        theme="darkhuggingface",
    )

    corrector = pipeline("text2text-generation", model=gram_model, device=-1)
    print("Finished loading the gramformer model - ", datetime.now())
    print(f"using model stored here: \n {model_loc} \n")

    # launch the gradio interface and start the server
    iface.launch(share=True)