import os
import tempfile
import time

import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.chat_history import InMemoryChatMessageHistory, BaseChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessageChunk, AIMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.errors import GraphRecursionError

from graph import get_graph

# Upload-stage state: on a brand-new session no PDF has been indexed yet.
if 'read_file' not in st.session_state:
    st.session_state['read_file'] = False
    st.session_state['retriever'] = None

# Chat-stage state: message histories keyed by session id, plus a flag that
# lets the greeting be added only once.
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = {}
    st.session_state['first_msg'] = True

def get_session_by_id(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history registered under ``session_id``.

    A fresh ``InMemoryChatMessageHistory`` is created and stored in
    ``st.session_state.chat_history`` the first time an id is requested, so
    later calls with the same id return that same object.
    """
    histories = st.session_state.chat_history
    if session_id not in histories:
        histories[session_id] = InMemoryChatMessageHistory()
    return histories[session_id]

# Upload stage: ask for a PDF, index it into an in-memory vector store, build
# the graph, then switch the session over to the chat stage.
if not st.session_state.read_file:
    st.title('🤓 Upload your PDF to talk with it', anchor=False)
    file = st.file_uploader('Upload a PDF file', type='pdf')
    if file:
        with st.status('🤗 Booting up the things!', expanded=True):
            with st.spinner('📁 Uploading the PDF...', show_time=True):
                # A per-upload temp file (instead of a fixed 'file.pdf' in the CWD)
                # keeps concurrent sessions from overwriting each other's upload.
                # delete=False so the loader can reopen it by path after we close it.
                tmp = tempfile.NamedTemporaryFile(suffix='.pdf', delete=False)
                try:
                    # Close (and thereby flush) the file *before* parsing it;
                    # reading while the writer is still open can see a truncated PDF.
                    with tmp:
                        tmp.write(file.getvalue())
                    loader = PyPDFLoader(tmp.name)
                    documents = loader.load_and_split(
                        RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
                    )
                finally:
                    # load_and_split returns a materialized list, so the file is
                    # no longer needed; remove it even if parsing failed.
                    os.remove(tmp.name)
                st.success('📁 File uploaded successfully!!!')
            with st.spinner('🧐 Reading the file...', show_time=True):
                vstore = InMemoryVectorStore.from_documents(documents, HuggingFaceEmbeddings(model_name='all-MiniLM-L6-v2'))
                st.session_state.retriever = vstore.as_retriever()
                st.success('🧐 File read successfully!!!')
            with st.spinner('😴 Waking up the LLM...', show_time=True):
                st.session_state.graph = get_graph(st.session_state.retriever)
                st.success('😁 LLM awakened!!!')
            st.balloons()
        # Short countdown ("05" ... "00") before rerunning into the chat stage.
        placeholder = st.empty()
        for seconds_left in range(5, -1, -1):
            placeholder.write(f'⏳ Chat starting in {seconds_left:02d} sec.')
            time.sleep(1)
        st.session_state.read_file = True
        st.rerun()

# Chat stage: replay the stored conversation, then stream the graph's answer
# to a newly submitted question.
if st.session_state.read_file:

    st.title('🤗 DocAI', anchor=False)
    st.subheader('Chat with your document!', anchor=False)

    def _escape_markdown_dollars(text: str) -> str:
        """Escape ``$`` so ``st.write`` shows it literally instead of starting LaTeX math."""
        return text.replace('$', r'\$')

    # A single fixed conversation id is used for this browser session.
    chat = get_session_by_id('chat42')

    if st.session_state.first_msg:
        # Seed the conversation with a greeting exactly once per session.
        st.session_state.first_msg = False
        chat.add_message(AIMessage(content='Hello, how are you? How about we talk about the '
                                           'document you sent me to read?'))

    # Replay the stored conversation. AI text is stored raw (so the LLM sees it
    # unmodified as history) and only escaped here, for display.
    for msg in chat.messages:
        if isinstance(msg, HumanMessage):
            with st.chat_message(name='user'):
                st.write(msg.content)
        else:
            with st.chat_message(name='ai'):
                st.write(_escape_markdown_dollars(msg.content))

    prompt = st.chat_input('Try to ask something about your file!')
    if prompt:
        with st.chat_message(name='user'):
            st.write(prompt)

        # Snapshot the history BEFORE recording the new question. graph.stream()
        # is lazy (nothing runs until it is iterated below), so handing it the live
        # list would let the add_message() that follows leak the current question
        # into 'history' by the time the graph actually executes.
        previous_messages = list(chat.messages)
        chat.add_message(HumanMessage(content=prompt))

        response = st.session_state.graph.stream(
            {
                'question': prompt,
                'scratchpad': None,
                'answer': None,
                'next_node': None,
                'history': previous_messages,
            },
            stream_mode='messages',
            # Caps graph steps; exceeding it raises GraphRecursionError (handled below).
            config=RunnableConfig(recursion_limit=4),
        )

        def get_message():
            """Yield the text of each non-empty AIMessageChunk streamed by the graph."""
            for chunk, _metadata in response:
                if isinstance(chunk, AIMessageChunk) and chunk.content:
                    yield chunk.content

        with st.chat_message(name='ai'):
            full_response = ''
            tool_placeholder = st.empty()
            tool_message_placeholder = None
            prompt_message_placeholder = st.empty()

            try:
                for token in get_message():
                    full_response += token
                    if '<tool>' in full_response:
                        # Text wrapped in <tool>...</tool> is shown as tool progress
                        # inside a status box rather than as part of the answer.
                        with tool_placeholder.status('Reading document...', expanded=True):
                            if tool_message_placeholder is None:
                                tool_message_placeholder = st.empty()
                            tool_message_placeholder.write(full_response
                                                           .replace('<tool>', '')
                                                           .replace('</tool>', '')
                                                           .replace('retriever', 'Retrieving document'))
                            prompt_message_placeholder.empty()
                    if '</tool>' in full_response:
                        # Tool section finished: start accumulating the answer afresh.
                        full_response = ''
                        continue
                    prompt_message_placeholder.write(_escape_markdown_dollars(full_response))
            except GraphRecursionError:
                # NOTE(review): this fallback text is Portuguese ("I couldn't answer
                # your question. Could you ask me something else?") while the rest of
                # the UI is English — confirm that is intended.
                message = 'Não consegui responder a sua pergunta. 😥 Poderia me perguntar outra coisa?'
                full_response = ''
                # Simple typewriter effect for the fallback message.
                for letter in message:
                    full_response += letter
                    time.sleep(0.015)
                    prompt_message_placeholder.write(full_response)

        # Store the raw answer; escaping for display happens in the replay loop.
        chat.add_message(AIMessage(content=full_response))