import os
import datetime
from operator import itemgetter

import mlflow
import streamlit as st

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_databricks import ChatDatabricks
from langchain_databricks.vectorstores import DatabricksVectorSearch

from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnablePassthrough,
)


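# Expected shape of chain_config.yaml (illustrative placeholders only; the keys
# below are the ones this module reads, the values are assumptions):
#
#   databricks_resources:
#     llm_endpoint_name: <model-serving-endpoint>
#     vector_search_endpoint_name: <vector-search-endpoint>
#   retriever_config:
#     embedding_model: <hugging-face-model-name>
#     vector_search_index: <catalog.schema.index>
#     parameters:
#       k: 5
#     schema:
#       description: description
#     chunk_template: "Term: {name}\nDefinition: {description}\n\n"
#   llm_config:
#     llm_prompt_template: <system prompt; may reference {date_str} and {context}>
#     llm_parameters:
#       temperature: 0.0
#       max_tokens: 500

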
class ChainBuilder:
    """Builds a retrieval-augmented chat chain backed by Databricks Vector Search."""

    def __init__(self):
        # Load the chain configuration (endpoints, prompt, retriever settings)
        # from chain_config.yaml via MLflow's ModelConfig.
        self.model_config = mlflow.models.ModelConfig(development_config="chain_config.yaml")
        self.databricks_resources = self.model_config.get("databricks_resources")
        self.llm_config = self.model_config.get("llm_config")
        self.retriever_config = self.model_config.get("retriever_config")
        self.vector_search_schema = self.retriever_config.get("schema")

    def extract_user_query_string(self, chat_messages_array):
        """Return the content of the most recent message, i.e. the current question."""
        return chat_messages_array[-1]["content"]

    def extract_chat_history(self, chat_messages_array):
        """Return every message except the latest one, i.e. the prior turns."""
        return chat_messages_array[:-1]

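    # The helpers above assume an OpenAI-style ChatCompletions payload, e.g.
    # (illustrative values only):
    #   {"messages": [
    #       {"role": "user", "content": "What does ARR stand for?"},
    #       {"role": "assistant", "content": "Annual recurring revenue."},
    #       {"role": "user", "content": "How is it calculated?"},
    #   ]}
    # The last entry is treated as the current question; the rest is history.
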
    def load_embedding_model(self):
        model_name = self.retriever_config.get("embedding_model")

        # Cache the HuggingFace embedding model across Streamlit reruns so it is
        # only downloaded and loaded once per process.
        @st.cache_resource
        def load_and_cache_embedding_model(model_name):
            embeddings = HuggingFaceEmbeddings(model_name=model_name, cache_folder="./langchain_cache/")
            return embeddings

        return load_and_cache_embedding_model(model_name)

    def get_retriever(self):
        endpoint = self.databricks_resources.get("vector_search_endpoint_name")
        index_name = self.retriever_config.get("vector_search_index")
        embeddings = self.load_embedding_model()
        search_kwargs = self.retriever_config.get("parameters")

        # Cache the retriever; the leading underscore on `_embeddings` tells
        # st.cache_resource not to hash that (unhashable) argument.
        @st.cache_resource
        def get_and_cache_retriever(endpoint, index_name, _embeddings, search_kwargs):
            vector_search_as_retriever = DatabricksVectorSearch(
                endpoint=endpoint,
                index_name=index_name,
                embedding=_embeddings,
                text_column="name",
                columns=["name", "description"],
            ).as_retriever(search_kwargs=search_kwargs)
            return vector_search_as_retriever

        return get_and_cache_retriever(endpoint, index_name, embeddings, search_kwargs)

    def format_context(self, retrieved_terms):
        """Render the retrieved documents into a single context string using the configured chunk template."""
        chunk_template = self.retriever_config.get("chunk_template")
        chunk_contents = [
            chunk_template.format(
                name=term.page_content,
                description=term.metadata[self.vector_search_schema.get("description")],
            )
            for term in retrieved_terms
        ]
        return "".join(chunk_contents)

    def get_prompt(self):
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.llm_config.get("llm_prompt_template")),
                MessagesPlaceholder(variable_name="formatted_chat_history"),
                ("user", "{question}"),
            ]
        )
        return prompt

    def format_chat_history_for_prompt(self, chat_messages_array):
        history = self.extract_chat_history(chat_messages_array)
        formatted_chat_history = []
        if len(history) > 0:
            for chat_message in history:
                if chat_message["role"] == "user":
                    formatted_chat_history.append(HumanMessage(content=chat_message["content"]))
                elif chat_message["role"] == "assistant":
                    formatted_chat_history.append(AIMessage(content=chat_message["content"]))
        return formatted_chat_history

    def get_query_rewrite_prompt(self):
        # Prompt used to condense the chat history and follow-up question into a
        # standalone query suitable for vector similarity search.
        query_rewrite_template = """Based on the chat history below, we want you to generate a query for an external data source to retrieve relevant information so
that we can better answer the question. The query should be in natural language. The external data source uses similarity search to search for relevant
information in a vector space. So, the query should be similar to the relevant information semantically. Answer with only the query. Do not add explanation.

Chat history: {chat_history}

Question: {question}"""

        query_rewrite_prompt = PromptTemplate(
            template=query_rewrite_template,
            input_variables=["chat_history", "question"],
        )
        return query_rewrite_prompt

    def get_model(self):
        endpoint = self.databricks_resources.get("llm_endpoint_name")
        extra_params = self.llm_config.get("llm_parameters")

        @st.cache_resource
        def get_and_cache_model(endpoint, extra_params):
            model = ChatDatabricks(
                endpoint=endpoint,
                extra_params=extra_params,
            )
            return model

        return get_and_cache_model(endpoint, extra_params)

    def build_chain(self):
        """Assemble the RAG chain: rewrite the query when history exists, retrieve context, then answer."""
        chain = (
            {
                # Wrap the date in a lambda so every value in this map is a Runnable or callable.
                "date_str": RunnableLambda(lambda x: datetime.datetime.now().strftime("%B %d, %Y")),
                "question": itemgetter("messages") | RunnableLambda(self.extract_user_query_string),
                "chat_history": itemgetter("messages") | RunnableLambda(self.extract_chat_history),
                "formatted_chat_history": itemgetter("messages")
                | RunnableLambda(self.format_chat_history_for_prompt),
            }
            | RunnablePassthrough()
            | {
                # With prior turns, rewrite the question into a standalone search query;
                # otherwise retrieve directly on the raw question.
                "context": RunnableBranch(
                    (
                        lambda x: len(x["chat_history"]) > 0,
                        self.get_query_rewrite_prompt() | self.get_model() | StrOutputParser(),
                    ),
                    itemgetter("question"),
                )
                | self.get_retriever()
                | RunnableLambda(self.format_context),
                "formatted_chat_history": itemgetter("formatted_chat_history"),
                "question": itemgetter("question"),
                # Keep the date available in case the prompt template references {date_str}.
                "date_str": itemgetter("date_str"),
            }
            | self.get_prompt()
            | self.get_model()
            | StrOutputParser()
        )
        return chain
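

# Example usage (a minimal sketch, not part of the original module): in a
# Streamlit app one might do something like the following, assuming
# chain_config.yaml is present and the configured Databricks endpoints are
# reachable from the app:
#
#   builder = ChainBuilder()
#   chain = builder.build_chain()
#   answer = chain.invoke(
#       {"messages": [{"role": "user", "content": "What does ARR stand for?"}]}
#   )
#   st.write(answer)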