|
import os |
|
import streamlit as st |
|
from dotenv import load_dotenv |
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
from langchain_community.llms import llamacpp |
|
from langchain_core.runnables.history import RunnableWithMessageHistory |
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
|
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler |
|
from langchain.chains import create_history_aware_retriever, create_retrieval_chain, ConversationalRetrievalChain |
|
from langchain.document_loaders import TextLoader |
|
from langchain.chains.combine_documents import create_stuff_documents_chain |
|
from langchain_community.chat_message_histories.streamlit import StreamlitChatMessageHistory |
|
from langchain.prompts import PromptTemplate |
|
from langchain.vectorstores import Chroma |
|
from utills import load_txt_documents, split_docs, load_uploaded_documents, retriever_from_chroma |
|
from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter |
|
from langchain_community.document_loaders.directory import DirectoryLoader |
|
from HTML_templates import css, bot_template, user_template |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.runnables import RunnablePassthrough |
|
from langchain import hub |
|
from langchain.retrievers import ContextualCompressionRetriever |
|
from langchain.retrievers.document_compressors import LLMChainExtractor |
|
|
|
# Load variables from a local .env file (if one exists) into os.environ so
# the LangSmith key read below can be provided without exporting it in the
# shell. Variables already set in the environment are not overridden.
load_dotenv()

lang_api_key = os.getenv("lang_api_key")

# LangSmith tracing is configured only when an API key is available:
# assigning None to os.environ raises TypeError, which would stop the app
# at import time.
if lang_api_key:
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.langchain.plus"
    os.environ["LANGCHAIN_API_KEY"] = lang_api_key
    os.environ["LANGCHAIN_PROJECT"] = "Lithuanian_Law_RAG_QA"
|
|
|
|
|
|
|
|
|
|
|
def create_retriever_from_chroma(vectorstore_path="./docs/chroma/", search_type='mmr', k=7, chunk_size=300, chunk_overlap=30, lambda_mult=0.7):
    """Return a retriever over a persisted Chroma store, building the store if needed.

    If ``vectorstore_path`` is a non-empty directory, the existing store is
    opened. Otherwise every ``*.txt`` file in ``./data/`` is loaded, split
    into token-sized chunks, embedded and persisted at ``vectorstore_path``.
    Progress messages are written to the Streamlit page.

    Args:
        vectorstore_path: Directory holding (or to hold) the Chroma store.
        search_type: Retriever search type, ``"mmr"`` or ``"similarity"``.
        k: Number of documents the retriever returns.
        chunk_size: Chunk size in tiktoken tokens; only used when building.
        chunk_overlap: Token overlap between chunks; only used when building.
        lambda_mult: MMR relevance/diversity trade-off; only applied when
            ``search_type == "mmr"``.

    Returns:
        The retriever produced by ``vectorstore.as_retriever``.

    Raises:
        ValueError: If the store must be built but ``./data/`` yields no chunks.
    """
    # NOTE(review): the embedding model is constructed on every call of this
    # function; consider caching it (e.g. st.cache_resource) if sharing one
    # instance across Streamlit sessions is acceptable.
    embeddings = HuggingFaceEmbeddings(
        model_name="Alibaba-NLP/gte-base-en-v1.5",
        model_kwargs={
            "device": "cpu",
            # Must be the bool True: the gte-*-v1.5 checkpoints load custom
            # modelling code from the Hugging Face Hub, which this flag permits
            # (it executes Python from the model repository). Any non-empty
            # string, including 'False', would also be truthy.
            "trust_remote_code": True,
        },
        encode_kwargs={"normalize_embeddings": True},
    )

    if os.path.isdir(vectorstore_path) and os.listdir(vectorstore_path):
        st.write("Vector store exists and is loaded")
        vectorstore = Chroma(persist_directory=vectorstore_path, embedding_function=embeddings)
    else:
        st.write("Vector store doesnt exist and will be created now")
        loader = DirectoryLoader('./data/', glob="./*.txt", loader_cls=TextLoader)
        docs = loader.load()
        st.write("Docs loaded")

        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n \n \n", "\n \n", "\n1", r"(?<=\. )", " ", ""],
            # The sentence-boundary separator is a look-behind regex; without
            # this flag the splitter escapes it and searches for it literally.
            is_separator_regex=True,
        )
        # Named "chunks" so the imported utills.split_docs helper is not shadowed.
        chunks = text_splitter.split_documents(docs)
        if not chunks:
            raise ValueError("No .txt documents found in ./data/ to build the vector store")

        vectorstore = Chroma.from_documents(
            documents=chunks, embedding=embeddings, persist_directory=vectorstore_path
        )
        st.write("VectorStore is created")

    search_kwargs = {"k": k}
    if search_type == "mmr":
        # lambda_mult is an MMR option; it is not a parameter of plain
        # similarity search, so it is only forwarded for MMR.
        search_kwargs["lambda_mult"] = lambda_mult

    return vectorstore.as_retriever(search_type=search_type, search_kwargs=search_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Render the Lithuanian-law chat page and answer a submitted question.

    Draws the header, description and retrieval-settings widgets, builds the
    retriever, replays the conversation stored in
    ``st.session_state["messages"]``, and, when the user submits a question
    through the chat input, builds the RAG chain and hands the question to
    ``handle_userinput``.
    """
    st.set_page_config(page_title="Chat with multiple Lithuanian Law Documents: ",
                       page_icon=":books:")
    st.write(css, unsafe_allow_html=True)

    st.header("Chat with multiple Lithuanian Law Documents:" ":books:")

    st.markdown("Hi, I am Birute (Powered by qwen2-0_5b model), chat assistant, based on republic of Lithuania law documents. You can choose below information retrieval type and how many documents you want to be retrieved.")
    st.markdown("Available Documents: LR_Civil_Code_2022, LR_Constitution_2022, LR_Criminal_Code_2018, LR_Criminal_Procedure_code_2022,LR_Labour_code_2010. P.S it's a shame that there are no newest documents translations... ")

    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "Hi, I'm a chatbot who is based on republic of Lithuania law documents. How can I help you?"}
        ]

    # The labels state the defaults actually selected below
    # (index=1 -> "similarity", value=4).
    search_type = st.selectbox(
        "Choose search type. Options are [Max marginal relevance search (mmr), Similarity search (similarity)]. Default value (similarity)",
        options=["mmr", "similarity"],
        index=1,
    )

    k = st.select_slider(
        "Select amount of documents to be retrieved. Default value (4): ",
        options=list(range(2, 16)),
        value=4,
    )

    retriever = create_retriever_from_chroma(vectorstore_path="docs/chroma/", search_type=search_type, k=k, chunk_size=200, chunk_overlap=30)

    # Streamlit rebuilds the page on every rerun, so earlier turns (and the
    # initial greeting) are only visible if they are redrawn here.
    for message in st.session_state.messages:
        st.chat_message(message["role"]).write(message["content"])

    # st.chat_input returns the text only on the run triggered by submitting
    # it (None otherwise), so changing a setting does not resubmit the last
    # question or append duplicate history entries.
    if user_question := st.chat_input("Ask a question about your documents:"):
        # The chain (and its local LLM) is only needed when answering.
        rag_chain = create_conversational_rag_chain(retriever)
        handle_userinput(user_question, retriever, rag_chain)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def handle_userinput(user_question, retriever, rag_chain):
    """Answer one question and record both turns in the chat history.

    Appends the question to ``st.session_state.messages`` and echoes it,
    lists the retrieved documents in the sidebar, asks ``rag_chain`` for an
    answer grounded in those documents, then appends and echoes the answer.

    Args:
        user_question: The text the user submitted.
        retriever: Retriever whose ``invoke`` returns documents exposing a
            ``page_content`` attribute.
        rag_chain: Runnable invoked with ``{"context": str, "question": str}``;
            its result is stored and displayed as the assistant message.
    """
    st.session_state.messages.append({"role": "user", "content": user_question})
    st.chat_message("user").write(user_question)

    docs = retriever.invoke(user_question)

    with st.sidebar:
        st.subheader("Your documents")
        for doc in docs:
            st.write(f"Document: {doc}")

    # Join the chunk texts into one block of text; passing the list itself
    # would put its Python repr (brackets, quotes, escaped newlines) into the
    # prompt.
    context = "\n\n".join(doc.page_content for doc in docs)

    # Generation is the slow step, so that is what the spinner covers.
    with st.spinner("Processing"):
        response = rag_chain.invoke({"context": context, "question": user_question})

    st.session_state.messages.append({"role": "assistant", "content": response})
    st.chat_message("assistant").write(response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_conversational_rag_chain(retriever):
    """Build the prompt -> local LLM -> string pipeline used to answer questions.

    The returned runnable expects ``{"question": ..., "context": ...}`` and
    returns the model's completion as a string. Generated tokens are also
    streamed to stdout through ``StreamingStdOutCallbackHandler``.

    Args:
        retriever: Currently unused; retrieval happens in the caller, which
            passes the retrieved text in as ``context``. Kept for interface
            compatibility.

    Returns:
        A LangChain runnable: ``PromptTemplate | LlamaCpp | StrOutputParser``.
    """
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

    llm = llamacpp.LlamaCpp(
        model_path="JCHAVEROT_Qwen2-0.5B-Chat_SFT_DPO.Q8_0.gguf",
        seed=41,
        n_gpu_layers=0,  # CPU-only inference
        temperature=0.0,
        n_ctx=22000,
        n_batch=2000,
        max_tokens=200,
        repeat_penalty=1.6,
        last_n_tokens_size=200,
        callback_manager=callback_manager,
        verbose=False,
    )

    template = (
        "You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\n"
        "\n"
        "Question: {question}\n"
        "\n"
        "Context: {context}\n"
        "\n"
        "Answer:\n"
    )

    # LlamaCpp is a text-completion model: a plain PromptTemplate sends the
    # template verbatim, whereas a ChatPromptTemplate would be flattened to
    # "Human: ..." before reaching the model.
    prompt = PromptTemplate.from_template(template)

    return prompt | llm | StrOutputParser()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: build and render the app when this file is executed as
    # the main script (the way `streamlit run` executes it).
    main()