# VUMC Chatbot — Streamlit app (Hugging Face Spaces deployment)
import os

import streamlit as st
from databricks.vector_search.client import VectorSearchClient
from mlflow import deployments

# Connection settings are injected via environment variables
# (e.g. Space secrets / host config) rather than hard-coded.
DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
DATABRICKS_API_TOKEN = os.environ.get("DATABRICKS_API_TOKEN")
VS_ENDPOINT_NAME = os.environ.get("VS_ENDPOINT_NAME")
VS_INDEX_NAME = os.environ.get("VS_INDEX_NAME")

# Fail fast with a clear message when required configuration is missing;
# the vector-search endpoint/index names are used below, so validate them too.
if DATABRICKS_HOST is None:
    raise ValueError("DATABRICKS_HOST environment variable must be set")
if DATABRICKS_API_TOKEN is None:
    raise ValueError("DATABRICKS_API_TOKEN environment variable must be set")
if VS_ENDPOINT_NAME is None:
    raise ValueError("VS_ENDPOINT_NAME environment variable must be set")
if VS_INDEX_NAME is None:
    raise ValueError("VS_INDEX_NAME environment variable must be set")
# UI copy for the page header.
TITLE = "VUMC Chatbot"
DESCRIPTION = "The first generation VUMC chatbot with knowledge of Vanderbilt specific terms."

# Canned prompts surfaced as one-click example buttons in the sidebar.
EXAMPLE_PROMPTS = [
    "Write a short story about a robot that has a nice day.",
    "In a table, what are some of the most common misconceptions about birds?",
    "Give me a recipe for vegan banana bread.",
    "Code a python function that can run merge sort on a list.",
    "Give me the character profile of a gumdrop obsessed knight in JSON.",
    "Write a rap battle between Alan Turing and Claude Shannon.",
]
# ---- Page layout ----
st.set_page_config(layout="wide")
st.title(TITLE)

# NOTE(review): the st.write calls below echo env-var values to the page to
# verify PAT/config wiring works end to end — remove before production, since
# they expose the workspace host to any viewer.
st.markdown(DESCRIPTION)
st.markdown("\n")
st.write(DATABRICKS_HOST)
st.markdown("\n")
st.write(VS_ENDPOINT_NAME)
st.markdown("\n")

# use this to format later
# with open("style.css") as css:
#     st.markdown(f'<style>{css.read()}</style>', unsafe_allow_html=True)
# TODO *** configure to run only on prompt for verification?
# VectorSearchClient picks up credentials from the environment.
vsc = VectorSearchClient()

# Hard-coded smoke-test query; should eventually come from user chat input.
question = "What is the data lake?"

# Embed the question with the Databricks-hosted BGE embedding model.
deploy_client = deployments.get_deploy_client("databricks")
response = deploy_client.predict(
    endpoint="databricks-bge-large-en",
    inputs={"input": [question]},
)
embeddings = [entry['embedding'] for entry in response.data]

# Retrieve the top-5 nearest entries (name/description columns) for the
# query embedding and render the raw result payload for inspection.
results = vsc.get_index(VS_ENDPOINT_NAME, VS_INDEX_NAME).similarity_search(
    query_vector=embeddings[0],
    columns=["name", "description"],
    num_results=5,
)
st.write(results)
# DBRX mainbody minus functions (commented out until handle_user_input /
# MODEL_AVATAR_URL / session-state setup are ported over)
# main = st.container()
# with main:
#     history = st.container(height=400)
#     with history:
#         for message in st.session_state["messages"]:
#             avatar = None
#             if message["role"] == "assistant":
#                 avatar = MODEL_AVATAR_URL
#             with st.chat_message(message["role"], avatar=avatar):
#                 if message["content"] is not None:
#                     st.markdown(message["content"])
#                 if message["error"] is not None:
#                     st.error(message["error"], icon="🚨")
#                 if message["warning"] is not None:
#                     st.warning(message["warning"], icon="⚠️")
#     if prompt := st.chat_input("Type a message!", max_chars=1000):
#         handle_user_input(prompt)
#     st.markdown("\n")  # add some space for iphone users
# with st.sidebar:
#     with st.container():
#         st.title("Examples")
#         for prompt in EXAMPLE_PROMPTS:
#             st.button(prompt, args=(prompt,), on_click=handle_user_input)