File size: 12,095 Bytes
2f87e44
f85ab70
10e9b7d
ba6c035
10e9b7d
eccf8e4
3c4371f
830c198
ba6c035
 
 
2f87e44
cc70c39
ba6c035
 
 
 
 
2f87e44
ba6c035
2f87e44
85d8289
e80aab9
3db6293
2f87e44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba6c035
2ac3a83
 
ba6c035
 
 
 
 
 
2ac3a83
 
ba6c035
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f87e44
ba6c035
 
 
2f87e44
 
 
ba6c035
 
 
 
 
 
 
 
2f87e44
ba6c035
 
2f87e44
ba6c035
2f87e44
ba6c035
 
2f87e44
ba6c035
 
 
 
 
 
 
 
 
 
 
 
 
 
cc70c39
2f87e44
3c0542a
cc70c39
7cfcba6
cc70c39
 
2f87e44
ba6c035
cc70c39
2f87e44
 
ba6c035
 
 
 
2ac3a83
ba6c035
 
2f87e44
 
2ac3a83
ba6c035
 
 
 
 
 
 
 
2f87e44
 
ba6c035
2f87e44
ba6c035
2f87e44
 
ba6c035
 
 
2f87e44
 
ba6c035
2f87e44
ba6c035
2f87e44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ac3a83
ba6c035
 
 
2f87e44
 
ba6c035
 
 
 
 
2f87e44
 
 
 
 
 
ba6c035
 
2f87e44
4021bf3
b795696
2f87e44
ba6c035
2f87e44
 
2ac3a83
91b0aca
2ac3a83
830c198
cc70c39
2ac3a83
cc70c39
91b0aca
 
cc70c39
91b0aca
cc70c39
31243f4
ba6c035
31243f4
f85ab70
cc70c39
91b0aca
eccf8e4
91b0aca
7d65c66
31243f4
f85ab70
85d8289
cc70c39
 
91b0aca
cc70c39
 
7d65c66
 
91b0aca
31243f4
 
91b0aca
 
 
31243f4
91b0aca
 
b795696
 
 
 
91b0aca
b795696
 
31243f4
91b0aca
 
b795696
 
 
 
91b0aca
b795696
 
cc70c39
31243f4
91b0aca
cc70c39
91b0aca
e80aab9
91b0aca
 
 
 
 
 
 
 
e80aab9
 
2f87e44
91b0aca
7d65c66
91b0aca
e80aab9
 
2f87e44
e80aab9
ba6c035
7e4a06b
2f87e44
 
 
ba6c035
e80aab9
 
830c198
f85ab70
830c198
ba6c035
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
# app.py (Refactored for Improved Performance)

import os
import re
import gradio as gr
import requests
import pandas as pd
import logging
import numexpr
from typing import TypedDict, Annotated

# --- Langchain & HF Imports (Modern and Correct) ---
from langchain_huggingface import HuggingFaceEndpoint
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END
from langgraph.errors import GraphRecursionError
from langchain_community.document_loaders.youtube import YoutubeLoader
from transformers import pipeline as hf_pipeline  # Renamed to avoid conflict

# --- Constants ---
# Base URL of the scoring service. run_and_submit_all appends "/questions"
# (GET) and "/submit" (POST) to this value.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

# --- System prompt ---
# Instruction text placed in front of every LLM call. GaiaAgent.__init__
# concatenates it with "\n{messages}\n\nQuestion: {question}" and passes the
# result to PromptTemplate.from_template.
# The two output conventions described here are what the agent code looks for:
#   * a tool call of the form `tool_name("argument")`, parsed in
#     GaiaAgent._call_tools;
#   * the `FINAL ANSWER:` marker, checked in GaiaAgent._decide_action and
#     extracted in GaiaAgent.__call__.
# NOTE(review): because this string ends up inside a prompt template, literal
# curly braces would presumably be treated as template variables -- confirm
# before adding any.
SYSTEM_PROMPT = """You are GAIA, a powerful expert assistant. You are designed to answer questions accurately and efficiently by using a set of available tools.

**Your STRICT Process:**

1.  **Analyze the User's Question:** Carefully determine the user's intent and what information is needed.

2.  **Tool Selection and Execution:**
    * **Is a tool necessary?**
        * For questions about public information, facts, current events, statistics, people, companies, etc., you **MUST** use the `web_search` tool. Do not rely on your internal knowledge.
        * If the question includes a URL pointing to an image (`.png`, `.jpg`, etc.), you **MUST** use the `image_analyzer` tool.
        * If the question includes a YouTube URL, you **MUST** use the `youtube_transcript_reader` tool.
        * If the question requires a calculation, you **MUST** use the `math_calculator` tool.
        * If the question is a simple logic puzzle, riddle, or language task you can solve directly, you do not need a tool.
    * **Tool Call Format:** To use a tool, you **MUST** respond with **only** the tool call on a single line. Do not add any other text or explanation.
        * Example: `web_search("How many albums did Mercedes Sosa release after 2000?")`

3.  **Analyze Tool Output:**
    * Review the information returned by the tool.
    * If you have enough information to answer the user's question, proceed to the final step.
    * If you need more information, you may use another tool.

4.  **Final Answer:**
    * Once you have a definitive answer, you **MUST** format it as follows, and nothing else:
        `FINAL ANSWER: [Your concise and accurate answer]`
"""

# --- Tool Definitions ---
# Module-level cache for the image captioning pipeline. It starts as None and
# is assigned by image_analyzer the first time that tool is called, so the
# model is only loaded when an image question actually needs it.
image_to_text_pipeline = None


@tool
def web_search(query: str) -> str:
    """Searches the web using DuckDuckGo for up-to-date information."""
    logging.info(f"--- Calling Web Search Tool with query: {query} ---")
    # A fresh search runner is built for every call; whatever it returns for
    # the query is handed straight back to the caller.
    return DuckDuckGoSearchRun().run(query)


@tool
def math_calculator(expression: str) -> str:
    """Calculates the result of a mathematical expression."""
    logging.info(f"--- Calling Math Calculator Tool with expression: {expression} ---")
    try:
        # Whitelist check: only digits, '.', the four arithmetic operators,
        # parentheses and whitespace may appear in the expression.
        is_plain_arithmetic = (
            re.match(r"^[0-9\.\+\-\*\/\(\)\s]+$", expression) is not None
        )
        if is_plain_arithmetic:
            value = numexpr.evaluate(expression).item()
            return str(value)
        return "Error: Invalid characters in expression."
    except Exception as err:
        # Any failure (bad syntax, division problems, ...) is reported as text
        # so the agent loop receives a message rather than an exception.
        return f"Error: {err}"


@tool
def image_analyzer(image_url: str) -> str:
    """Analyzes an image from a URL and returns a text description."""
    global image_to_text_pipeline
    logging.info(f"--- Calling Image Analyzer Tool with URL: {image_url} ---")
    try:
        captioner = image_to_text_pipeline
        if captioner is None:
            # First use: build the captioning pipeline and cache it in the
            # module-level variable so later calls reuse it.
            logging.info("--- Initializing Image Analyzer pipeline... ---")
            captioner = hf_pipeline(
                "image-to-text", model="Salesforce/blip-image-captioning-base"
            )
            image_to_text_pipeline = captioner
        # Only the first result is used; "Error" is the fallback when it has
        # no "generated_text" entry.
        first_result = captioner(image_url)[0]
        return first_result.get("generated_text", "Error")
    except Exception as err:
        return f"Error analyzing image: {err}"


@tool
def youtube_transcript_reader(youtube_url: str) -> str:
    """Reads the transcript of a YouTube video from its URL."""
    logging.info(f"--- Calling YouTube Transcript Reader with URL: {youtube_url} ---")
    try:
        documents = YoutubeLoader.from_youtube_url(
            youtube_url, add_video_info=False
        ).load()
        transcript = " ".join(doc.page_content for doc in documents)
        # Cap the returned text at 4000 characters.
        return transcript[:4000]
    except Exception as err:
        return f"Error reading YouTube transcript: {err}"


# --- Agent State & Graph (Unchanged) ---
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph.

    GaiaAgent.__call__ seeds it with the question and an empty message list;
    the "agent" and "tools" nodes each return a one-element ``messages`` list.
    """

    # The question text the agent was asked; set once in GaiaAgent.__call__.
    question: str
    # Running list of LLM responses and tool outputs. The attached lambda
    # concatenates two lists; presumably LangGraph uses it as the reducer that
    # appends each node's returned messages to the existing ones -- confirm
    # against the LangGraph state documentation.
    messages: Annotated[list, lambda x, y: x + y]


class GaiaAgent:
    """Agent that loops between an LLM call and a tool call until it answers.

    The compiled graph has two nodes: "agent" (calls the LLM) and "tools"
    (parses the LLM's reply as ``tool_name("argument")`` and runs that tool).
    After every "agent" step the reply is inspected: if it contains the
    ``FINAL ANSWER:`` marker the run ends, otherwise control goes to "tools"
    and then back to "agent".
    """

    # One pattern for both routing and answer extraction, so the two can never
    # disagree about what counts as a final answer (e.g. "Final Answer: 42").
    _FINAL_ANSWER_RE = re.compile(r"FINAL ANSWER:\s*(.*)", re.IGNORECASE | re.DOTALL)
    # Fallback ``name(`` opener used when no registered tool name is found.
    _GENERIC_CALL_RE = re.compile(r"(\w+)\s*\(")

    def __init__(self):
        """Create the tool list, the LLM chain and the compiled graph."""
        logging.info("Initializing GaiaAgent...")
        self.tools = [
            web_search,
            math_calculator,
            image_analyzer,
            youtube_transcript_reader,
        ]

        # IMPORTANT: Make sure you have accepted the terms of use for this model on the Hugging Face Hub!
        logging.info("Initializing LLM...")
        llm = HuggingFaceEndpoint(
            repo_id="HuggingFaceH4/zephyr-7b-beta",
            temperature=0.1,
            max_new_tokens=1024,
            # The endpoint's field is ``huggingfacehub_api_token``; the former
            # spelling ``huggingface_api_token`` is not a declared field.
            huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
        )

        prompt = PromptTemplate.from_template(
            SYSTEM_PROMPT + "\n{messages}\n\nQuestion: {question}"
        )
        self.agent = prompt | llm | StrOutputParser()
        self.graph = self._create_graph()
        logging.info("GaiaAgent initialized successfully.")

    @staticmethod
    def _find_closing_paren(text: str, start: int) -> int:
        """Return the index of the ``)`` closing the ``(`` just before *start*.

        Nested parentheses are counted, and parentheses inside single- or
        double-quoted sections are ignored. Returns -1 when no balanced
        closing parenthesis exists.
        """
        depth = 1
        quote = None
        index = start
        while index < len(text):
            char = text[index]
            if quote is not None:
                if char == "\\":
                    index += 1  # skip the escaped character
                elif char == quote:
                    quote = None
            elif char in "'\"":
                quote = char
            elif char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0:
                    return index
            index += 1
        return -1

    @classmethod
    def _parse_tool_call(cls, text: str, tool_names):
        """Extract ``(tool_name, argument)`` from an LLM reply, or ``None``.

        A call to one of *tool_names* is preferred, so that prose such as
        ``"Let me check (quickly)"`` ahead of the real call is not mistaken
        for a tool. If none is present, the first ``word(`` is used so the
        caller can still report an unknown tool name. The argument runs up to
        the matching closing parenthesis, which keeps inputs like
        ``"(2+3)*4"`` intact; surrounding whitespace and quotes are removed.
        """
        opener = None
        if tool_names:
            known = "|".join(re.escape(name) for name in tool_names)
            opener = re.search(rf"\b({known})\s*\(", text)
        if opener is None:
            opener = cls._GENERIC_CALL_RE.search(text)
        if opener is None:
            return None

        start = opener.end()
        close = cls._find_closing_paren(text, start)
        if close == -1:
            # Unbalanced quotes or parentheses: fall back to the last ")".
            close = text.rfind(")", start)
        if close == -1:
            return None

        argument = text[start:close].strip().strip("'\"")
        return opener.group(1), argument

    @classmethod
    def _has_final_answer(cls, text: str) -> bool:
        """Return True when *text* contains the final-answer marker (any case)."""
        return cls._FINAL_ANSWER_RE.search(text) is not None

    def _call_agent(self, state: "AgentState"):
        """Graph node: call the LLM and return its reply as a new message."""
        logging.info("--- Calling Agent ---")
        # Render the history as plain lines instead of a Python list repr,
        # which would wrap entries in quotes and escape their newlines.
        transcript = "\n".join(str(message) for message in state["messages"])
        response = self.agent.invoke(
            {"question": state["question"], "messages": transcript}
        )
        return {"messages": [response]}

    def _call_tools(self, state: "AgentState"):
        """Graph node: run the tool requested in the latest message.

        Always returns a one-element ``messages`` update: the tool output, or
        a textual error when the reply cannot be parsed, names an unknown
        tool, or the tool raises.
        """
        logging.info("--- Calling Tools ---")
        raw_tool_call = state["messages"][-1]
        parsed = self._parse_tool_call(raw_tool_call, [t.name for t in self.tools])
        if parsed is None:
            logging.warning("No valid tool call found in agent response.")
            return {
                "messages": [
                    "No valid tool call found. Please try again or provide a FINAL ANSWER."
                ]
            }

        tool_name, tool_input_str = parsed
        tool_to_call = next((t for t in self.tools if t.name == tool_name), None)
        if tool_to_call is None:
            available = ", ".join(t.name for t in self.tools)
            return {
                "messages": [
                    f"Tool '{tool_name}' not found. Available tools: {available}."
                ]
            }

        try:
            result = tool_to_call.run(tool_input_str)
            return {"messages": [str(result)]}
        except Exception as e:
            return {"messages": [f"Error executing tool {tool_name}: {e}"]}

    def _decide_action(self, state: "AgentState"):
        """Route to END once the latest message holds a final answer, else to "tools"."""
        return END if self._has_final_answer(state["messages"][-1]) else "tools"

    def _create_graph(self):
        """Build and compile the agent/tools loop with "agent" as entry point."""
        graph = StateGraph(AgentState)
        graph.add_node("agent", self._call_agent)
        graph.add_node("tools", self._call_tools)
        graph.add_conditional_edges(
            "agent", self._decide_action, {"tools": "tools", END: END}
        )
        graph.add_edge("tools", "agent")
        graph.set_entry_point("agent")
        return graph.compile()

    def __call__(self, question: str) -> str:
        """Answer *question* and return the text after ``FINAL ANSWER:``.

        Never raises for failures inside the graph: a recursion-limit hit or
        any other exception is logged and returned as an error string.
        """
        logging.info(f"Agent received question: {question[:100]}...")
        try:
            initial_state = {"question": question, "messages": []}
            # The recursion limit bounds the agent/tools ping-pong.
            final_state = self.graph.invoke(initial_state, {"recursion_limit": 15})
            final_response = final_state["messages"][-1]
            match = self._FINAL_ANSWER_RE.search(final_response)
            return (
                match.group(1).strip() if match else "Could not determine final answer."
            )
        except GraphRecursionError:
            logging.error("Agent got stuck in a loop.")
            return "Agent Error: Stuck in a loop."
        except Exception as e:
            logging.error(f"Error during agent invocation: {e}", exc_info=True)
            return f"Error: {e}"


# --- Main Application Logic (Unchanged) ---
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Run the agent on every evaluation question and submit the answers.

    Returns a ``(status_message, results)`` pair. ``results`` is a DataFrame
    with one row per processed question, or ``None`` when the run stops before
    any question is processed (no login, agent init failure, fetch failure).
    """
    # Guard: the evaluation needs a logged-in Hugging Face profile.
    if not profile:
        return "Please Login to Hugging Face.", None
    username = profile.username
    logging.info(f"User logged in: {username}")

    space_id = os.getenv("SPACE_ID")
    if not space_id:
        # No SPACE_ID in the environment: use the hard-coded fallback space.
        space_id = "leofltt/HF_Agents_Final_Assignment"
        logging.warning(f"SPACE_ID not found, using fallback: {space_id}")

    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"

    try:
        agent = GaiaAgent()
    except Exception as init_error:
        return f"Fatal error initializing agent: {init_error}", None

    logging.info("Fetching questions...")
    try:
        questions_response = requests.get(f"{DEFAULT_API_URL}/questions", timeout=20)
        questions_response.raise_for_status()
        questions_data = questions_response.json()
    except Exception as fetch_error:
        return f"Error fetching questions: {fetch_error}", None

    total = len(questions_data)
    logging.info(f"FULL EVALUATION MODE: Processing all {total} questions...")

    results_log = []
    answers_payload = []
    for position, item in enumerate(questions_data, start=1):
        task_id = item.get("task_id")
        question_text = item.get("question")
        logging.info(
            f"--- Processing question {position}/{total} (Task ID: {task_id}) ---"
        )
        try:
            answer = agent(question_text)
        except Exception as agent_error:
            # A failing question is logged and shown in the table, but it is
            # not added to the submission payload.
            shown_answer = f"AGENT ERROR on task {task_id}: {agent_error}"
            logging.error(shown_answer, exc_info=True)
        else:
            answers_payload.append({"task_id": task_id, "submitted_answer": answer})
            shown_answer = answer
        results_log.append(
            {
                "Task ID": task_id,
                "Question": question_text,
                "Submitted Answer": shown_answer,
            }
        )

    if not answers_payload:
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    logging.info(f"Submitting {len(answers_payload)} answers...")
    try:
        submit_response = requests.post(
            f"{DEFAULT_API_URL}/submit",
            json={
                "username": username,
                "agent_code": agent_code,
                "answers": answers_payload,
            },
            timeout=60,
        )
        submit_response.raise_for_status()
        score = submit_response.json().get("score", "N/A")
        return f"Submission Successful!\nScore: {score}%", pd.DataFrame(results_log)
    except Exception as submit_error:
        return f"Submission Failed: {submit_error}", pd.DataFrame(results_log)


# --- Gradio Interface ---
# The UI is built at import time: a title, a login button, a run button, a
# status textbox and a results table, created in that order inside one Blocks
# container named `demo`.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent Evaluation Runner")
    gr.LoginButton()
    run_button = gr.Button("Run Full Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Result", lines=4)
    results_table = gr.DataFrame(label="Questions and Answers", wrap=True)
    # The click handler declares no explicit inputs; the two values returned
    # by run_and_submit_all fill the status box and the table.
    # NOTE(review): the handler's `profile` argument is presumably injected by
    # Gradio from the login state via its gr.OAuthProfile annotation -- confirm.
    run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])

if __name__ == "__main__":
    # Logging is configured only when the file is run as a script.
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )
    demo.launch()