Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -37,6 +37,7 @@ def extend(input_text, num_return_sequences, max_size=20, top_k=50, top_p=0.95):
|
|
37 |
if len(output_sequences.shape) > 2:
|
38 |
output_sequences.squeeze_()
|
39 |
generated_sequences = []
|
|
|
40 |
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
|
41 |
generated_sequence = generated_sequence.tolist()
|
42 |
# Decode text
|
@@ -81,18 +82,13 @@ if __name__ == "__main__":
|
|
81 |
print(f"device:{device}, n_gpu:{n_gpu}, random_seed:{random_seed}, maxlen:{max_len}, top_k:{top_k}, top_p:{top_p}")
|
82 |
if len(text_area.strip()) == 0:
|
83 |
text_area = random.choice(suggested_text_list)
|
84 |
-
result = extend(input_text=text_area,
|
|
|
85 |
max_size=int(max_len),
|
86 |
top_k=int(top_k),
|
87 |
-
top_p=float(top_p)
|
88 |
-
num_return_sequences=int(num_return_sequences))
|
89 |
print("Done length: " + str(len(result)) + " bytes")
|
90 |
#<div class="rtl" dir="rtl" style="text-align:right;">
|
91 |
-
st.markdown(f"
|
92 |
st.write("\n\nResult length: " + str(len(result)) + " bytes")
|
93 |
-
print(f"\"{result}\"")
|
94 |
-
|
95 |
-
st.markdown(
|
96 |
-
"""Hebrew text generation model (125M parameters) based on EleutherAI's gpt-neo architecture. Originally trained on a TPUv3-8 which was made avilable to me via the [TPU Research Cloud Program](https://sites.research.google/trc/)."""
|
97 |
-
)
|
98 |
-
st.markdown("<footer><hr><p style=\"font-size:14px\">Enjoy</p><p style=\"font-size:12px\">Created by <a href=\"https://linktr.ee/Norod78\">Doron Adler</a></p></footer> ", unsafe_allow_html=True)
|
|
|
37 |
if len(output_sequences.shape) > 2:
|
38 |
output_sequences.squeeze_()
|
39 |
generated_sequences = []
|
40 |
+
print(output_sequences)
|
41 |
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
|
42 |
generated_sequence = generated_sequence.tolist()
|
43 |
# Decode text
|
|
|
82 |
print(f"device:{device}, n_gpu:{n_gpu}, random_seed:{random_seed}, maxlen:{max_len}, top_k:{top_k}, top_p:{top_p}")
|
83 |
if len(text_area.strip()) == 0:
|
84 |
text_area = random.choice(suggested_text_list)
|
85 |
+
result = extend(input_text=text_area,
|
86 |
+
num_return_sequences=int(num_return_sequences),
|
87 |
max_size=int(max_len),
|
88 |
top_k=int(top_k),
|
89 |
+
top_p=float(top_p))
|
|
|
90 |
print("Done length: " + str(len(result)) + " bytes")
|
91 |
#<div class="rtl" dir="rtl" style="text-align:right;">
|
92 |
+
st.markdown(f"{result}", unsafe_allow_html=True)
|
93 |
st.write("\n\nResult length: " + str(len(result)) + " bytes")
|
94 |
+
print(f"\"{result}\"")
|
|
|
|
|
|
|
|
|
|