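"""RAG PDF Chatbot (app.py).

Gradio app for chatting with uploaded PDFs: documents are split into
chunks, embedded into a FAISS vector store, and queried through a
LangChain ConversationalRetrievalChain backed by a Hugging Face
inference endpoint.
"""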
import os

import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint

# Hugging Face API token, read from the environment (e.g. a Space secret).
api_token = os.getenv("HF_TOKEN")

# Models served via the Hugging Face Inference API; the UI shows the short names.
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]


def load_doc(list_file_path):
    """Load PDFs and split them into overlapping chunks for retrieval."""
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=64,
    )
    return text_splitter.split_documents(pages)


def create_db(splits):
    """Embed the chunks and index them in an in-memory FAISS store.

    HuggingFaceEmbeddings() falls back to its default sentence-transformers
    model when no model name is given.
    """
    return FAISS.from_documents(splits, HuggingFaceEmbeddings())


def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Build a conversational retrieval chain around a HF inference endpoint."""
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    # Buffer memory tracks the running conversation; "answer" is the output
    # key to store, since the chain also returns source documents.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )


def initialize_database(list_file_obj):
    """Create the vector database from the uploaded PDF files."""
    if not list_file_obj:
        # Guard: the button was clicked before any file was uploaded.
        return None, "Please upload at least one PDF first."
    list_file_path = [x.name for x in list_file_obj if x is not None]
    return create_db(load_doc(list_file_path)), "Database created!"


def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Initialize the QA chain for the selected model and parameters."""
    llm_name = list_llm[llm_option]
    return initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db), "QA chain initialized!"


def format_chat_history(chat_history):
    """Flatten (user, bot) tuples into alternating User/Assistant lines."""
    formatted = []
    for user_msg, bot_msg in chat_history:
        formatted.extend([f"User: {user_msg}", f"Assistant: {bot_msg}"])
    return formatted


def conversation(qa_chain, message, history):
    """Run one chat turn and surface the top retrieved source chunks."""
    if qa_chain is None:
        # Guard: the user chatted before clicking "Initialize Chatbot".
        warning = "Please initialize the chatbot first."
        return (qa_chain, gr.update(value=""), history + [(message, warning)],
                "", 0, "", 0, "", 0)
    response = qa_chain.invoke({
        "question": message,
        "chat_history": format_chat_history(history),
    })
    # Some models echo the prompt scaffolding; keep only the final answer.
    answer = response["answer"].split("Helpful Answer:")[-1]
    # Pad to three references so the UI always receives nine outputs,
    # even if the retriever returns fewer than three documents.
    sources = list(response["source_documents"])[:3]
    refs = []
    for doc in sources:
        refs.extend([doc.page_content.strip(), doc.metadata["page"] + 1])
    for _ in range(3 - len(sources)):
        refs.extend(["", 0])
    return (qa_chain, gr.update(value=""), history + [(message, answer)], *refs)


def demo():
css = """
#main-container {
display: flex;
flex-wrap: nowrap;
width: 80% !important;
max-width: 100vw !important;
overflow: hidden;
}
#left-column {
flex: 0 1 35%;
min-width: 150px !important;
max-width: 100% !important;
}
#right-column {
flex: 1 1 65%;
min-width: 300px !important;
max-width: 100% !important;
}
.chatbot {
width: 100% !important;
max-width: 100% !important;
}
.textbox {
max-width: 100% !important;
}
@media (max-width: 900px) {
#main-container { flex-wrap: wrap; }
#left-column, #right-column { flex: 1 1 100% !important; }
}
"""
with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky"), css=css) as demo:
vector_db = gr.State()
qa_chain = gr.State()
gr.Markdown("## RAG PDF Chatbot")
with gr.Row(elem_id="main-container"):
with gr.Column(elem_id="left-column"):
# Step 1 Column
with gr.Group():
gr.Markdown("**Step 1 - Setup**")
docs = gr.Files(file_types=[".pdf"], label="Upload PDFs")
db_btn = gr.Button("Create Vector DB")
db_status = gr.Textbox("Not initialized", show_label=False)
gr.Markdown("**LLM Selection**")
llm_select = gr.Radio(list_llm_simple, label="Model", value=list_llm_simple[0], type="index")
with gr.Accordion("Parameters", open=False):
temp = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
tokens = gr.Slider(128, 8192, 4096, step=128, label="Max Tokens")
topk = gr.Slider(1, 10, 3, step=1, label="Top-K")
init_btn = gr.Button("Initialize Chatbot")
llm_status = gr.Textbox("Not initialized", show_label=False)
with gr.Column(elem_id="right-column"):
# Step 2 Column
gr.Markdown("**Step 2 - Chat**")
chatbot = gr.Chatbot(height=400, elem_classes="chatbot")
with gr.Accordion("Source Context", open=False):
with gr.Row():
src1 = gr.Textbox(label="Reference 1", lines=2, max_lines=2, elem_classes="textbox")
pg1 = gr.Number(label="Page")
with gr.Row():
src2 = gr.Textbox(label="Reference 2", lines=2, max_lines=2, elem_classes="textbox")
pg2 = gr.Number(label="Page")
with gr.Row():
src3 = gr.Textbox(label="Reference 3", lines=2, max_lines=2, elem_classes="textbox")
pg3 = gr.Number(label="Page")
msg = gr.Textbox(placeholder="Ask something...", elem_classes="textbox")
with gr.Row():
submit = gr.Button("Submit")
clear = gr.ClearButton([msg, chatbot])
# Event handlers
db_btn.click(
initialize_database, [docs], [vector_db, db_status]
)
init_btn.click(
initialize_LLM, [llm_select, temp, tokens, topk, vector_db], [qa_chain, llm_status]
).then(
lambda: [None,"",0,"",0,"",0], None, [chatbot, src1, pg1, src2, pg2, src3, pg3]
)
msg.submit(conversation, [qa_chain, msg, chatbot],
[qa_chain, msg, chatbot, src1, pg1, src2, pg2, src3, pg3])
submit.click(conversation, [qa_chain, msg, chatbot],
[qa_chain, msg, chatbot, src1, pg1, src2, pg2, src3, pg3])
clear.click(
lambda: [None,"",0,"",0,"",0], None, [chatbot, src1, pg1, src2, pg2, src3, pg3]
)
return demo
if __name__ == "__main__":
demo().launch()
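
# Note (assumption): when hosting on Hugging Face Spaces, HF_TOKEN is
# typically provided as a Space secret so that HuggingFaceEndpoint can
# authenticate against the Inference API.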