# Author: Théo Rousseaux
# Commit 5730896: "squat agent working"
from langchain.tools import tool
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from langchain_mistralai.chat_models import ChatMistralAI
import os
import sys
import json
sys.path.append(os.getcwd())
from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angle, moving_average
# If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.
# Rely on that default instead of hard-coding a secret in source control.
# NOTE(review): the key previously committed here is exposed in the repo
# history and should be revoked/rotated.
llm = ChatMistralAI(model='mistral-large-latest')
@tool
def get_keypoints_from_keypoints(video_path: str) -> str:
    """
    Extracts pose keypoints from a video file and saves them as JSON.
    Only the first detected person of each frame is kept; frames in which
    no person is detected are skipped.
    Args:
        video_path (str): path to the video file
    Returns:
        file_path (str): path to the JSON file containing the keypoints
    """
    save_folder = 'tmp'
    os.makedirs(save_folder, exist_ok=True)
    # NOTE(review): save=True presumably also makes the pose model write its
    # annotated output to disk — confirm this side effect is wanted.
    results = model(video_path, save=True, show_conf=False, show_boxes=False)
    keypoints = []
    for i, frame in enumerate(results):
        frame_keypoints = frame.keypoints
        # xy is indexed per detected person ([0] = first person). Guard against
        # frames with no detection, where [0] would raise or yield an empty
        # point list that the knee-angle computation cannot use.
        # TODO confirm the model's exact no-detection representation.
        if (
            frame_keypoints is None
            or len(frame_keypoints.xy) == 0
            or len(frame_keypoints.xy[0]) == 0
        ):
            continue
        keypoints.append({'frame': i, 'keypoints': frame_keypoints.xy[0].tolist()})
    file_path = os.path.join(save_folder, 'keypoints.json')
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(keypoints, f)
    return file_path
def compute_right_knee_angle_list(json_path: str) -> list[float]:
    """
    Computes the smoothed right-knee angle for every frame in a keypoints file.
    Args:
        json_path (str): path to the JSON file containing the keypoints — a list
            of {"frame": ..., "keypoints": [...]} records, as written by
            get_keypoints_from_keypoints
    Returns:
        right_knee_angle_list (list[float]): per-frame right-knee angles passed
            through moving_average(..., 10)
            (NOTE(review): the concrete return type is whatever moving_average
            returns — confirm it is a list and not e.g. an array)
    """
    # Context manager so the file handle is closed deterministically.
    with open(json_path, encoding='utf-8') as f:
        keypoints_list = json.load(f)
    right_knee_angle_list = [
        compute_right_knee_angle(entry['keypoints']) for entry in keypoints_list
    ]
    # Smooth per-frame jitter; presumably 10 is the averaging window — confirm
    # against moving_average in pose_estimator.
    return moving_average(right_knee_angle_list, 10)
@tool
def check_knee_angle(json_path: str) -> bool:
    """
    Checks if the minimum knee angle is smaller than a threshold.
    If the minimum knee angle is smaller than 90 degrees, the squat is considered correct, because it means the user is going deep enough.
    Args:
        json_path (str): path to the JSON file containing the keypoints
    Returns:
        is_correct (bool): True if the minimum knee angle is smaller than a threshold, False otherwise
    """
    # The docstring above is also the tool description, so it is kept verbatim.
    # "Minimum below 90" is checked as "some smoothed angle below 90", which
    # short-circuits on the first hit and yields False for an empty sequence.
    smoothed_angles = compute_right_knee_angle_list(json_path)
    return any(a < 90 for a in smoothed_angles)
# Tools the agent is allowed to call.
tools = [get_keypoints_from_keypoints, check_knee_angle]
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. Make sure to use the check_knee_angle tool if the user wants to check his movement.",
        ),
        # Not supplied by the invoke() call below, so it is presumably optional.
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
        # Slot the agent machinery fills with its intermediate steps.
        ("placeholder", "{agent_scratchpad}"),
    ]
)
# Construct the Tools agent
agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

if __name__ == "__main__":
    # Only query the LLM when run as a script, not as a side effect of import.
    response = agent_executor.invoke(
        {"input": "Is my squat correct ? The json file is in tmp/keypoints.json."}
    )
    print(response["output"])