Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Fix
Browse files
    	
        agent.py
    CHANGED
    
    | @@ -2,13 +2,10 @@ import os | |
| 2 | 
             
            import json
         | 
| 3 | 
             
            from dotenv import load_dotenv
         | 
| 4 |  | 
| 5 | 
            -
            # ---- Environment & Setup ----
         | 
| 6 | 
             
            load_dotenv()
         | 
| 7 | 
             
            os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
         | 
| 8 | 
            -
             | 
| 9 | 
             
            hf_token = os.getenv("HUGGINGFACE_INFERENCE_TOKEN")
         | 
| 10 |  | 
| 11 | 
            -
            # ---- Imports ----
         | 
| 12 | 
             
            from langgraph.graph import START, StateGraph, MessagesState
         | 
| 13 | 
             
            from langgraph.prebuilt import tools_condition, ToolNode
         | 
| 14 | 
             
            from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
         | 
| @@ -20,7 +17,7 @@ from langchain_core.messages import SystemMessage, HumanMessage | |
| 20 | 
             
            from langchain_core.tools import tool
         | 
| 21 | 
             
            from langchain.schema import Document
         | 
| 22 |  | 
| 23 | 
            -
            # ----  | 
| 24 |  | 
| 25 | 
             
            @tool
         | 
| 26 | 
             
            def multiply(a: int, b: int) -> int:
         | 
| @@ -51,87 +48,58 @@ def modulus(a: int, b: int) -> int: | |
| 51 |  | 
| 52 | 
             
            @tool
         | 
| 53 | 
             
            def wiki_search(query: str) -> str:
         | 
| 54 | 
            -
                """Search Wikipedia for the  | 
| 55 | 
             
                search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
         | 
| 56 | 
             
                formatted = "\n\n---\n\n".join(
         | 
| 57 | 
            -
                    [
         | 
| 58 | 
            -
             | 
| 59 | 
            -
                        for doc in search_docs
         | 
| 60 | 
            -
                    ]
         | 
| 61 | 
             
                )
         | 
| 62 | 
             
                return {"wiki_results": formatted}
         | 
| 63 |  | 
| 64 | 
             
            @tool
         | 
| 65 | 
             
            def web_search(query: str) -> str:
         | 
| 66 | 
            -
                """Search the web using Tavily  | 
| 67 | 
             
                search_docs = TavilySearchResults(max_results=3).invoke(query=query)
         | 
| 68 | 
             
                formatted = "\n\n---\n\n".join(
         | 
| 69 | 
            -
                    [
         | 
| 70 | 
            -
             | 
| 71 | 
            -
                        for doc in search_docs
         | 
| 72 | 
            -
                    ]
         | 
| 73 | 
             
                )
         | 
| 74 | 
             
                return {"web_results": formatted}
         | 
| 75 |  | 
| 76 | 
             
            @tool
         | 
| 77 | 
             
            def arvix_search(query: str) -> str:
         | 
| 78 | 
            -
                """Search Arxiv for  | 
| 79 | 
             
                search_docs = ArxivLoader(query=query, load_max_docs=3).load()
         | 
| 80 | 
             
                formatted = "\n\n---\n\n".join(
         | 
| 81 | 
            -
                    [
         | 
| 82 | 
            -
             | 
| 83 | 
            -
                        for doc in search_docs
         | 
| 84 | 
            -
                    ]
         | 
| 85 | 
             
                )
         | 
| 86 | 
             
                return {"arvix_results": formatted}
         | 
| 87 |  | 
| 88 | 
            -
             | 
| 89 | 
            -
            def similar_question_search(query: str) -> str:
         | 
| 90 | 
            -
                """Searches for questions similar to the input query using a vector database."""
         | 
| 91 | 
            -
                matched_docs = vector_store.similarity_search(query, 3)
         | 
| 92 | 
            -
                formatted = "\n\n---\n\n".join(
         | 
| 93 | 
            -
                    [
         | 
| 94 | 
            -
                        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
         | 
| 95 | 
            -
                        for doc in matched_docs
         | 
| 96 | 
            -
                    ]
         | 
| 97 | 
            -
                )
         | 
| 98 | 
            -
                return {"similar_questions": formatted}
         | 
| 99 | 
            -
             | 
| 100 | 
            -
             | 
| 101 | 
            -
            # ---- Embedding & Vector Store ----
         | 
| 102 | 
            -
             | 
| 103 | 
             
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
         | 
| 104 | 
            -
             | 
| 105 | 
            -
            json_QA = []
         | 
| 106 | 
            -
            with open('metadata.jsonl', 'r') as jsonl_file:
         | 
| 107 | 
            -
                for line in jsonl_file:
         | 
| 108 | 
            -
                    json_QA.append(json.loads(line))
         | 
| 109 | 
            -
             | 
| 110 | 
             
            documents = [
         | 
| 111 | 
             
                Document(
         | 
| 112 | 
             
                    page_content=f"Question : {sample['Question']}\n\nFinal answer : {sample['Final answer']}",
         | 
| 113 | 
             
                    metadata={"source": sample["task_id"]}
         | 
| 114 | 
            -
                )
         | 
| 115 | 
            -
                for sample in json_QA
         | 
| 116 | 
             
            ]
         | 
| 117 | 
            -
             | 
| 118 | 
             
            vector_store = Chroma.from_documents(
         | 
| 119 | 
             
                documents=documents,
         | 
| 120 | 
             
                embedding=embeddings,
         | 
| 121 | 
             
                persist_directory="./chroma_db",
         | 
| 122 | 
             
                collection_name="my_collection"
         | 
| 123 | 
             
            )
         | 
| 124 | 
            -
            vector_store.persist()
         | 
| 125 | 
             
            print("Documents inserted:", vector_store._collection.count())
         | 
| 126 |  | 
| 127 | 
             
            @tool
         | 
| 128 | 
             
            def similar_question_search(query: str) -> str:
         | 
|  | |
| 129 | 
             
                matched_docs = vector_store.similarity_search(query, 3)
         | 
| 130 | 
             
                formatted = "\n\n---\n\n".join(
         | 
| 131 | 
            -
                    [
         | 
| 132 | 
            -
             | 
| 133 | 
            -
                        for doc in matched_docs
         | 
| 134 | 
            -
                    ]
         | 
| 135 | 
             
                )
         | 
| 136 | 
             
                return {"similar_questions": formatted}
         | 
| 137 |  | 
| @@ -143,17 +111,14 @@ Now, I will ask you a question. Report your thoughts, and finish your answer wit | |
| 143 | 
             
            FINAL ANSWER: [YOUR FINAL ANSWER]. 
         | 
| 144 | 
             
            YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings...
         | 
| 145 | 
             
            """
         | 
| 146 | 
            -
             | 
| 147 | 
             
            sys_msg = SystemMessage(content=system_prompt)
         | 
| 148 |  | 
| 149 | 
            -
            # ---- Tool List ----
         | 
| 150 | 
            -
             | 
| 151 | 
             
            tools = [
         | 
| 152 | 
             
                multiply, add, subtract, divide, modulus,
         | 
| 153 | 
             
                wiki_search, web_search, arvix_search, similar_question_search
         | 
| 154 | 
             
            ]
         | 
| 155 |  | 
| 156 | 
            -
            # ---- Graph  | 
| 157 |  | 
| 158 | 
             
            def build_graph(provider: str = "huggingface"):
         | 
| 159 | 
             
                if provider == "huggingface":
         | 
|  | |
| 2 | 
             
            import json
         | 
| 3 | 
             
            from dotenv import load_dotenv
         | 
| 4 |  | 
|  | |
| 5 | 
             
            load_dotenv()
         | 
| 6 | 
             
            os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
         | 
|  | |
| 7 | 
             
            hf_token = os.getenv("HUGGINGFACE_INFERENCE_TOKEN")
         | 
| 8 |  | 
|  | |
| 9 | 
             
            from langgraph.graph import START, StateGraph, MessagesState
         | 
| 10 | 
             
            from langgraph.prebuilt import tools_condition, ToolNode
         | 
| 11 | 
             
            from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
         | 
|  | |
| 17 | 
             
            from langchain_core.tools import tool
         | 
| 18 | 
             
            from langchain.schema import Document
         | 
| 19 |  | 
| 20 | 
            +
            # ---- Tool Definitions (with docstrings) ----
         | 
| 21 |  | 
| 22 | 
             
            @tool
         | 
| 23 | 
             
            def multiply(a: int, b: int) -> int:
         | 
|  | |
| 48 |  | 
| 49 | 
             
            @tool
         | 
| 50 | 
             
            def wiki_search(query: str) -> str:
         | 
| 51 | 
            +
                """Search Wikipedia for the query and return text of up to 2 documents."""
         | 
| 52 | 
             
                search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
         | 
| 53 | 
             
                formatted = "\n\n---\n\n".join(
         | 
| 54 | 
            +
                    f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
         | 
| 55 | 
            +
                    for doc in search_docs
         | 
|  | |
|  | |
| 56 | 
             
                )
         | 
| 57 | 
             
                return {"wiki_results": formatted}
         | 
| 58 |  | 
| 59 | 
             
            @tool
         | 
| 60 | 
             
            def web_search(query: str) -> str:
         | 
| 61 | 
            +
                """Search the web for the query using Tavily and return up to 3 results."""
         | 
| 62 | 
             
                search_docs = TavilySearchResults(max_results=3).invoke(query=query)
         | 
| 63 | 
             
                formatted = "\n\n---\n\n".join(
         | 
| 64 | 
            +
                    f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
         | 
| 65 | 
            +
                    for doc in search_docs
         | 
|  | |
|  | |
| 66 | 
             
                )
         | 
| 67 | 
             
                return {"web_results": formatted}
         | 
| 68 |  | 
| 69 | 
             
            @tool
         | 
| 70 | 
             
            def arvix_search(query: str) -> str:
         | 
| 71 | 
            +
                """Search Arxiv for the query and return content from up to 3 papers."""
         | 
| 72 | 
             
                search_docs = ArxivLoader(query=query, load_max_docs=3).load()
         | 
| 73 | 
             
                formatted = "\n\n---\n\n".join(
         | 
| 74 | 
            +
                    f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
         | 
| 75 | 
            +
                    for doc in search_docs
         | 
|  | |
|  | |
| 76 | 
             
                )
         | 
| 77 | 
             
                return {"arvix_results": formatted}
         | 
| 78 |  | 
| 79 | 
            +
            # Build vector store once
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 80 | 
             
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
         | 
| 81 | 
            +
            json_QA = [json.loads(line) for line in open("metadata.jsonl", "r")]
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 82 | 
             
            documents = [
         | 
| 83 | 
             
                Document(
         | 
| 84 | 
             
                    page_content=f"Question : {sample['Question']}\n\nFinal answer : {sample['Final answer']}",
         | 
| 85 | 
             
                    metadata={"source": sample["task_id"]}
         | 
| 86 | 
            +
                ) for sample in json_QA
         | 
|  | |
| 87 | 
             
            ]
         | 
|  | |
| 88 | 
             
            vector_store = Chroma.from_documents(
         | 
| 89 | 
             
                documents=documents,
         | 
| 90 | 
             
                embedding=embeddings,
         | 
| 91 | 
             
                persist_directory="./chroma_db",
         | 
| 92 | 
             
                collection_name="my_collection"
         | 
| 93 | 
             
            )
         | 
|  | |
| 94 | 
             
            print("Documents inserted:", vector_store._collection.count())
         | 
| 95 |  | 
| 96 | 
             
            @tool
         | 
| 97 | 
             
            def similar_question_search(query: str) -> str:
         | 
| 98 | 
            +
                """Search for questions similar to the input query using the vector store."""
         | 
| 99 | 
             
                matched_docs = vector_store.similarity_search(query, 3)
         | 
| 100 | 
             
                formatted = "\n\n---\n\n".join(
         | 
| 101 | 
            +
                    f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
         | 
| 102 | 
            +
                    for doc in matched_docs
         | 
|  | |
|  | |
| 103 | 
             
                )
         | 
| 104 | 
             
                return {"similar_questions": formatted}
         | 
| 105 |  | 
|  | |
| 111 | 
             
            FINAL ANSWER: [YOUR FINAL ANSWER]. 
         | 
| 112 | 
             
            YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings...
         | 
| 113 | 
             
            """
         | 
|  | |
| 114 | 
             
            sys_msg = SystemMessage(content=system_prompt)
         | 
| 115 |  | 
|  | |
|  | |
| 116 | 
             
            tools = [
         | 
| 117 | 
             
                multiply, add, subtract, divide, modulus,
         | 
| 118 | 
             
                wiki_search, web_search, arvix_search, similar_question_search
         | 
| 119 | 
             
            ]
         | 
| 120 |  | 
| 121 | 
            +
            # ---- Graph Builder ----
         | 
| 122 |  | 
| 123 | 
             
            def build_graph(provider: str = "huggingface"):
         | 
| 124 | 
             
                if provider == "huggingface":
         | 
 
			
