File size: 6,379 Bytes
e61be93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from typing import Any, Dict, List, Optional, TypedDict, Annotated
import operator
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import StateGraph, START, END
import os
import requests
import json
from dotenv import load_dotenv

# Populate os.environ from a local .env file (python-dotenv) so that the
# getenv call below can see keys defined there.
load_dotenv()

# None when the variable is not set in the environment or the .env file.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Shared chat model invoked by the `assistant` graph node; temperature=0 asks
# for the least random output the model supports.
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=OPENAI_API_KEY)

# Base URL of the scoring service; the question/file/submit helpers build
# their endpoint URLs from it.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

class AgentState(TypedDict):
    """State passed between the nodes of the LangGraph graph."""

    # Question text supplied by run_agent.
    question: str
    # Answer written by the `assistant` node (absent from the initial state).
    answer: str
    # Task identifier carried alongside the question; the assistant node does
    # not read it.
    task_id: str
    # Running log. The operator.add reducer makes LangGraph concatenate each
    # node's returned list onto the existing one instead of replacing it.
    log: Annotated[List[str], operator.add]

# System prompt asking the model to reason and then close with a single
# "FINAL ANSWER: ..." line in a strict, exact-match-friendly format.
_SYSTEM_PROMPT = (
    "You are a general AI assistant. I will ask you a question. Report your "
    "thoughts, and finish your answer with the following template: "
    "FINAL ANSWER: <your answer here>. YOUR FINAL ANSWER should be a number "
    "OR as few words as possible OR a comma separated list of numbers and/or "
    "strings. If you are asked for a number, don't use comma to write your "
    "number neither use units such as $ or percent sign unless specified "
    "otherwise. If you are asked for a string, don't use articles, neither "
    "abbreviations (e.g. for cities), and write the digits in plain text "
    "unless specified otherwise. If you are asked for a comma separated list, "
    "apply the above rules depending of whether the element to be put in the "
    "list is a number or a string."
)

# Marker the system prompt tells the model to put before its answer.
_FINAL_ANSWER_MARKER = "FINAL ANSWER:"


def _extract_final_answer(text) -> str:
    """Return the stripped text after the last FINAL ANSWER marker.

    If the model omitted the marker, the whole response (stripped) is
    returned so that no answer is silently lost.
    """
    # rpartition picks the LAST marker, so an echoed copy of the template
    # earlier in the reasoning does not win over the real final line.
    _, marker, tail = str(text).rpartition(_FINAL_ANSWER_MARKER)
    return tail.strip() if marker else str(text).strip()


def assistant(state: AgentState) -> AgentState:
    """Graph node: ask the LLM the question in ``state`` and record its answer.

    Returns a partial state update:

    * ``answer`` is only the text after the model's ``FINAL ANSWER:`` line,
      which is the value to print, save, and submit.
    * ``log`` holds a single entry with the model's full, unprocessed
      response, including its reasoning.
    """
    messages = [
        SystemMessage(content=_SYSTEM_PROMPT),
        HumanMessage(content=state["question"]),
    ]
    response = llm.invoke(messages)
    raw_content = response.content
    return {
        "answer": _extract_final_answer(raw_content),
        "log": [f"Assistant response: {raw_content}"],
    }

# Functions to interact with the API
def get_all_questions():
    """Return every question served by the scoring API's ``/questions`` endpoint.

    The decoded JSON body is returned as-is. If ``requests`` raises any of
    its exceptions, the error is printed and an empty list is returned.
    """
    endpoint = f"{DEFAULT_API_URL}/questions"
    try:
        reply = requests.get(endpoint, timeout=15)
        reply.raise_for_status()
        payload = reply.json()
    except requests.exceptions.RequestException as exc:
        print(f"Error fetching questions: {exc}")
        return []
    return payload

def get_random_question():
    """Return one question from the scoring API's ``/random-question`` endpoint.

    The decoded JSON body is returned as-is. If ``requests`` raises any of
    its exceptions, the error is printed and ``None`` is returned.
    """
    endpoint = f"{DEFAULT_API_URL}/random-question"
    try:
        reply = requests.get(endpoint, timeout=15)
        reply.raise_for_status()
        payload = reply.json()
    except requests.exceptions.RequestException as exc:
        print(f"Error fetching random question: {exc}")
        return None
    return payload

def get_file_for_task(task_id: str):
    """Download the attachment belonging to *task_id* from the scoring API.

    Returns the raw response bytes. If ``requests`` raises any of its
    exceptions, the error is printed along with the task id and ``None`` is
    returned.
    """
    endpoint = f"{DEFAULT_API_URL}/files/{task_id}"
    try:
        reply = requests.get(endpoint, timeout=30)
        reply.raise_for_status()
        body = reply.content
    except requests.exceptions.RequestException as exc:
        print(f"Error fetching file for task {task_id}: {exc}")
        return None
    return body

def submit_answers(username: str, agent_code: str, answers: List[Dict]):
    """POST a batch of answers to the scoring API's ``/submit`` endpoint.

    *answers* is sent verbatim under the ``"answers"`` key. Its per-item
    schema is defined by the API (NOTE(review): presumably task_id plus
    submitted answer; confirm against the service).

    Returns the decoded JSON response. If ``requests`` raises any of its
    exceptions, the error is printed and ``None`` is returned.
    """
    payload = dict(username=username, agent_code=agent_code, answers=answers)
    try:
        reply = requests.post(f"{DEFAULT_API_URL}/submit", json=payload, timeout=60)
        reply.raise_for_status()
        verdict = reply.json()
    except requests.exceptions.RequestException as exc:
        print(f"Error submitting answers: {exc}")
        return None
    return verdict

# Build the graph: a single-node pipeline START -> assistant -> END over
# AgentState. There are no tools or conditional edges.
graph = StateGraph(AgentState)
graph.add_node("assistant", assistant)
graph.add_edge(START, "assistant")
graph.add_edge("assistant", END)
# Compiled, invocable graph; run_agent calls app.invoke(...) on it.
app = graph.compile()

def run_agent(question: str, task_id: str):
    """Invoke the compiled graph once for *question* and return its final state.

    The run starts from an empty ``log``; ``answer`` is not preset and is
    filled in by the graph's assistant node.
    """
    initial_state = dict(question=question, task_id=task_id, log=[])
    final_state = app.invoke(initial_state)
    return final_state

def run_agent_on_all_questions():
    """Run the agent on every question served by the scoring API.

    Each question is processed independently:

    * Entries missing ``task_id`` or ``question`` are reported and skipped.
    * If the agent raises on a question, the error is reported and that
      question is skipped, so one failure does not discard results already
      collected.

    Returns:
        A list of ``{"task_id", "question", "answer", "log"}`` dicts, one per
        successfully processed question, or ``None`` when no questions could
        be fetched.
    """
    print("Fetching all questions...")
    questions = get_all_questions()

    if not questions:
        print("No questions found or error occurred")
        return None

    total = len(questions)
    print(f"Found {total} questions")
    results = []

    # 1-based numbering throughout, so skip and progress messages agree.
    for number, question_data in enumerate(questions, start=1):
        task_id = question_data.get("task_id")
        question_text = question_data.get("question")

        if not task_id or not question_text:
            print(f"Skipping malformed question {number}")
            continue

        print(f"\nProcessing question {number}/{total}")
        print(f"Task ID: {task_id}")
        # Only mark the preview as truncated when it actually is.
        preview = question_text if len(question_text) <= 100 else f"{question_text[:100]}..."
        print(f"Question: {preview}")

        # Batch boundary: an LLM/network failure on one question must not
        # abort the whole run and lose every answer gathered so far.
        try:
            result = run_agent(question_text, task_id)
        except Exception as exc:
            print(f"Error running agent on task {task_id}: {exc}")
            continue

        results.append({
            "task_id": task_id,
            "question": question_text,
            "answer": result["answer"],
            "log": result["log"],
        })

        print(f"Answer: {result['answer']}")

    return results

def demo_single_question():
    """Fetch one random question, run the agent on it, and print the outcome.

    Returns:
        The agent's final state, or ``None`` if no usable question could be
        fetched (request failure, or a payload missing ``task_id`` or
        ``question``).
    """
    print("Fetching a random question...")
    question_data = get_random_question()

    if not question_data:
        print("Could not fetch random question")
        return None

    task_id = question_data.get("task_id")
    question_text = question_data.get("question")

    # Mirror the malformed-entry check used when processing all questions so
    # a missing question is never sent to the LLM as None.
    if not task_id or not question_text:
        print("Random question is malformed (missing task_id or question)")
        return None

    print(f"Task ID: {task_id}")
    print(f"Question: {question_text}")

    # Run the agent
    result = run_agent(question_text, task_id)

    print(f"\nAnswer: {result['answer']}")
    print(f"Log: {result['log']}")

    return result

if __name__ == "__main__":
    # Ad-hoc run modes; comment/uncomment to choose. As written, Option 2 and
    # Option 3 both run, and each makes LLM calls.

    # Option 1: Test with a single random question (disabled)
    # print("=== Testing with Random Question ===")
    # demo_single_question()

    # print("\n" + "="*50 + "\n")

    # Option 2: Run on all questions
    print("=== Running on All Questions ===")
    results = run_agent_on_all_questions()

    # Save results to file (skipped when nothing was processed)
    if results:
        # json.dump escapes non-ASCII by default; the explicit encoding keeps
        # the file independent of the platform's locale default anyway.
        with open("agent_results.json", "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print("\nResults saved to agent_results.json")

    # Option 3: Manual question for testing
    print("=== Manual Test ===")
    manual_question = "What is the capital of France?"
    manual_task_id = "test-123"
    manual_result = run_agent(manual_question, manual_task_id)
    print(f"Question: {manual_question}")
    print(f"Answer: {manual_result['answer']}")