Create basic_agent.py

#208
by ksdeexith - opened
Files changed (1) hide show
  1. basic_agent.py +176 -0
basic_agent.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, TypedDict, Annotated
2
+ import operator
3
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_core.tools import tool
6
+ from langgraph.prebuilt import ToolNode, tools_condition
7
+ from langgraph.graph import StateGraph, START, END
8
+ import os
9
+ import requests
10
+ import json
11
+ from dotenv import load_dotenv
12
+
13
# Load variables from a local .env file into the process environment before
# anything below reads them.
load_dotenv()

# API key read from the environment; os.getenv returns None when it is unset.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Shared chat model used by the assistant node (temperature=0).
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=OPENAI_API_KEY)

# Base URL of the question/scoring service used by the HTTP helpers below.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
19
+
20
class AgentState(TypedDict):
    """State schema handed to ``StateGraph`` for a single question."""

    # Text of the question sent to the model.
    question: str
    # Answer written by the assistant node.
    answer: str
    # Identifier of the task the question belongs to.
    task_id: str
    # Log entries; annotated with operator.add, which LangGraph uses as the
    # reducer so that lists returned by nodes are concatenated, not replaced.
    log: Annotated[List[str], operator.add]
25
+
26
# Marker the system prompt asks the model to put in front of its final answer.
_FINAL_ANSWER_MARKER = "FINAL ANSWER:"

_SYSTEM_PROMPT = (
    "You are a general AI assistant. I will ask you a question. "
    "Report your thoughts, and finish your answer with the following "
    "template: FINAL ANSWER: <your answer here>. "
    "YOUR FINAL ANSWER should be a number OR as few words as possible OR "
    "a comma separated list of numbers and/or strings. "
    "If you are asked for a number, don't use comma to write your number "
    "neither use units such as $ or percent sign unless specified "
    "otherwise. "
    "If you are asked for a string, don't use articles, neither "
    "abbreviations (e.g. for cities), and write the digits in plain text "
    "unless specified otherwise. "
    "If you are asked for a comma separated list, apply the above rules "
    "depending of whether the element to be put in the list is a number "
    "or a string."
)


def _extract_final_answer(text):
    """Return what follows the last ``FINAL ANSWER:`` marker in *text*.

    *text* is returned unchanged when it is not a string, when the marker is
    absent, or when nothing follows the marker.
    """
    if not isinstance(text, str):
        return text
    _, marker, tail = text.rpartition(_FINAL_ANSWER_MARKER)
    final = tail.strip()
    return final if marker and final else text


def assistant(state: AgentState) -> Dict[str, Any]:
    """Ask the LLM the question held in *state* and return a state update.

    Args:
        state: Current graph state; only ``state["question"]`` is read.

    Returns:
        A partial state update with two keys:
        ``answer`` - the text following the ``FINAL ANSWER:`` marker that the
        system prompt requests (the full response when the marker is missing);
        ``log`` - a one-element list holding the complete model response.
    """
    messages = [
        SystemMessage(content=_SYSTEM_PROMPT),
        HumanMessage(content=state["question"]),
    ]
    response = llm.invoke(messages)
    # The prompt makes the model reason aloud before the marker; keep that
    # reasoning in the log only, so ``answer`` holds just the final answer.
    return {
        "answer": _extract_final_answer(response.content),
        "log": [f"Assistant response: {response.content}"],
    }
33
+
34
+ # Functions to interact with the API
35
def get_all_questions():
    """Return the parsed JSON body of ``GET {DEFAULT_API_URL}/questions``.

    Any ``requests`` exception (including a 4xx/5xx status surfaced by
    ``raise_for_status``) is printed and an empty list is returned instead.
    """
    endpoint = f"{DEFAULT_API_URL}/questions"
    try:
        reply = requests.get(endpoint, timeout=15)
        reply.raise_for_status()
        payload = reply.json()
    except requests.exceptions.RequestException as err:
        print(f"Error fetching questions: {err}")
        return []
    return payload
44
+
45
def get_random_question():
    """Return the parsed JSON body of ``GET {DEFAULT_API_URL}/random-question``.

    Any ``requests`` exception (including a 4xx/5xx status surfaced by
    ``raise_for_status``) is printed and ``None`` is returned instead.
    """
    endpoint = f"{DEFAULT_API_URL}/random-question"
    try:
        reply = requests.get(endpoint, timeout=15)
        reply.raise_for_status()
        payload = reply.json()
    except requests.exceptions.RequestException as err:
        print(f"Error fetching random question: {err}")
        return None
    return payload
54
+
55
def get_file_for_task(task_id: str):
    """Return the raw bytes served at ``{DEFAULT_API_URL}/files/{task_id}``.

    Any ``requests`` exception (including a 4xx/5xx status surfaced by
    ``raise_for_status``) is printed and ``None`` is returned instead.
    """
    endpoint = f"{DEFAULT_API_URL}/files/{task_id}"
    try:
        reply = requests.get(endpoint, timeout=30)
        reply.raise_for_status()
        body = reply.content
    except requests.exceptions.RequestException as err:
        print(f"Error fetching file for task {task_id}: {err}")
        return None
    return body
64
+
65
def submit_answers(username: str, agent_code: str, answers: List[Dict]):
    """POST a submission to ``{DEFAULT_API_URL}/submit``.

    The JSON body carries the three arguments under the keys ``username``,
    ``agent_code`` and ``answers``.

    Returns:
        The parsed JSON reply, or ``None`` after printing the error when a
        ``requests`` exception occurs (including a 4xx/5xx status).
    """
    payload = dict(username=username, agent_code=agent_code, answers=answers)
    endpoint = f"{DEFAULT_API_URL}/submit"
    try:
        reply = requests.post(endpoint, json=payload, timeout=60)
        reply.raise_for_status()
        outcome = reply.json()
    except requests.exceptions.RequestException as err:
        print(f"Error submitting answers: {err}")
        return None
    return outcome
79
+
80
# Build the graph: one "assistant" node wired START -> assistant -> END.
graph = StateGraph(AgentState)
graph.add_node("assistant", assistant)
graph.add_edge(START, "assistant")
graph.add_edge("assistant", END)
# Compiled form of the graph; run_agent() calls app.invoke() on it.
app = graph.compile()
86
+
87
def run_agent(question: str, task_id: str):
    """Invoke the compiled graph for one question and return its result.

    The initial state carries the question, the task id and an empty log.
    """
    initial_state = dict(question=question, task_id=task_id, log=[])
    final_state = app.invoke(initial_state)
    return final_state
91
+
92
def run_agent_on_all_questions() -> List[Dict[str, Any]]:
    """Run the agent on every question returned by ``get_all_questions``.

    Entries that are not dicts, or that lack a ``task_id`` or ``question``,
    are skipped. A failure while answering one question is printed and the
    loop moves on, so results gathered so far are not lost.

    Returns:
        One dict per answered question with the keys ``task_id``,
        ``question``, ``answer`` and ``log``. The list is empty when no
        questions could be fetched.
    """
    print("Fetching all questions...")
    questions = get_all_questions()

    if not questions:
        print("No questions found or error occurred")
        # Always return a list so callers get one type on every path.
        return []

    total = len(questions)
    print(f"Found {total} questions")
    results: List[Dict[str, Any]] = []

    # Number questions from 1 so the "skipping" and "processing" messages
    # refer to the same position.
    for position, question_data in enumerate(questions, start=1):
        if not isinstance(question_data, dict):
            print(f"Skipping malformed question {position}")
            continue

        task_id = question_data.get("task_id")
        question_text = question_data.get("question")

        if not task_id or not question_text:
            print(f"Skipping malformed question {position}")
            continue

        print(f"\nProcessing question {position}/{total}")
        print(f"Task ID: {task_id}")
        print(f"Question: {question_text[:100]}...")

        # Run the agent. This is the batch boundary: report the failure and
        # continue rather than abort the run and drop earlier results.
        try:
            result = run_agent(question_text, task_id)
            answer = result["answer"]
            log = result["log"]
        except Exception as exc:
            print(f"Error processing task {task_id}: {exc}")
            continue

        results.append({
            "task_id": task_id,
            "question": question_text,
            "answer": answer,
            "log": log,
        })

        print(f"Answer: {answer}")

    return results
129
+
130
def demo_single_question() -> Optional[Dict[str, Any]]:
    """Fetch one random question, run the agent on it and print the outcome.

    Returns:
        The value returned by ``run_agent``, or ``None`` when no question
        could be fetched or the fetched payload has no ``task_id`` or
        ``question``.
    """
    print("Fetching a random question...")
    question_data = get_random_question()

    if not question_data:
        print("Could not fetch random question")
        return None

    task_id = question_data.get("task_id")
    question_text = question_data.get("question")

    # Apply the same check as the batch runner so a payload without a
    # question never reaches the model as None.
    if not task_id or not question_text:
        print("Skipping malformed question")
        return None

    print(f"Task ID: {task_id}")
    print(f"Question: {question_text}")

    # Run the agent
    result = run_agent(question_text, task_id)

    print(f"\nAnswer: {result['answer']}")
    print(f"Log: {result['log']}")

    return result
152
+
153
if __name__ == "__main__":
    # Option 1: Test with a single random question (disabled; uncomment to use)
    # print("=== Testing with Random Question ===")
    # demo_single_question()

    # print("\n" + "="*50 + "\n")

    # Option 2: Run on all questions
    print("=== Running on All Questions ===")
    results = run_agent_on_all_questions()

    # Save results to file; skipped when nothing was answered.
    if results:
        # Explicit encoding so the output does not depend on the locale.
        with open("agent_results.json", "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print("\nResults saved to agent_results.json")

    # Option 3: Manual question for testing
    print("=== Manual Test ===")
    manual_question = "What is the capital of France?"
    manual_task_id = "test-123"
    manual_result = run_agent(manual_question, manual_task_id)
    print(f"Question: {manual_question}")
    print(f"Answer: {manual_result['answer']}")