Doux Thibault committed
Commit 00a277e · 2 parents: c104abf 45e5f54

Merge branch 'main' of https://huggingface.co/spaces/EntrepreneurFirst/FitnessEquation

Modules/PoseEstimation/pose_agent.py CHANGED
@@ -3,6 +3,7 @@ from langchain.agents import AgentExecutor, create_tool_calling_agent
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.messages import HumanMessage
 from langchain_mistralai.chat_models import ChatMistralAI
+import torch
 import os
 import sys
 import json
@@ -11,6 +12,7 @@ from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angl
 
 # If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.
 llm = ChatMistralAI(model='mistral-large-latest', api_key="i5jSJkCFNGKfgIztloxTMjfckiFbYBj4")
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 @tool
 def get_keypoints_from_keypoints(video_path: str) -> str:
@@ -27,7 +29,7 @@ def get_keypoints_from_keypoints(video_path: str) -> str:
     save_folder='tmp'
     os.makedirs(save_folder, exist_ok=True)
     keypoints = []
-    results = model(video_path, save=True, show_conf=False, show_boxes=False)
+    results = model(video_path, save=True, show_conf=False, show_boxes=False, device=device)
     for (i, frame) in enumerate(results):
         frame_dict = {}
         frame_dict['frame'] = i
@@ -77,7 +79,7 @@ def check_knee_angle(json_path: str) -> bool:
     return False
 
 @tool
-def check_squat(video_path: str) -> bool:
+def check_squat(file_name: str) -> str:
     """
     Checks if the squat is correct.
     This function uses the check_knee_angle tool to check if the squat is correct. If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
@@ -89,8 +91,17 @@ def check_squat(video_path: str) -> bool:
     Returns:
         is_correct (bool): True if the squat is correct, False otherwise
     """
-    json_path = get_keypoints_from_keypoints(video_path)
-    return check_knee_angle(json_path)
+
+    video_path = os.path.join('uploaded', file_name)
+    if os.path.exists(video_path):
+        json_path = get_keypoints_from_keypoints(video_path)
+        is_correct = check_knee_angle(json_path)
+        if is_correct:
+            return "The squat is correct because your knee angle is smaller than 90 degrees."
+        else:
+            return "The squat is incorrect because your knee angle is greater than 90 degrees."
+    else:
+        return "The video file does not exist."
 
 tools = [check_squat]
 
@@ -98,7 +109,7 @@ prompt = ChatPromptTemplate.from_messages(
     [
         (
             "system",
-            "You are a helpful assistant. Make sure to use the check_knee_angle tool if the user wants to check his movement. Also explain your response",
+            "You are a helpful assistant. Make sure to use the check_squat tool if the user wants to check his movement. Also explain your response",
         ),
         ("placeholder", "{chat_history}"),
        ("human", "{input}"),
@@ -109,6 +120,4 @@ prompt = ChatPromptTemplate.from_messages(
 # Construct the Tools agent
 agent = create_tool_calling_agent(llm, tools, prompt)
 
-agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
-response = agent_executor.invoke({"input": f"Is my squat correct ? The video file is in data/pose/squat.mp4."})
-print(response["output"])
+agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
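
Note: after this change the module only builds agent_executor; app.py is now responsible for invoking it. A minimal driver sketch mirroring that call, assuming a video named squat.mp4 was already saved under uploaded/ (the file name here is illustrative):

# Hypothetical driver, mirroring how app.py calls the agent after an upload;
# check_squat resolves the given name against the 'uploaded/' folder.
from Modules.PoseEstimation.pose_agent import agent_executor

question = "Is my squat correct? the file name is squat.mp4"
response = agent_executor.invoke({"input": question})
print(response["output"])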
 
 
app.py CHANGED
@@ -9,15 +9,15 @@ import json
 from dotenv import load_dotenv
 load_dotenv() # load .env api keys
 import os
-
+import shutil
 from Modules.rag import rag_chain
 from Modules.router import router_chain
 from Modules.workout_plan import workout_chain
-# from Modules.PoseEstimation.pose_agent import agent_executor
+from Modules.PoseEstimation.pose_agent import agent_executor
 
 mistral_api_key = os.getenv("MISTRAL_API_KEY")
 from Modules.PoseEstimation import pose_estimator
-from utils import save_uploaded_file
+from utils import save_uploaded_file, encode_video_H264
 
 
 st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
@@ -44,6 +44,9 @@ with col1:
 
     if "messages" not in st.session_state:
         st.session_state.messages = []
+
+    if "file_name" not in st.session_state:
+        st.session_state.file_name = None
 
     for message in st.session_state.messages:
         with st.chat_message(message["role"]):
@@ -54,25 +57,28 @@ with col1:
         with st.chat_message("user"):
             st.markdown(prompt)
 
-        with st.chat_message("assistant"):
+        with st.chat_message("assistant", avatar="data/AI_Bro.png"):
             # Build answer from LLM
             direction = router_chain.invoke({"question":prompt})
             print(type(direction))
             print(direction)
             if direction=='fitness_advices':
-                response = rag_chain.invoke(
-                    prompt
-                )
+                with st.spinner("Thinking..."):
+                    response = rag_chain.invoke(
+                        prompt
+                    )
             elif direction=='smalltalk':
-                response = base_chain.invoke(
-                    {"question":prompt}
-                ).content
+                with st.spinner("Thinking..."):
+                    response = base_chain.invoke(
+                        {"question":prompt}
+                    ).content
             elif direction =='movement_analysis':
-                response = "I can't do that for the moment"
-                # response = agent_executor.invoke(
-                #     {"input" : instruction}
-                # )["output"]
-            # elif direction == 'workout_plan':
+                if st.session_state.file_name is not None:
+                    prompt += "the file name is " + st.session_state.file_name
+                with st.spinner("Analyzing movement..."):
+                    response = agent_executor.invoke(
+                        {"input" : prompt}
+                    )["output"]
             else:
                 response = "Sure! I just made a workout for you. Check on the table I just provided you."
                 json_output = workout_chain.invoke({"query":prompt})
@@ -87,28 +93,35 @@ with col1:
     if display_workout:
         st.subheader("Workout")
         st.data_editor(workout_df)
-    # TO DO
 # Second column containers
 with col2:
-    st.subheader("Sports Agenda")
+    # st.subheader("Sports Agenda")
     # TO DO
     st.subheader("Video Analysis")
-    ask_video = st.empty()
-    if video_uploaded is None:
-        video_uploaded = ask_video.file_uploader("Choose a video file", type=["mp4", "ogg", "webm"])
+
+    video_uploaded = st.file_uploader("Choose a video file", type=["mp4", "ogg", "webm", "MOV"])
     if video_uploaded:
         video_uploaded = save_uploaded_file(video_uploaded)
-        ask_video.empty()
-        _left, mid, _right = st.columns(3)
+        if video_uploaded.split("/")[-1] != st.session_state.file_name:
+            shutil.rmtree('fig', ignore_errors=True)
+            shutil.rmtree('/home/user/.pyenv/runs', ignore_errors=True)
+            st.session_state.file_name = None
+        st.session_state.file_name = video_uploaded.split("/")[-1]
+        _left, mid, _right = st.columns([1, 3, 1])
         with mid:
-            if os.path.exists('runs'):
-                st.video(os.path.join('runs', 'pose', 'predict', 'squat.mp4'), loop=True)
+            if os.path.exists('/home/user/.pyenv/runs'):
+                predict_list = os.listdir(os.path.join('/home/user/.pyenv/runs', 'pose'))
+                predict_list.sort()
+                predict_dir = predict_list[-1]
+                file_name = os.listdir(os.path.join('/home/user/.pyenv/runs', 'pose', predict_dir))[0]
+                file_path = os.path.join('/home/user/.pyenv/runs', 'pose', predict_dir, file_name)
+                file_path = encode_video_H264(file_path, remove_original=True)
+                st.video(file_path, loop=True)
             else :
                 st.video(video_uploaded)
 
-
-    st.subheader("Graph Displayer")
     if os.path.exists('fig'):
+        st.subheader("Graph Displayer")
         file_list = os.listdir('fig')
         for file in file_list:
             st.image(os.path.join('fig', file))
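
The upload flow above keys everything off the path returned by save_uploaded_file: its basename becomes st.session_state.file_name, which the agent's check_squat tool later joins with 'uploaded/'. The unchanged body of that helper is not part of this diff; a plausible sketch, assuming it simply writes the Streamlit upload into the uploaded/ folder:

# Assumed implementation of utils.save_uploaded_file (its body is not shown in this diff):
# write the uploaded bytes under 'uploaded/' and return the resulting path.
import os
import streamlit as st

def save_uploaded_file(uploaded_file):
    try:
        os.makedirs("uploaded", exist_ok=True)
        file_path = os.path.join("uploaded", uploaded_file.name)
        with open(file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        return file_path
    except Exception as e:
        st.error(f"Error: {e}")
        return None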
data/AI_Bro.png ADDED
requirements.txt CHANGED
@@ -1,20 +1,23 @@
 transformers
 torch
-langchain-core
-langchain-mistralai
 pandas
-langchain-community
 streamlit-audiorec
 openai-whisper
 tiktoken
-langchain
 bs4
 chromadb
-langgraph
-langchainhub
 pypdf
 duckduckgo-search
 python-dotenv
 pypdf
 chromadb
-ultralytics
+moviepy
+ultralytics
+langchain==0.1.16
+langchain-community==0.0.34
+langchain-core==0.1.45
+langchain-experimental==0.0.57
+langchain-mistralai==0.1.2
+langchain-text-splitters==0.0.1
+langchainhub==0.1.15
+langsmith==0.1.50
uploaded/__init__.py ADDED
File without changes
utils.py CHANGED
@@ -1,5 +1,6 @@
 import streamlit as st
 import os
+from moviepy.editor import VideoFileClip
 
 def save_uploaded_file(uploaded_file):
     try:
@@ -9,4 +10,24 @@ def save_uploaded_file(uploaded_file):
         return file_path
     except Exception as e:
         st.error(f"Error: {e}")
-        return None
+        return None
+
+def encode_video_H264(video_path, remove_original=False):
+    """
+    Encode video to H264 codec
+
+    Args:
+        video_path (str): path to video to be encoded
+        remove_original (bool): whether to remove original video after encoding
+    Returns:
+        output_path (str): path to encoded video
+    """
+
+    output_path = video_path.split('.')[0] + '_H264.mp4'
+    clip = VideoFileClip(video_path)
+    clip.write_videofile(output_path, codec='libx264')
+    if remove_original:
+        os.remove(video_path)
+        clip.close()
+
+    return output_path
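
For reference, app.py uses this helper to re-encode the annotated video that ultralytics writes under its runs directory, presumably so Streamlit's player can render it in the browser. An illustrative usage sketch (the path and file name are examples, not from the diff):

# Illustrative usage, mirroring app.py: re-encode a YOLO prediction video and display it.
from utils import encode_video_H264

playable = encode_video_H264('/home/user/.pyenv/runs/pose/predict/video.mp4', remove_original=True)
# st.video(playable, loop=True)  # as done in app.py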