import os
import gradio as gr
import torch
import logging
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyMuPDFLoader  # more stable PDF loader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the Hugging Face cache directory
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

# Check for GPU availability
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Global state
conversation_retrieval_chain = None
chat_history = []
llm_pipeline = None
embeddings = None
persist_directory = "/tmp/chroma_db"  # storage location for the vector DB


def init_llm():
    """Initialize the LLM and embeddings."""
    global llm_pipeline, embeddings

    hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    if not hf_token:
        raise ValueError("HUGGINGFACEHUB_API_TOKEN is not set in environment variables.")

    model_id = "tiiuae/falcon-rw-1b"  # small model; swap in a larger one if resources allow
    hf_pipeline = pipeline("text-generation", model=model_id, device=DEVICE,
                           token=hf_token, max_new_tokens=256)  # pass the required token; cap generation length
    llm_pipeline = HuggingFacePipeline(pipeline=hf_pipeline)

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": DEVICE},
    )
    logger.info("LLM and embeddings initialized successfully!")
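

# Quick sanity check (a sketch, not part of the app's UI flow; the first call
# downloads the model weights, so it can take a while):
#
#   init_llm()
#   print(llm_pipeline.invoke("Hello"))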


def process_document(file):
    """Process an uploaded PDF and create a retriever."""
    global conversation_retrieval_chain

    if not llm_pipeline or not embeddings:
        init_llm()

    try:
        file_path = file.name  # ensures the correct file path is passed
        logger.info(f"Processing PDF: {file_path}")

        loader = PyMuPDFLoader(file_path)  # alternative loader, chosen for stability
        documents = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
        texts = text_splitter.split_documents(documents)

        # Load or create the Chroma vector store
        db = Chroma.from_documents(texts, embedding=embeddings, persist_directory=persist_directory)
        retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 6})

        conversation_retrieval_chain = ConversationalRetrievalChain.from_llm(
            llm=llm_pipeline, retriever=retriever
        )
        logger.info("PDF processed successfully!")
        return "PDF uploaded and processed successfully! You can now ask questions."
    except Exception as e:
        logger.error(f"Error processing PDF: {str(e)}")
        return f"Error processing PDF: {str(e)}"


def process_prompt(prompt, chat_history_display):
    """Generate a response using the retrieval chain."""
    global conversation_retrieval_chain, chat_history
    if not conversation_retrieval_chain:
        return (chat_history_display or []) + [("No document uploaded.", "Please upload a PDF first.")]
    output = conversation_retrieval_chain.invoke({"question": prompt, "chat_history": chat_history})
    answer = output["answer"]
    chat_history.append((prompt, answer))
    return chat_history


def clear_chat():
    """Reset the global history as well as the on-screen chat."""
    global chat_history
    chat_history = []
    return []
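
# The chain expects `chat_history` as a list of (human, ai) tuples, which matches
# gr.Chatbot's tuple format, so the same list serves both the chain and the UI.
# A single turn (a sketch, assuming a document was already processed) looks like:
#
#   out = conversation_retrieval_chain.invoke(
#       {"question": "What is this PDF about?", "chat_history": []}
#   )
#   print(out["answer"])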


# Define the Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1 style='text-align: center;'>Personal Data Assistant</h1>")

    with gr.Row():
        dark_mode = gr.Checkbox(label="Toggle light/dark mode")  # display-only; no callback attached

    with gr.Column():
        gr.Markdown(
            "Hello there! I'm your friendly data assistant, ready to answer any questions "
            "regarding your data. Could you please upload a PDF file for me to analyze?"
        )
        file_input = gr.File(label="Upload File")
        upload_button = gr.Button("Upload File")
        status_output = gr.Textbox(label="Status", interactive=False)

    chat_history_display = gr.Chatbot(label="Chat History")

    with gr.Row():
        user_input = gr.Textbox(placeholder="Type your message here...", scale=4)
        submit_button = gr.Button("Send", scale=1)
        clear_button = gr.Button("Clear", scale=1)

    # Button click actions
    upload_button.click(process_document, inputs=file_input, outputs=status_output)
    submit_button.click(process_prompt, inputs=[user_input, chat_history_display], outputs=chat_history_display)
    clear_button.click(clear_chat, outputs=chat_history_display)  # resets the global history, not just the display

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)  # works on Hugging Face Spaces
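
# The Space also needs a requirements.txt; a plausible minimal set (package
# names only -- exact versions are an assumption, not taken from the source):
#
#   gradio
#   torch
#   transformers
#   langchain
#   langchain-community
#   langchain-text-splitters
#   chromadb
#   pymupdf
#   sentence-transformers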