import streamlit as st
from streamlit_chat import message


# st.cache(allow_output_mutation=True) is deprecated and has been removed in
# modern Streamlit releases; prefer st.cache_resource (the designated cache for
# unserializable resources like models) when the installed version provides it.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is None:  # fallback for old Streamlit versions
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def get_pipe():
    """Load and cache the dialog model and tokenizer (one copy per process).

    Returns:
        (model, tokenizer): the causal LM and its matching tokenizer.
    """
    # Imported lazily so the import cost is paid once, inside the cached call.
    from transformers import AutoTokenizer, AutoModelForCausalLM
    model_name = "heegyu/ajoublue-gpt2-medium-dialog"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer

def get_response(tokenizer, model, history, max_context: int = 7, bot_id: str = '1'):
    """Generate the bot's next utterance from the dialogue history.

    Args:
        tokenizer: HF-style tokenizer (callable, with a ``decode`` method).
        model: causal LM exposing ``generate``.
        history: alternating utterances; even indices are speaker "0",
            odd indices speaker "1".
        max_context: maximum number of past turns kept in the prompt.
        bot_id: speaker label the model should continue as.

    Returns:
        The generated reply with special tokens and newlines stripped.
    """
    # Label every turn with its speaker id and terminate it with </s>,
    # matching the model's training format.
    turns = [f"{i % 2}: {text}</s>" for i, text in enumerate(history)]

    # Keep only the most recent turns so the prompt stays bounded.
    if len(turns) > max_context:
        turns = turns[-max_context:]
    prompt = "".join(turns) + f"{bot_id}: "

    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]

    generation_args = dict(
        max_new_tokens=128,
        # min_length counts the whole sequence (prompt included), so this
        # forces at least 5 newly generated tokens.
        min_length=prompt_len + 5,
        eos_token_id=2,
        do_sample=True,
        top_p=0.95,
        temperature=1.35,
        early_stopping=True,
    )

    outputs = model.generate(**inputs, **generation_args)

    # Slice off the prompt at the *token* level. The previous approach decoded
    # the full sequence and dropped len(prompt) characters, which is fragile:
    # a tokenize/decode round-trip is not guaranteed to reproduce the prompt
    # string character-for-character.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=False)
    response = response.replace("</s>", "").replace("\n", "")
    # Keep only the first generated turn.
    response = response.split("<s>")[0]
    return response

st.title("ajoublue-gpt2-medium 한국어 대화 모델 demo")

with st.spinner("loading model..."):
    model, tokenizer = get_pipe()

# The full conversation survives reruns inside session state.
if 'message_history' not in st.session_state:
    st.session_state.message_history = []
history = st.session_state.message_history

# Replay every previous turn; even indices are the user's messages.
for idx, past in enumerate(history):
    message(past, is_user=idx % 2 == 0, key=idx)

input_ = st.text_input("아무 말이나 해보세요", value="")

if input_:
    # A rerun re-submits the same text; only generate when the last
    # answered user turn differs from the current input.
    already_answered = len(history) > 1 and history[-2] == input_
    if not already_answered:
        with st.spinner("대답을 생성중입니다..."):
            history.append(input_)
            # history now ends with the user's turn, which is the prompt.
            history.append(get_response(tokenizer, model, history))
            st.experimental_rerun()