# BMD / app.py
import gradio as gr
import os
import re
import torch
from pathlib import Path

import chromadb
from unidecode import unidecode

from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline  # used in initialize_llmchain
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
# Define MuRIL model and tokenizer.
# NOTE: MuRIL ("google/muril-base-cased") is an encoder-only masked LM, while
# the "text-generation" pipeline below expects a causal LM; transformers will
# flag this mismatch, and open-ended generation from this model is unreliable.
muril_tokenizer = AutoTokenizer.from_pretrained("google/muril-base-cased")
muril_model = AutoModelForMaskedLM.from_pretrained("google/muril-base-cased")

# Function to initialize the MuRIL generation pipeline
def initialize_muril_pipeline(temperature, max_tokens, top_k):
    muril_pipeline = pipeline(
        "text-generation",
        model=muril_model,
        tokenizer=muril_tokenizer,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,  # was accepted by the function but never applied
        top_k=top_k,
        num_return_sequences=1,
        # MuRIL's tokenizer defines no EOS token, so this resolves to None
        eos_token_id=muril_tokenizer.eos_token_id,
    )
    return muril_pipeline
# Load PDF document and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
loaders = [PyPDFLoader(x) for x in list_file_path]
pages = []
for loader in loaders:
pages.extend(loader.load())
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap)
doc_splits = text_splitter.split_documents(pages)
return doc_splits
# Create vector database
def create_db(splits, collection_name):
embedding = HuggingFaceEmbeddings()
new_client = chromadb.EphemeralClient()
vectordb = Chroma.from_documents(
documents=splits,
embedding=embedding,
client=new_client,
collection_name=collection_name,
)
return vectordb
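
# The "Generate Vector Database" button below is wired to `initialize_database`,
# which is missing from this snippet. A minimal sketch under that assumption:
# the collection-name sanitization (using `unidecode`, `re`, and `Path`) and
# the progress steps are illustrative, not the original implementation.
def create_collection_name(filepath):
    # Chroma collection names must be 3-63 characters from [a-zA-Z0-9._-]
    # and must start and end with an alphanumeric character.
    name = unidecode(Path(filepath).stem)
    name = re.sub(r"[^a-zA-Z0-9._-]", "-", name).strip("._-")[:63]
    return name if len(name) >= 3 else f"{name}-doc"

def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    # Gradio's Files component yields file objects whose .name is the temp path.
    list_file_path = [x.name for x in list_file_obj if x is not None]
    collection_name = create_collection_name(list_file_path[0])
    progress(0.25, desc="Loading and splitting documents...")
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    progress(0.5, desc="Creating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"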
# Initialize langchain LLM chain using MuRIL
def initialize_llmchain(temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
progress(0.1, desc="Initializing MuRIL model...")
muril_pipeline = initialize_muril_pipeline(temperature, max_tokens, top_k)
    # Integrate the pipeline with LangChain; temperature is already set on the
    # pipeline itself, so the previous model_kwargs={'temperature': ...} (which
    # a local HuggingFacePipeline ignores) is dropped.
    llm = HuggingFacePipeline(pipeline=muril_pipeline)
progress(0.75, desc="Defining buffer memory...")
memory = ConversationBufferMemory(
memory_key="chat_history",
output_key='answer',
return_messages=True
)
retriever = vector_db.as_retriever()
progress(0.8, desc="Defining retrieval chain...")
qa_chain = ConversationalRetrievalChain.from_llm(
llm,
retriever=retriever,
chain_type="stuff",
memory=memory,
return_source_documents=True,
verbose=False,
)
progress(0.9, desc="Done!")
return qa_chain
# Initialize the LLM chain for your chatbot
def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
qa_chain = initialize_llmchain(llm_temperature, max_tokens, top_k, vector_db, progress)
return qa_chain, "Complete!"
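
# The chat events below are wired to `conversation`, which is also missing from
# this snippet. A minimal sketch, assuming the ConversationalRetrievalChain
# response schema ("answer" plus "source_documents") and Gradio's list-of-pairs
# chat history; the exact history and source formatting here is an assumption.
def format_chat_history(chat_history):
    # Flatten Gradio's (user, assistant) pairs into plain strings for the chain.
    formatted = []
    for user_msg, bot_msg in chat_history:
        formatted.append(f"User: {user_msg}")
        formatted.append(f"Assistant: {bot_msg}")
    return formatted

def conversation(qa_chain, message, history):
    formatted_history = format_chat_history(history)
    response = qa_chain({"question": message, "chat_history": formatted_history})
    answer = response["answer"]
    # Surface up to three retrieved sources with 1-based page numbers.
    sources = [(doc.page_content.strip(), doc.metadata.get("page", 0) + 1)
               for doc in response["source_documents"][:3]]
    while len(sources) < 3:
        sources.append(("", 0))
    new_history = history + [(message, answer)]
    return (qa_chain, "", new_history,
            sources[0][0], sources[0][1],
            sources[1][0], sources[1][1],
            sources[2][0], sources[2][1])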
# Demo function with Gradio UI
def demo():
with gr.Blocks(theme="base") as demo:
vector_db = gr.State()
qa_chain = gr.State()
collection_name = gr.State()
gr.Markdown(
"""<center><h2>BookMyDarshan: Your Personalized Spiritual Assistant</h2></center>
<h3>Explore Sacred Texts and Enhance Your Spiritual Journey</h3>""")
gr.Markdown(
"""<b>About BookMyDarshan.in:</b> We are a Hyderabad-based startup dedicated to providing pilgrims with exceptional temple darshan experiences.
Our platform offers a comprehensive suite of spiritual and religious services, customized to meet your devotional needs.<br><br>
<b>Note:</b> This spiritual assistant uses state-of-the-art AI to help you explore and understand your uploaded spiritual documents.
With a blend of technology and tradition, this tool assists in connecting you more deeply with your faith.<br>""")
with gr.Tab("Step 1: Upload PDF"):
document = gr.Files(label="Upload your PDF documents", file_count="multiple", file_types=["pdf"], interactive=True)
with gr.Tab("Step 2: Process Document"):
db_btn = gr.Radio(["ChromaDB"], label="Select Vector Database", value="ChromaDB", info="Choose your vector database")
with gr.Accordion("Advanced Options: Text Splitter", open=False):
slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk Size", info="Adjust chunk size for text splitting")
slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk Overlap", info="Adjust overlap between chunks")
db_progress = gr.Textbox(label="Vector Database Initialization Status", value="None", interactive=False)
generate_db_btn = gr.Button("Generate Vector Database")
with gr.Tab("Step 3 - Initialize QA chain"):
with gr.Row():
with gr.Accordion("Advanced options - LLM model", open=False):
with gr.Row():
slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
with gr.Row():
slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
with gr.Row():
                        slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-k Samples", info="Model top-k samples", interactive=True)
with gr.Row():
llm_progress = gr.Textbox(value="None", label="QA chain initialization")
with gr.Row():
qachain_btn = gr.Button("Initialize Question Answering chain")
with gr.Tab("Step 4: Chatbot"):
chatbot = gr.Chatbot(label="Chat with your PDF", height=300)
with gr.Accordion("Advanced: Document References", open=False):
with gr.Row():
doc_source1 = gr.Textbox(label="Reference 1", lines=2)
source1_page = gr.Number(label="Page", interactive=True)
with gr.Row():
doc_source2 = gr.Textbox(label="Reference 2", lines=2)
source2_page = gr.Number(label="Page", interactive=True)
with gr.Row():
doc_source3 = gr.Textbox(label="Reference 3", lines=2)
source3_page = gr.Number(label="Page", interactive=True)
msg = gr.Textbox(placeholder="Type your question here...", label="Ask a Question", container=True)
with gr.Row():
submit_btn = gr.Button("Submit")
clear_btn = gr.Button("Clear Conversation")
# Preprocessing events
generate_db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, db_progress])
qachain_btn.click(
initialize_LLM,
inputs=[slider_temperature, slider_maxtokens, slider_topk, vector_db],
outputs=[qa_chain, llm_progress]
).then(
lambda: [None, "", 0, "", 0, "", 0],
inputs=None,
outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
queue=False
)
# Chatbot events
msg.submit(
conversation,
inputs=[qa_chain, msg, chatbot],
outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
queue=False
)
submit_btn.click(
conversation,
inputs=[qa_chain, msg, chatbot],
outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
queue=False
)
clear_btn.click(
lambda: [None, "", 0, "", 0, "", 0],
inputs=None,
outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
queue=False
)
demo.launch()
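
# Entry point: the snippet defines demo() but never calls it.
if __name__ == "__main__":
    demo()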