import os
import logging
from typing import List, Dict
import torch
import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain_community.document_loaders import (
    PyPDFLoader,
    Docx2txtLoader,
    CSVLoader,
    UnstructuredFileLoader
)
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import spaces
import tempfile
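# A Hugging Face Space (listed as "Running on Zero", i.e. ZeroGPU hardware)
# implementing a document Q&A assistant with Retrieval Augmented Generation:
# uploaded files are split into chunks, embedded with intfloat/multilingual-e5-large,
# indexed in a FAISS vector store, and answered by a Llama-2-7b-chat-hf text-generation
# pipeline behind a Gradio chat interface.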
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Constants
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
SUPPORTED_FORMATS = [".pdf", ".docx", ".doc", ".csv", ".txt"]
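# Note: meta-llama/Llama-2-7b-chat-hf is a gated checkpoint, so a Hugging Face token
# (HUGGINGFACE_TOKEN or HF_TOKEN) from an account that has accepted the Llama 2 license
# is required; initialize_model() below fails fast if no token is set.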
class DocumentLoader:
    """Enhanced document loader supporting multiple file formats."""

    @staticmethod
    def load_file(file_path: str) -> List:
        """Load a single file based on its extension."""
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == '.pdf':
                loader = PyPDFLoader(file_path)
            elif ext in ['.docx', '.doc']:
                loader = Docx2txtLoader(file_path)
            elif ext == '.csv':
                loader = CSVLoader(file_path)
            else:  # fallback for txt and other text files
                loader = UnstructuredFileLoader(file_path)

            documents = loader.load()

            # Add metadata
            for doc in documents:
                doc.metadata.update({
                    'title': os.path.basename(file_path),
                    'type': 'document',
                    'format': ext[1:],
                    'language': 'auto'
                })

            logger.info(f"Successfully loaded {file_path}")
            return documents
        except Exception as e:
            logger.error(f"Error loading {file_path}: {str(e)}")
            raise
class RAGSystem:
    """Enhanced RAG system with dynamic document loading."""

    def __init__(self, model_name: str = MODEL_NAME):
        self.model_name = model_name
        self.embeddings = None
        self.vector_store = None
        self.qa_chain = None
        self.tokenizer = None
        self.model = None
        self.llm = None
        self.is_initialized = False
        self.processed_files = set()  # Track which files have already been processed
    def initialize_model(self):
        """Initialize the base model and tokenizer."""
        try:
            logger.info("Initializing language model...")

            # Initialize embeddings
            self.embeddings = HuggingFaceEmbeddings(
                model_name="intfloat/multilingual-e5-large",
                model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
                encode_kwargs={'normalize_embeddings': True}
            )

            # Get the Hugging Face token needed for the gated Llama 2 checkpoint
            hf_token = os.environ.get('HUGGINGFACE_TOKEN') or os.environ.get('HF_TOKEN')
            if not hf_token:
                raise ValueError(
                    "No Hugging Face token found. Please set HUGGINGFACE_TOKEN in your environment variables."
                )

            # Initialize model and tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=hf_token,
                trust_remote_code=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                token=hf_token,
                torch_dtype=torch.float16,
                trust_remote_code=True,
                device_map="auto"
            )

            # Create generation pipeline
            pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=512,
                temperature=0.1,
                top_p=0.95,
                repetition_penalty=1.15,
                device_map="auto"
            )
            self.llm = HuggingFacePipeline(pipeline=pipe)

            self.is_initialized = True
            logger.info("Model initialization completed")
        except Exception as e:
            logger.error(f"Error during model initialization: {str(e)}")
            raise
    def process_documents(self, files: List[tempfile._TemporaryFileWrapper]) -> None:
        """Process uploaded documents and update the vector store."""
        try:
            documents = []
            new_files = []

            # Only process files that have not been seen before
            for file in files:
                if file.name not in self.processed_files:
                    docs = DocumentLoader.load_file(file.name)
                    documents.extend(docs)
                    new_files.append(file.name)
                    self.processed_files.add(file.name)

            if not new_files:
                logger.info("No new documents to process")
                return

            if not documents:
                raise ValueError("No documents were successfully loaded.")

            # Split documents into overlapping chunks
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=800,
                chunk_overlap=200,
                separators=["\n\n", "\n", ". ", " ", ""],
                length_function=len
            )
            chunks = text_splitter.split_documents(documents)

            # Create or update the vector store
            if self.vector_store is None:
                self.vector_store = FAISS.from_documents(chunks, self.embeddings)
            else:
                self.vector_store.add_documents(chunks)

            # (Re)build the QA chain over the current vector store
            prompt_template = """
            Context: {context}

            Based on the provided context, please answer the following question clearly and concisely.
            If the information is not in the context, please say so explicitly.

            Question: {question}
            """
            PROMPT = PromptTemplate(
                template=prompt_template,
                input_variables=["context", "question"]
            )
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type="stuff",
                retriever=self.vector_store.as_retriever(
                    search_kwargs={"k": 6}
                ),
                return_source_documents=True,
                chain_type_kwargs={"prompt": PROMPT}
            )

            logger.info(f"Successfully processed {len(documents)} documents")
        except Exception as e:
            logger.error(f"Error processing documents: {str(e)}")
            raise
    def generate_response(self, question: str) -> Dict:
        """Generate a response for a given question."""
        if not self.is_initialized or self.qa_chain is None:
            return {
                'answer': "Please upload some documents first before asking questions.",
                'sources': []
            }
        try:
            result = self.qa_chain({"query": question})

            response = {
                'answer': result['result'],
                'sources': []
            }
            for doc in result['source_documents']:
                source = {
                    'title': doc.metadata.get('title', 'Unknown'),
                    'content': doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content,
                    'metadata': doc.metadata
                }
                response['sources'].append(source)

            return response
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            raise
def process_response(user_input: str, chat_history: List, files: List) -> List:
    """Process user input and generate a response."""
    try:
        if not rag_system.is_initialized:
            rag_system.initialize_model()

        # Always process documents when new files have been uploaded
        if files:
            rag_system.process_documents(files)

        response = rag_system.generate_response(user_input)

        # Clean and format the answer
        answer = response['answer']
        if "Answer:" in answer:
            answer = answer.split("Answer:")[-1].strip()

        # Append the consulted sources
        sources = set([source['title'] for source in response['sources'][:3]])
        if sources:
            answer += "\n\nSources consulted:\n" + "\n".join([f"• {source}" for source in sources])

        chat_history.append((user_input, answer))
        return chat_history
    except Exception as e:
        logger.error(f"Error in process_response: {str(e)}")
        error_message = f"Sorry, an error occurred: {str(e)}"
        chat_history.append((user_input, error_message))
        return chat_history
# Initialize RAG system
logger.info("Initializing RAG system...")
try:
    rag_system = RAGSystem()
    logger.info("RAG system created successfully")
except Exception as e:
    logger.error(f"Failed to create RAG system: {str(e)}")
    raise
# Create Gradio interface
try:
    logger.info("Creating Gradio interface...")
    with gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}") as demo:
        gr.HTML("""
            <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
                <h1 style="color: #2d333a;">Easy RAG</h1>
                <p style="color: #4a5568;">
                    Your AI Assistant for Document Analysis and Q&A
                </p>
            </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                files = gr.Files(
                    label="Upload Your Documents",
                    file_types=SUPPORTED_FORMATS,
                    file_count="multiple"
                )
                gr.HTML("""
                    <div style="font-size: 0.9em; color: #666; margin-top: 0.5em;">
                        Supported formats: PDF, DOCX, CSV, TXT
                    </div>
                """)

        chatbot = gr.Chatbot(
            show_label=False,
            container=True,
            height=500,
            bubble_full_width=True,
            show_copy_button=True,
            scale=2
        )

        with gr.Row():
            message = gr.Textbox(
                placeholder="Ask me anything about your documents...",
                show_label=False,
                container=False,
                scale=8,
                autofocus=True
            )
            clear = gr.Button("Clear", size="sm", scale=1)

        # Instructions
        gr.HTML("""
            <div style="background-color: #f8f9fa; padding: 15px; border-radius: 10px; margin: 20px 0;">
                <h3 style="color: #2d333a; margin-bottom: 10px;">How to use:</h3>
                <ol style="color: #666; margin-left: 20px;">
                    <li>Upload one or more documents (PDF, DOCX, CSV, or TXT)</li>
                    <li>Wait for the documents to be processed</li>
                    <li>Ask questions about your documents</li>
                    <li>View sources used in the responses</li>
                </ol>
            </div>
        """)

        # Footer with credits
        gr.HTML("""
            <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
                        background-color: #f8f9fa; border-radius: 10px;">
                <div style="margin-bottom: 15px;">
                    <h3 style="color: #2d333a;">About this assistant</h3>
                    <p style="color: #666; font-size: 14px;">
                        This application uses RAG (Retrieval Augmented Generation) technology combining:
                    </p>
                    <ul style="list-style: none; color: #666; font-size: 14px;">
                        <li>LLM Engine: Llama-2-7b-chat-hf</li>
                        <li>Embeddings: multilingual-e5-large</li>
                        <li>Vector Store: FAISS</li>
                    </ul>
                </div>
                <div style="border-top: 1px solid #ddd; padding-top: 15px;">
                    <p style="color: #666; font-size: 14px;">
                        Created by <a href="https://www.linkedin.com/in/camilo-vega-169084b1/"
                        target="_blank" style="color: #2196F3; text-decoration: none;">Camilo Vega</a>,
                        AI Professor and Solutions Consultant
                    </p>
                </div>
            </div>
        """)
        # Configure event handlers
        def submit(user_input, chat_history, files):
            return process_response(user_input, chat_history, files)

        def clear_context():
            # Clear the chat history and reset the document index
            rag_system.vector_store = None
            rag_system.qa_chain = None
            rag_system.processed_files.clear()
            return None

        message.submit(submit, [message, chatbot, files], [chatbot])
        clear.click(clear_context, None, chatbot)

    logger.info("Gradio interface created successfully")
    demo.launch()
except Exception as e:
    logger.error(f"Error in Gradio interface creation: {str(e)}")
    raise
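# To run this app outside of Spaces (a sketch, assuming the file is saved as app.py
# and the gradio/langchain/transformers/faiss dependencies are installed):
#
#   export HUGGINGFACE_TOKEN=hf_...   # token for the gated Llama 2 weights
#   python app.py
#
# demo.launch() above starts the Gradio server; on Spaces this script is executed
# automatically when the Space boots.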