import gradio as gr
from gradio.themes.utils import colors
from t5 import T5
from koalpaca import KoAlpaca

# When True, the __main__ block loads only T5 and skips KoAlpaca.
LOCAL_TEST = False
# Tab labels, index-aligned with MODELS (T5 is appended first, KoAlpaca second).
MODEL_STRS = ['T5', 'KoAlpaca']
# Model instances; populated at startup by the __main__ block.
MODELS = []

def prepare_theme():
    """Return the app's Gradio theme: dark gray surfaces with emerald accents.

    Starts from ``gr.themes.Default`` with gray as the primary hue and emerald
    as both the secondary and neutral hue, then overrides individual theme
    variables. Every colour/fill override is applied to both the light and the
    ``_dark`` variant so the UI looks the same in either colour scheme.
    """
    # Variables whose value is mirrored onto the matching "<name>_dark" key.
    mirrored = {
        "body_background_fill": "*primary_800",
        "block_background_fill": "*primary_700",
        "border_color_primary": "*secondary_300",
        "input_background_fill": "*primary_700",
        "background_fill_primary": "*neutral_950",
        "background_fill_secondary": "*primary_700",
        "body_text_color": "white",
        "block_label_text_color": "*secondary_300",
        "block_label_background_fill": "*primary_800",
        "color_accent_soft": "*primary_600",
    }

    overrides = {}
    for name, value in mirrored.items():
        overrides[name] = value
        overrides[f"{name}_dark"] = value

    # Border widths are set only on the light key (no "_dark" counterpart).
    overrides["block_border_width"] = "3px"
    overrides["input_border_width"] = "2px"

    base = gr.themes.Default(
        primary_hue=colors.gray,
        secondary_hue=colors.emerald,
        neutral_hue=colors.emerald,
    )
    return base.set(**overrides)

if __name__ == '__main__':
    # Build the theme once and reuse it below (previously it was built twice
    # and the first result was discarded).
    theme = prepare_theme()

    # README.txt is rendered in the intro tab; read it explicitly as UTF-8 so
    # the (Korean) text decodes the same way regardless of platform locale.
    with open('README.txt', 'r', encoding='utf-8') as f:
        readme = f.read()

    # Models are appended in the same order as MODEL_STRS so labels line up.
    # The placeholder attribute is attached here and read back when building
    # each model's textbox below.
    MODELS.append(T5())
    MODELS[0].placeholder = '연애 관련 질문을 입력하세요!'
    if not LOCAL_TEST:
        MODELS.append(KoAlpaca())
        MODELS[1].placeholder = '연애 관련 질문을 입력하세요. (KoAlpaca는 추론 시 1분 이상 소요됩니다!)'

    with gr.Blocks(theme=theme) as demo:
        gr.HTML("<h1>KOMUChat : Korean community-style relationship counseling chatbot</h1>")
        with gr.Tab("소개"):
            gr.Markdown(readme)
        # One tab per loaded model. zip() pairs each label with its model, so
        # the KoAlpaca tab is simply omitted when LOCAL_TEST skipped loading it.
        for idx, (name, model) in enumerate(zip(MODEL_STRS, MODELS)):
            with gr.Tab(name):
                chatbot = gr.Chatbot(label=name, bubble_full_width=False)
                # elem_id must be a string; str(idx) yields the same DOM id as before.
                txt = gr.Textbox(
                    show_label=False,
                    placeholder=model.placeholder,
                    container=False,
                    elem_id=str(idx),
                )
                # On submit, model.chat receives the (textbox, chatbot) values
                # and its return values are written back to (textbox, chatbot).
                txt.submit(model.chat, [txt, chatbot], [txt, chatbot])

    demo.launch(debug=True, share=True)