# pdf-chatbot / app.py
import chromadb
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import streamlit as st
import fitz # PyMuPDF for PDF parsing
# # Step 1: Setup ChromaDB
# def setup_chromadb():
# # Initialize ChromaDB in-memory instance
# client = chromadb.Client()
# try:
# client.delete_collection("pdf_data")
# print("Existing collection 'pdf_data' deleted.")
# except:
# print("Collection 'pdf_data' not found, creating a new one.")
# # Create a new collection with the embedding function
# ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="sentence-transformers/all-MiniLM-L6-v2")
# collection = client.create_collection("pdf_data", embedding_function=ef)
# return client, collection
# Legacy settings-based configuration (kept for reference; unused, since
# chromadb.PersistentClient below handles persistence on its own):
# from chromadb.config import Settings
# config = Settings(
#     persist_directory="./chromadb_data",
#     chroma_db_impl="sqlite",
# )
# Initialize persistent client with SQLite
def setup_chromadb():
    # Persist embeddings under ./chromadb_data so the data survives app restarts
    client = chromadb.PersistentClient(path="./chromadb_data")
    collection = client.get_or_create_collection(
        name="pdf_data",
        embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        ),
    )
    return client, collection
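# Minimal, commented-out sanity check for the setup above (illustrative sketch only;
# the ids and strings here are made up and are not part of the app):
# _client, _collection = setup_chromadb()
# _collection.add(ids=["demo_0"], documents=["ChromaDB stores documents with embeddings."])
# print(_collection.query(query_texts=["what does chromadb store?"], n_results=1)["documents"])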
# Initialize ChromaDB client
# def setup_chromadb():
# try:
# client = chromadb.Client(config)
# collections = client.list_collections()
# print(f"Existing collections: {collections}")
# if "pdf_data" in [c.name for c in collections]:
# client.delete_collection("pdf_data")
# print("Existing collection 'pdf_data' deleted.")
# collection = client.create_collection(
# "pdf_data",
# embedding_function=chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction(
# model_name="sentence-transformers/all-MiniLM-L6-v2"
# ),
# )
# return client, collection
# except Exception as e:
# print("Error setting up ChromaDB:", e)
# raise e
# Step 2: Extract Text from PDF
# def extract_text_from_pdf(pdf_path):
# pdf_text = ""
# with fitz.open(pdf_path) as doc:
# for page in doc:
# pdf_text += page.get_text()
# return pdf_text
def extract_text_from_pdf(uploaded_file):
    # Read the uploaded file's bytes in memory and pull text from every page with PyMuPDF
    text = ""
    with fitz.open(stream=uploaded_file.read(), filetype="pdf") as doc:
        for page in doc:
            text += page.get_text()
    return text
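# Usage note (sketch, not part of the app): outside Streamlit the same extraction
# works on a file path instead of an in-memory stream, e.g.
#     with fitz.open("example.pdf") as doc:   # "example.pdf" is a placeholder path
#         text = "".join(page.get_text() for page in doc)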
# Step 3: Add Extracted Text to Vector Database
def add_pdf_text_to_db(collection, pdf_text):
    sentences = pdf_text.split("\n")  # Split text into lines for granularity
    for idx, sentence in enumerate(sentences):
        if sentence.strip():  # Skip empty lines
            collection.add(
                ids=[f"pdf_text_{idx}"],
                documents=[sentence],
                metadatas=[{"line_number": idx, "text": sentence}],  # one metadata dict per document
            )
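# Batched variant (sketch; the helper name below is ours, not from any library):
# chromadb's add() accepts parallel lists, so a single call per document is
# typically much faster than one call per line.
# def add_pdf_text_to_db_batched(collection, pdf_text):
#     lines = [ln for ln in pdf_text.split("\n") if ln.strip()]
#     collection.add(
#         ids=[f"pdf_text_{i}" for i in range(len(lines))],
#         documents=lines,
#         metadatas=[{"line_number": i} for i in range(len(lines))],
#     )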
# Step 4: Query Function
def query_pdf_data(collection, query, retriever_model):
    # Retrieve the 3 most similar lines, then let the LLM answer from that context
    results = collection.query(
        query_texts=[query],
        n_results=3
    )
    context = " ".join(results["documents"][0])
    answer = retriever_model(f"Context: {context}\nQuestion: {query}")
    return answer, results["metadatas"]
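# Example usage (illustrative; assumes the collection already holds PDF lines and
# that retriever_model is the text2text-generation pipeline created in main()):
# answer, metadata = query_pdf_data(collection, "What is this document about?", retriever_model)
# print(answer[0]["generated_text"])   # the pipeline returns a list of dicts
# print(metadata[0])                   # per-query metadata for the retrieved lines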
# Streamlit Interface
def main():
    st.title("PDF Chatbot with Retrieval-Augmented Generation")
    st.write("Upload a PDF, and ask questions about its content!")
    # Initialize components
    client, collection = setup_chromadb()
    retriever_model = pipeline("text2text-generation", model="google/flan-t5-small")  # Free LLM
    # File upload
    uploaded_file = st.file_uploader("Upload your PDF file", type="pdf")
    if uploaded_file:
        try:
            pdf_text = extract_text_from_pdf(uploaded_file)
            st.success("Text extracted successfully!")
            st.text_area("Extracted Text:", pdf_text, height=300)
        except Exception as e:
            st.error(f"Error extracting text: {e}")
            return
        # Populate the vector database so the extracted text can be queried
        add_pdf_text_to_db(collection, pdf_text)
        st.success("PDF text has been added to the database. You can now query it!")
        # Query input
        query = st.text_input("Enter your query about the PDF:")
        if query:
            try:
                answer, metadata = query_pdf_data(collection, query, retriever_model)
                st.subheader("Answer:")
                st.write(answer[0]["generated_text"])
                st.subheader("Retrieved Context:")
                for meta in metadata[0]:
                    st.write(meta)
            except Exception as e:
                st.error(f"An error occurred: {e}")
if __name__ == "__main__":
main()
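# To run locally (assuming Streamlit is installed): streamlit run app.py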
# import tempfile
# import PyPDF2
# import streamlit as st
# from transformers import GPT2LMHeadModel, GPT2Tokenizer
# # Load pre-trained GPT-2 model and tokenizer
# tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
# model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
# def extract_text_from_pdf(file_path):
#     text = ""
#     with open(file_path, "rb") as f:
#         reader = PyPDF2.PdfReader(f)  # PdfFileReader/getPage/extractText are removed in PyPDF2 3.x
#         for page in reader.pages:
#             text += page.extract_text()
#     return text
# def generate_response(user_input):
# input_ids = tokenizer.encode(user_input, return_tensors="pt")
# output = model.generate(input_ids, max_length=100, num_return_sequences=1, temperature=0.7)
# response = tokenizer.decode(output[0], skip_special_tokens=True)
# return response
# def main():
# st.title("PDF Chatbot")
# pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"], accept_multiple_files=False)
# if pdf_file is not None:
# with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
# tmp_file.write(pdf_file.read())
# st.success("PDF file successfully uploaded and stored temporarily.")
# file_path = tmp_file.name
# pdf_text = extract_text_from_pdf(file_path)
# st.text_area("PDF Content", pdf_text)
# else:
# st.markdown('File not found!')
# user_input = st.text_input("You:", "")
# if st.button("Send"):
# response = generate_response(user_input)
# st.text_area("Chatbot:", response)
# if __name__ == "__main__":
# main()