""" | |
RAG (Retrieval Augmented Generation) service. | |
This module provides the RAG implementation with tool creation and agent management. | |
""" | |
import traceback
from typing import List, Dict, Any, Optional, Tuple

from langchain.tools import tool
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferMemory
from sentence_transformers import SentenceTransformer
import faiss

from configs.config import Config
from utils import (
    retrieve_similar_chunks,
    filter_relevant_chunks,
    prepare_context_from_chunks
)
from services.llm_service import create_tavily_search_tool


def create_vector_search_tool(
    faiss_index: faiss.IndexHNSWFlat,
    document_chunks_with_metadata: List[Dict[str, Any]],
    embedding_model: SentenceTransformer,
    k: Optional[int] = None,
    max_chunk_length: Optional[int] = None
):
    """
    Create a vector search tool for document retrieval.

    Args:
        faiss_index: FAISS index for similarity search
        document_chunks_with_metadata: List of document chunks
        embedding_model: SentenceTransformer model
        k: Number of chunks to retrieve
        max_chunk_length: Maximum chunk length

    Returns:
        LangChain tool for vector search
    """
    if k is None:
        k = Config.DEFAULT_K_CHUNKS // 3  # Use fewer chunks for the tool
    if max_chunk_length is None:
        max_chunk_length = Config.DEFAULT_CHUNK_SIZE

    @tool
    def vector_database_search(query: str) -> str:
        """Search the uploaded PDF document for information related to the query.

        Args:
            query: The search query string to find relevant information in the document.

        Returns:
            A string containing relevant information found in the document.
        """
        # Handle very short or empty queries
        if not query or len(query.strip()) < 3:
            return "Please provide a more specific search query with at least 3 characters."

        try:
            # Retrieve similar chunks using the provided session-specific components
            similar_chunks_data = retrieve_similar_chunks(
                query,
                faiss_index,
                document_chunks_with_metadata,
                embedding_model,
                k=k,
                max_chunk_length=max_chunk_length
            )

            # Format the response
            if not similar_chunks_data:
                return "No relevant information found in the document for that query. Please try rephrasing your question or using different keywords."

            # Filter out chunks with very high distance (low similarity)
            filtered_chunks = filter_relevant_chunks(similar_chunks_data)
            if not filtered_chunks:
                return "No sufficiently relevant information found in the document for that query. Please try rephrasing your question or using different keywords."

            context = "\n\n---\n\n".join([chunk_text for chunk_text, _, _ in filtered_chunks])
            return f"The following information was found in the document regarding '{query}':\n{context}"
        except Exception as e:
            print(f"Error in vector search tool: {e}")
            return f"Error searching the document: {str(e)}"

    return vector_database_search
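
# The @tool decorator above turns vector_database_search into a LangChain tool
# object; that is what gives the returned value the .name and .description
# attributes checked in RAGService.execute_agent and the .run()/.invoke()
# interface used in RAGService.fallback_response.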


class RAGService:
    """Service for RAG operations."""

    def __init__(self):
        """Initialize RAG service."""
        self.tavily_tool = create_tavily_search_tool()

    def create_agent_tools(
        self,
        faiss_index: faiss.IndexHNSWFlat,
        document_chunks: List[Dict[str, Any]],
        embedding_model: SentenceTransformer,
        use_web_search: bool = False
    ) -> List:
        """
        Create tools for the RAG agent.

        Args:
            faiss_index: FAISS index
            document_chunks: Document chunks
            embedding_model: Embedding model
            use_web_search: Whether to include web search tool

        Returns:
            List of tools for the agent
        """
        tools = []

        # Add vector search tool
        vector_tool = create_vector_search_tool(
            faiss_index=faiss_index,
            document_chunks_with_metadata=document_chunks,
            embedding_model=embedding_model,
            max_chunk_length=Config.DEFAULT_CHUNK_SIZE,
            k=10
        )
        tools.append(vector_tool)

        # Add web search tool if requested and available
        if use_web_search and self.tavily_tool:
            tools.append(self.tavily_tool)

        return tools

    def create_agent_prompt(self, has_document_search: bool, has_web_search: bool) -> ChatPromptTemplate:
        """
        Create prompt template for the agent.

        Args:
            has_document_search: Whether document search is available
            has_web_search: Whether web search is available

        Returns:
            ChatPromptTemplate for the agent
        """
        # Build tool instructions dynamically
        tool_instructions = ""
        if has_document_search:
            tool_instructions += "Use vector_database_search to find information in the uploaded document. "
        if has_web_search:
            tool_instructions += "Use tavily_search_results_json for web searches when document search is insufficient. "
        if not tool_instructions:
            tool_instructions = "Answer based on the provided context only. "

        # Conversation history precedes the current question; the agent scratchpad
        # must come last so tool-call messages are appended after the user input.
        return ChatPromptTemplate.from_messages([
            ("system", f"""You are a helpful AI assistant that answers questions about documents.

Context: {{context}}

Tools available: {tool_instructions}

Instructions:
- Use the provided context first
- If context is insufficient, use available tools to search for more information
- Provide clear, helpful answers
- If you cannot find an answer, say so clearly"""),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])

    def execute_agent(
        self,
        llm,
        tools: List,
        query: str,
        context: str,
        memory: ConversationBufferMemory
    ) -> Dict[str, Any]:
        """
        Execute the RAG agent with given tools and context.

        Args:
            llm: Language model
            tools: List of tools
            query: User query
            context: Context string
            memory: Conversation memory

        Returns:
            Agent response
        """
        try:
            # Validate tools (every LangChain tool exposes a name and description)
            for candidate_tool in tools:
                if not hasattr(candidate_tool, 'name') or not hasattr(candidate_tool, 'description'):
                    raise ValueError(f"Tool {candidate_tool} is missing required attributes")

            # Create prompt
            has_document_search = any(t.name == "vector_database_search" for t in tools)
            has_web_search = any(t.name == "tavily_search_results_json" for t in tools)
            prompt = self.create_agent_prompt(has_document_search, has_web_search)

            # Create agent
            agent = create_tool_calling_agent(llm, tools, prompt)
            agent_executor = AgentExecutor(
                agent=agent,
                tools=tools,
                memory=memory,
                verbose=Config.AGENT_VERBOSE,
                handle_parsing_errors=True,
                max_iterations=Config.AGENT_MAX_ITERATIONS,
                return_intermediate_steps=False,
                # Some LangChain releases only support "force" here for
                # tool-calling agents; adjust if this raises a ValueError.
                early_stopping_method="generate"
            )
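
            # The ConversationBufferMemory passed in is expected to expose its
            # history under the "chat_history" key as message objects so the
            # MessagesPlaceholder in the prompt can consume it, e.g. (an
            # assumption, since the memory is constructed outside this module):
            #     ConversationBufferMemory(memory_key="chat_history", return_messages=True)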

            # Execute agent
            agent_input = {
                "input": query,
                "context": context,
            }
            response_payload = agent_executor.invoke(agent_input)

            # Validate response
            agent_output = response_payload.get("output", "") if response_payload else ""
            if not agent_output or len(agent_output.strip()) < 10:
                raise ValueError("Insufficient response from agent")

            # Check for incomplete responses that stop after a lead-in phrase
            problematic_prefixes = [
                "Based on the Document,",
                "According to a web search,",
                "Based on the available information,",
                "I need to",
                "Let me"
            ]
            stripped_output = agent_output.strip()
            if any(stripped_output == prefix.strip() or
                   stripped_output == prefix.strip() + "."
                   for prefix in problematic_prefixes):
                raise ValueError("Agent returned incomplete response")

            return response_payload
        except Exception:
            # Surface the full traceback for debugging, then re-raise so the
            # caller can fall back to fallback_response().
            traceback.print_exc()
            raise

    def fallback_response(
        self,
        llm,
        tools: List,
        query: str,
        context: str,
        use_tavily: bool = False
    ) -> Dict[str, Any]:
        """
        Generate fallback response using direct tool usage or LLM.

        Args:
            llm: Language model
            tools: List of available tools
            query: User query
            context: Context string
            use_tavily: Whether to use web search

        Returns:
            Fallback response
        """
        try:
            tool_results = []

            # Try vector search first if available
            vector_tool = next((t for t in tools if t.name == "vector_database_search"), None)
            if vector_tool:
                try:
                    search_result = vector_tool.run(query)
                    if search_result and "No relevant information" not in search_result:
                        tool_results.append(f"Document Search: {search_result}")
                except Exception:
                    pass  # Ignore tool errors and fall through to the direct LLM call

            # Try web search if needed and available
            if use_tavily:
                web_tool = next((t for t in tools if t.name == "tavily_search_results_json"), None)
                if web_tool:
                    try:
                        web_result = web_tool.run(query)
                        if web_result:
                            tool_results.append(f"Web Search: {web_result}")
                    except Exception:
                        pass  # Ignore tool errors and fall through to the direct LLM call

            # Combine tool results with context
            enhanced_context = context
            if tool_results:
                enhanced_context += "\n\nAdditional Information:\n" + "\n\n".join(tool_results)

            # Use direct LLM call with enhanced context
            direct_prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful assistant. Use the provided context and information to answer the user's question clearly and completely."),
                ("human", "Context and Information: {context}\n\nQuestion: {input}")
            ])
            formatted_prompt = direct_prompt.format_prompt(
                context=enhanced_context,
                input=query
            ).to_messages()
            response = llm.invoke(formatted_prompt)
            direct_output = response.content if hasattr(response, 'content') else str(response)
            return {"output": direct_output}
        except Exception:
            # Final fallback: simple LLM response without tool results
            fallback_prompt = ChatPromptTemplate.from_messages([
                ("system", """You are a helpful assistant that answers questions about documents.
Use the provided context to answer the user's question.
If the context contains relevant information, start your answer with "Based on the document, ..."
If the context is insufficient, clearly state what you don't know."""),
                ("human", "Context: {context}\n\nQuestion: {input}")
            ])
            formatted_fallback = fallback_prompt.format_prompt(
                context=context,
                input=query
            ).to_messages()
            response = llm.invoke(formatted_fallback)
            fallback_output = response.content if hasattr(response, 'content') else str(response)
            return {"output": fallback_output}

    def generate_response(
        self,
        llm,
        query: str,
        context_chunks: List[Tuple],
        faiss_index: faiss.IndexHNSWFlat,
        document_chunks: List[Dict[str, Any]],
        embedding_model: SentenceTransformer,
        memory: ConversationBufferMemory,
        use_tavily: bool = False
    ) -> Dict[str, Any]:
        """
        Generate RAG response using agent or fallback methods.

        Args:
            llm: Language model
            query: User query
            context_chunks: Initial context chunks
            faiss_index: FAISS index
            document_chunks: Document chunks
            embedding_model: Embedding model
            memory: Conversation memory
            use_tavily: Whether to use web search

        Returns:
            Generated response
        """
        # Validate inputs
        if not query or not query.strip():
            return {"output": "Please provide a valid question."}

        # Create tools
        tools = self.create_agent_tools(
            faiss_index, document_chunks, embedding_model, use_tavily
        )

        # If no tools could be created, answer directly from the LLM
        if not tools:
            fallback_prompt = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful assistant that answers questions about documents. Use the provided context to answer the user's question."),
                ("human", "Context: {context}\n\nQuestion: {input}")
            ])
            try:
                formatted_prompt = fallback_prompt.format_prompt(
                    context="No context available",
                    input=query
                ).to_messages()
                response = llm.invoke(formatted_prompt)
                return {"output": response.content if hasattr(response, 'content') else str(response)}
            except Exception:
                return {"output": "I'm sorry, I encountered an error processing your request."}

        # Prepare context
        context = prepare_context_from_chunks(context_chunks)

        # Try agent execution, then fall back to direct tool/LLM usage on failure.
        # (The earlier "if not tools" check already handled the no-tools case.)
        try:
            return self.execute_agent(llm, tools, query, context, memory)
        except Exception:
            try:
                return self.fallback_response(llm, tools, query, context, use_tavily)
            except Exception:
                return {"output": "I'm sorry, I encountered an error processing your request. Please try again."}


# Global RAG service instance
rag_service = RAGService()
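

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the service API. It assumes the
    # "all-MiniLM-L6-v2" SentenceTransformer model is available and that the
    # chunk-metadata dicts below (with "text"/"page" keys) match what
    # utils.retrieve_similar_chunks expects; both are illustrative assumptions.
    import numpy as np

    model = SentenceTransformer("all-MiniLM-L6-v2")
    texts = [
        "RAG combines document retrieval with language-model generation.",
        "FAISS provides fast approximate nearest-neighbour search over embeddings.",
    ]
    embeddings = np.asarray(model.encode(texts), dtype="float32")

    # Build a small HNSW index over the example chunks (32 neighbours per node).
    index = faiss.IndexHNSWFlat(embeddings.shape[1], 32)
    index.add(embeddings)

    chunks = [{"text": t, "page": 1} for t in texts]  # assumed metadata shape
    search_tool = create_vector_search_tool(index, chunks, model, k=2)
    print(search_tool.run("What does FAISS do?"))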