import os
import logging
import shutil
from typing import List, Dict

import torch
import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Note: Llama-2 is a gated model on Hugging Face; access must be granted and an auth token configured.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
UPLOAD_FOLDER = "uploaded_docs"


class DocumentManager:
    """Class to manage document uploads and processing."""

    def __init__(self):
        self.upload_folder = UPLOAD_FOLDER
        os.makedirs(self.upload_folder, exist_ok=True)
        self.max_files = 5
        self.max_file_size = 10 * 1024 * 1024  # 10 MB
        self.supported_formats = ['.pdf', '.txt', '.docx']
        self.documents = []

    def validate_file(self, file):
        """Check the size and extension of an uploaded file."""
        if os.path.getsize(file.name) > self.max_file_size:
            raise ValueError(f"File size exceeds {self.max_file_size // 1024 // 1024}MB limit")
        ext = os.path.splitext(file.name)[1].lower()
        if ext not in self.supported_formats:
            raise ValueError(
                f"Unsupported file format. Supported formats: {', '.join(self.supported_formats)}"
            )

    def load_document(self, file_path: str) -> List:
        """Load a single document with the loader matching its extension."""
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                loader = PyPDFLoader(file_path)
            elif ext == '.txt':
                loader = TextLoader(file_path)
            elif ext == '.docx':
                loader = Docx2txtLoader(file_path)
            else:
                raise ValueError(f"Unsupported file format: {ext}")

            documents = loader.load()
            for doc in documents:
                doc.metadata.update({
                    'source': os.path.basename(file_path),
                    'type': 'uploaded'
                })
            return documents
        except Exception as e:
            logger.error(f"Error loading {file_path}: {str(e)}")
            raise

    def process_upload(self, files: List) -> str:
        """Validate, copy, and load a batch of uploaded files."""
        if len(os.listdir(self.upload_folder)) + len(files) > self.max_files:
            raise ValueError(f"Maximum number of documents ({self.max_files}) exceeded")

        processed_files = []
        for file in files:
            try:
                self.validate_file(file)
                # Gradio hands back uploaded files as temporary paths; copy them into the upload folder.
                filename = os.path.basename(file.name)
                save_path = os.path.join(self.upload_folder, filename)
                shutil.copy(file.name, save_path)
                docs = self.load_document(save_path)
                self.documents.extend(docs)
                processed_files.append(filename)
            except Exception as e:
                logger.error(f"Error processing {file.name}: {str(e)}")
                return f"Error processing {file.name}: {str(e)}"

        return f"Successfully processed files: {', '.join(processed_files)}"


class RAGSystem:
    """Main RAG system class."""

    def __init__(self, model_name: str = MODEL_NAME):
        self.model_name = model_name
        self.document_manager = DocumentManager()
        self.embeddings = None
        self.vector_store = None
        self.qa_chain = None
        self.is_initialized = False

    def initialize_system(self, documents: List = None):
        """Initialize RAG system with provided documents."""
        try:
            if not documents:
                raise ValueError("No documents provided for initialization")

            # Initialize text splitter
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=50,
                separators=["\n\n", "\n", ". ", " ", ""]
            )

            # Process documents
            chunks = text_splitter.split_documents(documents)

            # Initialize embeddings
            self.embeddings = HuggingFaceEmbeddings(
                model_name="intfloat/multilingual-e5-large",
                model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
            )

            # Create vector store
            self.vector_store = FAISS.from_documents(chunks, self.embeddings)

            # Initialize LLM pipeline
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto"
            )

            # The model is already dispatched via device_map="auto", so the pipeline
            # needs no extra device argument; enable sampling so temperature applies.
            pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.1
            )

            llm = HuggingFacePipeline(pipeline=pipe)

            # Create prompt template
            prompt_template = """
            Context: {context}

            Based on the context above, please provide a clear and concise answer to the following question.
            If the information is not in the context, explicitly state so.

            Question: {question}
            """

            PROMPT = PromptTemplate(
                template=prompt_template,
                input_variables=["context", "question"]
            )

            # Set up QA chain
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=self.vector_store.as_retriever(search_kwargs={"k": 4}),
                return_source_documents=True,
                chain_type_kwargs={"prompt": PROMPT}
            )

            self.is_initialized = True
            return "System initialized successfully"
        except Exception as e:
            logger.error(f"Error during system initialization: {str(e)}")
            return f"Error: {str(e)}"

    def generate_response(self, question: str) -> Dict:
        """Generate response for a given question."""
        if not self.is_initialized:
            return {"error": "System not initialized. Please upload documents first."}

        try:
            result = self.qa_chain({"query": question})
            response = {
                'answer': result['result'],
                'sources': []
            }
            for doc in result['source_documents']:
                source = {
                    'title': doc.metadata.get('source', 'Unknown'),
                    'content': doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
                }
                response['sources'].append(source)
            return response
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return {"error": str(e)}


# Initialize RAG system
rag_system = RAGSystem()


def process_file_upload(files):
    """Handle file uploads and system initialization."""
    try:
        upload_result = rag_system.document_manager.process_upload(files)
        if "Error" in upload_result:
            return upload_result

        init_result = rag_system.initialize_system(rag_system.document_manager.documents)
        return f"{upload_result}\n{init_result}"
    except Exception as e:
        return f"Error: {str(e)}"


def process_query(message, history):
    """Process user query and generate response."""
    try:
        if not rag_system.is_initialized:
            return history + [(message, "Please upload documents first.")]

        response = rag_system.generate_response(message)
        if "error" in response:
            return history + [(message, f"Error: {response['error']}")]

        answer = response['answer']
        sources = set([source['title'] for source in response['sources']])
        if sources:
            answer += "\n\nšŸ“š Sources:\n" + "\n".join([f"ā€¢ {source}" for source in sources])

        return history + [(message, answer)]
    except Exception as e:
        return history + [(message, f"Error: {str(e)}")]


# Create Gradio interface
demo = gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}")

with demo:
    gr.HTML("""
        <div style="text-align: center; margin-bottom: 1rem">
            <h1>šŸ¤– Easy RAG</h1>
            <p>A simple and powerful RAG system for your documents</p>
        </div>
""") with gr.Row(): file_output = gr.File( file_count="multiple", label="Upload Documents (PDF, TXT, DOCX - Max 5 files, 10MB each)" ) upload_button = gr.Button("Upload and Initialize") system_output = gr.Textbox(label="System Status") chatbot = gr.Chatbot( show_label=False, container=True, height=400, show_copy_button=True ) with gr.Row(): message = gr.Textbox( placeholder="Ask a question about your documents...", show_label=False, container=False, scale=8 ) clear = gr.Button("šŸ—‘ļø Clear", size="sm", scale=1) gr.HTML("""
        <div style="margin-top: 1rem">
            <h3>šŸ” About Easy RAG</h3>
            <p>A powerful RAG system that lets you query your documents using:</p>
            <ul>
                <li>meta-llama/Llama-2-7b-chat-hf for answer generation</li>
                <li>intfloat/multilingual-e5-large embeddings with a FAISS vector store</li>
                <li>LangChain for retrieval and orchestration</li>
            </ul>
            <p>Based on original work by Camilo Vega</p>
        </div>
""") # Set up event handlers upload_button.click( process_file_upload, inputs=[file_output], outputs=[system_output] ) message.submit( process_query, inputs=[message, chatbot], outputs=[chatbot] ) clear.click(lambda: None, None, chatbot) demo.launch()