junlin3 committed on
Commit 957b75b · 1 Parent(s): 81917a3

Add implementation code

Files changed (4)
  1. .gitignore +5 -0
  2. agent.py +213 -0
  3. app.py +9 -4
  4. requirements.txt +162 -2
.gitignore ADDED
@@ -0,0 +1,5 @@
+ /.DS_Store
+ /.idea
+ /.venv
+ /chroma_langchain_db
+ /.env
agent.py ADDED
@@ -0,0 +1,213 @@
+ from langchain_core.tools import tool
+ from langchain.tools.retriever import create_retriever_tool
+ from langchain_community.document_loaders import WikipediaLoader
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import ArxivLoader
+
+ from langchain_core.messages import HumanMessage, SystemMessage
+
+ from langchain_ollama import ChatOllama
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_huggingface import HuggingFaceEmbeddings, ChatHuggingFace, HuggingFaceEndpoint
+ from langchain_chroma import Chroma
+ from langgraph.graph import START, StateGraph, MessagesState
+
+ from langgraph.prebuilt import ToolNode
+ from langgraph.prebuilt import tools_condition
+
+ import os
+ from dotenv import load_dotenv
+
+
+ load_dotenv()
+
+
+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two numbers and return the result.
+
+     Args:
+         a (int): The first number.
+         b (int): The second number.
+     Returns:
+         int: The product of the two numbers.
+     """
+     return a * b
+
+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two numbers and return the result.
+
+     Args:
+         a (int): The first number.
+         b (int): The second number.
+     Returns:
+         int: The sum of the two numbers.
+     """
+     return a + b
+
+ @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract two numbers and return the result.
+
+     Args:
+         a (int): The first number.
+         b (int): The second number.
+     Returns:
+         int: The difference between the two numbers.
+     """
+     return a - b
+
+ @tool
+ def divide(a: int, b: int) -> float:
+     """Divide two numbers and return the result.
+
+     Args:
+         a (int): The first number.
+         b (int): The second number.
+     Returns:
+         float: The quotient of the two numbers.
+     """
+     if b == 0:
+         raise ValueError('Cannot divide by zero.')
+     return a / b
+
+ @tool
+ def modulus(a: int, b: int) -> int:
+     """Calculate the modulus of two numbers and return the result.
+
+     Args:
+         a (int): The first number.
+         b (int): The second number.
+     Returns:
+         int: The modulus of the two numbers.
+     """
+     return a % b
+
+ @tool
+ def wiki_search(query: str) -> dict:
+     """Search Wikipedia for a given query and return the top results.
+
+     Args:
+         query (str): The search query.
+     """
+     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     formatted_search_docs = '\n\n---\n\n'.join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
+             for doc in search_docs
+         ]
+     )
+     return {'wiki_results': formatted_search_docs}
+
+ @tool
+ def web_search(query: str) -> dict:
+     """Search Tavily for a query and return a maximum of 3 results.
+
+     Args:
+         query (str): The search query.
+     """
+     search_docs = TavilySearchResults(max_results=3).invoke(query)
+     formatted_search_docs = '\n\n---\n\n'.join(
+         [
+             f'<Document source="{doc["url"]}" page="{doc.get("title", "")}">\n{doc.get("content", "")}\n</Document>'
+             for doc in search_docs
+         ]
+     )
+     return {'web_results': formatted_search_docs}
+
+ @tool
+ def arvix_search(query: str) -> dict:
+     """Search arXiv for a query and return a maximum of 3 results.
+
+     Args:
+         query (str): The search query.
+     """
+     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+     formatted_search_docs = '\n\n---\n\n'.join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
+             for doc in search_docs
+         ]
+     )
+     return {'arvix_results': formatted_search_docs}
+
+
+ system_prompt = """
+
+ """
+
+
+ # System message
+ sys_msg = SystemMessage(content=system_prompt)
+
+ # Retriever
+ embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
+ vector_store = Chroma(
+     collection_name="demo_collection",
+     embedding_function=embeddings,
+     persist_directory="./chroma_langchain_db",
+ )
+ retriever_tool = create_retriever_tool(
+     retriever=vector_store.as_retriever(),
+     name='Question Search',
+     description='A tool to retrieve similar questions from the vector store.'
+ )
+
+
+ tools = [
+     multiply,
+     add,
+     subtract,
+     divide,
+     modulus,
+     wiki_search,
+     web_search,
+     arvix_search
+ ]
+
+ # build graph function
+ def build_graph(tag: str = 'google'):
+     """Build the agent graph using the chat model selected by `tag`."""
+
+     if tag == 'local':
+         llm = ChatOllama(model="qwen3")
+     elif tag == 'google':
+         # Google Gemini
+         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
+     else:
+         url = 'https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf'
+         token = os.getenv('HF_TOKEN')
+         end_point = HuggingFaceEndpoint(
+             endpoint_url=url,
+             temperature=0,
+             huggingfacehub_api_token=token)
+         llm = ChatHuggingFace(llm=end_point)
+     # bind tools to llm
+     llm_with_tools = llm.bind_tools(tools)
+
+     def assistant(state: MessagesState):
+         return {'messages': [llm_with_tools.invoke(state['messages'])]}
+
+     def retriever(state: MessagesState):
+         # Look up the most similar stored question and inject it as extra context
+         # (skip the example message if the vector store returns nothing).
+         similar_question = vector_store.similarity_search(state['messages'][0].content)
+         if similar_question:
+             example_msg = HumanMessage(
+                 content=f'Here is a similar question and answer for reference:\n\n{similar_question[0].page_content}'
+             )
+             return {'messages': [sys_msg] + state['messages'] + [example_msg]}
+         return {'messages': [sys_msg] + state['messages']}
+
+     builder = StateGraph(MessagesState)
+     builder.add_node('retriever', retriever)
+     builder.add_node('assistant', assistant)
+     builder.add_node('tools', ToolNode(tools))
+     builder.add_edge(START, 'retriever')
+     builder.add_edge('retriever', 'assistant')
+     builder.add_conditional_edges(
+         'assistant',
+         tools_condition
+     )
+     builder.add_edge('tools', 'assistant')
+     return builder.compile()
+
+
+ # test
+ if __name__ == "__main__":
+     question = 'When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?'
+     # build the graph
+     graph = build_graph('local')
+     # run the graph
+     messages = [HumanMessage(content=question)]
+     messages = graph.invoke({'messages': messages})
+     for m in messages['messages']:
+         m.pretty_print()
+
app.py CHANGED
@@ -4,6 +4,10 @@ import requests
import inspect
import pandas as pd

+ from langchain_core.messages import HumanMessage
+ from agent import build_graph
+
+
# (Keep Constants as is)
# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
@@ -12,12 +16,13 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
class BasicAgent:
    def __init__(self):
-         print("BasicAgent initialized.")
+         self.graph = build_graph()
    def __call__(self, question: str) -> str:
        print(f"Agent received question (first 50 chars): {question[:50]}...")
-         fixed_answer = "This is a default answer."
-         print(f"Agent returning fixed answer: {fixed_answer}")
-         return fixed_answer
+         messages = [HumanMessage(content=question)]
+         result = self.graph.invoke({"messages": messages})
+         answer = result['messages'][-1].content
+         return answer

def run_and_submit_all( profile: gr.OAuthProfile | None):
    """
requirements.txt CHANGED
@@ -1,2 +1,162 @@
- gradio
- requests
+ aiofiles==24.1.0
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.12.12
+ aiosignal==1.3.2
+ annotated-types==0.7.0
+ anyio==4.9.0
+ asgiref==3.8.1
+ attrs==25.3.0
+ backoff==2.2.1
+ bcrypt==4.3.0
+ beautifulsoup4==4.13.4
+ build==1.2.2.post1
+ cachetools==5.5.2
+ certifi==2025.4.26
+ charset-normalizer==3.4.2
+ chromadb==1.0.12
+ click==8.2.1
+ coloredlogs==15.0.1
+ dataclasses-json==0.6.7
+ distro==1.9.0
+ durationpy==0.10
+ fastapi==0.115.9
+ ffmpy==0.6.0
+ filelock==3.18.0
+ filetype==1.2.0
+ flatbuffers==25.2.10
+ frozenlist==1.7.0
+ fsspec==2025.5.1
+ google-ai-generativelanguage==0.6.18
+ google-api-core==2.25.1
+ google-auth==2.40.3
+ googleapis-common-protos==1.70.0
+ gradio==5.34.0
+ gradio_client==1.10.3
+ groovy==0.1.2
+ grpcio==1.73.0
+ grpcio-status==1.73.0
+ h11==0.16.0
+ hf-xet==1.1.3
+ httpcore==1.0.9
+ httptools==0.6.4
+ httpx==0.28.1
+ httpx-sse==0.4.0
+ huggingface-hub==0.33.0
+ humanfriendly==10.0
+ idna==3.10
+ importlib_metadata==8.7.0
+ importlib_resources==6.5.2
+ Jinja2==3.1.6
+ joblib==1.5.1
+ jsonpatch==1.33
+ jsonpointer==3.0.0
+ jsonschema==4.24.0
+ jsonschema-specifications==2025.4.1
+ kubernetes==33.1.0
+ langchain==0.3.25
+ langchain-chroma==0.2.4
+ langchain-community==0.3.25
+ langchain-core==0.3.65
+ langchain-google-genai==2.1.5
+ langchain-huggingface==0.3.0
+ langchain-ollama==0.3.3
+ langchain-text-splitters==0.3.8
+ langgraph==0.4.8
+ langgraph-checkpoint==2.0.26
+ langgraph-prebuilt==0.2.2
+ langgraph-sdk==0.1.70
+ langsmith==0.3.45
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ marshmallow==3.26.1
+ mdurl==0.1.2
+ mmh3==5.1.0
+ mpmath==1.3.0
+ multidict==6.4.4
+ mypy_extensions==1.1.0
+ networkx==3.5
+ numpy==2.3.0
+ oauthlib==3.2.2
+ ollama==0.5.1
+ onnxruntime==1.22.0
+ opentelemetry-api==1.34.1
+ opentelemetry-exporter-otlp-proto-common==1.34.1
+ opentelemetry-exporter-otlp-proto-grpc==1.34.1
+ opentelemetry-instrumentation==0.55b1
+ opentelemetry-instrumentation-asgi==0.55b1
+ opentelemetry-instrumentation-fastapi==0.55b1
+ opentelemetry-proto==1.34.1
+ opentelemetry-sdk==1.34.1
+ opentelemetry-semantic-conventions==0.55b1
+ opentelemetry-util-http==0.55b1
+ orjson==3.10.18
+ ormsgpack==1.10.0
+ overrides==7.7.0
+ packaging==24.2
+ pandas==2.3.0
+ pillow==11.2.1
+ posthog==4.8.0
+ propcache==0.3.2
+ proto-plus==1.26.1
+ protobuf==6.31.1
+ pyasn1==0.6.1
+ pyasn1_modules==0.4.2
+ pydantic==2.11.5
+ pydantic-settings==2.9.1
+ pydantic_core==2.33.2
+ pydub==0.25.1
+ Pygments==2.19.1
+ PyPika==0.48.9
+ pyproject_hooks==1.2.0
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.1.0
+ python-multipart==0.0.20
+ pytz==2025.2
+ PyYAML==6.0.2
+ referencing==0.36.2
+ regex==2024.11.6
+ requests==2.32.4
+ requests-oauthlib==2.0.0
+ requests-toolbelt==1.0.0
+ rich==14.0.0
+ rpds-py==0.25.1
+ rsa==4.9.1
+ ruff==0.11.13
+ safehttpx==0.1.6
+ safetensors==0.5.3
+ scikit-learn==1.7.0
+ scipy==1.15.3
+ semantic-version==2.10.0
+ sentence-transformers==4.1.0
+ setuptools==80.9.0
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ soupsieve==2.7
+ SQLAlchemy==2.0.41
+ starlette==0.45.3
+ sympy==1.14.0
+ tenacity==9.1.2
+ threadpoolctl==3.6.0
+ tokenizers==0.21.1
+ tomlkit==0.13.3
+ torch==2.7.1
+ tqdm==4.67.1
+ transformers==4.52.4
+ typer==0.16.0
+ typing-inspect==0.9.0
+ typing-inspection==0.4.1
+ typing_extensions==4.14.0
+ tzdata==2025.2
+ urllib3==2.4.0
+ uvicorn==0.34.3
+ uvloop==0.21.0
+ watchfiles==1.0.5
+ websocket-client==1.8.0
+ websockets==15.0.1
+ wikipedia==1.4.0
+ wrapt==1.17.2
+ xxhash==3.5.0
+ yarl==1.20.1
+ zipp==3.23.0
+ zstandard==0.23.0