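"""Streamlit conversational RAG chatbot.

Loads .txt files from ./data, splits and embeds them into a Chroma vector
store, and answers questions with a local llama.cpp Mistral-7B model through
a history-aware retrieval chain. Run with: streamlit run <this file>.
"""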
import os

import streamlit as st
from langchain.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
# Needed for the CallbackManager / StreamingStdOutCallbackHandler used below.
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
from langchain_community.llms import llamacpp
from langchain_community.chat_message_histories.streamlit import StreamlitChatMessageHistory

from utills import split_docs, retriever_from_chroma, history_aware_retriever, chroma_db
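# The utills helpers are local project code; judging by how they are called
# below, they presumably wrap standard LangChain pieces, roughly:
#   split_docs(docs, chunk_size, overlap)     -> a text splitter's .split_documents(...)
#   chroma_db(docs, embeddings)               -> Chroma.from_documents(...)
#   retriever_from_chroma(db, search_type, k) -> db.as_retriever(...)
#   history_aware_retriever(llm, retriever, system_prompt)
#                                             -> create_history_aware_retriever(...)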
|
# Resolve paths relative to this script so the app runs from any working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
data_path = os.path.join(script_dir, "data")
# No leading "/" in the filename: os.path.join discards script_dir when the
# second argument is an absolute path.
model_path = os.path.join(script_dir, "mistral-7b-v0.1-layla-v4-Q4_K_M.gguf.2")
|
# Sentence-transformer embeddings used to vectorize the document chunks (CPU only).
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": True}
hf = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
|
# Load every .txt file in ./data and accumulate all loaded documents.
documents = []
for filename in os.listdir(data_path):
    if filename.endswith(".txt"):
        file_path = os.path.join(data_path, filename)
        documents.extend(TextLoader(file_path).load())
|
# Chunk the documents (chunk size 450, overlap 20) and index them in Chroma.
docs = split_docs(documents, 450, 20)
db = chroma_db(docs, hf)
retriever = retriever_from_chroma(db, "mmr", 6)
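# "mmr" (maximal marginal relevance) re-ranks results to balance relevance
# against diversity among the 6 returned chunks.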
|
|
# Stream generated tokens to stdout (handy when debugging locally).
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

# Local Mistral-7B GGUF model served via llama.cpp; n_gpu_layers=0 keeps it on CPU.
llm = llamacpp.LlamaCpp(
    model_path=model_path,
    n_gpu_layers=0,
    temperature=0.1,
    top_p=0.5,
    n_ctx=7000,
    max_tokens=350,
    repeat_penalty=1.7,
    stop=["Instruction:", "### Instruction:", "###<user>", "</user>"],
    callback_manager=callback_manager,
    verbose=False,
)
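# n_ctx bounds the total prompt-plus-generation window, while max_tokens caps
# each reply; the stop strings keep the model from echoing instruction-format
# markers back into its answers.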
|
# Prompt used to rewrite a follow-up question into a standalone one.
contextualize_q_system_prompt = """Given the chat history and the latest user question, \
which may reference context in the chat history, formulate a standalone question \
that can be understood without the chat history. Do NOT answer the question; \
just reformulate it if needed and otherwise return it as is."""
|
ha_retriever = history_aware_retriever(llm, retriever, contextualize_q_system_prompt)

# System prompt for the answering step; {context} is filled with the retrieved chunks.
qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Be as informative as possible; be polite and formal.
{context}"""
|
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
|
# Stuff the retrieved documents into the QA prompt, then wire up the full RAG
# chain; create_retrieval_chain returns a dict with "input", "context", "answer".
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
rag_chain = create_retrieval_chain(ha_retriever, question_answer_chain)

# Chat history stored in Streamlit session state under this key.
msgs = StreamlitChatMessageHistory(key="special_app_key")

conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    lambda session_id: msgs,  # one shared history per Streamlit session
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)
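# Example invocation (the chat history is injected automatically; the question
# below is only illustrative):
#   conversational_rag_chain.invoke(
#       {"input": "What do the documents say about X?"},
#       {"configurable": {"session_id": "any"}},
#   )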
|
|
def display_chat_history(chat_history):
    """Render the chat history in Streamlit."""
    for msg in chat_history.messages:
        st.chat_message(msg.type).write(msg.content)


def display_documents(docs, on_click=None):
    """Render retrieved documents with an optional click action."""
    for i, document in enumerate(docs):
        st.write(f"**Doc {i + 1}**")
        st.markdown(document.page_content, unsafe_allow_html=True)
        if on_click and st.button(f"Expand Article {i + 1}"):
            on_click(i)
|
|
|
def main(chain_with_history):
    """Main entry point for the Streamlit app."""
    st.title("Conversational RAG Chatbot")

    # msgs is backed by st.session_state, so the history survives reruns.
    display_chat_history(msgs)

    if prompt := st.chat_input():
        st.chat_message("human").write(prompt)

        # RunnableWithMessageHistory requires a session_id; a single shared
        # history is used here, so any value will do.
        config = {"configurable": {"session_id": "any"}}
        response = chain_with_history.invoke({"input": prompt}, config)
        st.chat_message("ai").write(response["answer"])

        # create_retrieval_chain exposes the retrieved documents under "context".
        if response.get("context"):
            docs = response["context"]

            def expand_document(index):
                st.write(f"Expanding document {index + 1}...")

            display_documents(docs, expand_document)


if __name__ == "__main__":
    main(conversational_rag_chain)