BigSalmon's picture
Update app.py
5c8f4a3
raw
history blame
4.76 kB
import argparse
import re
import os
import streamlit as st
import random
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import tokenizers
#os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Reseed Python's RNG (None/omitted -> OS entropy or current time) so each
# run of the app behaves differently.
random.seed()

# Few-shot prompt: one informal sentence paired with its "Abraham Lincoln"
# style rewrite, ending on an open "informal english:" line to be continued.
first = (
    "informal english: corn fields are all across illinois, visible once you leave chicago.\n"
    "Translated into the Style of Abraham Lincoln: corn fields ( permeate illinois / "
    "span the state of illinois / ( occupy / persist in ) all corners of illinois / "
    "line the horizon of illinois / envelop the landscape of illinois ), "
    "manifesting themselves visibly as one ventures beyond chicago.\n\n"
    "informal english:"
)

# Prompts picked at random when the text area is submitted blank.
suggested_text_list = [first]
@st.cache(
    hash_funcs={
        tokenizers.Tokenizer: id,
        tokenizers.AddedToken: id,
    }
)
def load_model(model_name):
    """Load and cache a causal-LM model together with its tokenizer.

    ``st.cache`` memoizes the result across Streamlit reruns; the
    ``hash_funcs`` mapping makes it hash ``tokenizers`` objects by identity
    (``id``) instead of by content.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Tokenizer is fetched before the model, matching the original order.
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name)
    return (loaded_model, loaded_tokenizer)
def extend(input_text, num_return_sequences, max_size=20, top_k=50, top_p=0.95):
    """Sample continuations of a prompt and write each one to the Streamlit page.

    Uses the module-level ``model``, ``tokenizer`` and ``device`` created in
    the ``__main__`` block, so it must only be called after that setup runs.

    Args:
        input_text: Prompt text to continue. An empty string is allowed; the
            model is then called with ``input_ids=None``.
        num_return_sequences: Number of samples requested from
            ``model.generate``.
        max_size: Tokens allowed beyond the prompt (added to the prompt
            length to form ``max_length``).
        top_k: ``top_k`` sampling value forwarded to ``model.generate``.
        top_p: ``top_p`` sampling value forwarded to ``model.generate``.

    Returns:
        The LAST sample's continuation (prompt text sliced off), with
        ``<|startoftext|>`` markers and carriage returns removed and double
        newlines collapsed to single ones; a fixed fallback string if that
        leaves nothing.
    """
    encoded_prompt = tokenizer.encode(
        input_text, add_special_tokens=False, return_tensors="pt"
    ).to(device)
    prompt_len = encoded_prompt.size()[-1]
    # An empty prompt encodes to zero tokens; pass None rather than an empty tensor.
    input_ids = encoded_prompt if prompt_len > 0 else None

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_size + prompt_len,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )
    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()
    print(output_sequences)

    # The decoded prompt's length is sliced off the front of every sample.
    # It does not depend on the sample, so decode it once.
    prompt_text = tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)

    # Initialized so an empty generation result cannot raise NameError below.
    total_sequence = ""
    for generated_sequence in output_sequences:
        text = tokenizer.decode(
            generated_sequence.tolist(), clean_up_tokenization_spaces=True
        )
        print(text)
        total_sequence = text[len(prompt_text):]
        st.write(total_sequence)

    # Only the last sample is post-processed and returned to the caller.
    parsed_text = (
        total_sequence.replace("<|startoftext|>", "")
        .replace("\r", "")
        .replace("\n\n", "\n")
    )
    if not parsed_text:
        # Fallback when nothing was generated (Hebrew for "error").
        parsed_text = "שגיאה"
    return parsed_text
# Script entry point: build the Streamlit UI, load the model and generate on demand.
# model, tokenizer and device are assigned at module level here because extend()
# reads them as globals.
if __name__ == "__main__":
    st.title("GPT2 Demo:")
    pre_model_path = "BigSalmon/InformalToFormalLincoln15"
    model, tokenizer = load_model(pre_model_path)
    # NOTE(review): stop_token and new_lines are not referenced anywhere in this file.
    stop_token = "<|endoftext|>"
    new_lines = "\n\n\n"
    np.random.seed(None)
    # Draw a plain Python int. A size=1 draw would give a 1-element array, whose
    # implicit int() conversion inside torch.manual_seed NumPy deprecates for
    # arrays with ndim > 0.
    random_seed = int(np.random.randint(10000))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
    torch.manual_seed(random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(random_seed)
    model.to(device)
    text_area = st.text_area("Enter the first few words (or leave blank), tap on \"Generate Text\" below. Tapping again will produce a different result.", first)
    st.sidebar.subheader("Configurable parameters")
    # Minimum of 1: with 0 new tokens, max_length equals the prompt length and
    # nothing new can be generated.
    max_len = st.sidebar.slider("Max-Length", 1, 256, 5, help="The maximum length of the sequence to be generated.")
    num_return_sequences = st.sidebar.slider("Outputs", 1, 50, 5, help="The number of outputs to be returned.")
    top_k = st.sidebar.slider("Top-K", 0, 100, 40, help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
    top_p = st.sidebar.slider("Top-P", 0.0, 1.0, 0.92, help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.")
    if st.button("Generate Text"):
        with st.spinner(text="Generating results..."):
            st.subheader("Result")
            print(f"device:{device}, n_gpu:{n_gpu}, random_seed:{random_seed}, maxlen:{max_len}, top_k:{top_k}, top_p:{top_p}")
            # Blank input falls back to one of the built-in few-shot prompts.
            if not text_area.strip():
                text_area = random.choice(suggested_text_list)
            result = extend(input_text=text_area,
                            num_return_sequences=int(num_return_sequences),
                            max_size=int(max_len),
                            top_k=int(top_k),
                            top_p=float(top_p))
            # len() of a str counts characters, not bytes.
            print(f"Done length: {len(result)} characters")
            # The output derives from user-supplied text, so render it as plain
            # markdown and do not let raw HTML through.
            st.markdown(result)
            st.write(f"\n\nResult length: {len(result)} characters")
            print(f"\"{result}\"")