import gradio as gr
from transformers import GPT2LMHeadModel
from indobenchmark import IndoNLGTokenizer

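# Load the IndoGPT tokenizer and the KancilGPT story generation model.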
gpt_tokenizer = IndoNLGTokenizer.from_pretrained("indobenchmark/indogpt")
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
kancilgpt = GPT2LMHeadModel.from_pretrained("abdiharyadi/kancilgpt")

def generate_story():
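    """Generate a story with KancilGPT and yield partial text so the Gradio
    UI can stream it while generation is still running."""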
    stop = False

    prompt = "<s> awal cerita | judul:"
    judul = ""
    isi = ""
    end_part = ""
    isi_not_checked = True

    yield "..."
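    # Outer loop: retry until the story ends with a valid marker ("bersambung"
    # or "tamat"). Inner loop: extend the prompt a couple of tokens at a time
    # so the partial text can be streamed to the UI.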
    while not stop:
        prompt_stop = False
        while not prompt_stop:
            gpt_input = gpt_tokenizer(prompt, return_tensors='pt')
            gpt_out = kancilgpt.generate(
                **gpt_input,
                do_sample=True,
                max_new_tokens=2,
                pad_token_id=gpt_tokenizer.eos_token_id,
                eos_token_id=gpt_tokenizer.eos_token_id
            )
            gpt_out = gpt_out[0]

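            # The decoded output follows the training format:
            # "<s> awal cerita | judul: <title> | <story body> | <ending marker>"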
            result = gpt_tokenizer.decode(gpt_out)
            splitted_result = result.split(" | ")
            if len(splitted_result) <= 2:
                _, judul_prompt = splitted_result
                _, *judul_words = judul_prompt.split()

                judul = " ".join(judul_words)
                yield judul + "..."
                if "." in judul:
                    print("Invalid judul!")
                    prompt = "<s> awal cerita | judul:"
                    continue

                isi = ""
                end_part = ""

                if gpt_out[-1] == gpt_tokenizer.eos_token_id:
                    continue
            else:
                _, judul_prompt, isi, *end_part = splitted_result
                end_part = "".join(end_part)
                _, *judul_words = judul_prompt.split()
                judul = " ".join(judul_words)

                yield judul + "\n" + ("-" * len(judul)) + "\n" + isi + "..."

                if len(splitted_result) == 3:
                    if gpt_out[-1] == gpt_tokenizer.eos_token_id:
                        continue

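                # Validate the story body once: an odd number of double quotes
                # means an unterminated quotation, so trim back to the last
                # quote and regenerate from there.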
                elif isi_not_checked:
                    quote_count = 0
                    prev_i = 0
                    for i, c in enumerate(isi):
                        if c == "\"":
                            quote_count += 1
                            prev_i = i

                    if quote_count % 2 != 0:
                        print("Invalid isi!")
                        trimmed_isi = isi[:prev_i].rstrip()
                        prompt = f"<s> awal cerita | judul: {judul} | {trimmed_isi}"
                        continue

                    isi_not_checked = False

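            # EOS means the model finished the story; otherwise feed the
            # decoded text back in as the next prompt and keep extending it.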
            if gpt_out[-1] == gpt_tokenizer.eos_token_id:
                prompt_stop = True
            else:
                prompt = result

        # prompt_stop: the model emitted EOS, so judul, isi, and end_part now
        # hold a complete candidate story.
        
        # A complete story must end with "bersambung" (to be continued) or
        # "tamat" (the end); anything else means the draft is invalid.
        if not any(end_part.startswith(x) for x in ["bersambung", "tamat"]):
            print("Invalid ending! Regenerating ....")
            prompt = f"<s> awal cerita | judul: {judul} | {isi} |"
            continue

        stop = True

    total_isi = isi

    print("We skip the rest of the part for debug.")

    # TODO: Solve this.
    # ellipsis = "..."
    # while not end_part.startswith("tamat"):
    #     yield judul + "\n" + ("-" * len(judul)) + "\n" + total_isi + f" {ellipsis}"
    #     ellipsis += "."

    #     i = 0
    #     in_quote = False
    #     end_sentence = False
    #     limit = 1750
    #     while i < len(isi) and not (end_sentence and (not in_quote) and isi[i] == " " and (len(isi) - i) < limit):
    #         if isi[i] == "\"":
    #             in_quote = not in_quote

    #         if end_sentence:
    #             end_sentence = isi[i] not in "abcdefghijklmnopqrstuvwxyz"
    #         else:
    #             end_sentence = isi[i] in ".?!"

    #         i += 1
    #     # i == len(isi) or end_sentence or (not in_quote) or isi[i] == " "

    #     while i < len(isi) and not (isi[i] in "abcdefghijklmnopqrstuvwxyz\""):
    #         i += 1
    #     # i == len(isi) or isi[i] in "abcdefghijklmnopqrstuvwxyz\""

    #     if i == len(isi):
    #         raise ValueError("What???")

    #     next_isi = isi[i:]

    #     stop = False
    #     while not stop:
    #         gpt_input = gpt_tokenizer(f'<s> pertengahan cerita | judul: {judul} | {next_isi}', return_tensors='pt')
    #         gpt_out = kancilgpt.generate(**gpt_input, do_sample=True, max_length=512, pad_token_id=gpt_tokenizer.eos_token_id)
    #         result = gpt_tokenizer.decode(gpt_out[0])

    #         _, judul_prompt, isi, *end_part = result.split(" | ")
    #         end_part = "".join(end_part)
    #         _, *judul_words = judul_prompt.split()
    #         judul = " ".join(judul_words)

    #         if isi[len(next_isi) + 1:].strip() != "":
    #             print(isi[len(next_isi) + 1:])

    #         if "</s>" in isi or "|" in isi or (not any(end_part.startswith(x) for x in ["bersambung", "tamat"])):
    #             print("Invalid output! Regenerating ....")
    #             continue

    #         quote_count = 0
    #         for c in isi:
    #             if c == "\"":
    #                 quote_count += 1

    #         if quote_count % 2 != 0:
    #             print("Invalid output! Regenerating ....")
    #             continue

    #         stop = True

    #     total_isi += " " + isi[len(next_isi) + 1:]
    #     ellipsis = "..."

    yield judul + "\n" + ("-" * len(judul)) + "\n" + total_isi + "\n\ntamat."

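# Gradio treats generate_story as a generator, streaming each yielded string
# into the Textbox so the story builds up incrementally.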
demo = gr.Interface(
    fn=generate_story,
    inputs=None,
    outputs=[
        gr.Textbox(label="cerita", lines=7)
    ]
)

demo.launch()