# Easy_RAG / app.py
# Author: CamiloVega (commit f498b40, "Create app.py"; 11.3 kB)
import logging
import os
import shutil
from typing import List, Dict

import gradio as gr
import torch
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

# Application-wide settings.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"  # Hugging Face model id of the chat LLM
UPLOAD_FOLDER = "uploaded_docs"  # local directory where uploaded files are stored

# Configure logging: INFO and above, one timestamped line per record.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class DocumentManager:
    """Validate, store and load user-uploaded documents.

    Accepted uploads are copied into ``UPLOAD_FOLDER`` and parsed with a
    LangChain loader chosen by file extension. Every loaded document is
    appended to ``self.documents``, which accumulates across uploads.
    """

    def __init__(self):
        self.upload_folder = UPLOAD_FOLDER
        os.makedirs(self.upload_folder, exist_ok=True)
        # Limit on files in upload_folder: already stored + newly uploaded.
        self.max_files = 5
        self.max_file_size = 10 * 1024 * 1024  # 10 MB
        self.supported_formats = ['.pdf', '.txt', '.docx']
        # All documents loaded so far, accumulated across uploads.
        self.documents = []

    @staticmethod
    def _file_path(file) -> str:
        """Return the filesystem path of an uploaded file.

        Depending on the Gradio version and ``gr.File(type=...)``, handlers
        receive either a path string or a tempfile-like object exposing the
        path as ``.name``. Both are accepted here.
        """
        if isinstance(file, (str, os.PathLike)):
            return os.fspath(file)
        return file.name

    def validate_file(self, file):
        """Check an uploaded file's size and extension.

        Args:
            file: a path string or an object with a ``.name`` path attribute.

        Raises:
            ValueError: if the file is larger than ``max_file_size`` bytes or
                its extension is not in ``supported_formats``.
        """
        path = self._file_path(file)
        if os.path.getsize(path) > self.max_file_size:
            raise ValueError(f"File size exceeds {self.max_file_size // 1024 // 1024}MB limit")
        ext = os.path.splitext(path)[1].lower()
        if ext not in self.supported_formats:
            raise ValueError(f"Unsupported file format. Supported formats: {', '.join(self.supported_formats)}")

    def load_document(self, file_path: str) -> List:
        """Load ``file_path`` with the LangChain loader matching its extension.

        Each returned document's metadata is updated with ``source`` (the
        file's base name) and ``type`` set to ``'uploaded'``.

        Raises:
            ValueError: for an unsupported extension. Any loader error is
                logged and re-raised unchanged.
        """
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                loader = PyPDFLoader(file_path)
            elif ext == '.txt':
                loader = TextLoader(file_path)
            elif ext == '.docx':
                loader = Docx2txtLoader(file_path)
            else:
                raise ValueError(f"Unsupported file format: {ext}")
            documents = loader.load()
            for doc in documents:
                doc.metadata.update({
                    'source': os.path.basename(file_path),
                    'type': 'uploaded'
                })
            return documents
        except Exception as e:
            logger.error(f"Error loading {file_path}: {str(e)}")
            raise

    def process_upload(self, files: List) -> str:
        """Validate, store and load a batch of uploaded files.

        Args:
            files: uploaded files (path strings or objects with ``.name``).

        Returns:
            "Successfully processed files: <names>" on success, or
            "Error processing <name>: <reason>" for the first file that fails.
            Files processed before the failing one stay loaded.

        Raises:
            ValueError: if no files were given, or if storing them would
                exceed ``max_files``.
        """
        if not files:
            raise ValueError("No files uploaded")
        if len(os.listdir(self.upload_folder)) + len(files) > self.max_files:
            raise ValueError(f"Maximum number of documents ({self.max_files}) exceeded")
        processed_files = []
        for file in files:
            src_path = self._file_path(file)
            file_name = os.path.basename(src_path)
            try:
                self.validate_file(file)
                # Join with the basename: joining an absolute temp path would
                # discard upload_folder entirely.
                save_path = os.path.join(self.upload_folder, file_name)
                # Gradio upload values have no .save(); copy the temp file instead.
                if os.path.abspath(src_path) != os.path.abspath(save_path):
                    shutil.copyfile(src_path, save_path)
                docs = self.load_document(save_path)
                self.documents.extend(docs)
                processed_files.append(file_name)
            except Exception as e:
                logger.error(f"Error processing {file_name}: {str(e)}")
                return f"Error processing {file_name}: {str(e)}"
        return f"Successfully processed files: {', '.join(processed_files)}"
class RAGSystem:
    """Question answering over uploaded documents (retrieval-augmented generation).

    Documents are split into chunks and embedded into a FAISS index.
    Questions are answered by a Hugging Face causal LM through a LangChain
    ``RetrievalQA`` "stuff" chain over the top 4 retrieved chunks.
    """

    def __init__(self, model_name: str = MODEL_NAME):
        self.model_name = model_name
        self.document_manager = DocumentManager()
        self.embeddings = None
        self.vector_store = None
        self.qa_chain = None
        self.is_initialized = False
        # LangChain LLM wrapper. Built on first initialization and reused, so
        # the model weights are not reloaded on every upload.
        self._llm = None

    def _split_documents(self, documents: List) -> List:
        """Split documents into overlapping chunks (500 chars, 50 overlap)."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", ". ", " ", ""]
        )
        return text_splitter.split_documents(documents)

    def _get_embeddings(self):
        """Return the embedding model, creating it on first use."""
        if self.embeddings is None:
            self.embeddings = HuggingFaceEmbeddings(
                model_name="intfloat/multilingual-e5-large",
                model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
            )
        return self.embeddings

    def _get_llm(self):
        """Return the LangChain LLM wrapper, loading the model on first use."""
        if self._llm is None:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                # Half precision only on GPU: float16 is poorly supported on
                # CPU in some PyTorch versions.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto"
            )
            # The model was already placed by device_map above, so the
            # pipeline gets no device arguments of its own.
            pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=512,
                temperature=0.1
            )
            self._llm = HuggingFacePipeline(pipeline=pipe)
        return self._llm

    @staticmethod
    def _build_prompt():
        """Return the QA prompt; {context} receives the retrieved chunks."""
        prompt_template = """
Context: {context}
Based on the context above, please provide a clear and concise answer to the following question.
If the information is not in the context, explicitly state so.
Question: {question}
"""
        return PromptTemplate(
            template=prompt_template,
            input_variables=["context", "question"]
        )

    def initialize_system(self, documents: List = None):
        """(Re)build the vector index and QA chain from ``documents``.

        Args:
            documents: LangChain documents to index (the app passes
                ``self.document_manager.documents``).

        Returns:
            "System initialized successfully", or "Error: <message>" if any
            step fails. Errors are logged rather than raised.
        """
        try:
            if not documents:
                raise ValueError("No documents provided for initialization")
            chunks = self._split_documents(documents)
            if not chunks:
                # e.g. scanned PDFs without a text layer. FAISS.from_documents
                # would otherwise fail on an empty list with an unhelpful error.
                raise ValueError("No text could be extracted from the provided documents")
            self.vector_store = FAISS.from_documents(chunks, self._get_embeddings())
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=self._get_llm(),
                chain_type="stuff",
                retriever=self.vector_store.as_retriever(search_kwargs={"k": 4}),
                return_source_documents=True,
                chain_type_kwargs={"prompt": self._build_prompt()}
            )
            self.is_initialized = True
            return "System initialized successfully"
        except Exception as e:
            logger.error(f"Error during system initialization: {str(e)}")
            return f"Error: {str(e)}"

    def generate_response(self, question: str) -> Dict:
        """Answer ``question`` with the QA chain.

        Returns:
            ``{'answer': str, 'sources': [{'title': str, 'content': str}, ...]}``.
            ``content`` is the retrieved chunk, cut to 200 characters plus
            "..." when longer. Returns ``{"error": str}`` instead when the
            system is not initialized, the question is blank, or the chain
            raises.
        """
        if not self.is_initialized:
            return {"error": "System not initialized. Please upload documents first."}
        if not question or not question.strip():
            return {"error": "Please enter a question."}
        try:
            result = self.qa_chain({"query": question})
            response = {
                'answer': result['result'],
                'sources': []
            }
            for doc in result['source_documents']:
                content = doc.page_content
                if len(content) > 200:
                    content = content[:200] + "..."
                response['sources'].append({
                    'title': doc.metadata.get('source', 'Unknown'),
                    'content': content
                })
            return response
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return {"error": str(e)}
# Single shared RAG system instance used by the Gradio handlers.
rag_system = RAGSystem(model_name=MODEL_NAME)
def process_file_upload(files):
    """Gradio handler: store the uploaded files, then (re)build the RAG system.

    Args:
        files: value of the ``gr.File`` component (may be None when empty).

    Returns:
        A status message for the "System Status" textbox.
    """
    try:
        upload_result = rag_system.document_manager.process_upload(files)
        # process_upload reports a failed file with a message starting with
        # "Error". A substring test would also trip on file names containing
        # "Error" and skip initialization.
        if upload_result.startswith("Error"):
            return upload_result
        init_result = rag_system.initialize_system(rag_system.document_manager.documents)
        return f"{upload_result}\n{init_result}"
    except Exception as e:
        logger.error(f"Error handling file upload: {str(e)}")
        return f"Error: {str(e)}"
def process_query(message, history):
    """Gradio handler: answer ``message`` and append the turn to the chat.

    Args:
        message: the user's question from the textbox.
        history: the Chatbot's list of (user, bot) pairs; may be None or empty
            after the chat has been cleared.

    Returns:
        The new history list with the (message, answer) pair appended. Blank
        messages return the history unchanged.
    """
    history = list(history or [])
    if not message or not message.strip():
        return history
    try:
        if not rag_system.is_initialized:
            return history + [(message, "Please upload documents first.")]
        response = rag_system.generate_response(message)
        if "error" in response:
            return history + [(message, f"Error: {response['error']}")]
        answer = response['answer']
        # dict.fromkeys de-duplicates while keeping retrieval order (a set would not).
        sources = list(dict.fromkeys(source['title'] for source in response['sources']))
        if sources:
            answer += "\n\n📚 Sources:\n" + "\n".join(f"• {source}" for source in sources)
        return history + [(message, answer)]
    except Exception as e:
        logger.error(f"Error processing query: {str(e)}")
        return history + [(message, f"Error: {str(e)}")]
# Create Gradio interface
demo = gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}")
with demo:
    gr.HTML("""
<div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
<h1 style="color: #2d333a;">🤖 Easy RAG</h1>
<p style="color: #4a5568;">
A simple and powerful RAG system for your documents
</p>
</div>
""")
    with gr.Row():
        file_output = gr.File(
            file_count="multiple",
            label="Upload Documents (PDF, TXT, DOCX - Max 5 files, 10MB each)"
        )
    upload_button = gr.Button("Upload and Initialize")
    system_output = gr.Textbox(label="System Status")
    chatbot = gr.Chatbot(
        show_label=False,
        container=True,
        height=400,
        show_copy_button=True
    )
    with gr.Row():
        message = gr.Textbox(
            placeholder="Ask a question about your documents...",
            show_label=False,
            container=False,
            scale=8
        )
        clear = gr.Button("🗑️ Clear", size="sm", scale=1)
    gr.HTML("""
<div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
background-color: #f8f9fa; border-radius: 10px;">
<div style="margin-bottom: 15px;">
<h3 style="color: #2d333a;">🔍 About Easy RAG</h3>
<p style="color: #666; font-size: 14px;">
A powerful RAG system that lets you query your documents using:
</p>
<ul style="list-style: none; color: #666; font-size: 14px;">
<li>🔹 LLM: Llama-2-7b-chat-hf</li>
<li>🔹 Embeddings: multilingual-e5-large</li>
<li>🔹 Vector Store: FAISS</li>
</ul>
</div>
<div style="border-top: 1px solid #ddd; padding-top: 15px;">
<p style="color: #666; font-size: 14px;">
Based on original work by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>
</p>
</div>
</div>
""")
    # Set up event handlers (must be registered inside the Blocks context)
    upload_button.click(
        process_file_upload,
        inputs=[file_output],
        outputs=[system_output]
    )
    # Answer the question, then empty the textbox for the next one.
    message.submit(
        process_query,
        inputs=[message, chatbot],
        outputs=[chatbot]
    ).then(lambda: "", None, message)
    clear.click(lambda: None, None, chatbot)

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()