# Hugging Face Space page header captured with this file: "Spaces: Sleeping".
import logging
import os
import shutil
from typing import Dict, List

import gradio as gr
import torch
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
# Configure logging: timestamped INFO-level records via the root logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Hugging Face Hub id of the chat model used to generate answers.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
# Directory (relative to the working directory) where uploads are stored.
UPLOAD_FOLDER = "uploaded_docs"
class DocumentManager:
    """Manage document uploads: validation, storage on disk and loading.

    Accepted files are copied into ``upload_folder`` and parsed with the
    LangChain loader matching their extension; the resulting documents
    accumulate in ``self.documents``.
    """

    def __init__(self):
        self.upload_folder = UPLOAD_FOLDER
        os.makedirs(self.upload_folder, exist_ok=True)
        self.max_files = 5
        self.max_file_size = 10 * 1024 * 1024  # 10 MB
        self.supported_formats = ['.pdf', '.txt', '.docx']
        # Documents loaded from every successfully processed upload.
        self.documents = []

    @staticmethod
    def _file_path(file) -> str:
        """Return the filesystem path of an uploaded file.

        Uploads may arrive as plain path strings or as file objects whose
        ``name`` attribute is the path (this differs between Gradio versions
        and ``gr.File(type=...)`` settings), so both are accepted.
        """
        if isinstance(file, (str, os.PathLike)):
            return os.fspath(file)
        return file.name

    def validate_file(self, file):
        """Check an uploaded file's size and extension.

        Args:
            file: A path (str / PathLike) or an object whose ``name`` is the path.

        Raises:
            ValueError: If the file is larger than ``max_file_size`` or its
                extension is not in ``supported_formats``.
            OSError: If the file cannot be stat'ed.
        """
        path = self._file_path(file)
        if os.path.getsize(path) > self.max_file_size:
            raise ValueError(f"File size exceeds {self.max_file_size // 1024 // 1024}MB limit")
        ext = os.path.splitext(path)[1].lower()
        if ext not in self.supported_formats:
            raise ValueError(f"Unsupported file format. Supported formats: {', '.join(self.supported_formats)}")

    def load_document(self, file_path: str) -> List:
        """Load ``file_path`` with the loader matching its extension.

        Each returned document gets ``source`` (the file's base name) and
        ``type='uploaded'`` merged into its metadata.

        Raises:
            ValueError: For an unsupported extension.
            Exception: Whatever the underlying loader raises (logged first).
        """
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                loader = PyPDFLoader(file_path)
            elif ext == '.txt':
                loader = TextLoader(file_path)
            elif ext == '.docx':
                loader = Docx2txtLoader(file_path)
            else:
                raise ValueError(f"Unsupported file format: {ext}")
            documents = loader.load()
            for doc in documents:
                doc.metadata.update({
                    'source': os.path.basename(file_path),
                    'type': 'uploaded'
                })
            return documents
        except Exception as e:
            logger.error(f"Error loading {file_path}: {str(e)}")
            raise

    def process_upload(self, files: List) -> str:
        """Validate, store and load a batch of uploaded files.

        Args:
            files: Uploaded files (paths or objects with a ``name`` path).

        Returns:
            ``"Successfully processed files: ..."`` listing the base names, or
            ``"Error processing <name>: <reason>"`` for the first file that
            failed. Files processed before the failure stay loaded.

        Raises:
            ValueError: If ``files`` is empty/None, or storing them would
                exceed ``max_files`` distinct files in the upload folder.
        """
        if not files:
            raise ValueError("No files uploaded")

        paths = [self._file_path(f) for f in files]
        incoming = {os.path.basename(p) for p in paths}
        # Re-uploading a name overwrites the stored copy, so count distinct
        # names. Files left in the folder by earlier sessions also count.
        stored = set(os.listdir(self.upload_folder))
        if len(stored | incoming) > self.max_files:
            raise ValueError(f"Maximum number of documents ({self.max_files}) exceeded")

        processed_files = []
        for path in paths:
            # Gradio uploads live at absolute temp paths; keep only the base
            # name so the copy really lands inside upload_folder.
            name = os.path.basename(path)
            try:
                self.validate_file(path)
                save_path = os.path.join(self.upload_folder, name)
                if os.path.abspath(path) != os.path.abspath(save_path):
                    shutil.copy(path, save_path)
                docs = self.load_document(save_path)
                # Drop documents from a previous upload of the same file so a
                # re-upload replaces rather than duplicates its content.
                self.documents = [
                    d for d in self.documents if d.metadata.get('source') != name
                ]
                self.documents.extend(docs)
                processed_files.append(name)
            except Exception as e:
                logger.error(f"Error processing {name}: {str(e)}")
                return f"Error processing {name}: {str(e)}"
        return f"Successfully processed files: {', '.join(processed_files)}"
class RAGSystem:
    """Retrieval-augmented question answering over uploaded documents.

    Documents are split into chunks, embedded into a FAISS index and queried
    through a LangChain ``RetrievalQA`` chain backed by a local Hugging Face
    causal language model.
    """

    # Sentence-embedding model used to index and retrieve chunks.
    EMBEDDING_MODEL = "intfloat/multilingual-e5-large"
    # Prompt for the "stuff" chain; retrieved chunks fill {context}.
    PROMPT_TEMPLATE = (
        "Context: {context}\n"
        "Based on the context above, please provide a clear and concise answer to the following question.\n"
        "If the information is not in the context, explicitly state so.\n"
        "Question: {question}\n"
    )

    def __init__(self, model_name: str = MODEL_NAME):
        self.model_name = model_name
        self.document_manager = DocumentManager()
        self.embeddings = None
        self.vector_store = None
        self.qa_chain = None
        self.is_initialized = False
        # LangChain LLM wrapper, built on first initialization and reused so
        # later uploads do not reload the model weights.
        self._llm = None

    def _get_embeddings(self):
        """Return the embedding model, loading it on first use."""
        if self.embeddings is None:
            self.embeddings = HuggingFaceEmbeddings(
                model_name=self.EMBEDDING_MODEL,
                model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
            )
        return self.embeddings

    def _get_llm(self):
        """Return the LangChain LLM wrapper, loading the model on first use."""
        if self._llm is None:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # device_map="auto" already places the weights, so the pipeline
            # below is not given a device of its own.
            model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=512,
                # temperature only applies when sampling; enable it explicitly
                # instead of depending on the checkpoint's generation config.
                do_sample=True,
                temperature=0.1,
            )
            self._llm = HuggingFacePipeline(pipeline=pipe)
        return self._llm

    def _build_vector_store(self, documents: List):
        """Split ``documents`` into chunks and index them in a new FAISS store.

        Raises:
            ValueError: If the documents contain no text to index.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", ". ", " ", ""]
        )
        chunks = text_splitter.split_documents(documents)
        if not chunks:
            # e.g. scanned PDFs without a text layer; FAISS cannot index nothing.
            raise ValueError("No text could be extracted from the provided documents")
        return FAISS.from_documents(chunks, self._get_embeddings())

    def _build_qa_chain(self, vector_store):
        """Create the RetrievalQA chain over ``vector_store`` (top-4 chunks)."""
        prompt = PromptTemplate(
            template=self.PROMPT_TEMPLATE,
            input_variables=["context", "question"]
        )
        return RetrievalQA.from_chain_type(
            llm=self._get_llm(),
            chain_type="stuff",
            retriever=vector_store.as_retriever(search_kwargs={"k": 4}),
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt}
        )

    def initialize_system(self, documents: List = None):
        """(Re)build the index and QA chain from ``documents``.

        Returns:
            ``"System initialized successfully"`` or ``"Error: <reason>"``.
            On failure any previously built index and chain are kept.
        """
        try:
            if not documents:
                raise ValueError("No documents provided for initialization")
            vector_store = self._build_vector_store(documents)
            qa_chain = self._build_qa_chain(vector_store)
            # Swap in the new state only once everything was built, so a
            # failure never leaves the index and chain out of sync.
            self.vector_store = vector_store
            self.qa_chain = qa_chain
            self.is_initialized = True
            return "System initialized successfully"
        except Exception as e:
            logger.error(f"Error during system initialization: {str(e)}")
            return f"Error: {str(e)}"

    @staticmethod
    def _preview(text: str, limit: int = 200) -> str:
        """Return ``text`` cut to ``limit`` characters, with "..." if cut."""
        return text[:limit] + "..." if len(text) > limit else text

    def generate_response(self, question: str) -> Dict:
        """Answer ``question`` with the QA chain.

        Returns:
            ``{'answer': str, 'sources': [{'title', 'content'}, ...]}`` on
            success, otherwise ``{'error': str}``.
        """
        if not self.is_initialized:
            return {"error": "System not initialized. Please upload documents first."}
        try:
            result = self.qa_chain({"query": question})
            sources = [
                {
                    'title': doc.metadata.get('source', 'Unknown'),
                    'content': self._preview(doc.page_content),
                }
                for doc in result['source_documents']
            ]
            return {'answer': result['result'], 'sources': sources}
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return {"error": str(e)}
# Single shared RAG system instance used by the Gradio callbacks.
rag_system = RAGSystem(model_name=MODEL_NAME)
def process_file_upload(files):
    """Store uploaded files and (re)initialize the RAG system with them.

    Args:
        files: Files from the ``gr.File`` component, or ``None``/empty when
            the button is clicked without selecting anything.

    Returns:
        A status message for the "System Status" box; it starts with
        ``"Error"`` whenever something failed.
    """
    if not files:
        return "Error: No files uploaded"
    try:
        upload_result = rag_system.document_manager.process_upload(files)
        # Per-file failures are reported as "Error processing ...". Use a
        # prefix check: a substring test also fires on file names that merely
        # contain "Error" inside a success message.
        if upload_result.startswith("Error"):
            return upload_result
        init_result = rag_system.initialize_system(rag_system.document_manager.documents)
        return f"{upload_result}\n{init_result}"
    except Exception as e:
        logger.exception("Error handling file upload")
        return f"Error: {str(e)}"
def process_query(message, history):
    """Answer ``message`` and append the exchange to the chat history.

    Args:
        message: The user's question; blank input is ignored.
        history: Chatbot history as a list of (user, bot) pairs; ``None`` is
            treated as empty.

    Returns:
        The updated history (unchanged for blank input).
    """
    history = list(history or [])
    if not message or not message.strip():
        return history
    try:
        if not rag_system.is_initialized:
            return history + [(message, "Please upload documents first.")]
        response = rag_system.generate_response(message)
        if "error" in response:
            return history + [(message, f"Error: {response['error']}")]
        answer = response['answer']
        # dict.fromkeys de-duplicates while keeping retrieval order; a set
        # would list the sources in arbitrary order.
        sources = list(dict.fromkeys(source['title'] for source in response['sources']))
        if sources:
            answer += "\n\n📚 Sources:\n" + "\n".join(f"• {source}" for source in sources)
        return history + [(message, answer)]
    except Exception as e:
        logger.exception("Error processing query")
        return history + [(message, f"Error: {str(e)}")]
# Create Gradio interface.
# Limits and model name shown in the UI come from the live configuration so
# the labels cannot drift from what is actually enforced.
_manager = rag_system.document_manager
_formats = ", ".join(ext.lstrip(".").upper() for ext in _manager.supported_formats)
_upload_label = (
    f"Upload Documents ({_formats} - Max {_manager.max_files} files, "
    f"{_manager.max_file_size // 1024 // 1024}MB each)"
)
_llm_name = rag_system.model_name.split("/")[-1]

with gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}") as demo:
    gr.HTML("""
    <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
        <h1 style="color: #2d333a;">🤖 Easy RAG</h1>
        <p style="color: #4a5568;">
            A simple and powerful RAG system for your documents
        </p>
    </div>
    """)

    with gr.Row():
        file_output = gr.File(
            file_count="multiple",
            label=_upload_label
        )
    upload_button = gr.Button("Upload and Initialize")
    system_output = gr.Textbox(label="System Status")

    chatbot = gr.Chatbot(
        show_label=False,
        container=True,
        height=400,
        show_copy_button=True
    )

    with gr.Row():
        message = gr.Textbox(
            placeholder="Ask a question about your documents...",
            show_label=False,
            container=False,
            scale=8
        )
        clear = gr.Button("🗑️ Clear", size="sm", scale=1)

    gr.HTML(f"""
    <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
                background-color: #f8f9fa; border-radius: 10px;">
        <div style="margin-bottom: 15px;">
            <h3 style="color: #2d333a;">📝 About Easy RAG</h3>
            <p style="color: #666; font-size: 14px;">
                A powerful RAG system that lets you query your documents using:
            </p>
            <ul style="list-style: none; color: #666; font-size: 14px;">
                <li>🔹 LLM: {_llm_name}</li>
                <li>🔹 Embeddings: multilingual-e5-large</li>
                <li>🔹 Vector Store: FAISS</li>
            </ul>
        </div>
        <div style="border-top: 1px solid #ddd; padding-top: 15px;">
            <p style="color: #666; font-size: 14px;">
                Based on original work by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
                target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>
            </p>
        </div>
    </div>
    """)

    # Set up event handlers (must be registered inside the Blocks context).
    upload_button.click(
        process_file_upload,
        inputs=[file_output],
        outputs=[system_output]
    )
    # Answer the question, then clear the textbox for the next one.
    message.submit(
        process_query,
        inputs=[message, chatbot],
        outputs=[chatbot]
    ).then(lambda: "", None, message)
    clear.click(lambda: None, None, chatbot)

# Launch only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()