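# app.py: Gradio demo for KancilGPT (abdiharyadi/kancilgpt), a GPT-2 model
# fine-tuned for Indonesian story generation with the IndoNLG tokenizer.
# The app streams a generated story (title first, then the body) into a
# single textbox, validating and regenerating parts that come out malformed.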
import gradio as gr
from transformers import GPT2LMHeadModel
from indobenchmark import IndoNLGTokenizer

# Load the IndoNLG tokenizer and the fine-tuned KancilGPT model.
gpt_tokenizer = IndoNLGTokenizer.from_pretrained("indobenchmark/indogpt")
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
kancilgpt = GPT2LMHeadModel.from_pretrained("abdiharyadi/kancilgpt")
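# KancilGPT emits stories in a flat, pipe-delimited format, which the parsing
# below relies on:
#
#   <s> awal cerita | judul: <title> | <story body> | <ending>
#
# "awal cerita" means "beginning of story" and "judul" means "title"; the
# ending marker is "tamat" ("the end") or "bersambung" ("to be continued").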

def generate_story():
    """Generate one story, yielding partial text as it grows."""
    stop = False
    prompt = "<s> awal cerita | judul:"
    judul = ""      # title
    isi = ""        # story body
    end_part = ""   # ending marker
    isi_not_checked = True
    yield "..."
    while not stop:
        prompt_stop = False
        while not prompt_stop:
            # Extend the text two tokens at a time so partial results can be
            # streamed to the UI between generate() calls.
            gpt_input = gpt_tokenizer(prompt, return_tensors='pt')
            gpt_out = kancilgpt.generate(
                **gpt_input,
                do_sample=True,
                max_new_tokens=2,
                pad_token_id=gpt_tokenizer.eos_token_id,
                eos_token_id=gpt_tokenizer.eos_token_id
            )
            gpt_out = gpt_out[0]
            result = gpt_tokenizer.decode(gpt_out)
            splitted_result = result.split(" | ")
            if len(splitted_result) <= 2:
                # Only "<s> awal cerita | judul: ..." so far: show the title.
                _, judul_prompt = splitted_result
                _, *judul_words = judul_prompt.split()
                judul = " ".join(judul_words)
                yield judul + "..."
                if "." in judul:
                    # A period inside the title is malformed; start over.
                    print("Invalid judul!")
                    prompt = "<s> awal cerita | judul:"
                    continue
                isi = ""
                end_part = ""
                if gpt_out[-1] == gpt_tokenizer.eos_token_id:
                    # EOS before any story body: retry from the same prompt.
                    continue
            else:
                # Title and body (and possibly an ending marker) are present.
                _, judul_prompt, isi, *end_part = splitted_result
                end_part = "".join(end_part)
                _, *judul_words = judul_prompt.split()
                judul = " ".join(judul_words)
                yield judul + "\n" + ("-" * len(judul)) + "\n" + isi + "..."
                if len(splitted_result) == 3:
                    if gpt_out[-1] == gpt_tokenizer.eos_token_id:
                        # EOS without an ending marker: retry.
                        continue
                    elif isi_not_checked:
                        # Quotation marks in the body must come in pairs.
                        quote_count = 0
                        prev_i = 0
                        for i, c in enumerate(isi):
                            if c == "\"":
                                quote_count += 1
                                prev_i = i
                        if quote_count % 2 != 0:
                            # Drop the dangling quote and regenerate from there.
                            print("Invalid isi!")
                            trimmed_isi = isi[:prev_i].rstrip()
                            prompt = f"<s> awal cerita | judul: {judul} | {trimmed_isi}"
                            continue
                        isi_not_checked = False
            if gpt_out[-1] == gpt_tokenizer.eos_token_id:
                prompt_stop = True
            else:
                prompt = result
        # prompt_stop: the model emitted EOS, so this part is complete.
        if not any(end_part.startswith(x) for x in ["bersambung", "tamat"]):
            # A valid part ends with "bersambung" ("to be continued") or
            # "tamat" ("the end"); otherwise regenerate the ending only.
            print("Invalid ending! Regenerating ....")
            prompt = f"<s> awal cerita | judul: {judul} | {isi} |"
            continue
        stop = True
    total_isi = isi
    print("Skipping the remaining story parts for debugging.")
    # TODO: Re-enable the continuation loop below.
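    # The commented-out block below is the unfinished continuation logic: it
    # takes the trailing sentences of the first part (under ~1750 characters,
    # starting at a sentence boundary) as a seed, then repeatedly prompts the
    # model with "<s> pertengahan cerita | ..." ("middle of story") until a
    # part ending in "tamat" is produced. It is kept for reference.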
    # ellipsis = "..."
    # while not end_part.startswith("tamat"):
    #     yield judul + "\n" + ("-" * len(judul)) + "\n" + total_isi + f" {ellipsis}"
    #     ellipsis += "."
    #     i = 0
    #     in_quote = False
    #     end_sentence = False
    #     limit = 1750
    #     while i < len(isi) and not (end_sentence and (not in_quote) and isi[i] == " " and (len(isi) - i) < limit):
    #         if isi[i] == "\"":
    #             in_quote = not in_quote
    #         if end_sentence:
    #             end_sentence = isi[i] not in "abcdefghijklmnopqrstuvwxyz"
    #         else:
    #             end_sentence = isi[i] in ".?!"
    #         i += 1
    #     # i == len(isi) or end_sentence or (not in_quote) or isi[i] == " "
    #     while i < len(isi) and not (isi[i] in "abcdefghijklmnopqrstuvwxyz\""):
    #         i += 1
    #     # i == len(isi) or isi[i] in "abcdefghijklmnopqrstuvwxyz\""
    #     if i == len(isi):
    #         raise ValueError("What???")
    #     next_isi = isi[i:]
    #     stop = False
    #     while not stop:
    #         gpt_input = gpt_tokenizer(f'<s> pertengahan cerita | judul: {judul} | {next_isi}', return_tensors='pt')
    #         gpt_out = kancilgpt.generate(**gpt_input, do_sample=True, max_length=512, pad_token_id=gpt_tokenizer.eos_token_id)
    #         result = gpt_tokenizer.decode(gpt_out[0])
    #         _, judul_prompt, isi, *end_part = result.split(" | ")
    #         end_part = "".join(end_part)
    #         _, *judul_words = judul_prompt.split()
    #         judul = " ".join(judul_words)
    #         if isi[len(next_isi) + 1:].strip() != "":
    #             print(isi[len(next_isi) + 1:])
    #         if "</s>" in isi or "|" in isi or (not any(end_part.startswith(x) for x in ["bersambung", "tamat"])):
    #             print("Invalid output! Regenerating ....")
    #             continue
    #         quote_count = 0
    #         for c in isi:
    #             if c == "\"":
    #                 quote_count += 1
    #         if quote_count % 2 != 0:
    #             print("Invalid output! Regenerating ....")
    #             continue
    #         stop = True
    #         total_isi += " " + isi[len(next_isi) + 1:]
    #     ellipsis = "..."
    yield judul + "\n" + ("-" * len(judul)) + "\n" + total_isi + "\n\ntamat."

demo = gr.Interface(
    fn=generate_story,
    inputs=None,
    outputs=[
        gr.Textbox(label="cerita", lines=7)  # "cerita" = "story"
    ]
)
demo.launch()
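# Because generate_story is a Python generator, Gradio streams every yielded
# string to the "cerita" textbox, so the title and story appear incrementally
# while generation is still running.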