import os from dotenv import load_dotenv # Load environment variables load_dotenv() # Set protobuf implementation to avoid C++ extension issues os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # Load keys from environment hf_token = os.getenv("HUGGINGFACE_INFERENCE_TOKEN") serper_api_key = os.getenv("SERPER_API_KEY") # ---- Imports ---- from langgraph.graph import START, StateGraph, MessagesState from langgraph.prebuilt import tools_condition, ToolNode from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings from langchain_community.tools.tavily_search import TavilySearchResults from langchain_community.document_loaders import WikipediaLoader, ArxivLoader from langchain_community.vectorstores import Chroma from langchain_core.documents import Document from langchain_core.messages import SystemMessage, HumanMessage from langchain_core.tools import tool from langchain.tools.retriever import create_retriever_tool from langchain.vectorstores import Chroma from langchain.embeddings import HuggingFaceEmbeddings from langchain.schema import Document import json import requests from typing import List, Dict, Any import re import math from datetime import datetime # ---- Enhanced Tools ---- @tool def multiply(a: float, b: float) -> float: """Multiply two numbers""" return a * b @tool def add(a: float, b: float) -> float: """Add two numbers""" return a + b @tool def subtract(a: float, b: float) -> float: """Subtract two numbers""" return a - b @tool def divide(a: float, b: float) -> float: """Divide two numbers""" if b == 0: raise ValueError("Cannot divide by zero.") return a / b @tool def modulus(a: int, b: int) -> int: """Calculate modulus of two integers""" return a % b @tool def power(a: float, b: float) -> float: """Calculate a raised to the power of b""" return a ** b @tool def square_root(a: float) -> float: """Calculate square root of a number""" return math.sqrt(a) @tool def factorial(n: int) -> int: """Calculate factorial of a number""" if n < 0: raise ValueError("Factorial is not defined for negative numbers") if n == 0 or n == 1: return 1 result = 1 for i in range(2, n + 1): result *= i return result @tool def gcd(a: int, b: int) -> int: """Calculate greatest common divisor""" while b: a, b = b, a % b return a @tool def lcm(a: int, b: int) -> int: """Calculate least common multiple""" return abs(a * b) // gcd(a, b) @tool def percentage(part: float, whole: float) -> float: """Calculate percentage""" return (part / whole) * 100 @tool def compound_interest(principal: float, rate: float, time: float, n: int = 1) -> float: """Calculate compound interest""" return principal * (1 + rate/n) ** (n * time) @tool def wiki_search(query: str) -> str: """Search Wikipedia for information""" try: search_docs = WikipediaLoader(query=query, load_max_docs=3).load() if not search_docs: return "No Wikipedia results found." formatted = "\n\n---\n\n".join([ f'\n{doc.page_content[:2000]}\n' for doc in search_docs ]) return formatted except Exception as e: return f"Wikipedia search error: {str(e)}" @tool def web_search(query: str) -> str: """Search the web using Tavily""" try: search_docs = TavilySearchResults(max_results=3).invoke(query=query) if not search_docs: return "No web search results found." formatted = "\n\n---\n\n".join([ f'\n{doc.get("content", "")[:2000]}\n' for doc in search_docs ]) return formatted except Exception as e: return f"Web search error: {str(e)}" @tool def arxiv_search(query: str) -> str: """Search ArXiv for academic papers""" try: search_docs = ArxivLoader(query=query, load_max_docs=2).load() if not search_docs: return "No ArXiv results found." formatted = "\n\n---\n\n".join([ f'\n{doc.page_content[:1500]}\n' for doc in search_docs ]) return formatted except Exception as e: return f"ArXiv search error: {str(e)}" @tool def serper_search(query: str) -> str: """Enhanced web search using Serper API""" if not serper_api_key: return "Serper API key not available" try: url = "https://google.serper.dev/search" payload = json.dumps({ "q": query, "num": 5 }) headers = { 'X-API-KEY': serper_api_key, 'Content-Type': 'application/json' } response = requests.request("POST", url, headers=headers, data=payload) results = response.json() if 'organic' not in results: return "No search results found" formatted = "\n\n---\n\n".join([ f'\n{result.get("snippet", "")}\n' for result in results['organic'][:3] ]) return formatted except Exception as e: return f"Serper search error: {str(e)}" # ---- Embedding & Vector Store Setup ---- def setup_vector_store(): try: embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # Check if metadata.jsonl exists and load it if os.path.exists('metadata.jsonl'): json_QA = [] with open('metadata.jsonl', 'r') as jsonl_file: for line in jsonl_file: if line.strip(): # Skip empty lines json_QA.append(json.loads(line)) if json_QA: documents = [ Document( page_content=f"Question: {sample.get('Question', '')}\n\nFinal answer: {sample.get('Final answer', '')}", metadata={"source": sample.get("task_id", "unknown")} ) for sample in json_QA if sample.get('Question') and sample.get('Final answer') ] if documents: vector_store = Chroma.from_documents( documents=documents, embedding=embeddings, persist_directory="./chroma_db", collection_name="my_collection" ) vector_store.persist() print(f"Vector store created with {len(documents)} documents") return vector_store # Create empty vector store if no data vector_store = Chroma( embedding_function=embeddings, persist_directory="./chroma_db", collection_name="my_collection" ) print("Empty vector store created") return vector_store except Exception as e: print(f"Vector store setup error: {e}") # Return a dummy vector store function return None vector_store = setup_vector_store() @tool def similar_question_search(query: str) -> str: """Search for similar questions in the knowledge base""" if not vector_store: return "Vector store not available" try: matched_docs = vector_store.similarity_search(query, 3) if not matched_docs: return "No similar questions found" formatted = "\n\n---\n\n".join([ f'\n{doc.page_content[:1000]}\n' for doc in matched_docs ]) return formatted except Exception as e: return f"Similar question search error: {str(e)}" # ---- Enhanced System Prompt ---- system_prompt = """ You are an expert assistant capable of solving complex questions using available tools. You have access to: 1. Mathematical tools: add, subtract, multiply, divide, modulus, power, square_root, factorial, gcd, lcm, percentage, compound_interest 2. Search tools: wiki_search, web_search, arxiv_search, serper_search, similar_question_search IMPORTANT INSTRUCTIONS: 1. Break down complex questions into smaller steps 2. Use tools systematically to gather information and perform calculations 3. For mathematical problems, show your work step by step 4. For factual questions, search for current and accurate information 5. Cross-reference information from multiple sources when possible 6. Be precise with numbers - avoid rounding unless necessary When providing your final answer, use this exact format: FINAL ANSWER: [YOUR ANSWER] Rules for the final answer: - Numbers: Use plain digits without commas, units, or symbols (unless specifically requested) - Strings: Use exact names without articles or abbreviations - Lists: Comma-separated values following the above rules - Be concise and accurate Think step by step and use the available tools to ensure accuracy. """ sys_msg = SystemMessage(content=system_prompt) # ---- Enhanced Tool List ---- tools = [ # Math tools multiply, add, subtract, divide, modulus, power, square_root, factorial, gcd, lcm, percentage, compound_interest, # Search tools wiki_search, web_search, arxiv_search, serper_search, similar_question_search ] # ---- Graph Definition ---- def build_graph(provider: str = "huggingface"): """Build the agent graph with improved HuggingFace model""" if provider == "huggingface": # Use a more capable model from HuggingFace endpoint = HuggingFaceEndpoint( repo_id="microsoft/DialoGPT-large", # You can also try "google/flan-t5-xl" or "bigscience/bloom-7b1" temperature=0.1, huggingfacehub_api_token=hf_token, model_kwargs={ "max_length": 1024, "return_full_text": False } ) llm = ChatHuggingFace(llm=endpoint) else: raise ValueError("Only 'huggingface' provider is supported in this version.") llm_with_tools = llm.bind_tools(tools) def assistant(state: MessagesState): """Enhanced assistant node with better error handling""" try: messages = state["messages"] response = llm_with_tools.invoke(messages) return {"messages": [response]} except Exception as e: print(f"Assistant error: {e}") # Fallback response fallback_msg = HumanMessage(content=f"I encountered an error: {str(e)}. Let me try a simpler approach.") return {"messages": [fallback_msg]} def retriever(state: MessagesState): """Enhanced retriever with better context injection""" messages = state["messages"] user_query = messages[-1].content if messages else "" # Try to find similar questions context_messages = [sys_msg] if vector_store: try: similar = vector_store.similarity_search(user_query, k=2) if similar: context_msg = HumanMessage( content=f"Here are similar questions for context:\n\n{similar[0].page_content}" ) context_messages.append(context_msg) except Exception as e: print(f"Retriever error: {e}") return {"messages": context_messages + messages} # Build the graph builder = StateGraph(MessagesState) builder.add_node("retriever", retriever) builder.add_node("assistant", assistant) builder.add_node("tools", ToolNode(tools)) # Define edges builder.add_edge(START, "retriever") builder.add_edge("retriever", "assistant") builder.add_conditional_edges("assistant", tools_condition) builder.add_edge("tools", "assistant") return builder.compile()