# Author: Théo Rousseaux
# Commit a755c90: "début pose agent" (start of the pose agent)
# Original file size: 2.03 kB
from Modules.PoseEstimation.pose_estimator import calculate_angle, joints_id_dict, model
from langchain.tools import tool
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from langchain_mistralai.chat_models import ChatMistralAI
# Chat model used by the agent.
# The API key is deliberately NOT passed here: when api_key is omitted,
# ChatMistralAI reads it from the MISTRAL_API_KEY environment variable, which
# keeps the secret out of source control.
# NOTE(review): the key previously hard-coded on this line was committed and
# should be treated as leaked and revoked.
llm = ChatMistralAI(model='mistral-large-latest')
@tool
def compute_right_knee_angle(pose: list) -> str:
    """
    Computes the right knee angle of a single pose.

    The angle is formed by the right hip, right knee and right ankle keypoints.

    Args:
        pose (list): keypoints of one person in one frame, e.g. one element of
            the list returned by get_keypoints_from_path.

    Returns:
        str: the right knee angle, as a string so it can be returned to the
            agent as the tool output.
    """
    # joints_id_dict maps joint names to their index in the keypoint list.
    right_hip = pose[joints_id_dict['right_hip']]
    right_knee = pose[joints_id_dict['right_knee']]
    right_ankle = pose[joints_id_dict['right_ankle']]
    # The knee is passed as the middle point; presumably calculate_angle
    # measures the angle at its middle argument — TODO confirm.
    # NOTE(review): the unit of the result (degrees vs radians) is decided by
    # calculate_angle, which is not visible here — confirm before relying on it.
    knee_angle = calculate_angle(right_hip, right_knee, right_ankle)
    return str(knee_angle)
@tool
def get_keypoints_from_path(video_path: str) -> list:
    """
    Get keypoints from a video.

    Runs the pose model on every frame of the video and keeps the keypoints of
    the first detected person in each frame. The annotated video is also saved.

    Args:
        video_path (str): path to the video

    Returns:
        keypoints (list): one entry per frame; each entry is the list of xy
            keypoint coordinates of the first detected person in that frame,
            or an empty list when no person was detected.
    """
    keypoints = []
    # stream=True makes the model yield results frame by frame instead of
    # building the whole video's result list in memory; save=True still
    # writes the annotated output while the generator is consumed.
    results = model(
        video_path,
        stream=True,
        save=True,
        show_conf=False,
        show_boxes=False,
    )
    for frame in results:
        xy = frame.keypoints.xy
        # xy has one row per detected person: keep only the first one, and use
        # an empty placeholder when nothing was detected so the output stays
        # aligned with the frame order.
        keypoints.append(xy[0].tolist() if len(xy) else [])
    return keypoints
# Tools the agent is allowed to call.
# NOTE(review): get_keypoints_from_path is defined above but is not registered
# here, so the agent cannot turn a video path into keypoints by itself —
# confirm whether that is intentional.
tools = [compute_right_knee_angle]

# Prompt layout: fixed system instruction, optional earlier conversation,
# the user's request, then the agent's intermediate tool-calling steps.
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant. Make sure to use the compute_right_knee_angle tool for information."),
    ("placeholder", "{chat_history}"),
    ("human", "{input}"),
    ("placeholder", "{agent_scratchpad}"),
])

# Build the tool-calling agent from the LLM, tools and prompt, and wrap it in
# an executor (with verbose logging) that actually runs the tool calls.
agent = create_tool_calling_agent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)