import os
import shutil
import base64

import streamlit as st
from st_audiorec import st_audiorec
from dotenv import load_dotenv
from langchain_mistralai import ChatMistralAI
from langchain_core.prompts import ChatPromptTemplate

from Modules.Speech2Text.transcribe import transcribe
from Modules.rag import rag_chain
from Modules.router import router_chain
from Modules.PoseEstimation import pose_estimator
from Modules.PoseEstimation.pose_agent import agent_executor
from utils import save_uploaded_file

# Load API keys from the .env file before reading them.
load_dotenv()
mistral_api_key = os.getenv("MISTRAL_API_KEY")


st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
# Two-column layout: coach chat on the left, video analysis on the right.
col1, col2 = st.columns(2)
video_uploaded = None
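# Mistral chat model backing the small-talk coaching chain below
# (temperature 0 for more deterministic replies).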
llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
# Named coach_prompt so it is not shadowed by the chat-input variable `prompt` below.
coach_prompt = ChatPromptTemplate.from_template(
    template="""You are a personal AI sports coach with expertise in nutrition and fitness.
    You are having a conversation with your client, who is either a beginner or an advanced athlete.
    You must be gentle, kind, and motivating.
    Always try to answer queries concisely.
    User: {question}
    AI Coach:"""
)
base_chain = coach_prompt | llm
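# router_chain (Modules.router) classifies each message as 'fitness_advices',
# 'smalltalk', or 'movement_analysis'; the matching chain below produces the reply.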

# First column containers
with col1:
    
    st.subheader("LLM answering")

    # Keep chat history and the uploaded video's file name in session state
    # so they persist across Streamlit reruns.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    
    if "file_name" not in st.session_state:
        st.session_state.file_name = None

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("What is up?"):
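        # Store the user's message in the history and echo it in the chat.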
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant", avatar="data/AI_Bro.png"):
            # Route the query to the matching chain: RAG, small talk, or pose analysis.
            direction = router_chain.invoke({"question": prompt})

            if direction == 'fitness_advices':
                with st.spinner("Thinking..."):
                    response = rag_chain.invoke(prompt)
            elif direction == 'smalltalk':
                with st.spinner("Thinking..."):
                    response = base_chain.invoke(
                        {"question": prompt}
                    ).content
            elif direction == 'movement_analysis':
                # Give the agent the uploaded video's file name, if one is available.
                if st.session_state.file_name is not None:
                    prompt += " The file name is " + st.session_state.file_name
                with st.spinner("Analyzing movement..."):
                    response = agent_executor.invoke(
                        {"input": prompt}
                    )["output"]
            else:
                # Fallback so `response` is always defined, even for unexpected routes.
                with st.spinner("Thinking..."):
                    response = base_chain.invoke(
                        {"question": prompt}
                    ).content
            st.session_state.messages.append({"role": "assistant", "content": response})
            st.markdown(response)

# Second column containers
with col2:
    # st.subheader("Sports Agenda")  # TODO
    st.subheader("Video Analysis")

    video_uploaded = st.file_uploader("Choose a video file", type=["mp4", "ogg", "webm", "MOV"])
    if video_uploaded:
        video_uploaded = save_uploaded_file(video_uploaded)
        uploaded_name = os.path.basename(video_uploaded)
        if uploaded_name != st.session_state.file_name:
            # New video: clear figures and pose-estimation runs from the previous one.
            shutil.rmtree('fig', ignore_errors=True)
            shutil.rmtree('/home/user/.pyenv/runs', ignore_errors=True)
        st.session_state.file_name = uploaded_name
        _left, mid, _right = st.columns([1, 3, 1])
        with mid:
            pose_runs_dir = os.path.join('/home/user/.pyenv/runs', 'pose')
            if os.path.isdir(pose_runs_dir) and os.listdir(pose_runs_dir):
                # Show the annotated video from the most recent pose-estimation run.
                predict_dir = sorted(os.listdir(pose_runs_dir))[-1]
                file_name = os.listdir(os.path.join(pose_runs_dir, predict_dir))[0]
                st.video(os.path.join(pose_runs_dir, predict_dir, file_name), loop=True)
            else:
                # No pose-estimation output yet: show the raw upload.
                st.video(video_uploaded)
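    # Display any figures that have been saved to the local 'fig' directory.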
            
    if os.path.exists('fig'):
        st.subheader("Graph Displayer")
        file_list = os.listdir('fig')
        for file in file_list:
            st.image(os.path.join('fig', file))