import os
import logging
from typing import List, Dict
import torch
import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
# LangChain moved these integrations into langchain_community (the bare langchain.*
# paths were deprecated and later removed), so import them from there for consistency.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders import (
PyPDFLoader,
Docx2txtLoader,
CSVLoader,
UnstructuredFileLoader
)
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import spaces
import tempfile
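
# NOTE: this app assumes the following packages are available (typically pinned in
# requirements.txt): gradio, spaces, torch, transformers, accelerate, langchain,
# langchain-community, faiss-cpu (or faiss-gpu), sentence-transformers, pypdf,
# docx2txt, and unstructured.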

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
SUPPORTED_FORMATS = [".pdf", ".docx", ".doc", ".csv", ".txt"]
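
# Llama-2 is a gated checkpoint: access must be granted on huggingface.co, and a token
# with that permission must be supplied via HUGGINGFACE_TOKEN or HF_TOKEN (see below).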

class DocumentLoader:
    """Enhanced document loader supporting multiple file formats."""

    @staticmethod
    def load_file(file_path: str) -> List:
        """Load a single file based on its extension."""
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                loader = PyPDFLoader(file_path)
            elif ext in ['.docx', '.doc']:
                loader = Docx2txtLoader(file_path)
            elif ext == '.csv':
                loader = CSVLoader(file_path)
            else:  # fallback for txt and other text files
                loader = UnstructuredFileLoader(file_path)
            documents = loader.load()

            # Add metadata
            for doc in documents:
                doc.metadata.update({
                    'title': os.path.basename(file_path),
                    'type': 'document',
                    'format': ext[1:],
                    'language': 'auto'
                })
            logger.info(f"Successfully loaded {file_path}")
            return documents
        except Exception as e:
            logger.error(f"Error loading {file_path}: {str(e)}")
            raise
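
# Illustrative usage (hypothetical path, not part of the app flow):
#   docs = DocumentLoader.load_file("/tmp/report.pdf")
#   docs[0].metadata['title']  # -> "report.pdf"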

class RAGSystem:
    """Enhanced RAG system with dynamic document loading."""

    def __init__(self, model_name: str = MODEL_NAME):
        self.model_name = model_name
        self.embeddings = None
        self.vector_store = None
        self.qa_chain = None
        self.tokenizer = None
        self.model = None
        self.llm = None
        self.is_initialized = False
        self.processed_files = set()  # keep track of already-processed files

    def initialize_model(self):
        """Initialize the embeddings, tokenizer, and base language model."""
        try:
            logger.info("Initializing language model...")

            # Initialize embeddings
            self.embeddings = HuggingFaceEmbeddings(
                model_name="intfloat/multilingual-e5-large",
                model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
                encode_kwargs={'normalize_embeddings': True}
            )

            # Get the Hugging Face token (required for the gated Llama-2 weights)
            hf_token = os.environ.get('HUGGINGFACE_TOKEN') or os.environ.get('HF_TOKEN')
            if not hf_token:
                raise ValueError(
                    "No Hugging Face token found. Please set HUGGINGFACE_TOKEN "
                    "in your environment variables."
                )

            # Initialize model and tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=hf_token,
                trust_remote_code=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                token=hf_token,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                device_map="auto"
            )

            # Create the generation pipeline and wrap it for LangChain
            pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=512,
                temperature=0.1,
                top_p=0.95,
                repetition_penalty=1.15,
                device_map="auto"
            )
            self.llm = HuggingFacePipeline(pipeline=pipe)
            self.is_initialized = True
            logger.info("Model initialization completed")
        except Exception as e:
            logger.error(f"Error during model initialization: {str(e)}")
            raise
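
    # Sizing note (approximate): Llama-2-7b in float16 needs on the order of 14 GB of
    # GPU memory for the weights alone, plus headroom for the KV cache and the
    # multilingual-e5-large embedding model, so a single data-center-class GPU is assumed.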

    def process_documents(self, files: List[tempfile._TemporaryFileWrapper]) -> None:
        """Process uploaded documents and update the vector store."""
        try:
            documents = []
            new_files = []

            # Process only files that have not been seen before
            for file in files:
                if file.name not in self.processed_files:
                    docs = DocumentLoader.load_file(file.name)
                    documents.extend(docs)
                    new_files.append(file.name)
                    self.processed_files.add(file.name)

            if not new_files:
                logger.info("No new documents to process")
                return
            if not documents:
                raise ValueError("No documents were successfully loaded.")

            # Split documents into overlapping chunks
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=800,
                chunk_overlap=200,
                separators=["\n\n", "\n", ". ", " ", ""],
                length_function=len
            )
            chunks = text_splitter.split_documents(documents)

            # Create or update the vector store
            if self.vector_store is None:
                self.vector_store = FAISS.from_documents(chunks, self.embeddings)
            else:
                self.vector_store.add_documents(chunks)

            # (Re)build the QA chain over the updated store
            prompt_template = """
Context: {context}

Based on the provided context, please answer the following question clearly and concisely.
If the information is not in the context, please say so explicitly.

Question: {question}
"""
            PROMPT = PromptTemplate(
                template=prompt_template,
                input_variables=["context", "question"]
            )
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",
                retriever=self.vector_store.as_retriever(
                    search_kwargs={"k": 6}
                ),
                return_source_documents=True,
                chain_type_kwargs={"prompt": PROMPT}
            )
            logger.info(f"Successfully processed {len(documents)} documents")
        except Exception as e:
            logger.error(f"Error processing documents: {str(e)}")
            raise
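
    # Chunking arithmetic: chunk_size=800 with chunk_overlap=200 gives an effective
    # stride of ~600 characters, so a 6,000-character document yields roughly 10 chunks;
    # the retriever then stuffs the top k=6 chunks (~4,800 chars) into the prompt.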

    def generate_response(self, question: str) -> Dict:
        """Generate a response for a given question."""
        if not self.is_initialized or self.qa_chain is None:
            return {
                'answer': "Please upload some documents first before asking questions.",
                'sources': []
            }
        try:
            result = self.qa_chain({"query": question})
            response = {
                'answer': result['result'],
                'sources': []
            }
            for doc in result['source_documents']:
                source = {
                    'title': doc.metadata.get('title', 'Unknown'),
                    'content': (doc.page_content[:200] + "...") if len(doc.page_content) > 200 else doc.page_content,
                    'metadata': doc.metadata
                }
                response['sources'].append(source)
            return response
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            raise
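
    # Shape of the returned dict (illustrative values only):
    #   {'answer': 'The report covers ...',
    #    'sources': [{'title': 'report.pdf', 'content': '...', 'metadata': {...}}]}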

@spaces.GPU(duration=60)
def process_response(user_input: str, chat_history: List, files: List) -> List:
    """Process user input and return the updated chat history."""
    try:
        if not rag_system.is_initialized:
            rag_system.initialize_model()

        # Always reprocess in case new files were uploaded
        if files:
            rag_system.process_documents(files)

        response = rag_system.generate_response(user_input)

        # Clean and format the answer
        answer = response['answer']
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()

        # Append deduplicated source titles
        sources = set([source['title'] for source in response['sources'][:3]])
        if sources:
            answer += "\n\n📚 Sources consulted:\n" + "\n".join([f"• {source}" for source in sources])

        chat_history.append((user_input, answer))
        return chat_history
    except Exception as e:
        logger.error(f"Error in process_response: {str(e)}")
        error_message = f"Sorry, an error occurred: {str(e)}"
        chat_history.append((user_input, error_message))
        return chat_history
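
# Note on @spaces.GPU above: on Hugging Face ZeroGPU Spaces a GPU is attached only while
# the decorated function runs (here for at most 60 seconds per call), which is why model
# initialization happens lazily inside process_response rather than at import time.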

# Initialize RAG system
logger.info("Initializing RAG system...")
try:
    rag_system = RAGSystem()
    logger.info("RAG system created successfully")
except Exception as e:
    logger.error(f"Failed to create RAG system: {str(e)}")
    raise
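
# rag_system is module-level state: in Gradio, globals like this are shared across all
# user sessions, so every connected user sees the same vector store and file registry.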

# Create Gradio interface
try:
    logger.info("Creating Gradio interface...")
    with gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}") as demo:
        gr.HTML("""
            <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
                <h1 style="color: #2d333a;">📚 Easy RAG</h1>
                <p style="color: #4a5568;">
                    Your AI Assistant for Document Analysis and Q&A
                </p>
            </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                files = gr.Files(
                    label="Upload Your Documents",
                    file_types=SUPPORTED_FORMATS,
                    file_count="multiple"
                )
                gr.HTML("""
                    <div style="font-size: 0.9em; color: #666; margin-top: 0.5em;">
                        Supported formats: PDF, DOCX, CSV, TXT
                    </div>
                """)

            chatbot = gr.Chatbot(
                show_label=False,
                container=True,
                height=500,
                bubble_full_width=True,
                show_copy_button=True,
                scale=2
            )

        with gr.Row():
            message = gr.Textbox(
                placeholder="💭 Ask me anything about your documents...",
                show_label=False,
                container=False,
                scale=8,
                autofocus=True
            )
            clear = gr.Button("🗑️ Clear", size="sm", scale=1)

        # Instructions
        gr.HTML("""
            <div style="background-color: #f8f9fa; padding: 15px; border-radius: 10px; margin: 20px 0;">
                <h3 style="color: #2d333a; margin-bottom: 10px;">🔍 How to use:</h3>
                <ol style="color: #666; margin-left: 20px;">
                    <li>Upload one or more documents (PDF, DOCX, CSV, or TXT)</li>
                    <li>Wait for the documents to be processed</li>
                    <li>Ask questions about your documents</li>
                    <li>View the sources used in the responses</li>
                </ol>
            </div>
        """)

        # Footer with credits
        gr.HTML("""
            <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
                        background-color: #f8f9fa; border-radius: 10px;">
                <div style="margin-bottom: 15px;">
                    <h3 style="color: #2d333a;">⚡ About this assistant</h3>
                    <p style="color: #666; font-size: 14px;">
                        This application uses RAG (Retrieval Augmented Generation) technology, combining:
                    </p>
                    <ul style="list-style: none; color: #666; font-size: 14px;">
                        <li>🔹 LLM Engine: Llama-2-7b-chat-hf</li>
                        <li>🔹 Embeddings: multilingual-e5-large</li>
                        <li>🔹 Vector Store: FAISS</li>
                    </ul>
                </div>
                <div style="border-top: 1px solid #ddd; padding-top: 15px;">
                    <p style="color: #666; font-size: 14px;">
                        Created by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
                        target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>,
                        AI Professor and Solutions Consultant 🤖
                    </p>
                </div>
            </div>
        """)

        # Configure event handlers
        def submit(user_input, chat_history, files):
            return process_response(user_input, chat_history, files)

        def clear_context():
            # Clear the chat history and fully reset the document state
            # (the QA chain must be dropped too, or it would keep answering
            # from the stale retriever)
            rag_system.vector_store = None
            rag_system.qa_chain = None
            rag_system.processed_files.clear()
            return None

        message.submit(submit, [message, chatbot, files], [chatbot])
        clear.click(clear_context, None, chatbot)

    logger.info("Gradio interface created successfully")
    demo.launch()
except Exception as e:
    logger.error(f"Error in Gradio interface creation: {str(e)}")
    raise
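
# When run locally (python app.py), demo.launch() serves the UI on http://127.0.0.1:7860
# by default; on Hugging Face Spaces this module-level launch is the app's entry point.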