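"""Streamlit app for multi-tenant PDF question answering.

Uploaded PDFs are split into chunks, embedded with a sentence-transformers
model, stored in a per-tenant Chroma vector store, and answered against
with a locally loaded causal language model.
"""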
import os

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# NOTE: on recent LangChain releases these imports live under langchain_community
# (e.g. from langchain_community.document_loaders import PDFMinerLoader).
from langchain.document_loaders import PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
# Initialize session state for the vector stores, model, tokenizer, and tenant
if 'vectordb' not in st.session_state:
    st.session_state.vectordb = {}
if 'model' not in st.session_state:
    st.session_state.model = None
if 'tokenizer' not in st.session_state:
    st.session_state.tokenizer = None
if 'tenant' not in st.session_state:
    st.session_state.tenant = "default_tenant"  # Default tenant

st.title("PDF Question Answering System")

# Tenant selection
st.sidebar.title("Settings")
tenant = st.sidebar.text_input("Enter your tenant:", value=st.session_state.tenant)
st.session_state.tenant = tenant  # Update the tenant in session state
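
# Multi-tenant design: each tenant name maps to its own Chroma store in
# st.session_state.vectordb, so changing the tenant in the sidebar switches
# which document collection is indexed and queried.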

# File uploader for PDFs
def load_pdfs():
    uploaded_files = st.file_uploader("Upload your PDF files", type=['pdf'], accept_multiple_files=True)
    if uploaded_files and st.button("Process PDFs"):
        with st.spinner("Processing PDFs..."):
            # Save uploaded files temporarily
            temp_paths = []
            for file in uploaded_files:
                temp_path = f"temp_{file.name}"
                with open(temp_path, "wb") as f:
                    f.write(file.getbuffer())
                temp_paths.append(temp_path)

            # Load PDFs
            documents = []
            for pdf_path in temp_paths:
                loader = PDFMinerLoader(pdf_path)
                doc = loader.load()
                for d in doc:
                    d.metadata["source"] = pdf_path
                documents.extend(doc)

            # Clean up temporary files
            for path in temp_paths:
                os.remove(path)
            # Split documents
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
            splits = text_splitter.split_documents(documents)

            # Create embeddings and vector store for the current tenant
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

            # Directory to store the vector database
            db_directory = "path/to/store/chroma_db"  # Update with your desired path
            # Keep each tenant's index in its own subdirectory so tenants do not share data
            tenant_db_directory = os.path.join(db_directory, st.session_state.tenant)

            if st.session_state.tenant not in st.session_state.vectordb:
                st.session_state.vectordb[st.session_state.tenant] = Chroma.from_documents(
                    documents=splits,
                    embedding=embeddings,
                    persist_directory=tenant_db_directory
                )
            else:
                # Update the existing vector store for the tenant
                st.session_state.vectordb[st.session_state.tenant].add_documents(splits)

            st.success("PDFs processed successfully!")
            return True
    return False
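
# NOTE: with chromadb >= 0.4 the Chroma store above persists to its
# persist_directory automatically; on older versions you would also call
# st.session_state.vectordb[tenant].persist() after adding documents.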

# Load model and tokenizer
def load_model(model_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",  # place the model on GPU when available (requires accelerate)
    )
    model.eval()
    return model, tokenizer
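
# The placeholder checkpoint name suggests a 4-bit quantized model. One option
# (assuming bitsandbytes is installed) is to load it with
# quantization_config=BitsAndBytesConfig(load_in_4bit=True) from transformers
# instead of torch_dtype=torch.float16.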

def generate_response(prompt, model, tokenizer, max_new_tokens=256):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,  # temperature/top_p only take effect when sampling is enabled
            temperature=0.1,
            top_p=0.95,
            repetition_penalty=1.15,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

def combine_documents_and_answer(retrieved_docs, question, model, tokenizer):
    context = "\n".join(doc.page_content for doc in retrieved_docs)
    prompt = f"""You are an assistant tasked with answering questions based SOLELY on the provided context.
Do not use any external knowledge or information not present in the given context.
If the question is unrelated to the provided context, respond only with "I can't tell you this, ask something from the provided context."
DO NOT INCLUDE YOUR OWN OPINION. IMPORTANT: Your answer should be well structured and meaningful.
Your answer should cover every detail mentioned in the context. So, answer the following question in detail, using only the context:
Question: {question}
Context:
{context}
Answer:"""
    return generate_response(prompt, model, tokenizer)
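
# Rough sizing note: with the retriever's k=4 chunks of chunk_size=1000
# characters each, the prompt is on the order of a thousand tokens; keep it
# within the loaded model's context window if k or chunk_size grows.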

# Main app logic
def main():
    if torch.cuda.is_available():
        st.sidebar.success("GPU is available!")
    else:
        st.sidebar.warning("GPU is not available. This app may run slowly on CPU.")

    # Model path input
    model_path = st.sidebar.text_input("Enter the path to your model:",
                                       placeholder="waqasali1707/llama_3.2_3B_4_bit_Quan")

    # Load PDFs first
    if st.session_state.tenant not in st.session_state.vectordb:
        pdfs_processed = load_pdfs()
        if not pdfs_processed:
            st.info("Please upload PDF files and click 'Process PDFs' to continue.")
            return

    # Load model if path is provided and model isn't loaded
    if model_path and st.session_state.model is None:
        with st.spinner("Loading model..."):
            try:
                st.session_state.model, st.session_state.tokenizer = load_model(model_path)
                st.success("Model loaded successfully!")
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")
                return

    # Question answering interface
    if st.session_state.tenant in st.session_state.vectordb and st.session_state.model is not None:
        question = st.text_area("Enter your question:", height=100)
        if st.button("Get Answer"):
            if question:
                with st.spinner("Generating answer..."):
                    try:
                        # Get relevant documents
                        retriever = st.session_state.vectordb[st.session_state.tenant].as_retriever(search_kwargs={"k": 4})
                        retrieved_docs = retriever.get_relevant_documents(question)

                        # Generate answer
                        answer = combine_documents_and_answer(
                            retrieved_docs,
                            question,
                            st.session_state.model,
                            st.session_state.tokenizer
                        )

                        # Display answer
                        st.subheader("Answer:")
                        st.write(answer)

                        # Display sources
                        st.subheader("Sources:")
                        sources = set(doc.metadata["source"] for doc in retrieved_docs)
                        for source in sources:
                            st.write(f"- {os.path.basename(source)}")
                    except Exception as e:
                        st.error(f"Error generating answer: {str(e)}")
            else:
                st.warning("Please enter a question.")


if __name__ == "__main__":
    main()
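
# To run locally (assuming this script is saved as app.py):
#   streamlit run app.py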