""" | |
Enhanced LangGraph Agent with Multi-LLM Support and Proper Question Answering | |
Combines your original LangGraph structure with enhanced response handling | |
""" | |
import os
import time
import random
from typing import List

from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState, END
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.checkpoint.memory import MemorySaver
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from supabase.client import Client, create_client

load_dotenv()
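# Expected environment variables (loaded from .env): GROQ_API_KEY or
# GOOGLE_API_KEY depending on the provider (HUGGINGFACEHUB_API_TOKEN is
# typically needed for the huggingface provider), TAVILY_API_KEY for web
# search, and optionally SUPABASE_URL / SUPABASE_SERVICE_KEY for the
# vector store.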
# Enhanced system prompt for better question answering
ENHANCED_SYSTEM_PROMPT = """You are a helpful assistant tasked with answering questions using a set of tools.

CRITICAL INSTRUCTIONS:
1. Read the question carefully and understand what specific information is being asked
2. Use the appropriate tools to find the exact information requested
3. For factual questions, search for current and accurate information
4. For calculations, use the math tools provided
5. Always provide specific, direct answers - never repeat the question as your answer
6. If you cannot find the information, state "Information not available"
7. Format your final response as: FINAL ANSWER: [your specific answer]

ANSWER FORMAT RULES:
- For numbers: provide just the number without commas or units unless specified
- For names/strings: provide the exact name or term without articles
- For lists: provide comma-separated values
- Be concise and specific in your final answer

Remember: Your job is to ANSWER the question, not repeat it back."""
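# Note: the "FINAL ANSWER:" prefix requested above is also enforced
# programmatically by the format_final_answer node defined in build_graph below.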
# ---- Enhanced Tool Definitions ----
# The @tool decorator (imported above but previously unused) turns each
# function into a LangChain tool so it can be bound to the LLM and executed
# by ToolNode.

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    return a - b

@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    return a % b
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.
    """
    try:
        time.sleep(random.uniform(0.5, 1.0))  # Rate limiting
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        if not search_docs:
            return "No Wikipedia results found"
        formatted_search_docs = "\n\n---\n\n".join([
            f'<Document source="{doc.metadata.get("source", "Wikipedia")}" title="{doc.metadata.get("title", "")}">\n{doc.page_content[:1500]}\n</Document>'
            for doc in search_docs
        ])
        return formatted_search_docs
    except Exception as e:
        return f"Wikipedia search failed: {e}"
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.
    """
    try:
        time.sleep(random.uniform(0.7, 1.2))  # Rate limiting
        search_tool = TavilySearchResults(max_results=3)
        search_docs = search_tool.invoke({"query": query})
        if not search_docs:
            return "No web search results found"
        formatted_search_docs = "\n\n---\n\n".join([
            f'<Document source="{doc.get("url", "")}">\n{doc.get("content", "")[:1200]}\n</Document>'
            for doc in search_docs
        ])
        return formatted_search_docs
    except Exception as e:
        return f"Web search failed: {e}"
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query.
    """
    try:
        time.sleep(random.uniform(0.5, 1.0))  # Rate limiting
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not search_docs:
            return "No ArXiv results found"
        formatted_search_docs = "\n\n---\n\n".join([
            f'<Document source="{doc.metadata.get("source", "ArXiv")}" title="{doc.metadata.get("title", "")}">\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        ])
        return formatted_search_docs
    except Exception as e:
        return f"ArXiv search failed: {e}"
# Initialize tools list
tools = [
    multiply, add, subtract, divide, modulus,
    wiki_search, web_search, arxiv_search
]
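# Decorated tools can also be invoked directly for quick sanity checks, e.g.:
#   multiply.invoke({"a": 25, "b": 17})  # -> 425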
# Enhanced State for better tracking
class EnhancedState(MessagesState):
    """Enhanced state with additional tracking.

    Note: TypedDict-based states cannot declare default values, so these are
    plain annotations. This class is not wired into build_graph below (which
    uses MessagesState directly); it is kept as an extension point.
    """
    query: str
    tools_used: List[str]
    search_results: str
def build_graph(provider: str = "groq"):
    """Build the enhanced graph with proper error handling and response formatting"""
    # Initialize LLM based on provider
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        llm = ChatGroq(model="llama3-70b-8192", temperature=0)  # Using more reliable model
    elif provider == "huggingface":
        # HuggingFaceEndpoint takes the endpoint via `endpoint_url` (not `url`)
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    # Initialize vector store if available
    vector_store = None
    try:
        if os.getenv("SUPABASE_URL") and os.getenv("SUPABASE_SERVICE_KEY"):
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
            supabase: Client = create_client(
                os.environ.get("SUPABASE_URL"),
                os.environ.get("SUPABASE_SERVICE_KEY")
            )
            vector_store = SupabaseVectorStore(
                client=supabase,
                embedding=embeddings,
                table_name="documents",
                query_name="match_documents_langchain",
            )
    except Exception as e:
        print(f"Vector store initialization failed: {e}")
    def retriever(state: MessagesState):
        """Enhanced retriever node with fallback"""
        messages = state["messages"]
        query = messages[-1].content if messages else ""

        # Try to get similar questions from the vector store
        similar_context = ""
        if vector_store:
            try:
                similar_questions = vector_store.similarity_search(query, k=1)
                if similar_questions:
                    similar_context = f"\n\nSimilar example for reference:\n{similar_questions[0].page_content}"
            except Exception as e:
                print(f"Vector search failed: {e}")

        # Enhanced system message with context
        enhanced_prompt = ENHANCED_SYSTEM_PROMPT + similar_context
        sys_msg = SystemMessage(content=enhanced_prompt)
        return {"messages": [sys_msg] + messages}
    def assistant(state: MessagesState):
        """Enhanced assistant node with better response handling"""
        try:
            response = llm_with_tools.invoke(state["messages"])
            # Guard against the model echoing the question back verbatim;
            # content can be a non-string for tool-call responses, so check type
            if hasattr(response, 'content') and isinstance(response.content, str):
                content = response.content
                original_query = state["messages"][-1].content if state["messages"] else ""
                if isinstance(original_query, str) and content.strip() == original_query.strip():
                    # Re-prompt once to force a substantive answer
                    enhanced_messages = state["messages"] + [
                        HumanMessage(content=f"Please provide a specific answer to this question, do not repeat the question: {original_query}")
                    ]
                    response = llm_with_tools.invoke(enhanced_messages)
            return {"messages": [response]}
        except Exception as e:
            error_response = AIMessage(content=f"Error processing request: {e}")
            return {"messages": [error_response]}
    def format_final_answer(state: MessagesState):
        """Format the final answer properly"""
        messages = state["messages"]
        if not messages:
            return {"messages": [AIMessage(content="FINAL ANSWER: Information not available")]}
        last_message = messages[-1]
        if hasattr(last_message, 'content') and isinstance(last_message.content, str):
            content = last_message.content
            # Ensure proper formatting
            if "FINAL ANSWER:" not in content:
                # Extract the key information and format it
                if content.strip():
                    formatted_content = f"FINAL ANSWER: {content.strip()}"
                else:
                    formatted_content = "FINAL ANSWER: Information not available"
                formatted_message = AIMessage(content=formatted_content)
                return {"messages": messages[:-1] + [formatted_message]}
        return {"messages": messages}
    # Build the graph
    builder = StateGraph(MessagesState)

    # Add nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_node("formatter", format_final_answer)

    # Add edges
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
        {
            # Route to the tool node when the LLM issued tool calls,
            # otherwise finish via the formatter
            "tools": "tools",
            END: "formatter"
        }
    )
    builder.add_edge("tools", "assistant")
    builder.add_edge("formatter", END)

    # Compile graph with checkpointer
    return builder.compile(checkpointer=MemorySaver())
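# Example usage (MemorySaver requires a thread_id in the run config):
#   graph = build_graph(provider="groq")
#   result = graph.invoke(
#       {"messages": [HumanMessage(content="What is 25 multiplied by 17?")]},
#       {"configurable": {"thread_id": "demo"}},
#   )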
# Test function
def test_agent():
    """Test the agent with sample questions"""
    graph = build_graph(provider="groq")
    test_questions = [
        "How many studio albums were published by Mercedes Sosa between 2000 and 2009?",
        "What is 25 multiplied by 17?",
        "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2004?"
    ]
    for question in test_questions:
        print(f"\nQuestion: {question}")
        print("-" * 60)
        try:
            messages = [HumanMessage(content=question)]
            config = {"configurable": {"thread_id": f"test_{hash(question)}"}}
            result = graph.invoke({"messages": messages}, config)
            if result and "messages" in result:
                final_message = result["messages"][-1]
                if hasattr(final_message, 'content'):
                    print(f"Answer: {final_message.content}")
                else:
                    print(f"Answer: {final_message}")
            else:
                print("Answer: No response generated")
        except Exception as e:
            print(f"Error: {e}")
        print()
if __name__ == "__main__":
    # Run tests
    test_agent()