Spaces:
Build error
Build error
# app.py (Refactored for Improved Performance) | |
import os | |
import re | |
import gradio as gr | |
import requests | |
import pandas as pd | |
import logging | |
import numexpr | |
from typing import TypedDict, Annotated | |
# --- Langchain & HF Imports (Modern and Correct) --- | |
from langchain_huggingface import HuggingFaceEndpoint | |
from langchain_community.tools import DuckDuckGoSearchRun | |
from langchain_core.prompts import PromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.tools import tool | |
from langgraph.graph import StateGraph, END | |
from langgraph.errors import GraphRecursionError | |
from langchain_community.document_loaders.youtube import YoutubeLoader | |
from transformers import pipeline as hf_pipeline # Renamed to avoid conflict | |
# --- Constants --- | |
# Base URL of the scoring service; run_and_submit_all appends "/questions"
# and "/submit" to it.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
### --- REFACTOR 1: A much stricter and more detailed System Prompt --- ###
# Instruction text prepended to every LLM call. It names the four tools,
# requires a single-line `tool_name("argument")` call format, and requires
# the final reply to start with "FINAL ANSWER:" (GaiaAgent searches for that
# marker to stop the loop and to extract the answer).
# NOTE(review): this text is concatenated into a PromptTemplate, so literal
# curly braces added here would presumably be read as template variables --
# confirm before adding any.
SYSTEM_PROMPT = """You are GAIA, a powerful expert assistant. You are designed to answer questions accurately and efficiently by using a set of available tools.
**Your STRICT Process:**
1. **Analyze the User's Question:** Carefully determine the user's intent and what information is needed.
2. **Tool Selection and Execution:**
* **Is a tool necessary?**
* For questions about public information, facts, current events, statistics, people, companies, etc., you **MUST** use the `web_search` tool. Do not rely on your internal knowledge.
* If the question includes a URL pointing to an image (`.png`, `.jpg`, etc.), you **MUST** use the `image_analyzer` tool.
* If the question includes a YouTube URL, you **MUST** use the `youtube_transcript_reader` tool.
* If the question requires a calculation, you **MUST** use the `math_calculator` tool.
* If the question is a simple logic puzzle, riddle, or language task you can solve directly, you do not need a tool.
* **Tool Call Format:** To use a tool, you **MUST** respond with **only** the tool call on a single line. Do not add any other text or explanation.
* Example: `web_search("How many albums did Mercedes Sosa release after 2000?")`
3. **Analyze Tool Output:**
* Review the information returned by the tool.
* If you have enough information to answer the user's question, proceed to the final step.
* If you need more information, you may use another tool.
4. **Final Answer:**
* Once you have a definitive answer, you **MUST** format it as follows, and nothing else:
`FINAL ANSWER: [Your concise and accurate answer]`
"""
# --- Tool Definitions (Unchanged) --- | |
image_to_text_pipeline = None | |
def web_search(query: str) -> str:
    """Run a DuckDuckGo search for ``query`` and return the result text.

    A fresh ``DuckDuckGoSearchRun`` is created on every call; whatever its
    ``run`` method returns is passed straight back to the caller.
    """
    logging.info(f"--- Calling Web Search Tool with query: {query} ---")
    return DuckDuckGoSearchRun().run(query)
def math_calculator(expression: str) -> str:
    """Evaluate an arithmetic expression and return the result as text.

    Only digits, ``.``, ``+ - * /``, parentheses and whitespace are accepted;
    anything else yields an error string instead of being evaluated. Any
    exception raised while validating or evaluating is reported as
    ``"Error: <message>"`` rather than propagated.
    """
    logging.info(f"--- Calling Math Calculator Tool with expression: {expression} ---")
    try:
        # Whitelist check first, so arbitrary names never reach numexpr.
        is_plain_arithmetic = (
            re.match(r"^[0-9\.\+\-\*\/\(\)\s]+$", expression) is not None
        )
        if is_plain_arithmetic:
            value = numexpr.evaluate(expression).item()
            return str(value)
        return "Error: Invalid characters in expression."
    except Exception as e:
        return f"Error: {e}"
def image_analyzer(image_url: str) -> str:
    """Caption the image at ``image_url`` and return the caption text.

    The captioning pipeline is stored in the module-level
    ``image_to_text_pipeline`` variable and is only constructed on the first
    call. Any exception (including one raised while building the pipeline)
    is returned as an ``"Error analyzing image: ..."`` string.
    """
    global image_to_text_pipeline
    logging.info(f"--- Calling Image Analyzer Tool with URL: {image_url} ---")
    try:
        # Build the pipeline lazily so start-up does not pay for model loading.
        if image_to_text_pipeline is None:
            logging.info("--- Initializing Image Analyzer pipeline... ---")
            image_to_text_pipeline = hf_pipeline(
                "image-to-text", model="Salesforce/blip-image-captioning-base"
            )
        captions = image_to_text_pipeline(image_url)
        first_caption = captions[0]
        return first_caption.get("generated_text", "Error")
    except Exception as e:
        return f"Error analyzing image: {e}"
def youtube_transcript_reader(youtube_url: str) -> str:
    """Return up to the first 4000 characters of a YouTube video's transcript.

    The page contents of all loaded documents are joined with single spaces
    before truncation. Any exception is returned as an
    ``"Error reading YouTube transcript: ..."`` string.
    """
    logging.info(f"--- Calling YouTube Transcript Reader with URL: {youtube_url} ---")
    try:
        loader = YoutubeLoader.from_youtube_url(youtube_url, add_video_info=False)
        documents = loader.load()
        transcript = " ".join(doc.page_content for doc in documents)
        # Cap the size of the text handed back to the LLM.
        return transcript[:4000]
    except Exception as e:
        return f"Error reading YouTube transcript: {e}"
# --- Agent State & Graph (Unchanged) --- | |
def _append_messages(existing: list, incoming: list) -> list:
    """Reducer for ``AgentState.messages``: concatenate old and new entries."""
    return existing + incoming


class AgentState(TypedDict):
    """State passed between the graph nodes of ``GaiaAgent``."""

    # The user's question, set once when the graph is invoked.
    question: str
    # Running transcript; node outputs are combined via the attached reducer
    # (list concatenation).
    messages: Annotated[list, _append_messages]
class GaiaAgent:
    """Question-answering agent that alternates between an LLM and tools.

    The compiled graph has two nodes: ``agent`` (one LLM call) and ``tools``
    (parse and execute the tool call found in the latest LLM reply). After
    each ``agent`` step the loop either ends, when the reply contains the
    ``FINAL ANSWER:`` marker, or continues into ``tools`` and back to
    ``agent``.
    """

    # Finds the start of a call such as ``web_search(``.
    _TOOL_CALL_START = re.compile(r"(\w+)\s*\(")
    # Lenient pattern (argument ends at the first ``)``), used only when the
    # parentheses in the reply are not balanced.
    _TOOL_CALL_FALLBACK = re.compile(r"(\w+)\s*\((.*?)\)", re.DOTALL)

    def __init__(self):
        """Build the tool list, the LLM chain and the compiled graph."""
        logging.info("Initializing GaiaAgent...")
        # Entries may be plain callables or LangChain tool objects; dispatch
        # in _call_tools handles both.
        self.tools = [
            web_search,
            math_calculator,
            image_analyzer,
            youtube_transcript_reader,
        ]
        # IMPORTANT: Make sure you have accepted the terms of use for this model on the Hugging Face Hub!
        logging.info("Initializing LLM...")
        llm = HuggingFaceEndpoint(
            repo_id="HuggingFaceH4/zephyr-7b-beta",
            temperature=0.1,
            max_new_tokens=1024,
            # The endpoint's field is ``huggingfacehub_api_token``; the former
            # ``huggingface_api_token`` spelling is not a recognised parameter.
            huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
        )
        prompt = PromptTemplate.from_template(
            SYSTEM_PROMPT + "\n{messages}\n\nQuestion: {question}"
        )
        self.agent = prompt | llm | StrOutputParser()
        self.graph = self._create_graph()
        logging.info("GaiaAgent initialized successfully.")

    def _call_agent(self, state: "AgentState"):
        """Graph node: run the LLM chain on the state and record its reply."""
        logging.info("--- Calling Agent ---")
        response = self.agent.invoke(state)
        return {"messages": [response]}

    @staticmethod
    def _clean_argument(raw_argument: str) -> str:
        """Trim whitespace and surrounding quote characters from an argument."""
        return raw_argument.strip().strip("'\"")

    @classmethod
    def _parse_tool_call(cls, text: str) -> "tuple[str, str] | None":
        """Extract ``(tool_name, argument)`` from an LLM reply.

        The argument runs up to the parenthesis that balances the opening
        one, so nested parentheses such as ``math_calculator("(1+2)*3")``
        are kept intact. If the parentheses never balance, the lenient
        first-``)`` pattern is tried instead. Returns ``None`` when no call
        is found.
        """
        opening = cls._TOOL_CALL_START.search(text)
        if opening is None:
            return None

        argument_start = opening.end()
        depth = 1
        for position in range(argument_start, len(text)):
            char = text[position]
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0:
                    raw_argument = text[argument_start:position]
                    return opening.group(1), cls._clean_argument(raw_argument)

        lenient = cls._TOOL_CALL_FALLBACK.search(text)
        if lenient is None:
            return None
        return lenient.group(1).strip(), cls._clean_argument(lenient.group(2))

    @staticmethod
    def _tool_name(tool_obj) -> str:
        """Return a tool's ``name`` attribute, else its ``__name__``."""
        return getattr(tool_obj, "name", None) or getattr(tool_obj, "__name__", "")

    def _call_tools(self, state: "AgentState"):
        """Graph node: execute the tool call contained in the latest message.

        Always returns a ``{"messages": [text]}`` update: the tool output, or
        a message describing why no tool could be run.
        """
        logging.info("--- Calling Tools ---")
        parsed_call = self._parse_tool_call(state["messages"][-1])
        if parsed_call is None:
            logging.warning("No valid tool call found in agent response.")
            return {
                "messages": [
                    "No valid tool call found. Please try again or provide a FINAL ANSWER."
                ]
            }

        tool_name, tool_input_str = parsed_call
        tool_to_call = next(
            (t for t in self.tools if self._tool_name(t) == tool_name), None
        )
        if tool_to_call is None:
            available = ", ".join(self._tool_name(t) for t in self.tools)
            return {
                "messages": [
                    f"Tool '{tool_name}' not found. Available tools: {available}."
                ]
            }

        try:
            # Tool objects expose ``run``; plain functions are called directly.
            runner = getattr(tool_to_call, "run", None)
            if callable(runner):
                result = runner(tool_input_str)
            else:
                result = tool_to_call(tool_input_str)
            return {"messages": [str(result)]}
        except Exception as e:
            return {"messages": [f"Error executing tool {tool_name}: {e}"]}

    def _decide_action(self, state: "AgentState"):
        """Route to ``END`` once the latest reply contains the final marker.

        The check ignores case so that it agrees with the case-insensitive
        extraction performed in ``__call__``.
        """
        if "FINAL ANSWER:" in state["messages"][-1].upper():
            return END
        return "tools"

    def _create_graph(self):
        """Assemble and compile the agent/tools loop."""
        graph = StateGraph(AgentState)
        graph.add_node("agent", self._call_agent)
        graph.add_node("tools", self._call_tools)
        graph.add_conditional_edges(
            "agent", self._decide_action, {"tools": "tools", END: END}
        )
        graph.add_edge("tools", "agent")
        graph.set_entry_point("agent")
        return graph.compile()

    def __call__(self, question: str) -> str:
        """Answer ``question`` and return the text after ``FINAL ANSWER:``.

        Returns a fixed message when the marker is missing, and an error
        string when the graph hits its recursion limit or raises.
        """
        logging.info(f"Agent received question: {question[:100]}...")
        try:
            initial_state = {"question": question, "messages": []}
            ### --- REFACTOR 3: Gracefully handle recursion errors --- ###
            final_state = self.graph.invoke(initial_state, {"recursion_limit": 15})
            final_response = final_state["messages"][-1]
            match = re.search(
                r"FINAL ANSWER:\s*(.*)", final_response, re.IGNORECASE | re.DOTALL
            )
            if match is None:
                return "Could not determine final answer."
            return match.group(1).strip()
        except GraphRecursionError:
            logging.error("Agent got stuck in a loop.")
            return "Agent Error: Stuck in a loop."
        except Exception as e:
            logging.error(f"Error during agent invocation: {e}", exc_info=True)
            return f"Error: {e}"
# --- Main Application Logic (Unchanged) --- | |
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Run the agent on every fetched question and submit the answers.

    Returns a ``(status_text, results)`` pair. ``results`` is ``None`` when
    the run stops before any question is processed (not logged in, agent
    initialisation failure, question fetch failure); otherwise it is a
    DataFrame with one row per question.
    """
    if not profile:
        return "Please Login to Hugging Face.", None

    username = profile.username
    logging.info(f"User logged in: {username}")

    space_id = os.getenv("SPACE_ID")
    if not space_id:
        space_id = "leofltt/HF_Agents_Final_Assignment"  # Your fallback
        logging.warning(f"SPACE_ID not found, using fallback: {space_id}")
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"

    try:
        agent = GaiaAgent()
    except Exception as e:
        return f"Fatal error initializing agent: {e}", None

    logging.info("Fetching questions...")
    try:
        questions_response = requests.get(f"{DEFAULT_API_URL}/questions", timeout=20)
        questions_response.raise_for_status()
        questions_data = questions_response.json()
    except Exception as e:
        return f"Error fetching questions: {e}", None

    total = len(questions_data)
    logging.info(f"FULL EVALUATION MODE: Processing all {total} questions...")

    results_log = []
    answers_payload = []
    for position, item in enumerate(questions_data, start=1):
        task_id = item.get("task_id")
        question_text = item.get("question")
        logging.info(
            f"--- Processing question {position}/{total} (Task ID: {task_id}) ---"
        )
        try:
            answer = agent(question_text)
        except Exception as e:
            # A failed question is shown in the table but not submitted.
            error_message = f"AGENT ERROR on task {task_id}: {e}"
            logging.error(error_message, exc_info=True)
            displayed = error_message
        else:
            answers_payload.append({"task_id": task_id, "submitted_answer": answer})
            displayed = answer
        results_log.append(
            {
                "Task ID": task_id,
                "Question": question_text,
                "Submitted Answer": displayed,
            }
        )

    if not answers_payload:
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    logging.info(f"Submitting {len(answers_payload)} answers...")
    try:
        submission_data = {
            "username": username,
            "agent_code": agent_code,
            "answers": answers_payload,
        }
        submit_response = requests.post(
            f"{DEFAULT_API_URL}/submit", json=submission_data, timeout=60
        )
        submit_response.raise_for_status()
        result_data = submit_response.json()
        status = f"Submission Successful!\nScore: {result_data.get('score', 'N/A')}%"
        return status, pd.DataFrame(results_log)
    except Exception as e:
        return f"Submission Failed: {e}", pd.DataFrame(results_log)
# --- Gradio Interface (Unchanged) --- | |
# Components are created in display order inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("# GAIA Agent Evaluation Runner")
    gr.LoginButton()
    run_button = gr.Button("Run Full Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Result", lines=4)
    results_table = gr.DataFrame(label="Questions and Answers", wrap=True)
    # The handler's two return values fill the status box and the table.
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table],
    )

if __name__ == "__main__":
    # Configure logging only when run as a script, then start the UI.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    demo.launch()