File size: 6,271 Bytes
e5f4bf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9749435
 
 
 
 
 
 
 
 
 
 
e5f4bf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9749435
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import streamlit as st
import SessionState
from mtranslate import translate
from prompts import PROMPT_LIST
import random
import time
from transformers import pipeline, set_seed
import psutil
import codecs
import streamlit.components.v1 as stc
import pathlib

# st.set_page_config(page_title="Indonesian Story Generator")

# Registry of selectable models: display label -> {"name": Hugging Face model id,
# "text_generator": pipeline slot, initially None and filled in by the app below}.
# The persona chatbot has no model id because it is rendered from static HTML.
MODELS = {
    label: {"name": repo_id, "text_generator": None}
    for label, repo_id in (
        ("Indonesian Literature - GPT-2 Small", "cahya/gpt2-small-indonesian-story"),
        ("Indonesian Literature - GPT-2 Medium", "cahya/gpt2-medium-indonesian-story"),
        ("Indonesian Persona Chatbot", ""),
    )
}


def stc_chatbot(root_dir, width=700, height=900):
    """Render the persona chatbot page as an embedded Streamlit HTML component.

    Reads ``app/chatbot.html`` under *root_dir* and replaces its
    ``<link rel="stylesheet" href="css/main.css">`` and
    ``<script src="js/main.js"></script>`` tags with the inlined contents of
    ``app/css/main.css`` and ``app/js/main.js``, so the page is self-contained
    (presumably because the component cannot fetch those relative URLs itself).
    The result is rendered with ``stc.html`` with scrolling enabled.

    Args:
        root_dir: Project root directory; a ``pathlib.Path`` or anything
            ``pathlib.Path()`` accepts (e.g. a ``str``).
        width: Component width passed to ``stc.html``.
        height: Component height passed to ``stc.html``.

    Returns:
        None. If any of the three asset files is missing, nothing is rendered.
    """
    root_dir = pathlib.Path(root_dir)
    html_file = root_dir / "app/chatbot.html"
    css_file = root_dir / "app/css/main.css"
    js_file = root_dir / "app/js/main.js"
    # Check the HTML too: previously only CSS/JS were checked and a missing
    # chatbot.html raised FileNotFoundError.
    if not (html_file.exists() and css_file.exists() and js_file.exists()):
        return None

    # read_text opens and closes each file (codecs.open(...).read() leaked the
    # handles) and pins the encoding instead of relying on the locale default.
    html = html_file.read_text(encoding="utf-8")
    css = css_file.read_text(encoding="utf-8")
    js = js_file.read_text(encoding="utf-8")
    html = html.replace('<link rel="stylesheet" href="css/main.css">', "<style>\n" + css + "\n</style>")
    html = html.replace('<script src="js/main.js"></script>', "<script>\n" + js + "\n</script>")
    stc.html(html, width=width, height=height, scrolling=True)
    return None


# Sidebar model picker; the chosen label is used as a key into MODELS and selects
# which application UI is rendered below.
model = st.sidebar.selectbox("Model", MODELS.keys())


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_generator(model_name: str):
    """Create a Hugging Face ``text-generation`` pipeline for *model_name*.

    The ``st.cache`` decorator memoizes the pipeline per *model_name* across
    Streamlit reruns. ``allow_output_mutation=True`` skips hashing the returned
    pipeline object, and ``suppress_st_warning=True`` silences the warning about
    calling ``st.write`` inside a cached function.
    """
    st.write(f"Loading the GPT2 model {model_name}, please wait...")
    return pipeline(task='text-generation', model=model_name)


# NOTE: st.cache is intentionally not applied to this function; caching it caused
# problems with newer Streamlit versions.
def process(text_generator, text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
            temperature: float = 1.0, max_time: float = 60.0, seed=42):
    """Seed the RNGs with *seed*, then run *text_generator* on *text*.

    Args:
        text_generator: A callable text-generation pipeline (as built by
            ``get_generator``).
        text: Prompt text to continue.
        max_length, do_sample, top_k, top_p, temperature, max_time: Forwarded
            unchanged as generation keyword arguments.
        seed: Passed to ``set_seed`` before generating, so equal inputs with
            sampling enabled give repeatable output.

    Returns:
        The pipeline's raw return value; the app reads
        ``result[0]["generated_text"]`` from it.
    """
    set_seed(seed)
    generation_options = {
        "max_length": max_length,
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "max_time": max_time,
    }
    return text_generator(text, **generation_options)


st.title("Indonesian GPT-2 Applications")
prompt_group_name = ""
if "Indonesian Literature" in model:
    # Story generator UI: choose a prompt category (or "Custom"), tune the
    # generation settings in the sidebar, then press "Run".
    st.subheader("Indonesian Literature")
    prompt_group_name = "Indonesian Literature"
    st.markdown(
        """
        This application is a demo for Indonesian Literature Generator using GPT2.
        """
    )
    # SessionState (project helper) keeps these fields across Streamlit reruns.
    session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
    ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys()) + ["Custom"]

    prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)

    # When the user switches to a different prompt category, forget the previous
    # example text so a fresh one is drawn below.
    if session_state.prompt is not None and prompt != session_state.prompt:
        session_state.prompt_box = None
        session_state.text = None
    session_state.prompt = prompt

    # Pre-fill the text area: fixed hint for "Custom", otherwise one random
    # example from the chosen category, kept until the prompt changes or a run
    # resets the state.
    if session_state.prompt == "Custom":
        session_state.prompt_box = "Enter your text here"
    elif session_state.prompt_box is None:
        session_state.prompt_box = random.choice(PROMPT_LIST[prompt_group_name][session_state.prompt])

    session_state.text = st.text_area("Enter text", session_state.prompt_box)

    max_length = st.sidebar.number_input(
        "Maximum length",
        value=100,
        min_value=1,
        max_value=512,
        help="The maximum length of the sequence to be generated."
    )

    # transformers rejects temperature <= 0 when sampling, so keep the slider
    # strictly positive instead of letting a run crash.
    temperature = st.sidebar.slider(
        "Temperature",
        value=1.0,
        min_value=0.01,
        max_value=10.0
    )

    do_sample = st.sidebar.checkbox(
        "Use sampling",
        value=True
    )

    top_k = 40
    top_p = 0.95

    if do_sample:
        # Bounds match what transformers accepts: top_k >= 0 (0 disables it) and
        # 0 <= top_p <= 1. Bound types match the value types, as Streamlit requires.
        top_k = st.sidebar.number_input(
            "Top k",
            value=top_k,
            min_value=0
        )
        top_p = st.sidebar.number_input(
            "Top p",
            value=top_p,
            min_value=0.0,
            max_value=1.0
        )

    seed = st.sidebar.number_input(
        "Random Seed",
        value=25,
        help="The number used to initialize a pseudorandom number generator"
    )

    # Load only the selected model (get_generator is st.cache'd, so reruns reuse
    # it). Previously every Literature model, including the Medium checkpoint,
    # was loaded into memory even when unused.
    MODELS[model]["text_generator"] = get_generator(MODELS[model]["name"])

    if st.button("Run"):
        with st.spinner(text="Getting results..."):
            memory = psutil.virtual_memory()
            st.subheader("Result")
            time_start = time.perf_counter()
            result = process(MODELS[model]["text_generator"], text=session_state.text, max_length=int(max_length),
                             temperature=temperature, do_sample=do_sample,
                             top_k=int(top_k), top_p=float(top_p), seed=int(seed))
            time_diff = time.perf_counter() - time_start
            result = result[0]["generated_text"]
            # Two trailing spaces force a Markdown line break for each newline.
            st.write(result.replace("\n", "  \n"))
            st.text("Translation")
            # translate() makes a network request; a failure there should not
            # hide the text that was already generated.
            try:
                translation = translate(result, "en", "id")
            except (OSError, ValueError) as err:
                st.warning(f"Translation failed: {err}")
            else:
                st.write(translation.replace("\n", "  \n"))
            info = f"""
            *Memory: {memory.total/(1024*1024*1024):.2f}GB, used: {memory.percent}%, available: {memory.available/(1024*1024*1024):.2f}GB*        
            *Text generated in {time_diff:.5} seconds*
            """
            st.write(info)

            # Reset state so the next rerun draws a new example prompt.
            session_state.prompt = None
            session_state.prompt_box = None
            session_state.text = None
elif model == "Indonesian Persona Chatbot":
    st.subheader("Indonesian GPT-2 Persona Chatbot")
    root_dir = pathlib.Path(".")
    stc_chatbot(root_dir)