import gradio as gr
import os
import re
import torch
from pathlib import Path

import chromadb
from unidecode import unidecode
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationalRetrievalChain, ConversationChain
from langchain.memory import ConversationBufferMemory
# Define MuRIL model and tokenizer
# Note: MuRIL (google/muril-base-cased) is an encoder-only masked language model,
# so using it with the "text-generation" pipeline below is a best-effort choice.
muril_tokenizer = AutoTokenizer.from_pretrained("google/muril-base-cased")
muril_model = AutoModelForMaskedLM.from_pretrained("google/muril-base-cased")
# Function to initialize MuRIL pipeline
def initialize_muril_pipeline(temperature, max_tokens, top_k):
    muril_pipeline = pipeline(
        "text-generation",
        model=muril_model,
        tokenizer=muril_tokenizer,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        max_new_tokens=max_tokens,
        do_sample=True,
        top_k=top_k,
        num_return_sequences=1,
        eos_token_id=muril_tokenizer.eos_token_id,
    )
    return muril_pipeline
# Load PDF document and create doc splits
def load_doc(list_file_path, chunk_size, chunk_overlap):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
# Create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb
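# The "Generate Vector Database" button handler below references initialize_database(),
# which is not defined in this listing. The following is a minimal sketch of such a
# helper (and a hypothetical create_collection_name() used by it), assuming it should
# derive a collection name from the first uploaded file, split the PDFs, and build the
# Chroma database. Adjust to match the original implementation if it differs.
def create_collection_name(filepath):
    # Hypothetical helper: sanitize the file name into a Chroma-compatible collection name
    collection_name = unidecode(Path(filepath).stem)
    collection_name = re.sub("[^A-Za-z0-9]+", "-", collection_name).strip("-")[:50]
    # Chroma collection names must be at least 3 characters long
    if len(collection_name) < 3:
        collection_name = collection_name + "xyz"
    return collection_name

def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    # Collect local file paths from the Gradio Files component
    list_file_path = [x.name for x in list_file_obj if x is not None]
    progress(0.1, desc="Creating collection name...")
    collection_name = create_collection_name(list_file_path[0])
    progress(0.25, desc="Loading and splitting documents...")
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"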
# Initialize langchain LLM chain using MuRIL
def initialize_llmchain(temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0.1, desc="Initializing MuRIL model...")
    muril_pipeline = initialize_muril_pipeline(temperature, max_tokens, top_k)
    # Integrate pipeline with langchain
    llm = HuggingFacePipeline(pipeline=muril_pipeline, model_kwargs={'temperature': temperature})
    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain
# Initialize the LLM chain for your chatbot
def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    qa_chain = initialize_llmchain(llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"
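# The chatbot events below call conversation(), which is also not defined in this
# listing. The sketch below assumes the conventional behavior for this UI: run the
# retrieval chain on the user message, append the turn to the chat history, and
# surface up to three source passages with their page numbers.
def conversation(qa_chain, message, history):
    # Format prior turns for the retrieval chain (memory attached to the chain
    # also tracks history, so this is supplementary context)
    formatted_chat_history = []
    for user_message, bot_message in history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    response_sources = response["source_documents"]
    # Keep at most three references and pad so every output component gets a value
    sources = [(doc.page_content.strip(), doc.metadata.get("page", 0) + 1)
               for doc in response_sources[:3]]
    while len(sources) < 3:
        sources.append(("", 0))
    new_history = history + [(message, response_answer)]
    # Return order matches the outputs list wired up in the UI below
    return (qa_chain, "", new_history,
            sources[0][0], sources[0][1],
            sources[1][0], sources[1][1],
            sources[2][0], sources[2][1])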
# Demo function with Gradio UI
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        gr.Markdown(
            """<center><h2>BookMyDarshan: Your Personalized Spiritual Assistant</h2></center>
            <h3>Explore Sacred Texts and Enhance Your Spiritual Journey</h3>""")
        gr.Markdown(
            """<b>About BookMyDarshan.in:</b> We are a Hyderabad-based startup dedicated to providing pilgrims with exceptional temple darshan experiences.
            Our platform offers a comprehensive suite of spiritual and religious services, customized to meet your devotional needs.<br><br>
            <b>Note:</b> This spiritual assistant uses state-of-the-art AI to help you explore and understand your uploaded spiritual documents.
            With a blend of technology and tradition, this tool assists in connecting you more deeply with your faith.<br>""")
        with gr.Tab("Step 1: Upload PDF"):
            document = gr.Files(label="Upload your PDF documents", file_count="multiple", file_types=["pdf"], interactive=True)
        with gr.Tab("Step 2: Process Document"):
            db_btn = gr.Radio(["ChromaDB"], label="Select Vector Database", value="ChromaDB", info="Choose your vector database")
            with gr.Accordion("Advanced Options: Text Splitter", open=False):
                slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk Size", info="Adjust chunk size for text splitting")
                slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk Overlap", info="Adjust overlap between chunks")
            db_progress = gr.Textbox(label="Vector Database Initialization Status", value="None", interactive=False)
            generate_db_btn = gr.Button("Generate Vector Database")
        with gr.Tab("Step 3: Initialize QA Chain"):
            with gr.Row():
                with gr.Accordion("Advanced Options: LLM Model", open=False):
                    with gr.Row():
                        slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
                    with gr.Row():
                        slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
                    with gr.Row():
                        slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-k Samples", info="Model top-k samples", interactive=True)
            with gr.Row():
                llm_progress = gr.Textbox(value="None", label="QA Chain Initialization Status")
            with gr.Row():
                qachain_btn = gr.Button("Initialize Question Answering Chain")
        with gr.Tab("Step 4: Chatbot"):
            chatbot = gr.Chatbot(label="Chat with your PDF", height=300)
            with gr.Accordion("Advanced: Document References", open=False):
                with gr.Row():
                    doc_source1 = gr.Textbox(label="Reference 1", lines=2)
                    source1_page = gr.Number(label="Page", interactive=True)
                with gr.Row():
                    doc_source2 = gr.Textbox(label="Reference 2", lines=2)
                    source2_page = gr.Number(label="Page", interactive=True)
                with gr.Row():
                    doc_source3 = gr.Textbox(label="Reference 3", lines=2)
                    source3_page = gr.Number(label="Page", interactive=True)
            msg = gr.Textbox(placeholder="Type your question here...", label="Ask a Question", container=True)
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.Button("Clear Conversation")
        # Preprocessing events
        generate_db_btn.click(
            initialize_database,
            inputs=[document, slider_chunk_size, slider_chunk_overlap],
            outputs=[vector_db, collection_name, db_progress]
        )
        qachain_btn.click(
            initialize_LLM,
            inputs=[slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_progress]
        ).then(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
        # Chatbot events
        msg.submit(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
        clear_btn.click(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
    demo.launch()
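# Entry point (assumed; the original file may invoke demo() differently)
if __name__ == "__main__":
    demo()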