import os

import chromadb
import requests
import streamlit as st
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from langchain_community.document_loaders import PDFPlumberLoader
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma

from prompts import rag_prompt, relevancy_prompt, relevant_context_picker_prompt, response_synth
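# NOTE: `prompts.py` (in the same repo) is expected to define the four template
# strings imported above; each is rendered through a PromptTemplate below.
# Dependency names inferred from these imports (a sketch, not a pinned list):
#   streamlit, requests, chromadb, langchain, langchain-community, langchain-groq,
#   langchain-experimental, langchain-huggingface, langchain-chroma, pdfplumber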
# Set API Keys
os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
# Load LLM models
llm_judge = ChatGroq(model="deepseek-r1-distill-llama-70b")
rag_llm = ChatGroq(model="mixtral-8x7b-32768")
llm_judge.verbose = True
rag_llm.verbose = True
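# Division of labor: `llm_judge` grades and filters the retrieved chunks
# (Steps 6-8), while `rag_llm` writes the final answer (Step 9).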
# Clear ChromaDB cache to fix tenant issue
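# (Streamlit reruns can otherwise hit a stale Chroma system client and fail with
# "Could not connect to tenant default_tenant" when the store is rebuilt.)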
chromadb.api.client.SharedSystemClient.clear_system_cache()
st.title("❓")
# Initialize session state variables
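# Streamlit re-executes this script on every interaction; these flags make the
# expensive stages (load, chunk, embed) run once per uploaded document.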
if "vector_store" not in st.session_state:
st.session_state.vector_store = None
if "documents" not in st.session_state:
st.session_state.documents = None
if "pdf_loaded" not in st.session_state:
st.session_state.pdf_loaded = False
if "chunked" not in st.session_state:
st.session_state.chunked = False
if "vector_created" not in st.session_state:
st.session_state.vector_created = False
# Step 1: Choose PDF Source
pdf_source = st.radio("Upload or provide a link to a PDF:", ["Upload a PDF file", "Enter a PDF URL"], index=0, horizontal=True)
pdf_path = None
if pdf_source == "Upload a PDF file":
uploaded_file = st.file_uploader("Upload your PDF file", type="pdf")
if uploaded_file:
pdf_path = "temp.pdf"
with open(pdf_path, "wb") as f:
f.write(uploaded_file.getbuffer())
st.session_state.pdf_loaded = False
st.session_state.chunked = False
st.session_state.vector_created = False
elif pdf_source == "Enter a PDF URL":
    pdf_url = st.text_input("Enter PDF URL:", value="https://arxiv.org/pdf/2406.06998")
    if pdf_url:
        with st.spinner("Downloading PDF..."):
            try:
                response = requests.get(pdf_url, timeout=30)
                if response.status_code == 200:
                    pdf_path = "temp.pdf"
                    with open(pdf_path, "wb") as f:
                        f.write(response.content)
                    st.success("✅ PDF Downloaded Successfully!")
                    # A new download likewise invalidates the processing state.
                    st.session_state.pdf_loaded = False
                    st.session_state.chunked = False
                    st.session_state.vector_created = False
                else:
                    st.error("❌ Failed to download PDF. Check the URL.")
            except Exception as e:
                st.error(f"Error downloading PDF: {e}")
# Step 2: Process PDF
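# PDFPlumberLoader yields one Document per page, which is why the success
# message below can report len(docs) as a page count.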
if pdf_path and not st.session_state.pdf_loaded:
    with st.spinner("Loading PDF..."):
        loader = PDFPlumberLoader(pdf_path)
        docs = loader.load()
        st.session_state.documents = docs
        st.session_state.pdf_loaded = True
        st.success(f"✅ **PDF Loaded!** Total Pages: {len(docs)}")
# Step 3: Chunking (Only if Not Already Done)
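# SemanticChunker splits on meaning rather than fixed size: it embeds the text
# and starts a new chunk where the distance between adjacent sentences spikes,
# so the chunk count varies from document to document.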
if st.session_state.pdf_loaded and not st.session_state.chunked:
    with st.spinner("Chunking the document..."):
        model_name = "nomic-ai/modernbert-embed-base"
        embedding_model = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": False},
        )
        # Keep the embedding model in session state so Step 4 can reach it on a
        # later rerun (referencing the bare local name there raises NameError).
        st.session_state.embedding_model = embedding_model
        text_splitter = SemanticChunker(embedding_model)
        documents = text_splitter.split_documents(st.session_state.documents)
        st.session_state.documents = documents
        st.session_state.chunked = True
        st.success(f"✅ **Document Chunked!** Total Chunks: {len(documents)}")
# Step 4: Setup Vectorstore
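# "hnsw:space": "cosine" sets the distance metric Chroma's HNSW index uses for
# this collection, so the similarity search in Step 5 ranks chunks by cosine distance.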
if st.session_state.chunked and not st.session_state.vector_created:
    with st.spinner("Creating vector store..."):
        vector_store = Chroma(
            collection_name="deepseek_collection",
            collection_metadata={"hnsw:space": "cosine"},
            embedding_function=st.session_state.embedding_model,
        )
        vector_store.add_documents(st.session_state.documents)
        num_documents = len(vector_store.get()["documents"])
        st.session_state.vector_store = vector_store
        st.session_state.vector_created = True
        st.success(f"✅ **Vector Store Created!** Total documents stored: {num_documents}")
# Step 5: Query Input
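# Steps 5-9 form the query-time pipeline: retrieve top-k chunks, have the judge
# LLM score their relevancy, pick the best ones, extract their text, and only
# then hand the distilled context to the answering LLM.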
if st.session_state.vector_created:
    query = st.text_input("🔍 Enter a Query:")
    if query:
        with st.spinner("Retrieving relevant contexts..."):
            retriever = st.session_state.vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
            contexts = retriever.invoke(query)
            context_texts = [doc.page_content for doc in contexts]
        st.success(f"✅ **Retrieved {len(context_texts)} Contexts!**")
        for i, text in enumerate(context_texts, 1):
            st.write(f"**Context {i}:** {text[:500]}...")
        # Step 6: Context Relevancy Checker
        with st.spinner("Evaluating context relevancy..."):
            context_relevancy_checker_prompt = PromptTemplate(input_variables=["retriever_query", "context"], template=relevancy_prompt)
            context_relevancy_chain = LLMChain(llm=llm_judge, prompt=context_relevancy_checker_prompt, output_key="relevancy_response")
            relevancy_response = context_relevancy_chain.invoke({"context": context_texts, "retriever_query": query})
        st.subheader("🟥 Context Relevancy Evaluation")
        st.json(relevancy_response['relevancy_response'])
        # Step 7: Selecting Relevant Contexts
        with st.spinner("Selecting the most relevant contexts..."):
            relevant_prompt = PromptTemplate(input_variables=["relevancy_response"], template=relevant_context_picker_prompt)
            pick_relevant_context_chain = LLMChain(llm=llm_judge, prompt=relevant_prompt, output_key="context_number")
            relevant_response = pick_relevant_context_chain.invoke({"relevancy_response": relevancy_response['relevancy_response']})
        st.subheader("🟦 Pick Relevant Context Chain")
        st.json(relevant_response['context_number'])
        # Step 8: Retrieving Context for Response Generation
        with st.spinner("Retrieving final context..."):
            context_prompt = PromptTemplate(input_variables=["context_number", "context"], template=response_synth)
            relevant_contexts_chain = LLMChain(llm=llm_judge, prompt=context_prompt, output_key="relevant_contexts")
            final_contexts = relevant_contexts_chain.invoke({"context_number": relevant_response['context_number'], "context": context_texts})
        st.subheader("🟥 Relevant Contexts Extracted")
        st.json(final_contexts['relevant_contexts'])
        # Step 9: Generate Final Response
        with st.spinner("Generating the final answer..."):
            final_prompt = PromptTemplate(input_variables=["query", "context"], template=rag_prompt)
            response_chain = LLMChain(llm=rag_llm, prompt=final_prompt, output_key="final_response")
            final_response = response_chain.invoke({"query": query, "context": final_contexts['relevant_contexts']})
        st.subheader("🟥 RAG Final Response")
        st.success(final_response['final_response'])
        # Final + Intermediate Outputs
        st.subheader("🔍 **Full Workflow Breakdown:**")
        st.json({
            "Context Relevancy Evaluation": relevancy_response["relevancy_response"],
            "Relevant Contexts": relevant_response["context_number"],
            "Extracted Contexts": final_contexts["relevant_contexts"],
            "Final Answer": final_response["final_response"],
        })
else:
    st.warning("📄 Please upload or provide a PDF URL first.")