import os
import base64

import streamlit as st
from st_audiorec import st_audiorec
from dotenv import load_dotenv
from langchain_mistralai import ChatMistralAI
from langchain_core.prompts import ChatPromptTemplate

load_dotenv()  # load API keys from .env before the Modules packages read them

from Modules.Speech2Text.transcribe import transcribe
from Modules.rag import rag_chain
from Modules.router import router_chain
from Modules.PoseEstimation import pose_estimator
# from Modules.PoseEstimation.pose_agent import agent_executor
from utils import save_uploaded_file

mistral_api_key = os.getenv("MISTRAL_API_KEY")


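# Page layout: the chat coach and (planned) movement analysis live in the left column,
# the sports agenda, video analysis, and generated graphs in the right one.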
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
# Create two columns
col1, col2 = st.columns(2)
video_uploaded = None
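# Base conversational chain: a fixed coaching persona prompt piped into Mistral Large.
# temperature=0 keeps the coach's answers stable across reruns.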
llm = ChatMistralAI(model="mistral-large-latest", mistral_api_key=mistral_api_key, temperature=0)
prompt = ChatPromptTemplate.from_template(
    template="""You are a personal AI sports coach with expertise in nutrition and fitness.
    You are having a conversation with your client, who is either a beginner or an advanced athlete.
    You must be gentle, kind, and motivating.
    Always answer the queries concisely.
    User: {question}
    AI Coach:"""
)
base_chain = prompt | llm 

# First column containers
with col1:
    
    st.subheader("LLM answering")

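    # Chat history is kept in st.session_state so it survives Streamlit's reruns,
    # and is replayed on every render before handling the new input.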
    if "messages" not in st.session_state:
        st.session_state.messages = []

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

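    # Each new message is routed by router_chain: fitness questions go to the RAG chain,
    # small talk to the plain coach chain; the pose-analysis agent branch is still disabled.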
    if user_query := st.chat_input("What is up?"):
        st.session_state.messages.append({"role": "user", "content": user_query})
        with st.chat_message("user"):
            st.markdown(user_query)

        with st.chat_message("assistant"):
            # Route the query to the chain that should answer it
            direction = router_chain.invoke({"question": user_query})
            if direction == 'fitness_advices':
                response = rag_chain.invoke(user_query)
            # elif direction == 'movement_analysis':
            #     response = agent_executor.invoke(
            #         {"input": user_query}
            #     )["output"]
            else:
                # 'smalltalk' and any unexpected router label fall back to the base coach
                # chain, so `response` is always defined before it is rendered
                response = base_chain.invoke(
                    {"question": user_query}
                ).content
            st.session_state.messages.append({"role": "assistant", "content": response})
            st.markdown(response)

    st.subheader("Movement Analysis")
    # TO DO
# Second column containers
with col2:
    st.subheader("Sports Agenda")
    # TO DO
    st.subheader("Video Analysis")
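    # Video flow: ask for an upload, persist it with save_uploaded_file, then show the
    # pose-annotated clip from 'runs' (presumably written by the pose estimation module)
    # if it exists, otherwise the raw upload.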
    ask_video = st.empty()
    if video_uploaded is None:
        video_uploaded = ask_video.file_uploader("Choose a video file", type=["mp4", "ogg", "webm"])
    if video_uploaded:
        video_uploaded = save_uploaded_file(video_uploaded)
        ask_video.empty()
        _left, mid, _right = st.columns(3)
        with mid:
            if os.path.exists('runs'):
                st.video(os.path.join('runs', 'pose', 'predict', 'squat.mp4'), loop=True)
            else:
                st.video(video_uploaded)
            

    st.subheader("Graph Displayer")
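    # Show every figure the analysis code has saved under 'fig' (assumed to contain only images).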
    if os.path.exists('fig'):
        file_list = os.listdir('fig')
        for file in file_list:
            st.image(os.path.join('fig', file))