KoRiF committed on
Commit 8b26094 · 1 Parent(s): 5b8c340

implement advanced LangGraph processing

Files changed (1):
  1. workflow.py +200 -9
workflow.py CHANGED
@@ -1,12 +1,25 @@
-from typing import TypedDict, Annotated, Callable, Optional, Any
+from typing import List, Dict, TypedDict, Annotated, Callable, Optional, Any
+from huggingface_hub.inference._generated.types import question_answering
 from langgraph.graph import StateGraph, END
+from langchain_core.messages import HumanMessage
+#from langchain_community.chat_models import ChatHuggingFace
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from answering import gen_question_answer
 
+import requests
+import os
+from dotenv import load_dotenv
+load_dotenv()
+
 
 class AgentState(TypedDict):
+    context: Dict[str, str]
     question: str
     answer: Annotated[str, lambda x, y: y]  # Overwrite with new value
     formatted_answer: Annotated[str, lambda x, y: y]
+    format_requirement: str
+    uris: List[str]
+    reasoning: List[str]
 
 class GAIAAnsweringWorkflow:
     def __init__(
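The Annotated reducers in AgentState give each field last-write-wins semantics: a node returning {"answer": ...} replaces the previous value rather than accumulating it. A minimal standalone sketch of that behavior (the Demo names are illustrative, not part of this commit):

from typing import Annotated, TypedDict
from langgraph.graph import StateGraph, END

class Demo(TypedDict):
    answer: Annotated[str, lambda old, new: new]  # same overwrite reducer as above

def node_a(state: Demo) -> dict:
    return {"answer": "first"}

def node_b(state: Demo) -> dict:
    return {"answer": "second"}  # overwrites "first" via the reducer

g = StateGraph(Demo)
g.add_node("a", node_a)
g.add_node("b", node_b)
g.set_entry_point("a")
g.add_edge("a", "b")
g.add_edge("b", END)
print(g.compile().invoke({"answer": ""}))  # {'answer': 'second'}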
@@ -17,14 +30,40 @@ class GAIAAnsweringWorkflow:
         """
         Initialize the GAIA agent workflow
 
-        Args:
+        Args:
            qa_function: Core question answering function (gen_question_answer)
            formatter: Answer formatting function (default: GAIA boxed format)
        """
        self.qa_function = gen_question_answer  # qa_function or self.default_qa_function
        self.formatter = formatter or self.default_formatter
        self.workflow = self.build_workflow()
-
+        # Initialize models with the HF Inference API
+        llm_endpoint = HuggingFaceEndpoint(
+            model="deepseek-ai/DeepSeek-R1",  # endpoint_url="https://api-inference.huggingface.co/models/cortexso/deepseek-r1:7b",
+            huggingfacehub_api_token=os.getenv("HF_TOKEN"),
+            task="text-generation",
+            # max_tokens=1024
+        )
+        self.reasoning_llm = ChatHuggingFace(llm=llm_endpoint)
+
+        self.llm = ChatHuggingFace(
+            llm=HuggingFaceEndpoint(
+                model="mistralai/Mistral-7B-Instruct-v0.3",
+                huggingfacehub_api_token=os.getenv("HF_TOKEN"),
+                task="text-generation",
+            )
+        )
+
+    def ask_llm(self, question: str, do_reasoning: bool = False) -> str:
+        """Send a single prompt to the reasoning model or the plain chat model."""
+        messages = [HumanMessage(content=question)]
+        llm = self.reasoning_llm if do_reasoning else self.llm  # honor the do_reasoning flag
+        response = llm.invoke(messages)
+        return str(response.content)
+
     @staticmethod
     def default_qa_function(question: str) -> str:
         """Placeholder QA function (override with your CodeAgent)"""
@@ -35,22 +74,138 @@ class GAIAAnsweringWorkflow:
         """Default GAIA formatting"""
         return answer  # f"\\boxed{{{answer}}}"
 
+    def check_context_independent(self, state: AgentState) -> bool:
+        """Return True when the question can be answered without extra resources."""
+        if ctx := state.get("context"):
+            if ctx.get("file_name"):
+                return False
+        prompt = f"""
+        I have a CodeAgent based on a text-to-text model that can use Internet search and parse the information found.
+        If this approach is enough to successfully cope with the task, then we will call such a task an "easy question".
+
+        AS AN ERUDITE PERSON YOU must analyze how difficult it will be to solve the following question:
+        <<{state["question"]}>>
+
+        If you think that the question is easy, then return an empty string. Important! You should NOT add any symbols to the output in this case!
+        If the question requires additional resources, such as complex analysis of downloaded files or resources on the Internet, then return an action plan.
+        """
+        reply = self.ask_llm(prompt, True)
+        prompt = f"""The reasoning from another LLM is provided: <<{reply}>>
+        You have to summarize:
+        output either an empty string ('') for an easy question
+        or extract the action plan for a non-easy question
+        """
+        reply = self.ask_llm(prompt, False)
+        if reply:
+            state["reasoning"].append(reply)
+            return False
+        return True
+
     def build_workflow(self) -> Any:
         """Construct and compile the LangGraph workflow"""
         # Create graph
         workflow = StateGraph(AgentState)
 
         # Add nodes
+        workflow.add_node("preparations", self.preparations_node)
+        workflow.add_node("triage", self.triage_node)
+        workflow.add_node("deep_processing", self.deep_processing_node)
         workflow.add_node("generate_answer", self.generate_answer_node)
         workflow.add_node("format_output", self.format_output_node)
 
         # Define edges
-        workflow.set_entry_point("generate_answer")
+        workflow.set_entry_point("preparations")
+        workflow.add_edge("preparations", "triage")
+        workflow.add_conditional_edges(
+            "triage",
+            self.check_context_independent,
+            {
+                True: "generate_answer",
+                False: "deep_processing",
+            },
+        )
+        workflow.add_edge("deep_processing", "format_output")
+
         workflow.add_edge("generate_answer", "format_output")
         workflow.add_edge("format_output", END)
 
         return workflow.compile()
+
+    def extract_noted_urls_with_llm(self, question: str) -> List[str]:
+        """Use the LLM to extract URLs explicitly noted in the question"""
+        prompt = f"""
+        Analyze the following question and extract ONLY URLs that are explicitly noted or referenced.
+        Return each URL on a separate line. If no URLs are noted, return an empty string.
+
+        QUESTION: {question}
+
+        Respond ONLY with the URLs, one per line, with no additional text or formatting.
+        """
+
+        try:
+            # Use the LLM to generate the response
+            response = self.ask_llm(prompt)
+
+            # Parse the response, keeping non-empty lines only
+            urls = []
+            for line in response.split('\n'):
+                line = line.strip()
+                if line:  # if line.startswith(('http://', 'https://', 'www.')):
+                    urls.append(line)
+
+            return urls
+        except Exception as e:
+            print(f"LLM-based URL extraction failed: {str(e)}")
+            return []
+
+    def download_file(self, task_id: str, file_name: str) -> str:
+        """Download a file from the API and return the local path"""
+        try:
+            # os.makedirs("files", exist_ok=True)
+            file_path = f"./{file_name}"  # files/{file_id}
+            api_base_url: str = "https://agents-course-unit4-scoring.hf.space"
+            api_endpoint = f"{api_base_url}/files/{task_id}"
+            response = requests.get(api_endpoint)
+            response.raise_for_status()
+
+            with open(file_path, "wb") as f:
+                f.write(response.content)
+
+            print(f"File saved: {file_path}")
+            return file_path
+        except Exception as e:
+            print(f"File download failed: {str(e)}")
+            return ""
+
+    def preparations_node(self, state: AgentState) -> dict:
+        """Node to prepare resources"""
+        if not state["context"]:
+            return {}
+        context = state["context"]
+        question = state["question"]
+        uris = state["uris"]
+
+        # 1. Handle file_name in context
+        if file_name := context.get("file_name"):
+            file_path = self.download_file(context["task_id"], file_name)
+            if file_path:
+                uris.append(file_path)
+
+        # 2. Extract URLs noted in the question
+        found_urls = self.extract_noted_urls_with_llm(question)
+        if found_urls:
+            uris.extend(found_urls)
+            print(f"Added {len(found_urls)} URL(s) from question")
+
+        return {"uris": uris}
+
+    def triage_node(self, state: AgentState) -> dict:
+        """Pass-through node; routing happens on its outgoing conditional edge"""
+        return {}
 
+    def deep_processing_node(self, state: AgentState) -> dict:
+        """Placeholder for context-dependent deep processing"""
+        return {}
+
     def generate_answer_node(self, state: AgentState) -> dict:
         """Node that executes the question answering tool"""
         try:
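The triage routing above relies on add_conditional_edges with a boolean-keyed path map. A self-contained sketch of the same pattern, with a toy predicate standing in for check_context_independent (all names are illustrative):

from typing import TypedDict
from langgraph.graph import StateGraph, END

class S(TypedDict):
    question: str
    route: str

def is_easy(state: S) -> bool:
    return "capital" in state["question"]  # toy stand-in for check_context_independent

g = StateGraph(S)
g.add_node("triage", lambda s: {})
g.add_node("easy", lambda s: {"route": "generate_answer path"})
g.add_node("hard", lambda s: {"route": "deep_processing path"})
g.set_entry_point("triage")
g.add_conditional_edges("triage", is_easy, {True: "easy", False: "hard"})
g.add_edge("easy", END)
g.add_edge("hard", END)
print(g.compile().invoke({"question": "What is the capital of France?", "route": ""}))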
@@ -62,13 +217,45 @@ class GAIAAnsweringWorkflow:
 
     def format_output_node(self, state: AgentState) -> dict:
         """Node that formats the answer for GAIA benchmark"""
+        prompt = f"""
+        As a very smart person, you should formulate what the output format of the answer to the following question should be:
+        <<{state["question"]}>>
+
+        You must formulate it very briefly and clearly!
+        The common requirements are:
+        <<
+        YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
+
+        If you are asked for a number, don't use commas to write your number nor units such as $ or percent signs unless specified otherwise, and don't include additional text.
+        If the answer is a number, represent it with digits.
+        If you are asked for a string, don't use articles or abbreviations (e.g. for cities), and write digits in plain text unless specified otherwise.
+        If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
+        >>
+        But you have to figure out what the answer should look like in the given case and reformulate the requirement according to the specified question.
+        """
+        format_requirement = self.ask_llm(prompt, True)
+
+        prompt = f"""
+        Your attentiveness and responsibility are very much needed! We are solving a strict test that is checked automatically, so we must formulate the answer in strict accordance with the task and the required format! Even one incorrect symbol in the answer can fail the task! Pull yourself together!
+
+        You are required to output the answer, formatted in accordance with the task.
+
+        Received answer: <<{state['answer']}>>
+
+        Format requirements: <<{format_requirement}>>
+
+        Do NOT include << >> in your answer! Don't use full-sentence formulations! If you are asked about a number it MUST be just a number, nothing more! It must always be a bare, clear answer (it is checked automatically).
+        """
         try:
-            formatted = self.formatter(state["answer"])
-            return {"formatted_answer": formatted}
+            formatted = self.ask_llm(prompt)
+            return {"formatted_answer": formatted.strip()}
         except Exception as e:
             return {"formatted_answer": f"\\boxed{{\\text{{Formatting error: {str(e)}}}}}"}
 
-    def __call__(self, question: str) -> str:
+    def __call__(self, question: str, context: Dict | None = None) -> str:
         """
         Execute the agent workflow for a given question
 
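format_output_node now formats in two LLM passes: the first derives a question-specific format requirement, the second restates the raw answer under that requirement. A toy offline illustration of the control flow, with a deterministic stub in place of ask_llm:

def stub_ask_llm(prompt: str, do_reasoning: bool = False) -> str:
    if "output format" in prompt:
        return "Reply with a single integer, digits only."  # pass 1: format requirement
    return "42"  # pass 2: pretend the model reformatted the raw answer

requirement = stub_ask_llm("... what the output format of the answer should be ...", True)
formatted = stub_ask_llm(f"Received answer: <<forty-two>> Format requirements: <<{requirement}>>")
print(formatted.strip())  # 42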
@@ -80,9 +267,13 @@ class GAIAAnsweringWorkflow:
         """
         # Initialize state
         initial_state = {
+            "context": context,
             "question": question,
             "answer": "",
-            "formatted_answer": ""
+            "formatted_answer": "",
+            "format_requirement": "",
+            "uris": [],
+            "reasoning": []
         }
 
         # Execute workflow
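__call__ seeds every AgentState field before running the graph, so the compiled workflow can also be invoked directly with a state of the same shape. A sketch, assuming the agent instance created in the __main__ block below (task_id and file_name are hypothetical placeholders, and running this would call the live HF endpoints):

result = agent.workflow.invoke({
    "context": {"task_id": "abc123", "file_name": "data.xlsx"},  # hypothetical IDs
    "question": "What is in the attached file?",
    "answer": "",
    "formatted_answer": "",
    "format_requirement": "",
    "uris": [],
    "reasoning": [],
})
print(result["formatted_answer"])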
@@ -102,7 +293,7 @@ if __name__ == "__main__":
     # Create agent instance
     agent = GAIAAnsweringWorkflow(
         qa_function=gen_question_answer,
-        formatter=lambda ans: f"ANSWER: \\boxed{{{ans}}}"  # Custom formatting
+        formatter=lambda ans: ans  # f"ANSWER: \\boxed{{{ans}}}"  # Custom formatting
     )
 
     # Test cases
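With the formatter reverted to the identity function above, an end-to-end call with a GAIA-style context dict would look like this (task_id and file_name are hypothetical; preparations_node will attempt to download the file from the scoring API):

answer = agent(
    "How many pages does the attached document have?",
    context={"task_id": "abc123", "file_name": "report.pdf"},  # hypothetical
)
print(answer)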