|
from modules.data_class import DataState |
|
from modules.instrctions import MEDICAL_INTAKE_SYSINT, WELCOME_MSG |
|
from modules.llm_in_use import get_llm |
|
from modules.tools import patient_id, symptom, pain, medical_hist, family_hist, social_hist, review_system, pain_manage, functional, plan, confirm_data, get_data, clear_data, save_data |
|
|
|
from datetime import date |
|
from typing import Literal |
|
from langgraph.graph import StateGraph, START, END |
|
from langchain_core.messages.ai import AIMessage |
|
|
|
# Model handle for the intake conversation; construction is delegated to the
# project's get_llm() factory.
llm = get_llm()

# Every tool exposed to the model, listed one per line so additions and
# removals show up as single-line diffs.
intake_tools = [
    patient_id,
    symptom,
    pain,
    medical_hist,
    family_hist,
    social_hist,
    review_system,
    pain_manage,
    functional,
    plan,
    confirm_data,
    get_data,
    clear_data,
    save_data,
]

# The same model with the intake tools bound; this is the handle the chatbot
# node invokes.
llm_with_tools = llm.bind_tools(intake_tools)
|
|
|
def human_node(state: DataState) -> DataState:
    """Show the model's latest message and collect the user's reply.

    Prints the content of the last entry in ``state["messages"]``, reads one
    line from standard input, and returns the state merged with the new user
    message. When the reply is an exit command (``q``, ``quit``, ``exit`` or
    ``goodbye``; surrounding whitespace and letter case are ignored) the
    returned state also carries ``finished=True``.

    The incoming ``state`` is never modified; the exit flag travels only in
    the returned mapping.

    Args:
        state: Current conversation state; must hold at least one message.

    Returns:
        A new mapping: ``state`` overlaid with the user's reply under
        ``messages`` and, on exit, ``finished`` set to ``True``.

    Raises:
        IndexError: If ``state["messages"]`` is empty.
    """
    last_msg = state["messages"][-1]
    print("Model:", last_msg.content)

    user_input = input("User: ")

    # Collect changes in a separate dict instead of writing into the
    # caller's state object.
    update = {"messages": [("user", user_input)]}

    # Compare a normalised copy so "Quit" or " exit " also end the session,
    # while the stored message keeps the user's text exactly as typed.
    if user_input.strip().lower() in {"q", "quit", "exit", "goodbye"}:
        update["finished"] = True

    return state | update
|
|
|
|
|
def maybe_exit_human_node(state: DataState) -> Literal["chatbot_healthassistant", "__end__"]:
    """Pick the node that follows the human turn.

    Returns ``END`` once the state's ``finished`` flag is truthy; otherwise
    control goes to ``"chatbot_healthassistant"``. A missing flag is treated
    as not finished.
    """
    wants_to_stop = state.get("finished", False)
    return END if wants_to_stop else "chatbot_healthassistant"
|
|
|
|
|
def _blank_intake_data() -> dict:
    """Build the empty intake record used when the state has no ``data`` yet."""
    # date(1900, 1, 1) is the value every date field starts with.
    unset_date = date(1900, 1, 1)
    return {
        "ID": {
            "name": "",
            "DOB": unset_date,
            "gender": "",
            "contact": "",
            "emergency_contact": "",
        },
        "symptom": {
            "main_symptom": "",
            "symptom_length": "",
        },
        "pain": {
            "pain_location": "",
            "pain_side": "",
            "pain_intensity": 0,
            "pain_description": "",
            "start_time": unset_date,
            "radiation": False,
            "triggers": "",
            "symptom": "",
        },
        "medical_hist": {
            "medical_condition": "",
            "first_time": unset_date,
            "surgery_history": "",
            "medication": "",
            "allergy": "",
        },
        "family_hist": {
            "family_history": "",
        },
        "social_hist": {
            "occupation": "",
            "smoke": False,
            "alcohol": False,
            "drug": False,
            "support_system": "",
            "living_condition": "",
        },
        "review_system": {
            "weight_change": "",
            "fever": False,
            "chill": False,
            "night_sweats": False,
            "sleep": "",
            "gastrointestinal": "",
            "urinary": "",
        },
        "pain_manage": {
            "pain_medication": "",
            "specialist": False,
            "other_therapy": "",
            "effectiveness": False,
        },
        "functional": {
            "life_quality": "",
            "limit_activity": "",
            "mood": "",
        },
        "plan": {
            "goal": "",
            "expectation": "",
            "alternative_treatment": "",
        },
    }


def chatbot_with_tools(state: DataState) -> DataState:
    """Produce the assistant's next message and return the updated state.

    With an empty message history the reply is the fixed welcome message;
    otherwise the tool-bound model is invoked with the system instruction
    followed by the history so far.

    The returned mapping layers three sources, later ones winning: the blank
    defaults (``data`` and ``finished``), the incoming ``state``, and finally
    ``messages`` holding only the new reply.
    """
    history = state["messages"]
    if not history:
        # Nothing said yet: open the conversation without calling the model.
        reply = AIMessage(content=WELCOME_MSG)
    else:
        reply = llm_with_tools.invoke([MEDICAL_INTAKE_SYSINT] + history)

    fallback = {"data": _blank_intake_data(), "finished": False}
    return {**fallback, **state, "messages": [reply]}
|
|
|
|
|
def maybe_route_to_tools(state: DataState) -> str:
    """Choose the node that runs after the chatbot has produced a message.

    Routing, in priority order:

    1. ``END`` when the state's ``finished`` flag is truthy.
    2. ``"documenting"`` when the latest message carries at least one tool
       call.
    3. ``"patient"`` otherwise.

    Every path returns a route name. Previously a message whose tool calls
    all had empty names fell through and returned ``None``.

    Args:
        state: Current conversation state.

    Returns:
        The name of the next node, or ``END``.

    Raises:
        ValueError: If the state contains no messages.
    """
    if not (msgs := state.get("messages", [])):
        raise ValueError(f"No messages found when parsing state: {state}")

    if state.get("finished", False):
        return END

    # getattr covers messages that lack a tool_calls attribute as well as
    # tool_calls being None or an empty list.
    # NOTE(review): assumes "documenting" is the node that executes tool
    # calls -- confirm against the graph wiring.
    if getattr(msgs[-1], "tool_calls", None):
        return "documenting"

    return "patient"
|
|