# app.py (Refactored for Improved Performance)
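# Gradio Space for the Hugging Face Agents course (Unit 4, GAIA-style evaluation):
# a LangGraph agent answers the scoring API's questions and submits the results.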
import os
import re
import gradio as gr
import requests
import pandas as pd
import logging
import numexpr
from typing import TypedDict, Annotated
# --- LangChain & Hugging Face Imports (current, non-deprecated paths) ---
from langchain_huggingface import HuggingFaceEndpoint
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END
from langgraph.errors import GraphRecursionError
from langchain_community.document_loaders.youtube import YoutubeLoader
from transformers import pipeline as hf_pipeline # Renamed to avoid conflict
# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
### --- REFACTOR 1: A much stricter and more detailed System Prompt --- ###
SYSTEM_PROMPT = """You are GAIA, a powerful expert assistant. You are designed to answer questions accurately and efficiently by using a set of available tools.
**Your STRICT Process:**
1. **Analyze the User's Question:** Carefully determine the user's intent and what information is needed.
2. **Tool Selection and Execution:**
* **Is a tool necessary?**
* For questions about public information, facts, current events, statistics, people, companies, etc., you **MUST** use the `web_search` tool. Do not rely on your internal knowledge.
* If the question includes a URL pointing to an image (`.png`, `.jpg`, etc.), you **MUST** use the `image_analyzer` tool.
* If the question includes a YouTube URL, you **MUST** use the `youtube_transcript_reader` tool.
* If the question requires a calculation, you **MUST** use the `math_calculator` tool.
* If the question is a simple logic puzzle, riddle, or language task you can solve directly, you do not need a tool.
* **Tool Call Format:** To use a tool, you **MUST** respond with **only** the tool call on a single line. Do not add any other text or explanation.
* Example: `web_search("How many albums did Mercedes Sosa release after 2000?")`
3. **Analyze Tool Output:**
* Review the information returned by the tool.
* If you have enough information to answer the user's question, proceed to the final step.
* If you need more information, you may use another tool.
4. **Final Answer:**
* Once you have a definitive answer, you **MUST** format it as follows, and nothing else:
`FINAL ANSWER: [Your concise and accurate answer]`
"""
# --- Tool Definitions (Unchanged) ---
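# The image-captioning pipeline is initialized lazily on first use so the model
# is only downloaded and loaded when an image question actually appears.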
image_to_text_pipeline = None
@tool
def web_search(query: str) -> str:
"""Searches the web using DuckDuckGo for up-to-date information."""
logging.info(f"--- Calling Web Search Tool with query: {query} ---")
search = DuckDuckGoSearchRun()
return search.run(query)
@tool
def math_calculator(expression: str) -> str:
"""Calculates the result of a mathematical expression."""
logging.info(f"--- Calling Math Calculator Tool with expression: {expression} ---")
try:
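        # Whitelist guard: only digits, basic arithmetic operators, parentheses,
        # dots, and whitespace are passed on to numexpr.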
if not re.match(r"^[0-9\.\+\-\*\/\(\)\s]+$", expression):
return "Error: Invalid characters in expression."
result = numexpr.evaluate(expression).item()
return str(result)
except Exception as e:
return f"Error: {e}"
@tool
def image_analyzer(image_url: str) -> str:
"""Analyzes an image from a URL and returns a text description."""
global image_to_text_pipeline
logging.info(f"--- Calling Image Analyzer Tool with URL: {image_url} ---")
try:
if image_to_text_pipeline is None:
logging.info("--- Initializing Image Analyzer pipeline... ---")
image_to_text_pipeline = hf_pipeline(
"image-to-text", model="Salesforce/blip-image-captioning-base"
)
description = image_to_text_pipeline(image_url)[0].get(
"generated_text", "Error"
)
return description
except Exception as e:
return f"Error analyzing image: {e}"
@tool
def youtube_transcript_reader(youtube_url: str) -> str:
"""Reads the transcript of a YouTube video from its URL."""
logging.info(f"--- Calling YouTube Transcript Reader with URL: {youtube_url} ---")
try:
loader = YoutubeLoader.from_youtube_url(youtube_url, add_video_info=False)
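        # Join all transcript chunks and truncate to ~4000 characters so the tool
        # output stays within the model's context budget.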
return " ".join([doc.page_content for doc in loader.load()])[:4000]
except Exception as e:
return f"Error reading YouTube transcript: {e}"
# --- Agent State & Graph (Unchanged) ---
class AgentState(TypedDict):
question: str
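    # The reducer appends each node's new messages to the running list instead
    # of overwriting state["messages"].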
messages: Annotated[list, lambda x, y: x + y]
class GaiaAgent:
def __init__(self):
logging.info("Initializing GaiaAgent...")
self.tools = [
web_search,
math_calculator,
image_analyzer,
youtube_transcript_reader,
]
        # NOTE: Using this model through the Inference API requires a valid Hugging Face API token.
logging.info("Initializing LLM...")
llm = HuggingFaceEndpoint(
repo_id="HuggingFaceH4/zephyr-7b-beta",
temperature=0.1,
max_new_tokens=1024,
            huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
)
prompt = PromptTemplate.from_template(
SYSTEM_PROMPT + "\n{messages}\n\nQuestion: {question}"
)
self.agent = prompt | llm | StrOutputParser()
self.graph = self._create_graph()
logging.info("GaiaAgent initialized successfully.")
def _call_agent(self, state: AgentState):
logging.info("--- Calling Agent ---")
response = self.agent.invoke(state)
return {"messages": [response]}
def _call_tools(self, state: AgentState):
logging.info("--- Calling Tools ---")
raw_tool_call = state["messages"][-1]
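        # Parse a call of the form tool_name("argument") out of the raw LLM text.
        # The non-greedy group stops at the first ')', so arguments containing a
        # closing parenthesis would be truncated.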
tool_call_match = re.search(r"(\w+)\s*\((.*?)\)", raw_tool_call, re.DOTALL)
if not tool_call_match:
logging.warning("No valid tool call found in agent response.")
return {
"messages": [
"No valid tool call found. Please try again or provide a FINAL ANSWER."
]
}
tool_name = tool_call_match.group(1).strip()
tool_input_str = tool_call_match.group(2).strip().strip("'\"")
tool_to_call = next((t for t in self.tools if t.name == tool_name), None)
if tool_to_call:
try:
                result = tool_to_call.invoke(tool_input_str)
return {"messages": [str(result)]}
except Exception as e:
return {"messages": [f"Error executing tool {tool_name}: {e}"]}
else:
return {
"messages": [
f"Tool '{tool_name}' not found. Available tools: web_search, math_calculator, image_analyzer, youtube_transcript_reader."
]
}
def _decide_action(self, state: AgentState):
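        # Keep routing to the tools node until the model emits "FINAL ANSWER:".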
return "tools" if "FINAL ANSWER:" not in state["messages"][-1] else END
def _create_graph(self):
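        # agent -> tools -> agent loop; _decide_action breaks the loop by
        # returning END once a final answer is produced.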
graph = StateGraph(AgentState)
graph.add_node("agent", self._call_agent)
graph.add_node("tools", self._call_tools)
graph.add_conditional_edges(
"agent", self._decide_action, {"tools": "tools", END: END}
)
graph.add_edge("tools", "agent")
graph.set_entry_point("agent")
return graph.compile()
def __call__(self, question: str) -> str:
logging.info(f"Agent received question: {question[:100]}...")
try:
initial_state = {"question": question, "messages": []}
### --- REFACTOR 3: Gracefully handle recursion errors --- ###
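            # recursion_limit caps the number of agent/tool turns; exceeding it
            # raises GraphRecursionError, handled below.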
final_state = self.graph.invoke(initial_state, {"recursion_limit": 15})
final_response = final_state["messages"][-1]
match = re.search(
r"FINAL ANSWER:\s*(.*)", final_response, re.IGNORECASE | re.DOTALL
)
return (
match.group(1).strip() if match else "Could not determine final answer."
)
except GraphRecursionError:
logging.error("Agent got stuck in a loop.")
return "Agent Error: Stuck in a loop."
except Exception as e:
logging.error(f"Error during agent invocation: {e}", exc_info=True)
return f"Error: {e}"
# --- Main Application Logic (Unchanged) ---
def run_and_submit_all(profile: gr.OAuthProfile | None):
    # Runs the agent over every question returned by the scoring API and submits
    # all answers for evaluation.
if not profile:
return "Please Login to Hugging Face.", None
username = profile.username
logging.info(f"User logged in: {username}")
space_id = os.getenv("SPACE_ID")
if not space_id:
space_id = "leofltt/HF_Agents_Final_Assignment" # Your fallback
logging.warning(f"SPACE_ID not found, using fallback: {space_id}")
agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
try:
agent = GaiaAgent()
except Exception as e:
return f"Fatal error initializing agent: {e}", None
logging.info("Fetching questions...")
try:
response = requests.get(f"{DEFAULT_API_URL}/questions", timeout=20)
response.raise_for_status()
questions_data = response.json()
except Exception as e:
return f"Error fetching questions: {e}", None
logging.info(
f"FULL EVALUATION MODE: Processing all {len(questions_data)} questions..."
)
results_log = []
answers_payload = []
for i, item in enumerate(questions_data):
task_id = item.get("task_id")
question_text = item.get("question")
logging.info(
f"--- Processing question {i+1}/{len(questions_data)} (Task ID: {task_id}) ---"
)
try:
answer = agent(question_text)
answers_payload.append({"task_id": task_id, "submitted_answer": answer})
results_log.append(
{
"Task ID": task_id,
"Question": question_text,
"Submitted Answer": answer,
}
)
except Exception as e:
error_message = f"AGENT ERROR on task {task_id}: {e}"
logging.error(error_message, exc_info=True)
results_log.append(
{
"Task ID": task_id,
"Question": question_text,
"Submitted Answer": error_message,
}
)
if not answers_payload:
return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
logging.info(f"Submitting {len(answers_payload)} answers...")
try:
submission_data = {
"username": username,
"agent_code": agent_code,
"answers": answers_payload,
}
response = requests.post(
f"{DEFAULT_API_URL}/submit", json=submission_data, timeout=60
)
response.raise_for_status()
result_data = response.json()
status = f"Submission Successful!\nScore: {result_data.get('score', 'N/A')}%"
return status, pd.DataFrame(results_log)
except Exception as e:
return f"Submission Failed: {e}", pd.DataFrame(results_log)
# --- Gradio Interface (Unchanged) ---
with gr.Blocks() as demo:
gr.Markdown("# GAIA Agent Evaluation Runner")
gr.LoginButton()
run_button = gr.Button("Run Full Evaluation & Submit All Answers")
status_output = gr.Textbox(label="Run Status / Result", lines=4)
results_table = gr.DataFrame(label="Questions and Answers", wrap=True)
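    # Gradio injects gr.OAuthProfile (or None) into run_and_submit_all based on
    # its type annotation, so no explicit inputs are listed here.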
run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
demo.launch()