import os
import shutil
import logging
from typing import List, Dict

import torch
import gradio as gr
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
UPLOAD_FOLDER = "uploaded_docs"
EMBEDDING_MODEL = "intfloat/multilingual-e5-large"
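
# Note: meta-llama/Llama-2-7b-chat-hf is a gated model on the Hugging Face Hub,
# so the HUGGINGFACE_TOKEN used below must belong to an account that has been
# granted access to it.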


class RAGSystem:
    """Main RAG system class."""

    def __init__(self):
        # Initialize device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Initialize folders
        self.upload_folder = UPLOAD_FOLDER
        if os.path.exists(self.upload_folder):
            shutil.rmtree(self.upload_folder)
        os.makedirs(self.upload_folder, exist_ok=True)

        # Set limits
        self.max_files = 5
        self.max_file_size = 10 * 1024 * 1024  # 10 MB
        self.supported_formats = ['.pdf', '.txt', '.docx']

        # Initialize components
        self.embeddings = None
        self.vector_store = None
        self.qa_chain = None
        self.documents = []

        # Initialize embeddings
        self.initialize_embeddings()
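
    # With normalize_embeddings=True the embedding vectors are unit length, so
    # similarity search over the FAISS index built from them effectively ranks
    # results by cosine similarity.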
    def initialize_embeddings(self):
        """Initialize embedding model."""
        try:
            self.embeddings = HuggingFaceEmbeddings(
                model_name=EMBEDDING_MODEL,
                model_kwargs={'device': self.device},
                encode_kwargs={'normalize_embeddings': True}
            )
            logger.info(f"Embeddings initialized successfully on {self.device}")
        except Exception as e:
            logger.error(f"Error initializing embeddings: {str(e)}")
            raise

    def validate_file(self, file_path: str, file_size: int) -> bool:
        """Validate uploaded file."""
        if file_size > self.max_file_size:
            raise ValueError(f"File size exceeds {self.max_file_size // 1024 // 1024}MB limit")
        ext = os.path.splitext(file_path)[1].lower()
        if ext not in self.supported_formats:
            raise ValueError(f"Unsupported format. Supported: {', '.join(self.supported_formats)}")
        return True

    def process_file(self, file: gr.File) -> List:
        """Process a single file and return documents."""
        # Resolve the path before the try block so the error log below can
        # always reference it.
        file_path = file.name
        try:
            file_size = os.path.getsize(file_path)
            self.validate_file(file_path, file_size)

            # Copy file to upload directory
            filename = os.path.basename(file_path)
            save_path = os.path.join(self.upload_folder, filename)
            shutil.copy2(file_path, save_path)

            # Load documents based on file type
            ext = os.path.splitext(file_path)[1].lower()
            if ext == '.pdf':
                loader = PyPDFLoader(save_path)
            elif ext == '.txt':
                loader = TextLoader(save_path)
            else:  # .docx
                loader = Docx2txtLoader(save_path)
            documents = loader.load()

            for doc in documents:
                doc.metadata.update({
                    'source': filename,
                    'type': 'uploaded'
                })
            return documents
        except Exception as e:
            logger.error(f"Error processing {file_path}: {str(e)}")
            raise
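
    # Chunking strategy: documents are split into ~500-character pieces with a
    # 50-character overlap, preferring paragraph, line, and sentence boundaries
    # so retrieved chunks stay reasonably self-contained.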
    def update_vector_store(self, new_documents: List):
        """Update vector store with new documents."""
        try:
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=50,
                separators=["\n\n", "\n", ". ", " ", ""]
            )
            chunks = text_splitter.split_documents(new_documents)
            if self.vector_store is None:
                self.vector_store = FAISS.from_documents(chunks, self.embeddings)
            else:
                self.vector_store.add_documents(chunks)
            logger.info(f"Vector store updated with {len(chunks)} chunks")
        except Exception as e:
            logger.error(f"Error updating vector store: {str(e)}")
            raise
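
    # LLM setup: the model is loaded in float16 on GPU (or with low_cpu_mem_usage
    # on CPU) and wrapped in a "stuff" RetrievalQA chain, which concatenates the
    # k=4 retrieved chunks into a single prompt for each question.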
    def initialize_llm(self):
        """Initialize the language model and QA chain."""
        try:
            # Get Hugging Face token
            hf_token = os.environ.get('HUGGINGFACE_TOKEN')
            if not hf_token:
                raise ValueError("Please set HUGGINGFACE_TOKEN environment variable")

            # Login to Hugging Face
            login(token=hf_token)

            # Initialize model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                MODEL_NAME,
                token=hf_token,
                trust_remote_code=True
            )

            # Configure model loading based on device
            model_config = {
                'device_map': 'auto',
                'trust_remote_code': True,
                'token': hf_token
            }
            if self.device == "cuda":
                model_config['torch_dtype'] = torch.float16
            else:
                model_config['low_cpu_mem_usage'] = True
            model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_config)

            # Create pipeline (the model is already dispatched by device_map above,
            # so no device argument is passed here)
            pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=512,
                do_sample=True,  # required for temperature to take effect
                temperature=0.1
            )
            llm = HuggingFacePipeline(pipeline=pipe)

            # Create prompt template
            prompt_template = """
            Context: {context}
            Based on the context above, please provide a clear and concise answer to the following question.
            If the information is not in the context, explicitly state so.
            Question: {question}
            """
            PROMPT = PromptTemplate(
                template=prompt_template,
                input_variables=["context", "question"]
            )

            self.qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=self.vector_store.as_retriever(search_kwargs={"k": 4}),
                return_source_documents=True,
                chain_type_kwargs={"prompt": PROMPT}
            )
            logger.info("LLM initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing LLM: {str(e)}")
            raise
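
    # Uploads are cumulative: each call adds new documents to the shared FAISS
    # index, and the heavyweight LLM is only loaded once, on the first
    # successful upload.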
    def process_upload(self, files: List[gr.File]) -> str:
        """Process uploaded files and initialize/update the system."""
        if not files:
            return "Please select files to upload."
        try:
            current_files = len(os.listdir(self.upload_folder))
            if current_files + len(files) > self.max_files:
                return f"Maximum number of documents ({self.max_files}) exceeded"

            processed_files = []
            new_documents = []
            for file in files:
                documents = self.process_file(file)
                new_documents.extend(documents)
                processed_files.append(os.path.basename(file.name))

            self.update_vector_store(new_documents)
            self.documents.extend(new_documents)

            if self.qa_chain is None:
                self.initialize_llm()

            return f"Successfully processed: {', '.join(processed_files)}"
        except Exception as e:
            return f"Error: {str(e)}"

    def generate_response(self, question: str) -> Dict:
        """Generate response for a given question."""
        if not self.qa_chain:
            return {"error": "System not initialized. Please upload documents first."}
        try:
            result = self.qa_chain({"query": question})
            response = {
                'answer': result['result'],
                'sources': []
            }
            for doc in result['source_documents']:
                source = {
                    'title': doc.metadata.get('source', 'Unknown'),
                    'content': doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content
                }
                response['sources'].append(source)
            return response
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return {"error": str(e)}
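

# A single module-level RAGSystem instance is shared by every visitor to the
# app, so all uploaded documents go into one common vector store and QA chain.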
# Initialize system
rag_system = RAGSystem()


def process_query(message: str, history: List) -> List:
    """Process user query and return updated history."""
    try:
        if not rag_system.qa_chain:
            return history + [(message, "Please upload documents first.")]

        response = rag_system.generate_response(message)
        if "error" in response:
            return history + [(message, f"Error: {response['error']}")]

        answer = response['answer']
        sources = set([source['title'] for source in response['sources']])
        if sources:
            answer += "\n\n📚 Sources:\n" + "\n".join([f"• {source}" for source in sources])
        return history + [(message, answer)]
    except Exception as e:
        return history + [(message, f"Error: {str(e)}")]


# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("""
        <div style="text-align: center; margin-bottom: 1rem;">
            <h1 style="color: #2d333a;">🤖 Easy RAG</h1>
            <p style="color: #4a5568;">A simple and powerful RAG system for your documents</p>
        </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.HTML("""
                    <div style="padding: 1rem; border: 1px solid #e5e7eb; border-radius: 0.5rem; background-color: white;">
                        <h3 style="margin-top: 0;">📄 Upload Documents</h3>
                """)
                file_output = gr.File(
                    file_count="multiple",
                    label="Select Files",
                    elem_id="file-upload"
                )
                gr.HTML("""
                    <div style="font-size: 0.8em; color: #666;">
                        <p>• Maximum 5 files</p>
                        <p>• 10MB per file</p>
                        <p>• Supported: PDF, TXT, DOCX</p>
                    </div>
                """)
                system_output = gr.Textbox(
                    label="Status",
                    interactive=False
                )
                gr.HTML("</div>")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                value=[],
                label="Chat",
                height=600,
                show_copy_button=True
            )
            with gr.Row():
                message = gr.Textbox(
                    placeholder="Ask a question about your documents...",
                    show_label=False,
                    container=False,
                    scale=8
                )
                clear = gr.Button("🗑️", size="sm", scale=1)

    gr.HTML("""
        <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 1rem;
                    background-color: #f8f9fa; border-radius: 10px;">
            <div style="margin-bottom: 1rem;">
                <h3 style="color: #2d333a;">📖 About Easy RAG</h3>
                <p style="color: #666; font-size: 0.9em;">
                    Powered by state-of-the-art AI technology:
                </p>
                <ul style="list-style: none; color: #666; font-size: 0.9em;">
                    <li>🔹 LLM: Llama-2-7b-chat-hf</li>
                    <li>🔹 Embeddings: multilingual-e5-large</li>
                    <li>🔹 Vector Store: FAISS</li>
                </ul>
            </div>
            <div style="border-top: 1px solid #ddd; padding-top: 1rem;">
                <p style="color: #666; font-size: 0.8em;">
                    Based on original work by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
                    target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>
                </p>
            </div>
        </div>
    """)

    # Set up event handlers
    file_output.upload(
        rag_system.process_upload,
        inputs=[file_output],
        outputs=[system_output]
    )
    message.submit(
        process_query,
        inputs=[message, chatbot],
        outputs=[chatbot]
    )
    clear.click(lambda: None, None, chatbot)


if __name__ == "__main__":
    # Log system information
    logger.info("Starting Easy RAG system...")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
    else:
        logger.info("Running in CPU mode with optimizations")

    # Check for HUGGINGFACE_TOKEN
    if not os.environ.get('HUGGINGFACE_TOKEN'):
        logger.warning("HUGGINGFACE_TOKEN not found in environment variables")
        logger.warning("Please set it before running the application")
        print("Please set your HUGGINGFACE_TOKEN environment variable")
        print("Example: export HUGGINGFACE_TOKEN=your_token_here")
        exit(1)

    # Get sharing preference from environment
    share_enabled = os.environ.get('SHARE_APP', 'false').lower() == 'true'
    if share_enabled:
        logger.info("Public sharing is enabled - a public URL will be generated")

    try:
        # Launch the application
        demo.launch(
            server_name="0.0.0.0",   # Listen on all network interfaces
            server_port=7860,        # Default Gradio port
            share=share_enabled,     # Generate public URL if enabled
            show_error=True,         # Show detailed error messages
            quiet=True               # Reduce console output noise
        )
    except KeyboardInterrupt:
        logger.info("Shutting down server...")
    except Exception as e:
        logger.error(f"Error launching server: {str(e)}")
        raise