import os
from io import BytesIO

import streamlit as st
from dotenv import load_dotenv
from transformers import pipeline
from pypdf import PdfReader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# get_index_for_pdf is assumed to live in main.py, per the original comment.
try:
    from main import get_index_for_pdf
except ImportError:
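    # Hedged fallback: a minimal sketch used only when main.py is absent.
    # Assumptions: the real helper reads each PDF with pypdf, attaches the
    # 'filename' and 'page' metadata keys consumed further below, embeds the
    # page texts with a HuggingFace sentence-transformer, and returns a
    # FAISS vector store.
    def get_index_for_pdf(pdf_bytes_list, filenames, model_name):
        texts, metadatas = [], []
        for raw, name in zip(pdf_bytes_list, filenames):
            reader = PdfReader(BytesIO(raw))
            for page_num, page in enumerate(reader.pages, start=1):
                texts.append(page.extract_text() or "")
                metadatas.append({"filename": name, "page": page_num})
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        return FAISS.from_texts(texts, embeddings, metadatas=metadatas)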

# Initialize session state for the app
if "vectordb" not in st.session_state:
    st.session_state["vectordb"] = None
if "prompt" not in st.session_state:
    st.session_state["prompt"] = [{"role": "system", "content": "none"}]

# Set the title for the Streamlit app
st.title("RAG-Enhanced Chatbot")

# Hugging Face API key, loaded from .env (avoid hardcoding in production)
load_dotenv()
HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
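# Expected .env entry (variable name taken from the lookup above):
# HUGGINGFACE_API_KEY=hf_xxxxxxxxxxxxxxxxxxxx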

# Optional sidebar model picker (currently disabled):
# st.title('Model Configuration')
# model_name = st.sidebar.selectbox(
#     "Choose a Hugging Face Model",
#     [
#         "sentence-transformers/all-mpnet-base-v2",
#         "sentence-transformers/all-MiniLM-L6-v2",
#         "msmarco-distilbert-base-tas-b",
#         "deepset/roberta-large-squad2",
#         "facebook/dpr-ctx_encoder-single-nq-base",
#     ],
#     index=0,  # Default model
# )

# Build the extractive question-answering pipeline once and cache it across reruns
@st.cache_resource
def load_qa_pipeline():
    return pipeline(
        "question-answering",
        model="deepset/roberta-base-squad2",  # Replace with your desired model
        token=HUGGINGFACE_API_KEY,  # 'token' supersedes the deprecated 'use_auth_token'
    )

qa_pipeline = load_qa_pipeline()
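# The extractive pipeline returns a dict such as
# {"score": 0.92, "start": 37, "end": 54, "answer": "..."}; only "answer" is used below.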

# Define a prompt template for the assistant
prompt_template = """
You are a helpful assistant who answers users' questions based on PDF extracts.
Keep your answers detailed, and break long answers into bullet points.
Context information includes 'filename' and 'page'. Always reference these in your responses.
If the text is irrelevant or insufficient to answer, respond with "Not applicable."
The provided PDF content is:
{pdf_extract}
"""

# Cached function to create a vector database for the provided PDF files.
# '_files' is underscore-prefixed so Streamlit skips hashing the UploadedFile
# objects; the cache is keyed on the filenames and model name instead.
@st.cache_resource
def create_vectordb(_files, filenames, huggingface_model_name):
    # Show a spinner while creating the vector database
    with st.spinner("Creating Vector Database..."):
        vectordb = get_index_for_pdf(
            [file.getvalue() for file in _files], filenames, huggingface_model_name
        )
    return vectordb

# Upload PDF files using the Streamlit file uploader
pdf_files = st.file_uploader("Upload your PDFs", type="pdf", accept_multiple_files=True)

# If PDF files are uploaded, create the vector database and store it in the session state
if pdf_files:
    pdf_file_names = [file.name for file in pdf_files]
    huggingface_model_name = "sentence-transformers/all-MiniLM-L6-v2"  # Embedding model
    st.session_state["vectordb"] = create_vectordb(pdf_files, pdf_file_names, huggingface_model_name)

# Display previous chat messages (the system message is skipped)
for message in st.session_state["prompt"]:
    if message["role"] != "system":
        with st.chat_message(message["role"]):
            st.write(message["content"])

# Get the user's question using the Streamlit chat input
question = st.chat_input("Ask anything")

# Handle the user's question
if question:
    vectordb = st.session_state.get("vectordb", None)
    if vectordb is None:
        with st.chat_message("assistant"):
            st.write("You need to upload a PDF first.")
        st.stop()
    # Search the vector database for content similar to the user's question
    search_results = vectordb.similarity_search(question, k=3)
    pdf_extract = "\n".join(
        f"{result.page_content} (Filename: {result.metadata['filename']}, Page: {result.metadata['page']})"
        for result in search_results
    )
    # Echo the user's message, then answer with the QA pipeline over the retrieved context
    with st.chat_message("user"):
        st.write(question)
    response = qa_pipeline(question=question, context=pdf_extract)
    with st.chat_message("assistant"):
        st.write(response["answer"])

    # Record the exchange in the session state so it persists across reruns
    st.session_state["prompt"].append({"role": "user", "content": question})
    st.session_state["prompt"].append({"role": "assistant", "content": response["answer"]})
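
# To try the app locally (assuming this file is saved as app.py):
#   pip install streamlit python-dotenv transformers pypdf langchain-community faiss-cpu sentence-transformers
#   streamlit run app.py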