import chromadb
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import streamlit as st
import fitz  # PyMuPDF for PDF parsing
# # Step 1: Setup ChromaDB (legacy in-memory version, kept for reference)
# def setup_chromadb():
#     # Initialize ChromaDB in-memory instance
#     client = chromadb.Client()
#     try:
#         client.delete_collection("pdf_data")
#         print("Existing collection 'pdf_data' deleted.")
#     except:
#         print("Collection 'pdf_data' not found, creating a new one.")
#     # Create a new collection with the embedding function
#     ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="sentence-transformers/all-MiniLM-L6-v2")
#     collection = client.create_collection("pdf_data", embedding_function=ef)
#     return client, collection
# Legacy ChromaDB configuration, kept for reference only; the active setup below uses
# chromadb.PersistentClient, which handles on-disk persistence itself.
# from chromadb.config import Settings
# config = Settings(
#     persist_directory="./chromadb_data",
#     chroma_db_impl="sqlite",
# )
# Step 1: Setup ChromaDB
# Initialize a persistent client (data is stored on disk under ./chromadb_data)
def setup_chromadb():
    client = chromadb.PersistentClient(path="./chromadb_data")
    collection = client.get_or_create_collection(
        name="pdf_data",
        embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        ),
    )
    return client, collection
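# Optional (sketch): cache the heavy resources across Streamlit reruns so the ChromaDB
# client and the Flan-T5 pipeline are not rebuilt on every interaction. load_resources is
# a hypothetical helper, not wired into main(); it assumes a Streamlit version that
# provides st.cache_resource (1.18+).
@st.cache_resource
def load_resources():
    client, collection = setup_chromadb()
    retriever_model = pipeline("text2text-generation", model="google/flan-t5-small")
    return client, collection, retriever_model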
# Initialize ChromaDB client (legacy Settings-based version, kept for reference)
# def setup_chromadb():
#     try:
#         client = chromadb.Client(config)
#         collections = client.list_collections()
#         print(f"Existing collections: {collections}")
#         if "pdf_data" in [c.name for c in collections]:
#             client.delete_collection("pdf_data")
#             print("Existing collection 'pdf_data' deleted.")
#         collection = client.create_collection(
#             "pdf_data",
#             embedding_function=chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction(
#                 model_name="sentence-transformers/all-MiniLM-L6-v2"
#             ),
#         )
#         return client, collection
#     except Exception as e:
#         print("Error setting up ChromaDB:", e)
#         raise e
# Step 2: Extract Text from PDF
# def extract_text_from_pdf(pdf_path):
#     pdf_text = ""
#     with fitz.open(pdf_path) as doc:
#         for page in doc:
#             pdf_text += page.get_text()
#     return pdf_text
def extract_text_from_pdf(uploaded_file):
    with fitz.open(stream=uploaded_file.read(), filetype="pdf") as doc:
        text = ""
        for page in doc:
            text += page.get_text()
    return text
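# Optional (sketch): an alternative chunking helper. add_pdf_text_to_db below splits on
# single newlines, which can produce very short fragments; splitting on blank lines keeps
# whole paragraphs together and usually gives the retriever more context per chunk.
# split_into_paragraphs is a hypothetical helper, not wired into main().
def split_into_paragraphs(pdf_text):
    # Treat blank lines as paragraph boundaries and drop empty chunks
    return [chunk.strip() for chunk in pdf_text.split("\n\n") if chunk.strip()]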
# Step 3: Add Extracted Text to Vector Database
def add_pdf_text_to_db(collection, pdf_text):
    sentences = pdf_text.split("\n")  # Split text into lines for granularity
    for idx, sentence in enumerate(sentences):
        if sentence.strip():  # Skip empty lines
            # upsert (rather than add) keeps Streamlit reruns and re-uploads idempotent
            collection.upsert(
                ids=[f"pdf_text_{idx}"],
                documents=[sentence],
                metadatas=[{"line_number": idx, "text": sentence}],
            )
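# Optional (sketch): a batched variant of add_pdf_text_to_db. A single upsert call with
# parallel lists embeds all lines in one pass instead of once per line.
# add_pdf_text_to_db_batched is a hypothetical helper, not wired into main().
def add_pdf_text_to_db_batched(collection, pdf_text):
    lines = [line.strip() for line in pdf_text.split("\n") if line.strip()]
    if not lines:
        return
    collection.upsert(
        ids=[f"pdf_text_{i}" for i in range(len(lines))],
        documents=lines,
        metadatas=[{"line_number": i, "text": line} for i, line in enumerate(lines)],
    )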
# Step 4: Query Function
def query_pdf_data(collection, query, retriever_model):
    results = collection.query(
        query_texts=[query],
        n_results=3
    )
    context = " ".join(results["documents"][0])
    answer = retriever_model(f"Context: {context}\nQuestion: {query}")
    return answer, results["metadatas"]
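# Return shapes, for reference: the text2text-generation pipeline returns a list such as
# [{"generated_text": "..."}], so the answer text is answer[0]["generated_text"];
# results["metadatas"] is a list with one inner list of metadata dicts per query.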
# Streamlit Interface
def main():
    st.title("PDF Chatbot with Retrieval-Augmented Generation")
    st.write("Upload a PDF, and ask questions about its content!")
    # Initialize components
    client, collection = setup_chromadb()
    retriever_model = pipeline("text2text-generation", model="google/flan-t5-small")  # Free, lightweight LLM
    # File upload
    uploaded_file = st.file_uploader("Upload your PDF file", type="pdf")
    if uploaded_file:
        try:
            pdf_text = extract_text_from_pdf(uploaded_file)
            st.success("Text extracted successfully!")
            st.text_area("Extracted Text:", pdf_text, height=300)
            # Populate the vector database with the extracted text
            add_pdf_text_to_db(collection, pdf_text)
            st.success("PDF text has been added to the database. You can now query it!")
        except Exception as e:
            st.error(f"Error processing PDF: {e}")
        # Query input
        query = st.text_input("Enter your query about the PDF:")
        if query:
            try:
                answer, metadata = query_pdf_data(collection, query, retriever_model)
                st.subheader("Answer:")
                st.write(answer[0]["generated_text"])
                st.subheader("Retrieved Context:")
                for meta in metadata[0]:
                    st.write(meta)
            except Exception as e:
                st.error(f"An error occurred: {e}")
if __name__ == "__main__":
    main()
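# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py
# The commented-out block below is an alternative GPT-2-based version of the app, kept for reference.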
# import tempfile
# import PyPDF2
# import streamlit as st
# from transformers import GPT2LMHeadModel, GPT2Tokenizer
# # Load pre-trained GPT-2 model and tokenizer
# tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
# model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
# def extract_text_from_pdf(file_path):
#     text = ""
#     with open(file_path, "rb") as f:
#         reader = PyPDF2.PdfFileReader(f)
#         for page_num in range(reader.numPages):
#             text += reader.getPage(page_num).extractText()
#     return text
# def generate_response(user_input):
#     input_ids = tokenizer.encode(user_input, return_tensors="pt")
#     output = model.generate(input_ids, max_length=100, num_return_sequences=1, temperature=0.7)
#     response = tokenizer.decode(output[0], skip_special_tokens=True)
#     return response
# def main():
#     st.title("PDF Chatbot")
#     pdf_file = st.file_uploader("Upload a PDF file", type=["pdf"], accept_multiple_files=False)
#     if pdf_file is not None:
#         with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
#             tmp_file.write(pdf_file.read())
#             st.success("PDF file successfully uploaded and stored temporarily.")
#             file_path = tmp_file.name
#         pdf_text = extract_text_from_pdf(file_path)
#         st.text_area("PDF Content", pdf_text)
#     else:
#         st.markdown('File not found!')
#     user_input = st.text_input("You:", "")
#     if st.button("Send"):
#         response = generate_response(user_input)
#         st.text_area("Chatbot:", response)
# if __name__ == "__main__":
#     main()