from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core.agent import ReActAgent
from llama_index.core.tools import FunctionTool
import os
import gradio as gr
import requests
import pandas as pd
import torch
import re
import gc
from datetime import datetime

# Import optional tool dependencies
try:
    from duckduckgo_search import DDGS
except ImportError:
    print("Warning: duckduckgo_search not installed. Web search will be limited.")
    DDGS = None

try:
    from sympy import sympify
    from sympy.core.sympify import SympifyError
except ImportError:
    print("Warning: sympy not installed. Math calculator will be limited.")
    sympify = None
    SympifyError = Exception
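
# Both optional dependencies are published on PyPI; installing them enables the
# full web-search and math tools:
#     pip install duckduckgo-search sympy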

# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
MEMORY_LIMIT_GB = 16  # Your system's memory limit


# --- Advanced Agent Definition ---
class SmartAgent:
    def __init__(self):
        print(f"Initializing Local LLM Agent (Memory Limit: {MEMORY_LIMIT_GB}GB)...")
        self.model_loaded = False

        # Model options sorted by capability: (name, approx size in GB, quantization)
        model_options = [
            ("google/flan-t5-large", 3, "8-bit"),  # Best balance for 16GB
            ("google/flan-t5-base", 1, "8-bit"),   # Smaller fallback
            ("facebook/opt-1.3b", 2.5, "8-bit"),   # Alternative option
        ]
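        # Caveat (assumption about llama_index internals): the flan-t5 entries are
        # encoder-decoder models, and HuggingFaceLLM typically loads checkpoints as
        # causal LMs; if they fail to load here, facebook/opt-1.3b is the
        # decoder-only option that should work.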

        # Try loading models in order until one succeeds
        for model_name, size_gb, quantization in model_options:
            if size_gb <= MEMORY_LIMIT_GB and self._try_load_model(model_name, quantization):
                self.model_loaded = True
                break
        if not self.model_loaded:
            raise RuntimeError("Failed to load any suitable model within memory constraints")

        # Initialize tools with enhanced implementations
        self.tools = [
            FunctionTool.from_defaults(
                fn=self.smart_web_search,
                name="web_search",
                description="Searches the web for current information. Use for questions about recent events, people, or facts not in the model's training data.",
            ),
            FunctionTool.from_defaults(
                fn=self.robust_math_calculator,
                name="math_calculator",
                description="Solves mathematical expressions and equations. Use for calculations, arithmetic, algebra, or numerical problems.",
            ),
        ]

        # Initialize the ReAct agent with memory optimization
        try:
            self.agent = ReActAgent.from_tools(
                tools=self.tools,
                llm=self.llm,
                verbose=True,
                max_iterations=4,
                # Extra system context for the ReAct prompt; the kwarg is named
                # `context` in current llama_index releases
                context="""Think step by step. Use tools when needed:
- For current/recent information: web_search
- For calculations: math_calculator
- Be concise but accurate""",
            )
            print("ReAct Agent initialized successfully")
        except Exception as e:
            print(f"ReAct Agent init failed: {e}")
            self.agent = None

    def _try_load_model(self, model_name: str, quantization: str) -> bool:
        """Attempt to load a model within the memory constraints."""
        try:
            print(f"Loading {model_name} with {quantization} quantization...")
            model_kwargs = {
                "torch_dtype": torch.float16,
                "low_cpu_mem_usage": True,
            }
            # 8-/4-bit loading requires the bitsandbytes package and a CUDA device
            if quantization == "8-bit":
                model_kwargs["load_in_8bit"] = True
            elif quantization == "4-bit":
                model_kwargs["load_in_4bit"] = True

            self.llm = HuggingFaceLLM(
                model_name=model_name,
                tokenizer_name=model_name,
                context_window=2048,
                max_new_tokens=256,
                generate_kwargs={
                    "temperature": 0.4,
                    "do_sample": True,
                    "top_p": 0.9,
                    "repetition_penalty": 1.1,
                },
                device_map="auto" if torch.cuda.is_available() else "cpu",
                model_kwargs=model_kwargs,
            )

            # Smoke-test the model before committing to it
            test_response = self.llm.complete("Test response:")
            if not test_response:
                raise ValueError("Model failed test response")
            print(f"Successfully loaded {model_name}")
            return True
        except Exception as e:
            print(f"Failed to load {model_name}: {e}")
            self.cleanup_memory()
            return False
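    # Note (version-dependent): newer transformers releases prefer a quantization
    # config object over the bare load_in_8bit flag, e.g.:
    #     from transformers import BitsAndBytesConfig
    #     model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)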

    def smart_web_search(self, query: str) -> str:
        """Enhanced web search with focused results."""
        print(f"Searching: {query[:60]}...")
        if not DDGS:
            return "Web search unavailable (duckduckgo_search not installed)"
        try:
            with DDGS() as ddgs:
                # Get focused results with longer snippets
                results = list(ddgs.text(query, max_results=3))
                if not results:
                    return "No results found"
                # Process results for key information
                processed = []
                for i, res in enumerate(results, 1):
                    title = res.get('title', 'No title')
                    body = res.get('body', 'No description')
                    url = res.get('href', '')
                    # Extract the part most relevant to the query
                    key_info = self._extract_relevant_info(query, body)
                    processed.append(
                        f"Result {i}:\n"
                        f"Title: {title}\n"
                        f"Info: {key_info[:250]}\n"
                        f"Source: {url}\n"
                    )
                return "\n".join(processed)
        except Exception as e:
            return f"Search error: {e}"

    def _extract_relevant_info(self, query: str, text: str) -> str:
        """Extract the portion of text most relevant to the query."""
        query_lower = query.lower()
        # Handle different question types
        if any(w in query_lower for w in ['who is', 'biography', 'born']):
            # Look for birth info
            match = re.search(r"(born [^.]+? in [^.]+?\.)", text, re.I)
            return match.group(1) if match else text[:250]
        elif any(w in query_lower for w in ['died', 'death']):
            match = re.search(r"(died [^.]+?\.)", text, re.I)
            return match.group(1) if match else text[:250]
        elif any(w in query_lower for w in ['award', 'prize', 'won']):
            match = re.search(r"(awarded [^.]+? in [^.]+?\.)", text, re.I)
            return match.group(1) if match else text[:250]
        # Default: keep the sentences that share words with the query
        sentences = re.split(r'(?<=[.!?]) +', text)
        important = [s for s in sentences if any(w in s.lower() for w in query_lower.split())]
        return " ".join(important[:3]) if important else text[:250]

    def robust_math_calculator(self, expression: str) -> str:
        """Improved math calculator with better parsing."""
        print(f"Calculating: {expression}")
        # Clean and preprocess the expression
        expr = expression.strip("'\"")
        # Replace operator words with symbols
        replacements = {
            'plus': '+', 'minus': '-', 'times': '*', 'divided by': '/',
            '^': '**', 'percent': '/100', 'modulo': '%',
        }
        for word, op in replacements.items():
            expr = expr.replace(word, op)
        # Extract the math expression from surrounding text; the repeated group
        # captures chains like "12 + 30 / 6", not just a single operator
        math_match = re.search(r"([-+]?\d*\.?\d+(?:[+\-*/%^()\s]+\d+\.?\d*)+)", expr)
        if math_match:
            expr = math_match.group(1)
        # Safety check: whitelist characters before eval (this blocks names and
        # attribute access, though not pathological inputs like huge exponents)
        allowed_chars = set("0123456789+-*/().%^ ")
        if not all(c in allowed_chars for c in expr.replace(" ", "")):
            return "Error: Invalid characters in expression"
        try:
            # Try direct evaluation first
            result = eval(expr)  # guarded by the character whitelist above
            return f"Result: {result}"
        except Exception:
            # Fall back to sympy if available
            if sympify:
                try:
                    result = sympify(expr).evalf()
                    return f"Result: {result}"
                except SympifyError as e:
                    return f"Math error: {e}"
            return "Error: Could not evaluate the expression"

    def __call__(self, question: str) -> str:
        """Main interface for answering questions."""
        print(f"\nQuestion: {question[:100]}...")
        try:
            # Step 1: Classify the question type
            q_type = self._classify_question(question)
            # Step 2: Dispatch to the appropriate strategy
            if q_type == "fact":
                return self._answer_fact_question(question)
            elif q_type == "math":
                return self._answer_math_question(question)
            else:
                return self._answer_general_question(question)
        except Exception as e:
            print(f"Error processing question: {e}")
            return self._fallback_response(question)

    def _classify_question(self, question: str) -> str:
        """Determine the type of question."""
        q_lower = question.lower()
        # Math questions (note: the bare '-' also matches hyphenated words, so
        # some fact questions may be routed here)
        math_keywords = ['calculate', 'compute', 'sum', 'total', 'average',
                         'percentage', 'equation', 'solve', 'math', 'number',
                         '+', '-', '*', '/', '=']
        if any(kw in q_lower for kw in math_keywords):
            return "math"
        # Fact-based questions
        fact_keywords = ['current', 'latest', 'recent', 'today', 'news',
                         'who is', 'what is', 'when did', 'where is',
                         'competition', 'winner', 'recipient', 'nationality',
                         'country', 'malko', 'century', 'award', 'born', 'died']
        if any(kw in q_lower for kw in fact_keywords):
            return "fact"
        return "general"

    def _answer_fact_question(self, question: str) -> str:
        """Handle fact-based questions with web search."""
        # Extract capitalized entities for a focused search
        entities = re.findall(r"([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)", question)
        search_query = " ".join(entities[:3]) or question[:50]
        # Get search results
        search_results = self.smart_web_search(search_query)
        # Summarize with the LLM if available
        if self.model_loaded:
            prompt = f"""Question: {question}

Search Results:
{search_results}

Based ONLY on these results, provide a concise answer.
If the answer isn't there, say so."""
            try:
                response = self.llm.complete(prompt)
                return str(response).strip()
            except Exception:
                return f"Search results for '{search_query}':\n{search_results}"
        return f"Search results for '{search_query}':\n{search_results}"

    def _answer_math_question(self, question: str) -> str:
        """Handle math questions with the calculator."""
        # Try to extract a math expression directly (anchored on a digit so a
        # bare space between words doesn't count as an expression)
        math_expr = re.search(r"(\d[\d\s+\-*/().^]*)", question)
        if math_expr:
            return self.robust_math_calculator(math_expr.group(1))
        # If no clear expression, fall back to agent reasoning
        if self.agent:
            try:
                response = self.agent.query(question)
                return str(response)
            except Exception:
                return self._fallback_response(question)
        return self._fallback_response(question)

    def _answer_general_question(self, question: str) -> str:
        """Handle general knowledge questions."""
        if self.agent:
            try:
                response = self.agent.query(question)
                return str(response)
            except Exception:
                return self._fallback_response(question)
        # Fall back to a plain LLM completion
        try:
            response = self.llm.complete(question)
            return str(response)
        except Exception:
            return self._fallback_response(question)

    def _fallback_response(self, question: str) -> str:
        """Final fallback when all else fails."""
        return (f"I couldn't generate a complete answer for: {question[:150]}... "
                "Please try rephrasing or ask about something more specific.")

    def cleanup_memory(self):
        """Release cached GPU memory and force garbage collection."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
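
# Minimal local smoke test (illustrative; the first run downloads model weights):
#     agent = SmartAgent()
#     print(agent("Calculate 12 + 30 / 6"))  # -> "Result: 17.0"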


# --- Submission Logic ---
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Handle the full evaluation process: fetch questions, answer them, submit."""
    space_id = os.getenv("SPACE_ID")
    if profile:
        username = f"{profile.username}"
        print(f"User logged in: {username}")
    else:
        print("User not logged in.")
        return "Please log in to Hugging Face with the button.", None

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    # Initialize the agent (may be slow: model weights are loaded here)
    try:
        agent = SmartAgent()
    except Exception as e:
        print(f"Agent initialization failed: {e}")
        return f"Error initializing agent: {e}", None

    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(f"Agent code URL: {agent_code}")

    # Fetch questions
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            return "No questions received from server.", None
        print(f"Fetched {len(questions_data)} questions.")
    except Exception as e:
        return f"Error fetching questions: {e}", None

    # Process questions
    results_log = []
    answers_payload = []
    for i, item in enumerate(questions_data, 1):
        task_id = item.get("task_id")
        question = item.get("question")
        if not task_id or not question:
            continue
        print(f"Processing question {i}/{len(questions_data)} (ID: {task_id})")
        try:
            answer = agent(question)
            answers_payload.append({
                "task_id": task_id,
                "submitted_answer": answer[:2000],  # limit answer length
            })
            results_log.append({
                "Task ID": task_id,
                "Question": question[:100] + "..." if len(question) > 100 else question,
                "Answer": answer[:200] + "..." if len(answer) > 200 else answer,
            })
            # Clean up memory every 5 questions
            if i % 5 == 0:
                agent.cleanup_memory()
        except Exception as e:
            print(f"Error on question {task_id}: {e}")
            answers_payload.append({
                "task_id": task_id,
                "submitted_answer": f"Error processing question: {e}",
            })
            results_log.append({
                "Task ID": task_id,
                "Question": question[:100] + "..." if len(question) > 100 else question,
                "Answer": f"Error: {e}",
            })

    # Submit answers
    submission_data = {
        "username": username.strip(),
        "agent_code": agent_code,
        "answers": answers_payload,
    }
    print(f"Submitting {len(answers_payload)} answers...")
    try:
        response = requests.post(submit_url, json=submission_data, timeout=60)
        response.raise_for_status()
        result = response.json()
        status = (
            f"✅ Submission Successful!\n\n"
            f"User: {result.get('username')}\n"
            f"Score: {result.get('score', 'N/A')}% "
            f"({result.get('correct_count', '?')}/{result.get('total_attempted', '?')})\n"
            f"Message: {result.get('message', '')}"
        )
        return status, pd.DataFrame(results_log)
    except Exception as e:
        error_msg = f"❌ Submission Failed: {e}"
        print(error_msg)
        return error_msg, pd.DataFrame(results_log)


# --- Gradio UI ---
with gr.Blocks(title="Local LLM Agent Evaluation") as demo:
    gr.Markdown("""
    # Local LLM Agent Evaluation
    **Run your local agent against the course evaluation questions**
    """)
    with gr.Row():
        gr.LoginButton()
        run_btn = gr.Button(
            "Run Evaluation & Submit Answers",
            variant="primary",
        )
    status_out = gr.Textbox(
        label="Status",
        interactive=False,
    )
    results_table = gr.DataFrame(
        label="Results",
        interactive=False,
        wrap=True,
    )
    run_btn.click(
        fn=run_and_submit_all,
        outputs=[status_out, results_table],
    )
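    # Note: gr.LoginButton pairs with the `profile: gr.OAuthProfile | None`
    # annotation on run_and_submit_all; Gradio injects the logged-in profile
    # automatically, so no `inputs=` argument is needed on .click().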


if __name__ == "__main__":
    print("\n" + "=" * 60)
    print(f"Starting Agent Evaluation - {datetime.now().strftime('%Y-%m-%d %H:%M')}")
    print(f"Memory Limit: {MEMORY_LIMIT_GB}GB")
    print("=" * 60)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )