import os
import shutil
import tempfile

import gradio as gr
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader, PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import Chroma

# Define the prompt template for the QA chain
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Answer:"""
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)
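# At query time, RetrievalQA fills {context} with the retrieved document
# chunks and {question} with the user's query.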
# Load the Hugging Face API token and fail fast if it is missing
# (assigning None to os.environ would otherwise raise an opaque TypeError)
TOKEN = os.getenv("HF_TOKEN")
if TOKEN is None:
    raise RuntimeError("The HF_TOKEN environment variable must be set.")
os.environ["HUGGINGFACEHUB_API_TOKEN"] = TOKEN
# Initialize LangChain components
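# all-MiniLM-L6-v2 is a compact sentence-transformer that maps each chunk to
# a 384-dimensional vector; it runs locally, so embedding needs no API call.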
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Create vector store
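# Chroma persists its index under ./chroma_db, so chunks added through the
# upload panel survive app restarts.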
vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
# Initialize LLM
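# HuggingFaceHub sends generation requests to the hosted Inference API using
# HUGGINGFACEHUB_API_TOKEN. A 405B repo may not be served on the free tier,
# so swapping in a smaller instruct model is a reasonable fallback if calls fail.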
llm = HuggingFaceHub(
    repo_id="meta-llama/Meta-Llama-3.1-405B-Instruct-FP8",
    model_kwargs={"temperature": 0.7, "max_length": 512},
)
# Create RetrievalQA chain
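# chain_type="stuff" concatenates every retrieved chunk into the prompt's
# {context} slot for a single LLM call; as_retriever() returns the most
# similar chunks (k=4 by default in LangChain).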
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=vectorstore.as_retriever(),
    return_source_documents=True,
    chain_type_kwargs={"prompt": PROMPT},
)

def process_uploaded_file(file):
    if file is None:
        return "No file uploaded."
    try:
        # gr.File may hand us a tempfile wrapper (with a .name path) or a
        # plain path string depending on the Gradio version; handle both.
        src_path = file.name if hasattr(file, "name") else file
        # Create a temporary directory
        with tempfile.TemporaryDirectory() as temp_dir:
            # Copy the upload into the temporary directory under its own name
            temp_file_path = os.path.join(temp_dir, os.path.basename(src_path))
            shutil.copy(src_path, temp_file_path)
            # Determine file type and load accordingly
            if temp_file_path.endswith(".pdf"):
                loader = PyPDFLoader(temp_file_path)
            else:
                loader = TextLoader(temp_file_path)
            documents = loader.load()
            # Split into overlapping chunks so retrieval can match passages
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            texts = text_splitter.split_documents(documents)
            # Update the vector store with the new documents
            vectorstore.add_documents(texts)
            vectorstore.persist()
        return f"File processed and added to the knowledge base. {len(texts)} chunks created."
    except Exception as e:
        return f"An error occurred while processing the file: {str(e)}"

def respond(message, history, system_message, max_tokens, temperature, top_p):
    # max_tokens, temperature, and top_p come from the UI sliders but are not
    # forwarded here: HuggingFaceHub was configured with fixed model_kwargs.
    full_prompt = f"{system_message}\n\nHuman: {message}"
    # Use the RetrievalQA chain to retrieve context and generate the answer
    result = qa_chain({"query": full_prompt})
    answer = result["result"]
    # Yield (rather than return) so callers can consume this as a generator
    yield answer

with gr.Blocks() as demo:
    gr.Markdown("# RAG Chatbot with Content Upload")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot()
            msg = gr.Textbox()
            clear = gr.Button("Clear")
        with gr.Column(scale=1):
            file_upload = gr.File(label="Upload Content for RAG (TXT or PDF)")
            upload_button = gr.Button("Process Uploaded File")
            upload_status = gr.Textbox(label="Upload Status", interactive=False)
    system_message = gr.Textbox(value="You are a friendly Chatbot.", label="System message")
    max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
    temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    def user(user_message, history):
        # Append the user turn with a placeholder for the bot's reply
        return "", history + [[user_message, None]]

    def bot(history, system_message, max_tokens, temperature, top_p):
        user_message = history[-1][0]
        # respond() is a generator; take its single yielded answer
        bot_message = next(respond(user_message, history[:-1], system_message, max_tokens, temperature, top_p))
        history[-1][1] = bot_message
        return history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, system_message, max_tokens, temperature, top_p], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    # Report processing status in the dedicated textbox rather than creating
    # an anonymous component inline in the event handler
    upload_button.click(process_uploaded_file, inputs=[file_upload], outputs=[upload_status])
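
# Rough dependency note (an assumption inferred from the imports above): this
# app needs gradio, langchain, langchain-community, chromadb,
# sentence-transformers, and pypdf installed, plus HF_TOKEN in the environment.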
if __name__ == "__main__":
    demo.launch()