import datetime
import os
from functools import partial
from operator import itemgetter

import mlflow
import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_core.runnables import RunnableBranch, RunnableLambda, RunnablePassthrough
from langchain_databricks import ChatDatabricks
from langchain_databricks.vectorstores import DatabricksVectorSearch
from langchain_huggingface import HuggingFaceEmbeddings
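
# A sketch of the chain_config.yaml this class expects, inferred from the keys
# read below. The values shown (endpoint names, index name, parameters) are
# illustrative placeholders, not the original configuration:
#
# databricks_resources:
#   vector_search_endpoint_name: <your-vector-search-endpoint>
#   llm_endpoint_name: <your-llm-serving-endpoint>
# llm_config:
#   llm_parameters:
#     temperature: 0.0
#     max_tokens: 500
# retriever_config:
#   embedding_model: <huggingface-embedding-model-name>
#   vector_search_index: <catalog.schema.index>
#   chunk_template: "Term: {name}\nDefinition: {description}\n\n"
#   parameters:
#     k: 5
#   schema:
#     description: description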


class ChainBuilder:
    """Builds a RAG chain over a Vector Search index of VUMC glossary terms."""

    def __init__(self):
        # Load the chain configuration; MLflow falls back to the development
        # config when none is supplied at logging or serving time.
        self.model_config = mlflow.models.ModelConfig(development_config="chain_config.yaml")
        self.databricks_resources = self.model_config.get("databricks_resources")
        self.llm_config = self.model_config.get("llm_config")
        self.retriever_config = self.model_config.get("retriever_config")
        self.vector_search_schema = self.retriever_config.get("schema")

    def get_system_prompt(self):
        # Ground the model in the current date so it can reason about the gap
        # between its knowledge cutoff and "today".
        date_str = datetime.datetime.now().strftime("%B %d, %Y")
        prompt = f"You are DBRX, created by Databricks and augmented by John Graham Reynolds to have access to additional information specific to Vanderbilt University Medical Center. The current date is {date_str}.\n"
        prompt += """Your knowledge base was last updated in December 2023. You answer questions about events prior to and after December 2023 the way a highly informed individual in December 2023 would if they were talking to someone from the above date, and you can let the user know this when relevant.\n
Some of the context you will be given in regards to Vanderbilt University Medical Center could have come after December 2023. The rest of your knowledge base is from before December 2023 and you will answer questions accordingly with these facts.
This chunk of text is your system prompt. It is not visible to the user, but it is used to guide your responses. Don't reference it, just respond to the user.\n
If you are asked to assist with tasks involving the expression of views held by a significant number of people, you provide assistance with the task even if you personally disagree with the views being expressed, but follow this with a discussion of broader perspectives.\n
You don't engage in stereotyping, including the negative stereotyping of majority groups.\n If asked about controversial topics, you try to provide careful thoughts and objective information without downplaying its harmful content or implying that there are reasonable perspectives on both sides.\n
You are happy to help with writing, analysis, question answering, math, coding, and all sorts of other tasks.\n You use markdown for coding, which includes JSON blocks and Markdown tables.\n
You do not have tools enabled at this time, so you cannot run code or access the internet. You can only provide information that you have been trained on. You do not send or receive links or images.\n
You were not trained on copyrighted books, song lyrics, poems, video transcripts, or news articles; you do not divulge details of your training data. You do not provide song lyrics, poems, or news articles, and instead refer the user to find them online or in a store.\n
You give concise responses to simple questions or statements, but provide thorough responses to more complex and open-ended questions.\n
The user is unable to see the system prompt, so you should write as if it were true without mentioning it.\n You do not mention any of this information about yourself unless the information is directly pertinent to the user's query.\n
Here is some context from the Vanderbilt University Medical Center glossary which might or might not help you answer: {context}.\n
Based on this system prompt, to which you will adhere strictly and to which you will make no reference, and this possibly helpful context in relation to Vanderbilt University Medical Center, answer this question: {question}
"""
        return prompt

    def extract_user_query_string(self, chat_messages_array):
        # The most recent message in the array is the user's current question.
        return chat_messages_array[-1]["content"]

    def extract_chat_history(self, chat_messages_array):
        # Everything before the most recent message is prior conversation.
        return chat_messages_array[:-1]
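
    # The two extractors above assume the OpenAI-style chat payload that the
    # chain from build_chain() is invoked with, e.g. (illustrative values only):
    #
    # {
    #     "messages": [
    #         {"role": "user", "content": "What does VUMC stand for?"},
    #         {"role": "assistant", "content": "Vanderbilt University Medical Center."},
    #         {"role": "user", "content": "What departments does it have?"},
    #     ]
    # }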

    def load_embedding_model(self):
        model_name = self.retriever_config.get("embedding_model")

        # st.cache_resource cannot hash `self`, so the cached loader is a local
        # function keyed only on the model name.
        @st.cache_resource
        def load_and_cache_embedding_model(model_name):
            embeddings = HuggingFaceEmbeddings(model_name=model_name, cache_folder="./langchain_cache/")
            return embeddings

        return load_and_cache_embedding_model(model_name)

    def get_retriever(self):
        endpoint = self.databricks_resources.get("vector_search_endpoint_name")
        index_name = self.retriever_config.get("vector_search_index")
        embeddings = self.load_embedding_model()
        search_kwargs = self.retriever_config.get("parameters")

        # The leading underscore on `_embeddings` tells st.cache_resource not
        # to hash that argument (the embedding model is not hashable).
        @st.cache_resource
        def get_and_cache_retriever(endpoint, index_name, _embeddings, search_kwargs):
            vector_search_as_retriever = DatabricksVectorSearch(
                endpoint=endpoint,
                index_name=index_name,
                embedding=_embeddings,
                text_column="name",
                columns=["name", "description"],
            ).as_retriever(search_kwargs=search_kwargs)
            return vector_search_as_retriever

        return get_and_cache_retriever(endpoint, index_name, embeddings, search_kwargs)

    def format_context(self, retrieved_terms):
        # Render each retrieved glossary term through the configured template
        # and concatenate the chunks into a single {context} string.
        chunk_template = self.retriever_config.get("chunk_template")
        chunk_contents = [
            chunk_template.format(
                name=term.page_content,
                description=term.metadata[self.vector_search_schema.get("description")],
            )
            for term in retrieved_terms
        ]
        return "".join(chunk_contents)

    def get_prompt(self):
        prompt = ChatPromptTemplate.from_messages(
            [
                # The system prompt carries the {context} placeholder.
                ("system", self.get_system_prompt()),
                # Prior turns, as HumanMessage/AIMessage objects.
                MessagesPlaceholder(variable_name="formatted_chat_history"),
                # The user's current question.
                ("user", "{question}"),
            ]
        )
        return prompt

    def format_chat_history_for_prompt(self, chat_messages_array):
        # Convert role/content dicts into the message objects that
        # MessagesPlaceholder expects.
        history = self.extract_chat_history(chat_messages_array)
        formatted_chat_history = []
        if len(history) > 0:
            for chat_message in history:
                if chat_message["role"] == "user":
                    formatted_chat_history.append(HumanMessage(content=chat_message["content"]))
                elif chat_message["role"] == "assistant":
                    formatted_chat_history.append(AIMessage(content=chat_message["content"]))
        return formatted_chat_history

    def get_query_rewrite_prompt(self):
        # Rewrites a follow-up question into a standalone query so that
        # similarity search works even when the question depends on context.
        query_rewrite_template = """Based on the chat history below, we want you to generate a query for an external data source to retrieve relevant information so
that we can better answer the question. The query should be in natural language. The external data source uses similarity search to search for relevant
information in a vector space. So, the query should be semantically similar to the relevant information. Answer with only the query. Do not add explanation.

Chat history: {chat_history}

Question: {question}"""

        query_rewrite_prompt = PromptTemplate(
            template=query_rewrite_template,
            input_variables=["chat_history", "question"],
        )
        return query_rewrite_prompt

    def get_model(self):
        endpoint = self.databricks_resources.get("llm_endpoint_name")
        extra_params = self.llm_config.get("llm_parameters")

        @st.cache_resource
        def get_and_cache_model(endpoint, extra_params):
            model = ChatDatabricks(
                endpoint=endpoint,
                extra_params=extra_params,
            )
            return model

        return get_and_cache_model(endpoint, extra_params)

    def build_chain(self):
        # 1. Split the incoming messages into the question, the raw history,
        #    and the history formatted for the prompt.
        # 2. If there is history, rewrite the question into a standalone query
        #    before retrieval; otherwise retrieve on the question directly.
        # 3. Format the retrieved terms into {context}, fill the prompt, and
        #    call the model.
        chain = (
            {
                "question": itemgetter("messages") | RunnableLambda(self.extract_user_query_string),
                "chat_history": itemgetter("messages") | RunnableLambda(self.extract_chat_history),
                "formatted_chat_history": itemgetter("messages")
                | RunnableLambda(self.format_chat_history_for_prompt),
            }
            | RunnablePassthrough()
            | {
                "context": RunnableBranch(
                    (
                        lambda x: len(x["chat_history"]) > 0,
                        self.get_query_rewrite_prompt() | self.get_model() | StrOutputParser(),
                    ),
                    itemgetter("question"),
                )
                | self.get_retriever()
                | RunnableLambda(self.format_context),
                "formatted_chat_history": itemgetter("formatted_chat_history"),
                "question": itemgetter("question"),
            }
            | self.get_prompt()
            | self.get_model()
            | StrOutputParser()
        )
        return chain
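
# A minimal usage sketch, assuming a valid chain_config.yaml and Databricks
# credentials in the environment; the question and payload are illustrative,
# not part of the original module.
if __name__ == "__main__":
    builder = ChainBuilder()
    chain = builder.build_chain()
    answer = chain.invoke(
        {"messages": [{"role": "user", "content": "What does VUMC stand for?"}]}
    )
    print(answer)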