# Easy_RAG/app.py
import os
import sys
import shutil
import logging
from typing import List, Dict
import torch
import gradio as gr
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Constants
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
UPLOAD_FOLDER = "uploaded_docs"
EMBEDDING_MODEL = "intfloat/multilingual-e5-large"
class RAGSystem:
"""Main RAG system class."""
def __init__(self):
# Initialize device
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {self.device}")
# Initialize folders
self.upload_folder = UPLOAD_FOLDER
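        # Remove any leftovers from a previous session so each run starts with an empty upload folder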
if os.path.exists(self.upload_folder):
shutil.rmtree(self.upload_folder)
os.makedirs(self.upload_folder, exist_ok=True)
# Set limits
self.max_files = 5
self.max_file_size = 10 * 1024 * 1024 # 10 MB
self.supported_formats = ['.pdf', '.txt', '.docx']
# Initialize components
self.embeddings = None
self.vector_store = None
self.qa_chain = None
self.documents = []
# Initialize embeddings
self.initialize_embeddings()
def initialize_embeddings(self):
"""Initialize embedding model."""
try:
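            # normalize_embeddings=True yields unit-length vectors, so FAISS L2 search ranks results like cosine similarity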
self.embeddings = HuggingFaceEmbeddings(
model_name=EMBEDDING_MODEL,
model_kwargs={'device': self.device},
encode_kwargs={'normalize_embeddings': True}
)
logger.info(f"Embeddings initialized successfully on {self.device}")
except Exception as e:
logger.error(f"Error initializing embeddings: {str(e)}")
raise
def validate_file(self, file_path: str, file_size: int) -> bool:
"""Validate uploaded file."""
if file_size > self.max_file_size:
raise ValueError(f"File size exceeds {self.max_file_size // 1024 // 1024}MB limit")
ext = os.path.splitext(file_path)[1].lower()
if ext not in self.supported_formats:
raise ValueError(f"Unsupported format. Supported: {', '.join(self.supported_formats)}")
return True
def process_file(self, file: gr.File) -> List:
"""Process a single file and return documents."""
try:
file_path = file.name
file_size = os.path.getsize(file_path)
self.validate_file(file_path, file_size)
# Copy file to upload directory
filename = os.path.basename(file_path)
save_path = os.path.join(self.upload_folder, filename)
shutil.copy2(file_path, save_path)
# Load documents based on file type
ext = os.path.splitext(file_path)[1].lower()
if ext == '.pdf':
loader = PyPDFLoader(save_path)
elif ext == '.txt':
loader = TextLoader(save_path)
else: # .docx
loader = Docx2txtLoader(save_path)
documents = loader.load()
for doc in documents:
doc.metadata.update({
'source': filename,
'type': 'uploaded'
})
return documents
except Exception as e:
logger.error(f"Error processing {file_path}: {str(e)}")
raise
def update_vector_store(self, new_documents: List):
"""Update vector store with new documents."""
try:
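            # Split documents into ~500-character chunks with 50-character overlap, preferring paragraph and sentence boundaries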
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
separators=["\n\n", "\n", ". ", " ", ""]
)
chunks = text_splitter.split_documents(new_documents)
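            # Build a new FAISS index on the first upload; later uploads extend the existing index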
if self.vector_store is None:
self.vector_store = FAISS.from_documents(chunks, self.embeddings)
else:
self.vector_store.add_documents(chunks)
logger.info(f"Vector store updated with {len(chunks)} chunks")
except Exception as e:
logger.error(f"Error updating vector store: {str(e)}")
raise
def initialize_llm(self):
"""Initialize the language model and QA chain."""
try:
# Get Hugging Face token
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
if not hf_token:
raise ValueError("Please set HUGGINGFACE_TOKEN environment variable")
# Login to Hugging Face
login(token=hf_token)
# Initialize model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
token=hf_token,
trust_remote_code=True
)
# Configure model loading based on device
model_config = {
'device_map': 'auto',
'trust_remote_code': True,
'token': hf_token
}
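            # On GPU, fp16 roughly halves memory versus fp32; on CPU, low_cpu_mem_usage loads weights incrementally to limit peak RAM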
if self.device == "cuda":
model_config['torch_dtype'] = torch.float16
else:
model_config['low_cpu_mem_usage'] = True
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_config)
# Create pipeline
            pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=512,
                do_sample=True,  # sampling must be enabled for the low temperature to take effect
                temperature=0.1
            )
llm = HuggingFacePipeline(pipeline=pipe)
# Create prompt template
prompt_template = """
Context: {context}
Based on the context above, please provide a clear and concise answer to the following question.
If the information is not in the context, explicitly state so.
Question: {question}
"""
PROMPT = PromptTemplate(
template=prompt_template,
input_variables=["context", "question"]
)
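            # The "stuff" chain places the k=4 retrieved chunks directly into the prompt defined above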
self.qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=self.vector_store.as_retriever(search_kwargs={"k": 4}),
return_source_documents=True,
chain_type_kwargs={"prompt": PROMPT}
)
logger.info("LLM initialized successfully")
except Exception as e:
logger.error(f"Error initializing LLM: {str(e)}")
raise
def process_upload(self, files: List[gr.File]) -> str:
"""Process uploaded files and initialize/update the system."""
if not files:
return "Please select files to upload."
try:
current_files = len(os.listdir(self.upload_folder))
if current_files + len(files) > self.max_files:
return f"Maximum number of documents ({self.max_files}) exceeded"
processed_files = []
new_documents = []
for file in files:
documents = self.process_file(file)
new_documents.extend(documents)
processed_files.append(os.path.basename(file.name))
self.update_vector_store(new_documents)
self.documents.extend(new_documents)
if self.qa_chain is None:
self.initialize_llm()
return f"Successfully processed: {', '.join(processed_files)}"
except Exception as e:
return f"Error: {str(e)}"
def generate_response(self, question: str) -> Dict:
"""Generate response for a given question."""
if not self.qa_chain:
return {"error": "System not initialized. Please upload documents first."}
try:
result = self.qa_chain({"query": question})
response = {
'answer': result['result'],
'sources': []
}
for doc in result['source_documents']:
source = {
'title': doc.metadata.get('source', 'Unknown'),
'content': doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
}
response['sources'].append(source)
return response
except Exception as e:
logger.error(f"Error generating response: {str(e)}")
return {"error": str(e)}
# Initialize system
rag_system = RAGSystem()
def process_query(message: str, history: List) -> List:
"""Process user query and return updated history."""
try:
if not rag_system.qa_chain:
return history + [(message, "Please upload documents first.")]
response = rag_system.generate_response(message)
if "error" in response:
return history + [(message, f"Error: {response['error']}")]
answer = response['answer']
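        # De-duplicate source filenames before listing them under the answer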
sources = set([source['title'] for source in response['sources']])
if sources:
answer += "\n\nπŸ“š Sources:\n" + "\n".join([f"β€’ {source}" for source in sources])
return history + [(message, answer)]
except Exception as e:
return history + [(message, f"Error: {str(e)}")]
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.HTML("""
<div style="text-align: center; margin-bottom: 1rem;">
<h1 style="color: #2d333a;">πŸ€– Easy RAG</h1>
<p style="color: #4a5568;">A simple and powerful RAG system for your documents</p>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
with gr.Group():
gr.HTML("""
<div style="padding: 1rem; border: 1px solid #e5e7eb; border-radius: 0.5rem; background-color: white;">
<h3 style="margin-top: 0;">πŸ“ Upload Documents</h3>
""")
file_output = gr.File(
file_count="multiple",
label="Select Files",
elem_id="file-upload"
)
gr.HTML("""
<div style="font-size: 0.8em; color: #666;">
<p>β€’ Maximum 5 files</p>
<p>β€’ 10MB per file</p>
<p>β€’ Supported: PDF, TXT, DOCX</p>
</div>
""")
system_output = gr.Textbox(
label="Status",
interactive=False
)
gr.HTML("</div>")
with gr.Column(scale=3):
chatbot = gr.Chatbot(
value=[],
label="Chat",
height=600,
show_copy_button=True
)
with gr.Row():
message = gr.Textbox(
placeholder="Ask a question about your documents...",
show_label=False,
container=False,
scale=8
)
clear = gr.Button("πŸ—‘οΈ", size="sm", scale=1)
gr.HTML("""
<div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 1rem;
background-color: #f8f9fa; border-radius: 10px;">
<div style="margin-bottom: 1rem;">
<h3 style="color: #2d333a;">πŸ” About Easy RAG</h3>
<p style="color: #666; font-size: 0.9em;">
Powered by state-of-the-art AI technology:
</p>
<ul style="list-style: none; color: #666; font-size: 0.9em;">
<li>πŸ”Ή LLM: Llama-2-7b-chat-hf</li>
<li>πŸ”Ή Embeddings: multilingual-e5-large</li>
<li>πŸ”Ή Vector Store: FAISS</li>
</ul>
</div>
<div style="border-top: 1px solid #ddd; padding-top: 1rem;">
<p style="color: #666; font-size: 0.8em;">
Based on original work by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>
</p>
</div>
</div>
""")
# Set up event handlers
file_output.upload(
rag_system.process_upload,
inputs=[file_output],
outputs=[system_output]
)
message.submit(
process_query,
inputs=[message, chatbot],
outputs=[chatbot]
)
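    # The clear button resets the Chatbot by returning None as its new value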
clear.click(lambda: None, None, chatbot)
if __name__ == "__main__":
# Log system information
logger.info("Starting Easy RAG system...")
logger.info(f"PyTorch version: {torch.__version__}")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
else:
logger.info("Running on CPU mode with optimizations")
# Check for HUGGINGFACE_TOKEN
if not os.environ.get('HUGGINGFACE_TOKEN'):
logger.warning("HUGGINGFACE_TOKEN not found in environment variables")
logger.warning("Please set it before running the application")
print("Please set your HUGGINGFACE_TOKEN environment variable")
print("Example: export HUGGINGFACE_TOKEN=your_token_here")
        sys.exit(1)
# Get sharing preference from environment
share_enabled = os.environ.get('SHARE_APP', 'false').lower() == 'true'
if share_enabled:
logger.info("Public sharing is enabled - a public URL will be generated")
try:
# Launch the application
demo.launch(
server_name="0.0.0.0", # Listen on all network interfaces
server_port=7860, # Default Gradio port
share=share_enabled, # Generate public URL if enabled
show_error=True, # Show detailed error messages
quiet=True # Reduce console output noise
)
except KeyboardInterrupt:
logger.info("Shutting down server...")
except Exception as e:
logger.error(f"Error launching server: {str(e)}")
raise