Spaces:
Sleeping
Sleeping
Merge branch 'main' of https://huggingface.co/spaces/EntrepreneurFirst/FitnessEquation
Browse files- Modules/PoseEstimation/pose_agent.py +17 -8
- app.py +39 -26
- data/AI_Bro.png +0 -0
- requirements.txt +10 -7
- uploaded/__init__.py +0 -0
- utils.py +22 -1
Modules/PoseEstimation/pose_agent.py
CHANGED
@@ -3,6 +3,7 @@ from langchain.agents import AgentExecutor, create_tool_calling_agent
|
|
3 |
from langchain_core.prompts import ChatPromptTemplate
|
4 |
from langchain_core.messages import HumanMessage
|
5 |
from langchain_mistralai.chat_models import ChatMistralAI
|
|
|
6 |
import os
|
7 |
import sys
|
8 |
import json
|
@@ -11,6 +12,7 @@ from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angl
|
|
11 |
|
12 |
# If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.
|
13 |
llm = ChatMistralAI(model='mistral-large-latest', api_key="i5jSJkCFNGKfgIztloxTMjfckiFbYBj4")
|
|
|
14 |
|
15 |
@tool
|
16 |
def get_keypoints_from_keypoints(video_path: str) -> str:
|
@@ -27,7 +29,7 @@ def get_keypoints_from_keypoints(video_path: str) -> str:
|
|
27 |
save_folder='tmp'
|
28 |
os.makedirs(save_folder, exist_ok=True)
|
29 |
keypoints = []
|
30 |
-
results = model(video_path, save=True, show_conf=False, show_boxes=False)
|
31 |
for (i, frame) in enumerate(results):
|
32 |
frame_dict = {}
|
33 |
frame_dict['frame'] = i
|
@@ -77,7 +79,7 @@ def check_knee_angle(json_path: str) -> bool:
|
|
77 |
return False
|
78 |
|
79 |
@tool
|
80 |
-
def check_squat(
|
81 |
"""
|
82 |
Checks if the squat is correct.
|
83 |
This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
|
@@ -89,8 +91,17 @@ def check_squat(video_path: str) -> bool:
|
|
89 |
Returns:
|
90 |
is_correct (bool): True if the squat is correct, False otherwise
|
91 |
"""
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
tools = [check_squat]
|
96 |
|
@@ -98,7 +109,7 @@ prompt = ChatPromptTemplate.from_messages(
|
|
98 |
[
|
99 |
(
|
100 |
"system",
|
101 |
-
"You are a helpful assistant. Make sure to use the
|
102 |
),
|
103 |
("placeholder", "{chat_history}"),
|
104 |
("human", "{input}"),
|
@@ -109,6 +120,4 @@ prompt = ChatPromptTemplate.from_messages(
|
|
109 |
# Construct the Tools agent
|
110 |
agent = create_tool_calling_agent(llm, tools, prompt)
|
111 |
|
112 |
-
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
113 |
-
response = agent_executor.invoke({"input": f"Is my squat correct ? The video file is in data/pose/squat.mp4."})
|
114 |
-
print(response["output"])
|
|
|
3 |
from langchain_core.prompts import ChatPromptTemplate
|
4 |
from langchain_core.messages import HumanMessage
|
5 |
from langchain_mistralai.chat_models import ChatMistralAI
|
6 |
+
import torch
|
7 |
import os
|
8 |
import sys
|
9 |
import json
|
|
|
12 |
|
13 |
# If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.
|
14 |
llm = ChatMistralAI(model='mistral-large-latest', api_key="i5jSJkCFNGKfgIztloxTMjfckiFbYBj4")
|
15 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
16 |
|
17 |
@tool
|
18 |
def get_keypoints_from_keypoints(video_path: str) -> str:
|
|
|
29 |
save_folder='tmp'
|
30 |
os.makedirs(save_folder, exist_ok=True)
|
31 |
keypoints = []
|
32 |
+
results = model(video_path, save=True, show_conf=False, show_boxes=False, device=device)
|
33 |
for (i, frame) in enumerate(results):
|
34 |
frame_dict = {}
|
35 |
frame_dict['frame'] = i
|
|
|
79 |
return False
|
80 |
|
81 |
@tool
|
82 |
+
def check_squat(file_name: str) -> str:
|
83 |
"""
|
84 |
Checks if the squat is correct.
|
85 |
This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
|
|
|
91 |
Returns:
|
92 |
is_correct (bool): True if the squat is correct, False otherwise
|
93 |
"""
|
94 |
+
|
95 |
+
video_path = os.path.join('uploaded', file_name)
|
96 |
+
if os.path.exists(video_path):
|
97 |
+
json_path = get_keypoints_from_keypoints(video_path)
|
98 |
+
is_correct = check_knee_angle(json_path)
|
99 |
+
if is_correct:
|
100 |
+
return "The squat is correct because your knee angle is smaller than 90 degrees."
|
101 |
+
else:
|
102 |
+
return "The squat is incorrect because your knee angle is greater than 90 degrees."
|
103 |
+
else:
|
104 |
+
return "The video file does not exist."
|
105 |
|
106 |
tools = [check_squat]
|
107 |
|
|
|
109 |
[
|
110 |
(
|
111 |
"system",
|
112 |
+
"You are a helpful assistant. Make sure to use the check_squat tool if the user wants to check his movement. Also explain your response",
|
113 |
),
|
114 |
("placeholder", "{chat_history}"),
|
115 |
("human", "{input}"),
|
|
|
120 |
# Construct the Tools agent
|
121 |
agent = create_tool_calling_agent(llm, tools, prompt)
|
122 |
|
123 |
+
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
|
|
|
|
app.py
CHANGED
@@ -9,15 +9,15 @@ import json
|
|
9 |
from dotenv import load_dotenv
|
10 |
load_dotenv() # load .env api keys
|
11 |
import os
|
12 |
-
|
13 |
from Modules.rag import rag_chain
|
14 |
from Modules.router import router_chain
|
15 |
from Modules.workout_plan import workout_chain
|
16 |
-
|
17 |
|
18 |
mistral_api_key = os.getenv("MISTRAL_API_KEY")
|
19 |
from Modules.PoseEstimation import pose_estimator
|
20 |
-
from utils import save_uploaded_file
|
21 |
|
22 |
|
23 |
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
|
@@ -44,6 +44,9 @@ with col1:
|
|
44 |
|
45 |
if "messages" not in st.session_state:
|
46 |
st.session_state.messages = []
|
|
|
|
|
|
|
47 |
|
48 |
for message in st.session_state.messages:
|
49 |
with st.chat_message(message["role"]):
|
@@ -54,25 +57,28 @@ with col1:
|
|
54 |
with st.chat_message("user"):
|
55 |
st.markdown(prompt)
|
56 |
|
57 |
-
with st.chat_message("assistant"):
|
58 |
# Build answer from LLM
|
59 |
direction = router_chain.invoke({"question":prompt})
|
60 |
print(type(direction))
|
61 |
print(direction)
|
62 |
if direction=='fitness_advices':
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
66 |
elif direction=='smalltalk':
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
70 |
elif direction =='movement_analysis':
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
76 |
else:
|
77 |
response = "Sure! I just made a workout for you. Check on the table I just provided you."
|
78 |
json_output = workout_chain.invoke({"query":prompt})
|
@@ -87,28 +93,35 @@ with col1:
|
|
87 |
if display_workout:
|
88 |
st.subheader("Workout")
|
89 |
st.data_editor(workout_df)
|
90 |
-
# TO DO
|
91 |
# Second column containers
|
92 |
with col2:
|
93 |
-
st.subheader("Sports Agenda")
|
94 |
# TO DO
|
95 |
st.subheader("Video Analysis")
|
96 |
-
|
97 |
-
|
98 |
-
video_uploaded = ask_video.file_uploader("Choose a video file", type=["mp4", "ogg", "webm"])
|
99 |
if video_uploaded:
|
100 |
video_uploaded = save_uploaded_file(video_uploaded)
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
103 |
with mid:
|
104 |
-
if os.path.exists('runs'):
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
else :
|
107 |
st.video(video_uploaded)
|
108 |
|
109 |
-
|
110 |
-
st.subheader("Graph Displayer")
|
111 |
if os.path.exists('fig'):
|
|
|
112 |
file_list = os.listdir('fig')
|
113 |
for file in file_list:
|
114 |
st.image(os.path.join('fig', file))
|
|
|
9 |
from dotenv import load_dotenv
|
10 |
load_dotenv() # load .env api keys
|
11 |
import os
|
12 |
+
import shutil
|
13 |
from Modules.rag import rag_chain
|
14 |
from Modules.router import router_chain
|
15 |
from Modules.workout_plan import workout_chain
|
16 |
+
from Modules.PoseEstimation.pose_agent import agent_executor
|
17 |
|
18 |
mistral_api_key = os.getenv("MISTRAL_API_KEY")
|
19 |
from Modules.PoseEstimation import pose_estimator
|
20 |
+
from utils import save_uploaded_file, encode_video_H264
|
21 |
|
22 |
|
23 |
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
|
|
|
44 |
|
45 |
if "messages" not in st.session_state:
|
46 |
st.session_state.messages = []
|
47 |
+
|
48 |
+
if "file_name" not in st.session_state:
|
49 |
+
st.session_state.file_name = None
|
50 |
|
51 |
for message in st.session_state.messages:
|
52 |
with st.chat_message(message["role"]):
|
|
|
57 |
with st.chat_message("user"):
|
58 |
st.markdown(prompt)
|
59 |
|
60 |
+
with st.chat_message("assistant", avatar="data/AI_Bro.png"):
|
61 |
# Build answer from LLM
|
62 |
direction = router_chain.invoke({"question":prompt})
|
63 |
print(type(direction))
|
64 |
print(direction)
|
65 |
if direction=='fitness_advices':
|
66 |
+
with st.spinner("Thinking..."):
|
67 |
+
response = rag_chain.invoke(
|
68 |
+
prompt
|
69 |
+
)
|
70 |
elif direction=='smalltalk':
|
71 |
+
with st.spinner("Thinking..."):
|
72 |
+
response = base_chain.invoke(
|
73 |
+
{"question":prompt}
|
74 |
+
).content
|
75 |
elif direction =='movement_analysis':
|
76 |
+
if st.session_state.file_name is not None:
|
77 |
+
prompt += "the file name is " + st.session_state.file_name
|
78 |
+
with st.spinner("Analyzing movement..."):
|
79 |
+
response = agent_executor.invoke(
|
80 |
+
{"input" : prompt}
|
81 |
+
)["output"]
|
82 |
else:
|
83 |
response = "Sure! I just made a workout for you. Check on the table I just provided you."
|
84 |
json_output = workout_chain.invoke({"query":prompt})
|
|
|
93 |
if display_workout:
|
94 |
st.subheader("Workout")
|
95 |
st.data_editor(workout_df)
|
|
|
96 |
# Second column containers
|
97 |
with col2:
|
98 |
+
# st.subheader("Sports Agenda")
|
99 |
# TO DO
|
100 |
st.subheader("Video Analysis")
|
101 |
+
|
102 |
+
video_uploaded = st.file_uploader("Choose a video file", type=["mp4", "ogg", "webm", "MOV"])
|
|
|
103 |
if video_uploaded:
|
104 |
video_uploaded = save_uploaded_file(video_uploaded)
|
105 |
+
if video_uploaded.split("/")[-1] != st.session_state.file_name:
|
106 |
+
shutil.rmtree('fig', ignore_errors=True)
|
107 |
+
shutil.rmtree('/home/user/.pyenv/runs', ignore_errors=True)
|
108 |
+
st.session_state.file_name = None
|
109 |
+
st.session_state.file_name = video_uploaded.split("/")[-1]
|
110 |
+
_left, mid, _right = st.columns([1, 3, 1])
|
111 |
with mid:
|
112 |
+
if os.path.exists('/home/user/.pyenv/runs'):
|
113 |
+
predict_list = os.listdir(os.path.join('/home/user/.pyenv/runs', 'pose'))
|
114 |
+
predict_list.sort()
|
115 |
+
predict_dir = predict_list[-1]
|
116 |
+
file_name = os.listdir(os.path.join('/home/user/.pyenv/runs', 'pose', predict_dir))[0]
|
117 |
+
file_path =os.path.join('/home/user/.pyenv/runs', 'pose', predict_dir, file_name)
|
118 |
+
file_path = encode_video_H264(file_path, remove_original=True)
|
119 |
+
st.video(file_path, loop=True)
|
120 |
else :
|
121 |
st.video(video_uploaded)
|
122 |
|
|
|
|
|
123 |
if os.path.exists('fig'):
|
124 |
+
st.subheader("Graph Displayer")
|
125 |
file_list = os.listdir('fig')
|
126 |
for file in file_list:
|
127 |
st.image(os.path.join('fig', file))
|
data/AI_Bro.png
ADDED
![]() |
requirements.txt
CHANGED
@@ -1,20 +1,23 @@
|
|
1 |
transformers
|
2 |
torch
|
3 |
-
langchain-core
|
4 |
-
langchain-mistralai
|
5 |
pandas
|
6 |
-
langchain-community
|
7 |
streamlit-audiorec
|
8 |
openai-whisper
|
9 |
tiktoken
|
10 |
-
langchain
|
11 |
bs4
|
12 |
chromadb
|
13 |
-
langgraph
|
14 |
-
langchainhub
|
15 |
pypdf
|
16 |
duckduckgo-search
|
17 |
python-dotenv
|
18 |
pypdf
|
19 |
chromadb
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
transformers
|
2 |
torch
|
|
|
|
|
3 |
pandas
|
|
|
4 |
streamlit-audiorec
|
5 |
openai-whisper
|
6 |
tiktoken
|
|
|
7 |
bs4
|
8 |
chromadb
|
|
|
|
|
9 |
pypdf
|
10 |
duckduckgo-search
|
11 |
python-dotenv
|
12 |
pypdf
|
13 |
chromadb
|
14 |
+
moviepy
|
15 |
+
ultralytics
|
16 |
+
langchain==0.1.16
|
17 |
+
langchain-community==0.0.34
|
18 |
+
langchain-core==0.1.45
|
19 |
+
langchain-experimental==0.0.57
|
20 |
+
langchain-mistralai==0.1.2
|
21 |
+
langchain-text-splitters==0.0.1
|
22 |
+
langchainhub==0.1.15
|
23 |
+
langsmith==0.1.50
|
uploaded/__init__.py
ADDED
File without changes
|
utils.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
import os
|
|
|
3 |
|
4 |
def save_uploaded_file(uploaded_file):
|
5 |
try:
|
@@ -9,4 +10,24 @@ def save_uploaded_file(uploaded_file):
|
|
9 |
return file_path
|
10 |
except Exception as e:
|
11 |
st.error(f"Error: {e}")
|
12 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
import os
|
3 |
+
from moviepy.editor import VideoFileClip
|
4 |
|
5 |
def save_uploaded_file(uploaded_file):
|
6 |
try:
|
|
|
10 |
return file_path
|
11 |
except Exception as e:
|
12 |
st.error(f"Error: {e}")
|
13 |
+
return None
|
14 |
+
|
15 |
+
def encode_video_H264(video_path, remove_original=False):
|
16 |
+
"""
|
17 |
+
Encode video to H264 codec
|
18 |
+
|
19 |
+
Args:
|
20 |
+
video_path (str): path to video to be encoded
|
21 |
+
remove_original (bool): whether to remove original video after encoding
|
22 |
+
Returns:
|
23 |
+
output_path (str): path to encoded video
|
24 |
+
"""
|
25 |
+
|
26 |
+
output_path = video_path.split('.')[0] + '_H264.mp4'
|
27 |
+
clip = VideoFileClip(video_path)
|
28 |
+
clip.write_videofile(output_path, codec='libx264')
|
29 |
+
if remove_original:
|
30 |
+
os.remove(video_path)
|
31 |
+
clip.close()
|
32 |
+
|
33 |
+
return output_path
|