import os
import streamlit as st
import random
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import tokenizers
#os.environ["TOKENIZERS_PARALLELISM"] = "false"
random.seed(None)
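# Default few-shot prompt demonstrating the informal-to-"Lincoln" rewriting task.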
first = """informal english: corn fields are all across illinois, visible once you leave chicago.\nTranslated into the Style of Abraham Lincoln: corn fields ( permeate illinois / span the state of illinois / ( occupy / persist in ) all corners of illinois / line the horizon of illinois / envelop the landscape of illinois ), manifesting themselves visibly as one ventures beyond chicago.\n\ninformal english:"""
suggested_text_list = [first]
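# Cache the model and tokenizer across Streamlit reruns. tokenizers objects
# are not hashable by st.cache, so hash_funcs maps them to id().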
@st.cache(hash_funcs={tokenizers.Tokenizer: id, tokenizers.AddedToken: id})
def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
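# Generate continuations of `input_text` with top-k / nucleus sampling.
# Relies on the module-level `model`, `tokenizer`, and `device` set up in the
# __main__ block below, so it must only be called after that initialization.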
def extend(input_text, num_return_sequences, max_size=20, top_k=50, top_p=0.95):
    # Encode the prompt; an empty prompt yields an empty tensor, in which
    # case generation starts from scratch (input_ids=None).
    encoded_prompt = tokenizer.encode(
        input_text, add_special_tokens=False, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(device)
    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_size + len(encoded_prompt[0]),
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=num_return_sequences,
    )
    # Remove the batch dimension when returning multiple sequences.
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()
    generated_sequences = []
    prompt_text = tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)
    for generated_sequence in output_sequences:
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
        # Strip the echoed prompt so only the newly generated text remains.
        total_sequence = text[len(prompt_text):]
        generated_sequences.append(total_sequence)
        st.write(total_sequence)

    # Only the last sequence is cleaned up and returned; every sequence has
    # already been rendered above via st.write.
    parsed_text = total_sequence.replace("<|startoftext|>", "").replace("\r", "").replace("\n\n", "\n")
    if len(parsed_text) == 0:
        parsed_text = "Error"
    return parsed_text
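
# Streamlit page setup and widget layout.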
if __name__ == "__main__":
    st.title("GPT2 Demo:")
    pre_model_path = "BigSalmon/InformalToFormalLincoln15"
    model, tokenizer = load_model(pre_model_path)
    # Re-seed on every run so repeated taps of "Generate Text" differ.
    np.random.seed(None)
    random_seed = int(np.random.randint(10000))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
    torch.manual_seed(random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(random_seed)
    model.to(device)
    text_area = st.text_area(
        "Enter the first few words (or leave blank), then tap \"Generate Text\" "
        "below. Tapping again will produce a different result.", first)
    st.sidebar.subheader("Configurable parameters")
    max_len = st.sidebar.slider("Max-Length", 0, 256, 5,
                                help="The maximum number of tokens to generate.")
    num_return_sequences = st.sidebar.slider("Outputs", 1, 50, 5,
                                             help="The number of output sequences to return.")
    top_k = st.sidebar.slider("Top-K", 0, 100, 40,
                              help="The number of highest-probability vocabulary tokens to keep for top-k filtering.")
    top_p = st.sidebar.slider("Top-P", 0.0, 1.0, 0.92,
                              help="If set to a float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.")
    if st.button("Generate Text"):
        with st.spinner(text="Generating results..."):
            st.subheader("Result")
            print(f"device:{device}, n_gpu:{n_gpu}, random_seed:{random_seed}, "
                  f"max_len:{max_len}, top_k:{top_k}, top_p:{top_p}")
            if len(text_area.strip()) == 0:
                text_area = random.choice(suggested_text_list)
            result = extend(input_text=text_area,
                            num_return_sequences=int(num_return_sequences),
                            max_size=int(max_len),
                            top_k=int(top_k),
                            top_p=float(top_p))
            st.markdown(result, unsafe_allow_html=True)
            st.write("\n\nResult length: " + str(len(result)) + " characters")
            print(f"Done. Result length: {len(result)} characters")
            print(f"\"{result}\"")