"""LangGraph-based question-answering agent.

Wires a HuggingFace chat model to math, web-search, and retrieval tools, and
seeds each run with similar previously-answered questions pulled from a
Chroma vector store built from ``metadata.jsonl``.
"""

import os

from dotenv import load_dotenv

# Load environment variables before anything reads them.
load_dotenv()

# Force the pure-Python protobuf implementation to avoid C++ extension issues.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# Load keys from environment.
hf_token = os.getenv("HUGGINGFACE_INFERENCE_TOKEN")
# NOTE(review): loaded but never used below — web_search uses Tavily, which
# reads TAVILY_API_KEY from the environment. Confirm whether Serper is needed.
serper_api_key = os.getenv("SERPER_API_KEY")

# ---- Imports ----
import json

from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
# Fixed: ``langchain.embeddings.HuggingFaceEmbeddings`` is the deprecated
# path; the file already depends on langchain_huggingface, which ships the
# maintained implementation of the same class.
from langchain_huggingface import HuggingFaceEmbeddings


# ---- Tools ----

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers together."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two numbers together."""
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second number from the first."""
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divide the first number by the second.

    Raises:
        ValueError: if ``b`` is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Returns the remainder after division of the first number by the second."""
    return a % b


@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for information.
    Useful for factual questions about people, places, events, etc."""
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        formatted = "\n\n---\n\n".join(
            f'\n{doc.page_content}\n' for doc in search_docs
        )
        # Fixed: return the string the ``-> str`` annotation promises; the
        # original returned {"wiki_results": formatted}, which the tool
        # runtime would stringify into a dict repr.
        return formatted
    except Exception as e:
        return f"Wikipedia search failed: {str(e)}"


@tool
def web_search(query: str) -> str:
    """Search the web for current information. Useful when you need recent or non-Wikipedia information."""
    try:
        search = TavilySearchResults(max_results=3)
        search_docs = search.invoke(query)
        formatted = "\n\n---\n\n".join(
            f'\n{doc["content"]}\n' for doc in search_docs
        )
        # Fixed: dict return replaced with the annotated str (see wiki_search).
        return formatted
    except Exception as e:
        return f"Web search failed: {str(e)}"


@tool
def arxiv_search(query: str) -> str:
    """Search academic papers on ArXiv. Useful for technical or scientific questions."""
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=2).load()
        # Only the first 1000 characters of each paper are included to keep
        # the tool output within the model's context budget.
        formatted = "\n\n---\n\n".join(
            f'\n{doc.page_content[:1000]}\n' for doc in search_docs
        )
        # Fixed: dict return replaced with the annotated str (see wiki_search).
        return formatted
    except Exception as e:
        return f"ArXiv search failed: {str(e)}"


# ---- Embedding & Vector Store Setup ----
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Load previously-answered QA pairs (one JSON object per line).
json_QA = []
try:
    with open('metadata.jsonl', 'r') as jsonl_file:
        for line in jsonl_file:
            json_QA.append(json.loads(line))
except Exception as e:
    print(f"Error loading metadata.jsonl: {e}")
    json_QA = []

# One Document per QA pair; the answer is duplicated into metadata so the
# retriever node can surface it without re-parsing page_content.
documents = [
    Document(
        page_content=f"Question: {sample['Question']}\n\nAnswer: {sample['Final answer']}",
        metadata={"source": sample["task_id"], "question": sample["Question"], "answer": sample["Final answer"]}
    )
    for sample in json_QA
]

try:
    vector_store = Chroma.from_documents(
        documents=documents,
        embedding=embeddings,
        persist_directory="./chroma_db",
        collection_name="qa_collection"
    )
    # NOTE(review): persist() is a no-op/deprecated on newer Chroma releases
    # (persist_directory already auto-persists) — kept for compatibility.
    vector_store.persist()
    print(f"Documents inserted: {len(documents)}")
except Exception as e:
    print(f"Error creating vector store: {e}")
    raise


@tool
def similar_question_search(query: str) -> str:
    """Search for similar questions that have been answered before.
    Always check here first before using other tools."""
    try:
        matched_docs = vector_store.similarity_search(query, k=3)
        # Fixed: the original joined empty f-strings (f'\n\n'), discarding
        # every matched document — the tool always returned whitespace.
        formatted = "\n\n---\n\n".join(
            f'\n{doc.page_content}\n' for doc in matched_docs
        )
        return formatted
    except Exception as e:
        return f"Similar question search failed: {str(e)}"


# ---- System Prompt ----
system_prompt = """
You are an expert question-answering assistant. Follow these steps for each question:

1. FIRST check for similar questions using the similar_question_search tool
2. If a similar question exists with a clear answer, use that answer
3. If not, determine which tools might help answer the question
4. Use the tools systematically to gather information
5. Combine information from multiple sources if needed
6. Format your final answer precisely as: FINAL ANSWER: [your answer here]

Rules for answers:
- Numbers: plain digits only (no commas, units, or symbols)
- Strings: minimal words, no articles, full names
- Lists: comma-separated with no extra formatting
- Be concise but accurate
"""

sys_msg = SystemMessage(content=system_prompt)

# ---- Tool List ----
tools = [
    similar_question_search,  # Check this first
    multiply, add, subtract, divide, modulus,  # Math tools
    wiki_search, web_search, arxiv_search  # Information tools
]


# ---- Graph Definition ----
def build_graph():
    """Compile the retriever -> assistant -> tools LangGraph agent.

    Returns:
        A compiled StateGraph over MessagesState.

    Raises:
        Exception: re-raised after logging if model or graph setup fails.
    """
    try:
        # Using a powerful HuggingFace model.
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                temperature=0,
                max_new_tokens=512,
                huggingfacehub_api_token=hf_token
            )
        )
        llm_with_tools = llm.bind_tools(tools)

        def assistant(state: MessagesState):
            # Single LLM step; tools_condition decides whether to loop to tools.
            return {"messages": [llm_with_tools.invoke(state["messages"])]}

        def retriever(state: MessagesState):
            # Prepend the system prompt and, when available, worked examples
            # of similar previously-answered questions.
            # NOTE(review): MessagesState merges via the add_messages reducer,
            # so returning existing messages re-upserts them — confirm the
            # resulting ordering matches intent.
            try:
                # First try to find similar questions
                similar = vector_store.similarity_search(state["messages"][-1].content, k=2)
                if similar:
                    example_msg = HumanMessage(
                        content=f"Here are similar questions and their answers:\n\n" +
                        "\n\n".join([f"Q: {doc.metadata['question']}\nA: {doc.metadata['answer']}" for doc in similar])
                    )
                    return {"messages": [sys_msg] + state["messages"] + [example_msg]}
                return {"messages": [sys_msg] + state["messages"]}
            except Exception as e:
                print(f"Retriever error: {e}")
                return {"messages": [sys_msg] + state["messages"]}

        builder = StateGraph(MessagesState)
        builder.add_node("retriever", retriever)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(tools))

        builder.add_edge(START, "retriever")
        builder.add_edge("retriever", "assistant")
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")

        return builder.compile()
    except Exception as e:
        print(f"Error building graph: {e}")
        raise