BigSalmon committed
Commit: 2473dde · 1 Parent(s): 5c8f4a3

Update app.py

Files changed (1):
  1. app.py  +13 -2
app.py CHANGED

@@ -16,7 +16,7 @@ def load_model(model_name):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForCausalLM.from_pretrained(model_name)
     return model, tokenizer
-def extend(input_text, num_return_sequences, max_size=20, top_k=50, top_p=0.95):
+def extend(input_text, num_return_sequences, max_size=20, top_k=50, top_p=0.95, bad_words):
     if len(input_text) == 0:
         input_text = ""
     encoded_prompt = tokenizer.encode(
@@ -27,10 +27,18 @@ def extend(input_text, num_return_sequences, max_size=20, top_k=50, top_p=0.95):
     else:
         input_ids = encoded_prompt

+    bad_words = bad_words.split()
+    bad_word_ids = []
+    for bad_word in bad_words:
+        bad_word = " " + bad_word
+        ids = tokenizer(bad_word).input_ids
+        bad_word_ids.append(ids)
+
     output_sequences = model.generate(
         input_ids=input_ids,
         max_length=max_size + len(encoded_prompt[0]),
         top_k=top_k,
+        bad_word_ids = bad_word_ids,
         top_p=top_p,
         do_sample=True,
         num_return_sequences=num_return_sequences)
@@ -73,6 +81,8 @@ if __name__ == "__main__":
     num_return_sequences = st.sidebar.slider("Outputs", 1, 50, 5, help="The number of outputs to be returned.")
     top_k = st.sidebar.slider("Top-K", 0, 100, 40, help="The number of highest probability vocabulary tokens to keep for top-k-filtering.")
     top_p = st.sidebar.slider("Top-P", 0.0, 1.0, 0.92, help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.")
+    bad_words = st.text_input("Words You Do Not Want Generated", " core lemon height time ")
+
     if st.button("Generate Text"):
         with st.spinner(text="Generating results..."):
             st.subheader("Result")
@@ -83,7 +93,8 @@ if __name__ == "__main__":
             num_return_sequences=int(num_return_sequences),
             max_size=int(max_len),
             top_k=int(top_k),
-            top_p=float(top_p))
+            top_p=float(top_p),
+            bad_words = bad_words))
             print("Done length: " + str(len(result)) + " bytes")
             #<div class="rtl" dir="rtl" style="text-align:right;">
             st.markdown(f"{result}", unsafe_allow_html=True)
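Note that, as committed, this change does not run: in Python a parameter without a default (bad_words) cannot follow defaulted parameters, so the new extend signature is a SyntaxError; generate() in transformers expects bad_words_ids (plural), so the bad_word_ids keyword would not enable word banning; and the call site ends with two closing parentheses where the old code had one. A minimal corrected sketch of the function, assuming model and tokenizer are the globals returned by load_model as in the rest of app.py:

# Corrected sketch of extend(); model and tokenizer are assumed to be
# the globals produced by load_model(), as in the rest of app.py.
def extend(input_text, num_return_sequences, max_size=20, top_k=50,
           top_p=0.95, bad_words=""):
    # bad_words now has a default: a non-default parameter after
    # defaulted ones is a SyntaxError.
    if len(input_text) == 0:
        input_text = ""
    input_ids = tokenizer.encode(input_text, return_tensors="pt")

    # One id list per banned word; the leading space targets the form
    # a BPE tokenizer produces for the word mid-sentence.
    bad_words_ids = [tokenizer(" " + w, add_special_tokens=False).input_ids
                     for w in bad_words.split()]

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_size + input_ids.shape[1],
        top_k=top_k,
        top_p=top_p,
        # generate() takes bad_words_ids, not bad_word_ids; it rejects
        # an empty list, so pass None when no words were given.
        bad_words_ids=bad_words_ids or None,
        do_sample=True,
        num_return_sequences=num_return_sequences)
    return output_sequences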
 
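At the call site, the matching fix is a single closing parenthesis rather than two, with top_p and bad_words passed through. A sketch, where the result assignment and the prompt argument are assumed from context (the hunk begins below them, and the next line prints len(result)):

result = extend(input_text,  # prompt argument assumed from context
                num_return_sequences=int(num_return_sequences),
                max_size=int(max_len),
                top_k=int(top_k),
                top_p=float(top_p),
                bad_words=bad_words)  # one closing parenthesis, not two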
 
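The space prepended to each bad word is deliberate: with a GPT-2-style BPE tokenizer, a word maps to different token ids at the start of a string than after a space, and mid-sentence generation produces the space-prefixed form. A quick way to see this (gpt2 is a stand-in checkpoint; the app loads whichever model_name the user selects):

from transformers import AutoTokenizer

# gpt2 is a stand-in; app.py loads the user-selected model_name.
tok = AutoTokenizer.from_pretrained("gpt2")
print(tok("lemon", add_special_tokens=False).input_ids)   # ids without a leading space
print(tok(" lemon", add_special_tokens=False).input_ids)  # different ids with one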