Spaces:
Sleeping
Sleeping
File size: 3,324 Bytes
a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 a755c90 5730896 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
from langchain.tools import tool
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from langchain_mistralai.chat_models import ChatMistralAI
import os
import sys
import json
sys.path.append(os.getcwd())
from Modules.PoseEstimation.pose_estimator import model, compute_right_knee_angle, moving_average
# Mistral chat model backing the agent.
# If api_key is not passed, default behavior is to use the `MISTRAL_API_KEY` environment variable.
# Credentials must never be hard-coded in source: export MISTRAL_API_KEY before running.
llm = ChatMistralAI(model='mistral-large-latest')
@tool
def get_keypoints_from_keypoints(video_path: str) -> str:
    """
    Extracts keypoints from a video file.

    Runs the pose model on the video and, for every frame, stores the keypoints of
    the first detection as ``{"frame": <frame index>, "keypoints": <xy list>}`` in
    ``tmp/keypoints.json``. Frames without any detection are skipped.

    Args:
        video_path (str): path to the video file
    Returns:
        file_path (str): path to the JSON file containing the keypoints
    """
    save_folder = 'tmp'
    os.makedirs(save_folder, exist_ok=True)

    keypoints = []
    # NOTE(review): the call signature looks like an Ultralytics YOLO pose model;
    # if so, `stream=True` would avoid holding every frame's result in memory.
    results = model(video_path, save=True, show_conf=False, show_boxes=False)
    for i, frame in enumerate(results):
        frame_keypoints = frame.keypoints
        # `xy[0]` selects the first detection; when a frame has no detection
        # (keypoints is None or xy is empty) indexing would fail, so skip it
        # while preserving the original frame index in the output.
        if frame_keypoints is None or len(frame_keypoints.xy) == 0:
            continue
        keypoints.append({'frame': i, 'keypoints': frame_keypoints.xy[0].tolist()})

    file_path = os.path.join(save_folder, 'keypoints.json')
    with open(file_path, 'w') as f:
        json.dump(keypoints, f)
    return file_path
def compute_right_knee_angle_list(json_path: str, window: int = 10) -> list[float]:
    """
    Computes the smoothed right-knee angle for every frame in a keypoints JSON file.

    Args:
        json_path (str): path to the JSON file containing the keypoints; it must be a
            list of objects with a ``"keypoints"`` entry (the format written by
            ``get_keypoints_from_keypoints``)
        window (int): second argument forwarded to ``moving_average``
            (presumably the smoothing window size); defaults to 10
    Returns:
        right_knee_angle_list (list[float]): per-frame right-knee angles after
            ``moving_average`` smoothing
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(json_path) as f:
        keypoints_list = json.load(f)
    right_knee_angle_list = [
        compute_right_knee_angle(entry['keypoints']) for entry in keypoints_list
    ]
    return moving_average(right_knee_angle_list, window)
@tool
def check_knee_angle(json_path: str) -> bool:
    """
    Checks if the minimum knee angle is smaller than a threshold.

    The squat is considered correct when the smoothed right-knee angle drops below
    90 degrees in at least one frame, because it means the user is going deep enough.

    Args:
        json_path (str): path to the JSON file containing the keypoints
    Returns:
        is_correct (bool): True if the minimum knee angle is smaller than 90 degrees,
            False otherwise
    """
    smoothed_angles = compute_right_knee_angle_list(json_path)
    # "Minimum below 90" is equivalent to "any angle below 90"; `any` stops at the first hit.
    return any(angle < 90 for angle in smoothed_angles)
# Tools exposed to the agent.
tools = [get_keypoints_from_keypoints, check_knee_angle]

# System instruction nudging the model toward the knee-angle check.
_system_instruction = (
    "You are a helpful assistant. Make sure to use the check_knee_angle tool "
    "if the user wants to check his movement."
)
_prompt_messages = [
    ("system", _system_instruction),
    # Slot for any prior conversation messages.
    ("placeholder", "{chat_history}"),
    ("human", "{input}"),
    # Slot where the agent's intermediate tool calls and results are inserted.
    ("placeholder", "{agent_scratchpad}"),
]
prompt = ChatPromptTemplate.from_messages(_prompt_messages)

# Construct the tool-calling agent and the executor that drives its tool loop.
agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
if __name__ == "__main__":
    # Only query the LLM when run as a script, not as a side effect of importing
    # this module (the call goes over the network to the Mistral API).
    response = agent_executor.invoke(
        {"input": "Is my squat correct ? The json file is in tmp/keypoints.json."}
    )
    print(response["output"])