import os
import shutil
import logging
from typing import List, Dict

import torch
import gradio as gr
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import login
import bitsandbytes as bnb

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
UPLOAD_FOLDER = "uploaded_docs"
EMBEDDING_MODEL = "intfloat/multilingual-e5-large"


class RAGSystem:
    """Main RAG system class."""

    def __init__(self):
        # Initialize device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Initialize folders (start from an empty upload directory)
        self.upload_folder = UPLOAD_FOLDER
        if os.path.exists(self.upload_folder):
            shutil.rmtree(self.upload_folder)
        os.makedirs(self.upload_folder, exist_ok=True)

        # Set limits
        self.max_files = 5
        self.max_file_size = 10 * 1024 * 1024  # 10 MB
        self.supported_formats = ['.pdf', '.txt', '.docx']

        # Initialize components
        self.embeddings = None
        self.vector_store = None
        self.qa_chain = None
        self.documents = []

        # Initialize embeddings
        self.initialize_embeddings()

    def initialize_embeddings(self):
        """Initialize embedding model."""
        try:
            self.embeddings = HuggingFaceEmbeddings(
                model_name=EMBEDDING_MODEL,
                model_kwargs={'device': self.device}
            )
            logger.info("Embeddings initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing embeddings: {str(e)}")
            raise

    def initialize_llm(self):
        """Initialize the language model and QA chain."""
        try:
            # Get Hugging Face token
            hf_token = os.environ.get('HUGGINGFACE_TOKEN')
            if not hf_token:
                raise ValueError("Please set HUGGINGFACE_TOKEN environment variable")

            # Login to Hugging Face
            login(token=hf_token)

            # Configure model loading based on device
            if self.device == "cuda":
                model_config = {
                    'torch_dtype': torch.float16,
                    'device_map': "auto",
                }
            else:
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float32,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=True,
                )
                model_config = {
                    'quantization_config': quantization_config,
                    'device_map': "auto",
                    'torch_dtype': torch.float32,
                    'low_cpu_mem_usage': True,
                }

            # Initialize tokenizer and model
            tokenizer = AutoTokenizer.from_pretrained(
                MODEL_NAME,
                token=hf_token,
                trust_remote_code=True
            )
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                token=hf_token,
                trust_remote_code=True,
                **model_config
            )

            # Create pipeline
            pipe_config = {
                "model": model,
                "tokenizer": tokenizer,
                "max_new_tokens": 512,
                "temperature": 0.1,
                "device_map": "auto",
                "torch_dtype": torch.float32 if self.device == "cpu" else torch.float16,
            }
            if self.device == "cpu":
                pipe_config["model"] = pipe_config["model"].to('cpu')

            pipe = pipeline("text-generation", **pipe_config)

            # Create QA chain
            llm = HuggingFacePipeline(pipeline=pipe)

            prompt_template = """
Context: {context}

Based on the context above, please provide a clear and concise answer to the following question.
If the information is not in the context, explicitly state so.

Question: {question}
"""
            PROMPT = PromptTemplate(
                template=prompt_template,
                input_variables=["context", "question"]
            )
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=self.vector_store.as_retriever(search_kwargs={"k": 4}),
                return_source_documents=True,
                chain_type_kwargs={"prompt": PROMPT}
            )
            logger.info("LLM initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing LLM: {str(e)}")
            raise

    def validate_file(self, file_path: str, file_size: int) -> bool:
        """Validate uploaded file."""
        if file_size > self.max_file_size:
            raise ValueError(f"File size exceeds {self.max_file_size // 1024 // 1024}MB limit")
        ext = os.path.splitext(file_path)[1].lower()
        if ext not in self.supported_formats:
            raise ValueError(f"Unsupported format. Supported: {', '.join(self.supported_formats)}")
        return True

    def process_file(self, file: gr.File) -> List:
        """Process a single file and return documents."""
        try:
            file_path = file.name
            file_size = os.path.getsize(file_path)
            self.validate_file(file_path, file_size)

            # Copy file to upload directory
            filename = os.path.basename(file_path)
            save_path = os.path.join(self.upload_folder, filename)
            shutil.copy2(file_path, save_path)

            # Load documents based on file type
            ext = os.path.splitext(file_path)[1].lower()
            if ext == '.pdf':
                loader = PyPDFLoader(save_path)
            elif ext == '.txt':
                loader = TextLoader(save_path)
            else:  # .docx
                loader = Docx2txtLoader(save_path)

            documents = loader.load()
            for doc in documents:
                doc.metadata.update({
                    'source': filename,
                    'type': 'uploaded'
                })
            return documents
        except Exception as e:
            logger.error(f"Error processing {file_path}: {str(e)}")
            raise

    def update_vector_store(self, new_documents: List):
        """Update vector store with new documents."""
        try:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=50,
                separators=["\n\n", "\n", ". ", " ", ""]
            )
            chunks = text_splitter.split_documents(new_documents)

            if self.vector_store is None:
                self.vector_store = FAISS.from_documents(chunks, self.embeddings)
            else:
                self.vector_store.add_documents(chunks)
            logger.info(f"Vector store updated with {len(chunks)} chunks")
        except Exception as e:
            logger.error(f"Error updating vector store: {str(e)}")
            raise

    def process_upload(self, files: List[gr.File]) -> str:
        """Process uploaded files and initialize/update the system."""
        if not files:
            return "Please select files to upload."

        try:
            current_files = len(os.listdir(self.upload_folder))
            if current_files + len(files) > self.max_files:
                return f"Maximum number of documents ({self.max_files}) exceeded"

            processed_files = []
            new_documents = []
            for file in files:
                documents = self.process_file(file)
                new_documents.extend(documents)
                processed_files.append(os.path.basename(file.name))

            self.update_vector_store(new_documents)
            self.documents.extend(new_documents)

            if self.qa_chain is None:
                self.initialize_llm()

            return f"Successfully processed: {', '.join(processed_files)}"
        except Exception as e:
            return f"Error: {str(e)}"

    def generate_response(self, question: str) -> Dict:
        """Generate response for a given question."""
        if not self.qa_chain:
            return {"error": "System not initialized. Please upload documents first."}

        try:
            result = self.qa_chain({"query": question})
            response = {
                'answer': result['result'],
                'sources': []
            }
            for doc in result['source_documents']:
                source = {
                    'title': doc.metadata.get('source', 'Unknown'),
                    'content': doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
                }
                response['sources'].append(source)
            return response
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return {"error": str(e)}


# Initialize system
rag_system = RAGSystem()
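
# Illustrative programmatic use (assumption: at least one document has already been
# uploaded and indexed, so rag_system.qa_chain is initialized; the question is only an example):
#   result = rag_system.generate_response("What is this document about?")
#   print(result["answer"])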

def process_query(message: str, history: List) -> List:
    """Process user query and return updated history."""
    try:
        if not rag_system.qa_chain:
            return history + [(message, "Please upload documents first.")]

        response = rag_system.generate_response(message)
        if "error" in response:
            return history + [(message, f"Error: {response['error']}")]

        answer = response['answer']
        sources = set([source['title'] for source in response['sources']])
        if sources:
            answer += "\n\nšŸ“š Sources:\n" + "\n".join([f"• {source}" for source in sources])

        return history + [(message, answer)]
    except Exception as e:
        return history + [(message, f"Error: {str(e)}")]


# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("""
        <div style="text-align: center;">
            <h1>šŸ¤– Easy RAG</h1>
            <p>A simple and powerful RAG system for your documents</p>
        </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.HTML("""
                    <h3>šŸ“ Upload Documents</h3>
                """)
                file_output = gr.File(
                    file_count="multiple",
                    label="Select Files",
                    elem_id="file-upload"
                )
                gr.HTML("""
                    <p>• Maximum 5 files</p>
                    <p>• 10MB per file</p>
                    <p>• Supported: PDF, TXT, DOCX</p>
                """)
                system_output = gr.Textbox(
                    label="Status",
                    interactive=False
                )
                gr.HTML("")

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                show_label=False,
                container=True,
                height=600,
                show_copy_button=True
            )
            with gr.Row():
                message = gr.Textbox(
                    placeholder="Ask a question about your documents...",
                    show_label=False,
                    container=False,
                    scale=8
                )
                clear = gr.Button("šŸ—‘ļø", size="sm", scale=1)

    gr.HTML("""
        <h3>šŸ” About Easy RAG</h3>
        <p>Powered by state-of-the-art AI technology:</p>
        <p>Based on original work by Camilo Vega</p>
    """)

    # Set up event handlers
    file_output.upload(
        rag_system.process_upload,
        inputs=[file_output],
        outputs=[system_output]
    )
    message.submit(
        process_query,
        inputs=[message, chatbot],
        outputs=[chatbot]
    )
    clear.click(lambda: None, None, chatbot)


if __name__ == "__main__":
    # Log system information
    logger.info("Starting Easy RAG system...")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
    else:
        logger.info("Running on CPU mode with optimizations")

    # Check for HUGGINGFACE_TOKEN
    if not os.environ.get('HUGGINGFACE_TOKEN'):
        logger.warning("HUGGINGFACE_TOKEN not found in environment variables")
        logger.warning("Please set it before running the application")
        print("Please set your HUGGINGFACE_TOKEN environment variable")
        print("Example: export HUGGINGFACE_TOKEN=your_token_here")
        exit(1)

    # Create upload directory if it doesn't exist
    if not os.path.exists(UPLOAD_FOLDER):
        os.makedirs(UPLOAD_FOLDER)
        logger.info(f"Created upload directory: {UPLOAD_FOLDER}")

    try:
        # Launch the Gradio interface with request queueing enabled
        demo.queue()
        demo.launch(
            share=False,  # Set to True if you want to create a public link
            server_name="0.0.0.0",  # Listen on all network interfaces
            server_port=7860,  # Default Gradio port
            show_error=True
        )
    except Exception as e:
        logger.error(f"Error launching Gradio interface: {str(e)}")
        raise
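
# Example local run (assumption: this script is saved as app.py):
#   export HUGGINGFACE_TOKEN=your_token_here
#   python app.py
# The interface is then served at http://localhost:7860.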