# Spaces: Sleeping (hosting-page status text captured with this file; not part of the program)
import streamlit as st | |
from st_audiorec import st_audiorec | |
from Modules.Speech2Text.transcribe import transcribe | |
import base64 | |
from langchain_mistralai import ChatMistralAI | |
from langchain_core.prompts import ChatPromptTemplate | |
from dotenv import load_dotenv | |
load_dotenv() # load .env api keys | |
import os | |
from Modules.rag import rag_chain | |
from Modules.router import router_chain | |
# from Modules.PoseEstimation.pose_agent import agent_executor | |
# Mistral API key taken from the process environment; None when the variable is unset.
mistral_api_key = os.environ.get("MISTRAL_API_KEY")
from Modules.PoseEstimation import pose_estimator | |
from utils import save_uploaded_file | |
# --- Page layout -------------------------------------------------------------
# Page configuration is issued before any other Streamlit element is created.
st.set_page_config(initial_sidebar_state="collapsed", layout="wide")

# Two side-by-side columns: the chat goes in col1, video/figures in col2.
col1, col2 = st.columns(2)

# Nothing has been uploaded yet at the start of a script run.
video_uploaded = None

# --- Coach LLM and base conversation chain -----------------------------------
llm = ChatMistralAI(
    model="mistral-large-latest",
    mistral_api_key=mistral_api_key,
    temperature=0,
)

# Prompt text for the general-conversation chain; `{question}` is the only
# template variable.
_COACH_TEMPLATE = """ You are a personal AI sports coach with an expertise in nutrition and fitness.
You are having a conversation with your client, which is either a beginner or an advanced athlete.
You must be gentle, kind, and motivative.
Always try to answer concisely to the queries.
User: {question}
AI Coach:"""

prompt = ChatPromptTemplate.from_template(template=_COACH_TEMPLATE)

# Pipe the formatted prompt into the model.
base_chain = prompt | llm
# First column: chat with the AI coach, plus the movement-analysis placeholder.
with col1:
    st.subheader("LLM answering")

    # The conversation is kept in session state so it is still available the
    # next time the script body runs.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # The input is bound to `user_question` rather than `prompt` so the
    # module-level prompt template is not rebound to a plain string.
    if user_question := st.chat_input("What is up?"):
        st.session_state.messages.append({"role": "user", "content": user_question})
        with st.chat_message("user"):
            st.markdown(user_question)

        with st.chat_message("assistant"):
            # Ask the router which chain should answer this question.
            direction = router_chain.invoke({"question": user_question})
            if direction == 'fitness_advices':
                response = rag_chain.invoke(user_question)
            else:
                # 'smalltalk', and any route that has no handler yet (the
                # 'movement_analysis' branch below is disabled), is answered
                # by the general coach chain so `response` is always bound.
                response = base_chain.invoke({"question": user_question}).content
            # elif direction =='movement_analysis':
            #     response = agent_executor.invoke(
            #         {"input" : instruction}
            #     )["output"]

            st.session_state.messages.append({"role": "assistant", "content": response})
            st.markdown(response)

    st.subheader("Movement Analysis")
    # TO DO
# Second column: agenda placeholder, video upload/playback, and saved figures.
with col2:
    st.subheader("Sports Agenda")
    # TO DO

    st.subheader("Video Analysis")
    # Placeholder slot so the uploader widget can be cleared once a file is in.
    ask_video = st.empty()
    if video_uploaded is None:
        video_uploaded = ask_video.file_uploader(
            "Choose a video file", type=["mp4", "ogg", "webm"]
        )
    if video_uploaded:
        video_uploaded = save_uploaded_file(video_uploaded)
        ask_video.empty()
        # Three columns, with the player in the middle one.
        _left, mid, _right = st.columns(3)
        with mid:
            # Check the exact file that will be played, not just the top-level
            # 'runs' directory, so a missing file falls back to the upload
            # instead of raising inside st.video.
            annotated_video = os.path.join('runs', 'pose', 'predict', 'squat.mp4')
            if os.path.isfile(annotated_video):
                st.video(annotated_video, loop=True)
            else:
                st.video(video_uploaded)

    st.subheader("Graph Displayer")
    # isdir (not exists) so a stray regular file named 'fig' is ignored.
    if os.path.isdir('fig'):
        # Sorted so figures appear in a stable order across reruns.
        for file_name in sorted(os.listdir('fig')):
            figure_path = os.path.join('fig', file_name)
            # Skip sub-directories; only regular files are handed to st.image.
            if os.path.isfile(figure_path):
                st.image(figure_path)