# Hugging Face Spaces page banner (status: Sleeping) - scraped page residue, not part of app.py
import argparse | |
import logging | |
import time | |
import gradio as gr | |
import torch | |
from transformers import pipeline | |
# Root logger setup: INFO-level records, each prefixed with a timestamp and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Evaluated once at import time; used below to pick the pipeline device.
use_gpu = torch.cuda.is_available()
def generate_text(
    prompt: str,
    gen_length=64,
    num_beams=4,
    no_repeat_ngram_size=2,
    length_penalty=1.0,
    # perma params (not set by user)
    repetition_penalty=3.5,
    abs_max_length=512,
    verbose=False,
):
    """
    generate_text - generate text from a prompt using the module-level
    text generation pipeline (``generator``)

    The pipeline is called with ``do_sample=False``, ``early_stopping=True``
    and ``truncation=True``.

    Args:
        prompt (str): the prompt to generate text from
        gen_length (int, optional): added to the prompt's token count to form
            the ``max_length`` passed to the pipeline. Defaults to 64.
        num_beams (int, optional): forwarded to the pipeline as ``num_beams``. Defaults to 4.
        no_repeat_ngram_size (int, optional): forwarded to the pipeline as
            ``no_repeat_ngram_size``. Defaults to 2.
        length_penalty (float, optional): forwarded to the pipeline as
            ``length_penalty``. Defaults to 1.0.
        repetition_penalty (float, optional): forwarded to the pipeline as
            ``repetition_penalty``. Defaults to 3.5.
        abs_max_length (int, optional): prompt token count above which a
            warning is logged; generation is still attempted. Defaults to 512.
        verbose (bool, optional): when True, also log the generated text. Defaults to False.

    Returns:
        str: the ``generated_text`` field of the first pipeline result
    """
    logging.info("Generating text from prompt: %s", prompt)
    st = time.perf_counter()

    # `generator` is only read here, so no `global` declaration is needed.
    input_tokens = generator.tokenizer(prompt)
    input_len = len(input_tokens["input_ids"])
    if input_len > abs_max_length:
        # Not fatal, but worth surfacing above INFO level.
        logging.warning(
            "Input too long %s > %s, may cause errors", input_len, abs_max_length
        )

    max_length = gen_length + input_len
    # Ask for at least 4 new tokens, but never more than max_length allows
    # (keeps min_length <= max_length when gen_length < 4).
    min_length = min(input_len + 4, max_length)

    result = generator(
        prompt,
        max_length=max_length,
        min_length=min_length,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        length_penalty=length_penalty,
        do_sample=False,
        early_stopping=True,
        # tokenizer
        truncation=True,
    )  # generate
    response = result[0]["generated_text"]
    rt = time.perf_counter() - st
    if verbose:
        logging.info("Generated text: %s", response)
    logging.info("Generation time: %.2fs", rt)
    return response
def get_parser():
    """
    get_parser - build the command-line argument parser for the demo

    Returns:
        argparse.ArgumentParser: parser exposing ``-m/--model`` and ``-v/--verbose``
    """
    cli = argparse.ArgumentParser(description="Text Generation demo for postbot")

    # One (flags, keyword-options) entry per supported command-line switch.
    option_table = (
        (
            ("-m", "--model"),
            {
                "required": False,
                "type": str,
                "default": "postbot/distilgpt2-emailgen",
                "help": "Pass an different huggingface model tag to use a custom model",
            },
        ),
        (
            ("-v", "--verbose"),
            {
                "required": False,
                "action": "store_true",
                "help": "Verbose output",
            },
        ),
    )
    for flags, options in option_table:
        cli.add_argument(*flags, **options)

    return cli
# Starter text shown in the prompt textbox; note that it begins with a newline.
default_prompt = (
    "\n"
    "Hello,\n"
    "Following up on the bubblegum shipment."
)
# Script entry point: parse CLI args, load the model, build the Gradio UI, launch it.
if __name__ == "__main__":
    logging.info("\n\n\nStarting new instance of app.py")
    args = get_parser().parse_args()
    logging.info(f"received args:\t{args}")
    model_tag = args.model
    verbose = args.verbose

    logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
    # Bound at module level on purpose: generate_text() reads `generator` as a global.
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,
    )

    def _generate_from_ui(
        prompt, gen_length, num_beams, no_repeat_ngram_size, length_penalty
    ):
        """Forward the five UI inputs to generate_text, adding the CLI --verbose flag."""
        # Without this wrapper the parsed --verbose flag never reaches
        # generate_text, because the UI only supplies the five inputs above.
        return generate_text(
            prompt,
            gen_length=gen_length,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            length_penalty=length_penalty,
            verbose=verbose,
        )

    demo = gr.Blocks()

    logging.info("launching interface...")

    with demo:
        gr.Markdown("# Autocompleting Emails with Textgen - Demo")
        gr.Markdown(
            "Enter part of an email, and the model will autocomplete it for you! The model used is [postbot/distilgpt2-emailgen](https://huggingface.co/postbot/distilgpt2-emailgen)"
        )
        gr.Markdown("---")

        # Input controls, output box and the trigger button.
        with gr.Column():
            gr.Markdown("## Generate Text")
            gr.Markdown(
                "Enter/edit the prompt and adjust the parameters as needed. Then press the Generate button!"
            )
            prompt_text = gr.Textbox(
                lines=4,
                label="Email Prompt",
                value=default_prompt,
            )

            num_gen_tokens = gr.Slider(
                label="Generation Tokens",
                value=64,
                maximum=128,
                minimum=32,
                step=16,
            )
            num_beams = gr.Radio(
                choices=[4, 8, 16],
                label="num beams",
                value=4,
            )
            no_repeat_ngram_size = gr.Radio(
                choices=[1, 2, 3, 4],
                label="no repeat ngram size",
                value=2,
            )
            length_penalty = gr.Slider(
                minimum=0.5, maximum=1.0, label="length penalty", value=0.8, step=0.1
            )
            generated_email = gr.Textbox(
                label="Generated Result",
                placeholder="The completed email will appear here",
            )

            generate_button = gr.Button(
                "Generate!",
            )
        gr.Markdown("---")

        # Static information about the model and its intended use.
        with gr.Column():
            gr.Markdown("## About")
            gr.Markdown(
                "This model is a fine-tuned version of distilgpt2 on a dataset of 50k emails sourced from the internet, including the classic `aeslc` dataset."
            )
            gr.Markdown(
                "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements before accepting/sending something."
            )
            gr.Markdown("---")

        # Inputs are passed positionally, in this order, to the handler.
        generate_button.click(
            fn=_generate_from_ui,
            inputs=[
                prompt_text,
                num_gen_tokens,
                num_beams,
                no_repeat_ngram_size,
                length_penalty,
            ],
            outputs=[generated_email],
        )

    demo.launch(
        enable_queue=True,
        share=True,  # for local testing
    )