# demo/backend/routes/benchmark.py
from fastapi import APIRouter, HTTPException
from typing import Dict, Any
import time
import threading
from tasks.create_bench_config_file import CreateBenchConfigTask
from tasks.create_bench import CreateBenchTask
router = APIRouter(tags=["benchmark"])
# Store active tasks by session_id (imported in main.py)
active_tasks = {}
# Reference to session_files (provided by main.py)
# This declaration is overwritten by the assignment in __init__.py
session_files = {}
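
# Minimal sketch (an assumption, not the actual wiring) of how the application
# entry point is expected to attach the shared session_files mapping, since the
# handlers below read it through `router.session_files`:
#
#     from routes import benchmark
#     app.include_router(benchmark.router)
#     benchmark.router.session_files = session_files  # hypothetical attribute wiring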
@router.post("/generate-benchmark")
async def generate_benchmark(data: Dict[str, Any]):
"""
Generate a benchmark configuration and run the ingestion process
Args:
data: Dictionary containing session_id
Returns:
Dictionary with logs and status
"""
session_id = data.get("session_id")
# Debug to check session_files and received session_id
print(f"DEBUG: Session ID received: {session_id}")
print(f"DEBUG: Available session files: {list(router.session_files.keys())}")
if not session_id or session_id not in router.session_files:
return {"error": "Invalid or missing session ID"}
    # Check whether a benchmark is already running or completed for this session
if session_id in active_tasks:
task = active_tasks[session_id]
        # If the benchmark has already finished, return the existing logs
if task.is_task_completed():
return {
"status": "already_completed",
"logs": task.get_logs(),
"is_completed": True
}
        # If the benchmark is still running, return the current logs
else:
return {
"status": "already_running",
"logs": task.get_logs(),
"is_completed": False
}
file_path = router.session_files[session_id]
all_logs = []
try:
        # Initialize the task that will manage the whole process
task = UnifiedBenchmarkTask(session_uid=session_id)
        # Store the task so its logs can be retrieved later
active_tasks[session_id] = task
        # Start the benchmark process
task.run(file_path)
        # Retrieve the initial logs
all_logs = task.get_logs()
return {
"status": "running",
"logs": all_logs
}
except Exception as e:
return {
"status": "error",
"error": str(e),
"logs": all_logs
}
@router.get("/benchmark-progress/{session_id}")
async def get_benchmark_progress(session_id: str):
"""
Get the logs and status for a running benchmark task
Args:
session_id: Session ID for the task
Returns:
Dictionary with logs and completion status
"""
if session_id not in active_tasks:
raise HTTPException(status_code=404, detail="Benchmark task not found")
task = active_tasks[session_id]
logs = task.get_logs()
is_completed = task.is_task_completed()
return {
"logs": logs,
"is_completed": is_completed
}
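
# Illustrative client-side polling sketch (an assumption, not part of this
# module): a frontend could poll the progress endpoint until completion, e.g.:
#
#     import time, requests  # hypothetical client dependencies
#     while True:
#         r = requests.get(f"{base_url}/benchmark-progress/{session_id}").json()
#         if r["is_completed"]:
#             break
#         time.sleep(1)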
# Class that unifies the whole benchmark process
class UnifiedBenchmarkTask:
"""
Task that handles the entire benchmark process from configuration to completion
"""
def __init__(self, session_uid: str):
"""
Initialize the unified benchmark task
Args:
session_uid: Session ID for this task
"""
self.session_uid = session_uid
self.logs = []
self.is_completed = False
self.config_task = None
self.bench_task = None
self._add_log("[INFO] Initializing benchmark task")
def _add_log(self, message: str):
"""
Add a log message
Args:
message: Log message to add
"""
if message not in self.logs: # Avoid duplicates
self.logs.append(message)
# Force a copy to avoid reference problems
self.logs = self.logs.copy()
print(f"[{self.session_uid}] {message}")
def get_logs(self):
"""
Get all logs
Returns:
List of log messages
"""
return self.logs.copy()
def is_task_completed(self):
"""
Check if the task is completed
Returns:
True if completed, False otherwise
"""
return self.is_completed
def run(self, file_path: str):
"""
Run the benchmark process
Args:
file_path: Path to the uploaded file
"""
        # Start in a separate thread to avoid blocking
        thread = threading.Thread(target=self._run_process, args=(file_path,))
thread.daemon = True
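        # Daemon thread: it will not keep the process alive at shutdown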
thread.start()
def _run_process(self, file_path: str):
"""
Internal method to run the process
Args:
file_path: Path to the uploaded file
"""
try:
# Step 1: Configuration
self._add_log("[INFO] Starting configuration process")
self.config_task = CreateBenchConfigTask(session_uid=self.session_uid)
# Execute the configuration task
try:
config_path = self.config_task.run(file_path=file_path)
# Get configuration logs
config_logs = self.config_task.get_logs()
for log in config_logs:
self._add_log(log)
# Mark configuration step as completed
if "[SUCCESS] Stage completed: config_generation" not in self.logs:
self._add_log("[SUCCESS] Stage completed: configuration")
# Step 2: Benchmark
self._add_log("[INFO] Starting benchmark process")
self.bench_task = CreateBenchTask(session_uid=self.session_uid, config_path=config_path)
# Run the benchmark task
self.bench_task.run()
# Wait for the benchmark task to complete
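                # Note: _add_log() skips duplicates, so re-adding the full log
                # list on every iteration only appends genuinely new entries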
while not self.bench_task.is_task_completed():
# Get new logs and add them
bench_logs = self.bench_task.get_logs()
for log in bench_logs:
self._add_log(log)
time.sleep(1)
# Get final logs
final_logs = self.bench_task.get_logs()
for log in final_logs:
self._add_log(log)
# Mark as completed
self.is_completed = True
self._add_log("[SUCCESS] Benchmark process completed successfully")
except Exception as config_error:
error_msg = str(config_error)
# Log detailed error
self._add_log(f"[ERROR] Configuration failed: {error_msg}")
# Check if it's a provider error and provide a more user-friendly message
if "Required models not available" in error_msg:
self._add_log("[ERROR] Some required models are not available at the moment. Please try again later.")
# Mark as completed with error
self.is_completed = True
except Exception as e:
self._add_log(f"[ERROR] Benchmark process failed: {str(e)}")
self.is_completed = True