import os

import gradio as gr
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS

# Hugging Face API token read from the environment (None when HF_TOKEN is unset).
api_token = os.getenv("HF_TOKEN")

# Full Hub repo ids of the selectable models; the UI shows only the basenames,
# and the Radio (type="index") maps the selected position back into list_llm.
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]


def load_doc(list_file_path, chunk_size=1024, chunk_overlap=64):
    """Load PDF files and split their pages into overlapping text chunks.

    Args:
        list_file_path: Iterable of filesystem paths to PDF files.
        chunk_size: Maximum chunk size passed to RecursiveCharacterTextSplitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The list of chunk Documents produced by the splitter (may be empty if
        the PDFs contain no extractable text).
    """
    # Build every loader first so a bad path fails before any PDF is parsed.
    loaders = [PyPDFLoader(path) for path in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_documents(pages)


def create_db(splits):
    """Embed *splits* with HuggingFaceEmbeddings (default settings) into a FAISS index."""
    return FAISS.from_documents(splits, HuggingFaceEmbeddings())


def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Build a ConversationalRetrievalChain over *vector_db* using a HF endpoint.

    Args:
        llm_model: Hugging Face Hub repo id of the model to call.
        temperature: Sampling temperature for the endpoint.
        max_tokens: Passed as ``max_new_tokens`` to the endpoint.
        top_k: Sampling top-k for the LLM (not the number of retrieved chunks;
            the retriever uses the vector store's default search settings).
        vector_db: Vector store exposing ``as_retriever()``.

    Returns:
        The configured chain, with its own per-chain conversation memory.
    """
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    # output_key is needed because the chain also returns source_documents, so
    # the memory must be told which output to record as the assistant turn.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",  # all retrieved chunks are put into a single prompt
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )


def initialize_database(list_file_obj):
    """Gradio handler: build the vector DB from the uploaded files.

    Args:
        list_file_obj: Value of a ``gr.Files`` component. This is None when
            nothing is uploaded. Items are either objects with a ``.name`` path
            attribute or plain path strings.

    Returns:
        ``(vector_db, status_message)``. ``vector_db`` is None when there is
        nothing to index.
    """
    # Accept both file-like wrappers (path in .name) and plain path strings.
    list_file_path = [
        getattr(x, "name", x) for x in (list_file_obj or []) if x is not None
    ]
    if not list_file_path:
        return None, "Please upload at least one PDF file."
    splits = load_doc(list_file_path)
    if not splits:
        # FAISS cannot be built from an empty document list.
        return None, "No text could be extracted from the uploaded PDFs."
    return create_db(splits), "Database created!"


def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Gradio handler: create the QA chain for the selected model.

    Args:
        llm_option: Index into ``list_llm`` (from a Radio with type="index").
        llm_temperature, max_tokens, top_k: Generation settings; see
            ``initialize_llmchain``.
        vector_db: The store built by ``initialize_database``, or None.

    Returns:
        ``(qa_chain, status_message)``. ``qa_chain`` is None if no database exists yet.
    """
    if vector_db is None:
        return None, "Please create the vector database first."
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain, "QA chain initialized!"
# Number of retrieved chunks displayed in the "Source Context" panel.
_NUM_SOURCES_SHOWN = 3


def format_chat_history(message, chat_history):
    """Flatten Gradio (user, bot) pairs into "User: ..."/"Assistant: ..." lines.

    ``message`` is unused; it is kept for signature compatibility.
    """
    return [
        line
        for user_msg, bot_msg in (chat_history or [])
        for line in (f"User: {user_msg}", f"Assistant: {bot_msg}")
    ]


def _source_fields(sources, count=_NUM_SOURCES_SHOWN):
    """Return [text1, page1, text2, page2, ...] for the first *count* documents.

    Pages are shown 1-based. A document without an integer "page" metadata
    entry gets page 0, and missing documents are padded with ("", 0). These are
    the same values the UI uses when it resets the source panel.
    """
    fields = []
    for doc in list(sources)[:count]:
        page = doc.metadata.get("page")
        fields.append(doc.page_content.strip())
        fields.append(page + 1 if isinstance(page, int) else 0)
    while len(fields) < 2 * count:
        fields.extend(["", 0])
    return fields


def conversation(qa_chain, message, history):
    """Gradio handler: answer *message* with the QA chain and update the UI.

    Returns a 9-tuple: (qa_chain, textbox update clearing the input, updated
    history, then text/page for each of the three source panels).
    """
    history = history or []
    if qa_chain is None:
        # The user has not created the DB / initialized the chatbot yet.
        notice = "Please create the vector database and initialize the chatbot first."
        return (
            qa_chain,
            gr.update(value=""),
            history + [(message, notice)],
            *_source_fields([]),
        )
    # NOTE(review): the chain also has a ConversationBufferMemory keyed on
    # "chat_history", which LangChain merges into the inputs and which likely
    # takes precedence over this value. Confirm before relying on it.
    response = qa_chain.invoke({
        "question": message,
        "chat_history": format_chat_history(message, history),
    })
    # Some models echo the prompt's "Helpful Answer:" marker; keep only the text after it.
    answer = response["answer"].split("Helpful Answer:")[-1].strip()
    sources = response.get("source_documents") or []
    return (
        qa_chain,
        gr.update(value=""),
        history + [(message, answer)],
        *_source_fields(sources),
    )


def demo():
    """Build and return the two-column Gradio Blocks app (setup + chat)."""
    css = """
    #main-container { display: flex; flex-wrap: nowrap; width: 80% !important; max-width: 100vw !important; overflow: hidden; }
    #left-column { flex: 0 1 35%; min-width: 150px !important; max-width: 100% !important; }
    #right-column { flex: 1 1 65%; min-width: 300px !important; max-width: 100% !important; }
    .chatbot { width: 100% !important; max-width: 100% !important; }
    .textbox { max-width: 100% !important; }
    @media (max-width: 900px) {
        #main-container { flex-wrap: wrap; }
        #left-column, #right-column { flex: 1 1 100% !important; }
    }
    """
    theme = gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky")
    with gr.Blocks(theme=theme, css=css) as app:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("## RAG PDF Chatbot")

        with gr.Row(elem_id="main-container"):
            with gr.Column(elem_id="left-column"):
                # Step 1: upload PDFs, build the vector DB, pick and configure the LLM.
                with gr.Group():
                    gr.Markdown("**Step 1 - Setup**")
                    docs = gr.Files(file_types=[".pdf"], label="Upload PDFs")
                    db_btn = gr.Button("Create Vector DB")
                    db_status = gr.Textbox("Not initialized", show_label=False)

                    gr.Markdown("**LLM Selection**")
                    llm_select = gr.Radio(
                        list_llm_simple, label="Model", value=list_llm_simple[0], type="index"
                    )
                    with gr.Accordion("Parameters", open=False):
                        temp = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
                        tokens = gr.Slider(128, 8192, 4096, step=128, label="Max Tokens")
                        topk = gr.Slider(1, 10, 3, step=1, label="Top-K")
                    init_btn = gr.Button("Initialize Chatbot")
                    llm_status = gr.Textbox("Not initialized", show_label=False)

            with gr.Column(elem_id="right-column"):
                # Step 2: chat with the indexed documents.
                gr.Markdown("**Step 2 - Chat**")
                chatbot = gr.Chatbot(height=400, elem_classes="chatbot")
                with gr.Accordion("Source Context", open=False):
                    with gr.Row():
                        src1 = gr.Textbox(label="Reference 1", lines=2, max_lines=2, elem_classes="textbox")
                        pg1 = gr.Number(label="Page")
                    with gr.Row():
                        src2 = gr.Textbox(label="Reference 2", lines=2, max_lines=2, elem_classes="textbox")
                        pg2 = gr.Number(label="Page")
                    with gr.Row():
                        src3 = gr.Textbox(label="Reference 3", lines=2, max_lines=2, elem_classes="textbox")
                        pg3 = gr.Number(label="Page")
                msg = gr.Textbox(placeholder="Ask something...", elem_classes="textbox")
                with gr.Row():
                    submit = gr.Button("Submit")
                    clear = gr.ClearButton([msg, chatbot])

        # Event handlers
        source_outputs = [src1, pg1, src2, pg2, src3, pg3]
        reset_outputs = [chatbot, *source_outputs]
        chat_inputs = [qa_chain, msg, chatbot]
        chat_outputs = [qa_chain, msg, chatbot, *source_outputs]

        def reset_chat():
            # Empty chat plus blank source panels ([None, "", 0, "", 0, "", 0]).
            return [None, *_source_fields([])]

        db_btn.click(initialize_database, [docs], [vector_db, db_status])
        init_btn.click(
            initialize_LLM,
            [llm_select, temp, tokens, topk, vector_db],
            [qa_chain, llm_status],
        ).then(reset_chat, None, reset_outputs)
        msg.submit(conversation, chat_inputs, chat_outputs)
        submit.click(conversation, chat_inputs, chat_outputs)
        clear.click(reset_chat, None, reset_outputs)
    return app


if __name__ == "__main__":
    demo().launch()