|
import os |
|
|
|
import streamlit as st |
|
from langchain_core.messages import AIMessage, HumanMessage |
|
|
|
from modules.graph import invoke_our_graph |
|
from modules.st_callable_util import get_streamlit_cb |
|
|
|
|
|
|
|
# Page header, rendered at the top of every script run.
st.title("Paintrek Medical Assistant")
st.markdown("Chat with an AI-powered health assistant.")

# Seed the flag that controls whether the intro expander is shown open.
# It is only written when the session does not already carry a value.
if "expander_open" not in st.session_state:
    st.session_state["expander_open"] = True
|
|
|
|
|
# Ask for the API key in the sidebar only when the environment does not
# already provide a non-empty GOOGLE_API_KEY.
if not os.getenv("GOOGLE_API_KEY"):
    st.sidebar.header("GOOGLE_API_KEY Setup")

    # Strip surrounding whitespace so a stray space or newline from a paste
    # is not mistaken for (or embedded in) a key.
    api_key = st.sidebar.text_input(
        label="API Key",
        type="password",
        label_visibility="collapsed",
    ).strip()

    # Without a key there is nothing useful to run: show a hint and halt this
    # script run until the user submits one.
    if not api_key:
        st.info("Please enter your GOOGLE_API_KEY in the sidebar.")
        st.stop()

    # Export only after validation so an empty value is never written to the
    # environment.
    # NOTE(review): os.environ is process-wide, so the key entered here is
    # presumably visible to every session served by this process - confirm
    # that is acceptable for the deployment.
    os.environ["GOOGLE_API_KEY"] = api_key
|
|
|
|
|
|
|
# Chat box for the user's next message; `prompt` is None when nothing was
# submitted on this run.
prompt = st.chat_input()

# A submitted message collapses the intro expander from this run onward.
if prompt is not None:
    st.session_state["expander_open"] = False

# Intro panel. The bare string literal below is not a comment: it is an
# expression statement that Streamlit "magic" renders inside the expander.
show_intro = st.session_state["expander_open"]
with st.expander(label="Paintrek Bot", expanded=show_intro):
    """
    At any time you can type 'q' or 'quit' to quit.
    """
|
|
|
|
|
# First run of a session: start the transcript with the assistant's greeting.
if "messages" not in st.session_state:
    st.session_state["messages"] = [
        AIMessage(content="Welcome to the Paintrek world. I am a health assistant, an interactive clinical recording system. I will ask you questions about your pain and related symptoms and record your responses. I will then store this information securely. At any time, you can type `q` to quit.")
    ]

# Replay the stored transcript on every run. Only AIMessage and HumanMessage
# entries are drawn; any other entry type is skipped.
for past_message in st.session_state["messages"]:
    if isinstance(past_message, AIMessage):
        role = "assistant"
    elif isinstance(past_message, HumanMessage):
        role = "user"
    else:
        continue
    st.chat_message(role).write(past_message.content)
|
|
|
|
|
# Handle a newly submitted (non-empty) user message.
if prompt:
    # Record the user's turn in the transcript and echo it in the chat.
    user_turn = HumanMessage(content=prompt)
    st.session_state.messages.append(user_turn)
    st.chat_message("user").write(prompt)

    with st.chat_message("assistant"):
        # The callback is bound to a container inside the assistant bubble and
        # handed to the graph alongside the full transcript.
        stream_handler = get_streamlit_cb(st.container())
        response = invoke_our_graph(st.session_state.messages, [stream_handler])

        # The last message of the graph result is stored as the assistant's
        # turn and written into the bubble.
        reply_text = response["messages"][-1].content
        st.session_state.messages.append(AIMessage(content=reply_text))
        st.write(reply_text)
|
|