File size: 3,620 Bytes
f420881
58abf68
632ca18
f420881
58abf68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f420881
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
from transformers import GPT2LMHeadModel
from indobenchmark import IndoNLGTokenizer

# Causal language model checkpoint used to sample story text.
kancilgpt = GPT2LMHeadModel.from_pretrained("abdiharyadi/kancilgpt")

# Tokenizer from the IndoGPT base model. Its EOS token doubles as the padding
# token; generate() calls in this file likewise pass eos_token_id as
# pad_token_id.
gpt_tokenizer = IndoNLGTokenizer.from_pretrained("indobenchmark/indogpt")
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token

# Markers accepted right after the story text in a generated sample.
_END_MARKERS = ("bersambung", "tamat")
# Characters recognised as a sentence start when choosing where to cut the
# running story; only lowercase ASCII letters count.
_LOWERCASE = "abcdefghijklmnopqrstuvwxyz"


def _parse_output(result):
    """Split a decoded sample into ``(judul, isi, end_part)``.

    The sample is split on ``" | "``: the second segment is the title
    segment (its first word, the ``judul:`` label, is dropped), the third is
    the story text, and any remaining segments are concatenated into
    ``end_part``.

    Returns None when there are fewer than three segments, so the caller can
    regenerate instead of crashing on tuple unpacking.
    """
    parts = result.split(" | ")
    if len(parts) < 3:
        return None
    _, judul_prompt, isi, *end_parts = parts
    # Slicing (rather than star-unpacking) tolerates an empty title segment.
    judul = " ".join(judul_prompt.split()[1:])
    return judul, isi, "".join(end_parts)


def _is_valid(judul, isi, end_part):
    """Return True if a parsed sample is usable.

    A sample is rejected if:
    - the title or story contains ``</s>``;
    - the story contains a stray ``|``;
    - the ending does not start with one of ``_END_MARKERS``;
    - the story has an odd number of double quotes (an unclosed quotation).
    """
    if "</s>" in judul or "</s>" in isi or "|" in isi:
        return False
    if not end_part.startswith(_END_MARKERS):
        return False
    return isi.count('"') % 2 == 0


def _generate_part(prompt):
    """Sample from the model until it yields a valid ``(judul, isi, end_part)``.

    Retries without limit and prints a notice for each rejected sample.
    """
    while True:
        gpt_input = gpt_tokenizer(prompt, return_tensors='pt')
        gpt_out = kancilgpt.generate(
            **gpt_input,
            do_sample=True,
            max_length=512,
            pad_token_id=gpt_tokenizer.eos_token_id,
        )
        parsed = _parse_output(gpt_tokenizer.decode(gpt_out[0]))
        if parsed is not None and _is_valid(*parsed):
            return parsed
        print("Invalid output! Regenerating ....")


def _find_continuation_start(isi, limit=1750):
    """Return the index in *isi* where the next continuation prompt starts.

    The scan runs in two phases:
    1. Find the first space that comes after a sentence-ending mark
       (``.``, ``?`` or ``!``) with no lowercase letter in between, lies
       outside a double-quoted passage, and leaves fewer than *limit*
       characters after it.
    2. Advance from that space to the next lowercase letter or double quote.

    Raises:
        ValueError: if the scan reaches the end of *isi*.
    """
    i = 0
    in_quote = False
    end_sentence = False
    while i < len(isi) and not (
        end_sentence and not in_quote and isi[i] == " " and len(isi) - i < limit
    ):
        if isi[i] == '"':
            in_quote = not in_quote

        if end_sentence:
            # Stay "after a sentence end" until a lowercase letter appears.
            end_sentence = isi[i] not in _LOWERCASE
        else:
            end_sentence = isi[i] in ".?!"

        i += 1

    # Skip spaces/punctuation up to the first character of the next sentence.
    while i < len(isi) and isi[i] not in _LOWERCASE + '"':
        i += 1

    if i == len(isi):
        raise ValueError(
            "Could not find a sentence boundary in the generated text to "
            "continue the story from."
        )
    return i


def generate_story(judul: str = ""):
    """Generate a story with KancilGPT, yielding ``(judul, cerita)`` pairs.

    The story is produced in parts. First comes an opening part
    ("awal cerita"). Continuation parts ("pertengahan cerita") follow, each
    prompted with the tail of the previous part, until a part's ending
    starts with "tamat".

    Yields:
        First a ``("...", "...")`` placeholder. Then the title and the story
        so far, suffixed with " ...". Finally the full story followed by a
        blank line and "tamat.".

    Args:
        judul: Unused. The title is generated by the model and this name is
            overwritten. It defaults to "" so the function can be called with
            no arguments: the Interface below passes ``inputs=None``, which
            presumably makes Gradio call it without arguments (verify against
            the installed Gradio version).

    Raises:
        ValueError: if no sentence boundary is found to continue from.
    """
    yield "...", "..."

    judul, isi, end_part = _generate_part('<s> awal cerita | judul:')
    yield judul, isi + " ..."

    total_isi = isi
    while not end_part.startswith("tamat"):
        yield judul, total_isi + " ..."

        # Prompt the next part with the tail of the latest part, starting at
        # a sentence boundary.
        next_isi = isi[_find_continuation_start(isi):]
        judul, isi, end_part = _generate_part(
            f'<s> pertengahan cerita | judul: {judul} | {next_isi}'
        )

        # The decoded story repeats the prompt tail; keep only what follows
        # it (the +1 presumably skips the separating space).
        new_text = isi[len(next_isi) + 1:]
        if new_text.strip() != "":
            print(new_text)
        total_isi += " " + new_text

    yield judul, total_isi + "\n\ntamat."

# Two read-only text areas: the generated title and the story body.
story_outputs = [
    gr.Textbox(label="judul", lines=1),
    gr.Textbox(label="cerita", lines=7),
]

# No input components: the story is generated from scratch each time, and
# generate_story streams partial results as it yields them.
demo = gr.Interface(fn=generate_story, inputs=None, outputs=story_outputs)

demo.launch()