"""
Task to create and save the configuration file
"""
import os
import pathlib
import uuid
import yaml
import time
import threading
from typing import Optional, Dict, Any, List
from loguru import logger
from tasks.get_available_model_provider import get_available_model_provider
from config.models_config import (
DEFAULT_BENCHMARK_MODEL,
BENCHMARK_MODEL_ROLES,
DEFAULT_BENCHMARK_TIMEOUT,
PREFERRED_PROVIDERS,
ALTERNATIVE_BENCHMARK_MODELS,
)
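# HF_TOKEN (required) and HF_ORGANIZATION are read from the environment at runtime.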
class CreateBenchConfigTask:
"""
Task to create and save a configuration file for YourbenchSimpleDemo
"""
    def __init__(self, session_uid: Optional[str] = None, timeout: Optional[float] = None):
"""
Initialize the task with a session ID
Args:
session_uid: Optional session ID, will be generated if None
timeout: Timeout in seconds for benchmark operations (if None, uses default)
"""
self.session_uid = session_uid or str(uuid.uuid4())
self.logs: List[str] = []
self.is_completed = False
self.is_running_flag = threading.Event()
        self.thread = None  # Vestigial; run() now executes the task synchronously
self.timeout = timeout if timeout is not None else DEFAULT_BENCHMARK_TIMEOUT
self._add_log("[INFO] Initializing configuration creation task")
def _add_log(self, message: str) -> None:
"""
Add a log message to the logs list
Args:
message: Log message to add
"""
if message not in self.logs: # Avoid duplicates
self.logs.append(message)
# Force a copy of the list to avoid reference issues
self.logs = self.logs.copy()
# Log to system logs
logger.info(f"[{self.session_uid}] {message}")
def get_logs(self) -> List[str]:
"""
Get all logs for this task
Returns:
List of log messages
"""
        return self.logs.copy()  # Return a copy to avoid reference issues
def save_uploaded_file(self, file_path: str) -> str:
"""
Process the uploaded file that is already in the correct directory
Args:
file_path: Path to the uploaded file
Returns:
Path to the file (same as input)
"""
try:
# The file is already in the correct location: uploaded_files/{session_id}/uploaded_files/
# Just log that we're processing it and return the path
self._add_log(f"[INFO] Processing file: {os.path.basename(file_path)}")
return file_path
except Exception as e:
error_msg = f"Error processing file: {str(e)}"
self._add_log(f"[ERROR] {error_msg}")
raise RuntimeError(error_msg)
def get_model_provider(self, model_name: str) -> Optional[str]:
"""
        Get an available inference provider for a model, trying preferred providers first
Args:
model_name: Name of the model to check
Returns:
Available provider or None if none found
"""
self._add_log(f"[INFO] Finding available provider for {model_name}")
        # Try to find a provider for the model
provider = get_available_model_provider(model_name, verbose=True)
if provider:
self._add_log(f"[INFO] Found provider for {model_name}: {provider}")
return provider
        # If no provider was found among the preferred ones,
        # try to find any available provider, ignoring the preference list
from huggingface_hub import model_info
from tasks.get_available_model_provider import test_provider
self._add_log(f"[WARNING] No preferred provider found for {model_name}, trying all available providers...")
try:
            # Get all possible providers for this model
info = model_info(model_name, expand="inferenceProviderMapping")
if hasattr(info, "inference_provider_mapping"):
providers = list(info.inference_provider_mapping.keys())
                # Exclude the preferred providers that have already been tested
other_providers = [p for p in providers if p not in PREFERRED_PROVIDERS]
if other_providers:
self._add_log(f"[INFO] Testing additional providers: {', '.join(other_providers)}")
                    # Test each provider
for provider in other_providers:
self._add_log(f"[INFO] Testing provider {provider}")
if test_provider(model_name, provider, verbose=True):
self._add_log(f"[INFO] Found alternative provider for {model_name}: {provider}")
return provider
except Exception as e:
self._add_log(f"[WARNING] Error while testing additional providers: {str(e)}")
self._add_log(f"[WARNING] No available provider found for {model_name}")
return None
def generate_base_config(self, hf_org: str, hf_dataset_name: str) -> Dict[str, Any]:
"""
Create the base configuration dictionary
Args:
            hf_org: Hugging Face organization name (currently unused; the config references $HF_ORGANIZATION directly)
hf_dataset_name: Hugging Face dataset name
Returns:
Configuration dictionary
"""
self._add_log(f"[INFO] Generating base configuration for {hf_dataset_name}")
# Check if HF token is available
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
raise RuntimeError("HF_TOKEN environment variable is not defined")
# Get provider for the default model
provider = self.get_model_provider(DEFAULT_BENCHMARK_MODEL)
        # If no provider is found for the default model, try the alternative models
selected_model = DEFAULT_BENCHMARK_MODEL
if not provider:
self._add_log(f"[WARNING] Primary model {DEFAULT_BENCHMARK_MODEL} not available. Trying alternatives...")
            # Use the list of alternative models from the configuration
for alt_model in ALTERNATIVE_BENCHMARK_MODELS:
self._add_log(f"[INFO] Trying alternative model: {alt_model}")
alt_provider = self.get_model_provider(alt_model)
if alt_provider:
self._add_log(f"[INFO] Found working alternative model: {alt_model} with provider: {alt_provider}")
selected_model = alt_model
provider = alt_provider
break
        # If there is still no provider, raise an exception
if not provider:
error_msg = "No model with available provider found. Cannot proceed with benchmark."
self._add_log(f"[ERROR] {error_msg}")
raise RuntimeError(error_msg)
# Create model configuration
model_list = [{
"model_name": selected_model,
"provider": provider,
"api_key": "$HF_TOKEN",
"max_concurrent_requests": 32,
}]
        # Update the model roles if an alternative model is used
model_roles = dict(BENCHMARK_MODEL_ROLES)
if selected_model != DEFAULT_BENCHMARK_MODEL:
for role in model_roles:
if role != "chunking": # Ne pas changer le modèle de chunking
model_roles[role] = [selected_model]
self._add_log(f"[INFO] Updated model roles to use {selected_model}")
# Add minimum delay of 2 seconds for provider_check stage
self._add_log("[INFO] Finalizing provider check...")
time.sleep(2)
# Mark provider check stage as completed
self._add_log("[SUCCESS] Stage completed: provider_check")
return {
"hf_configuration": {
"token": "$HF_TOKEN",
"hf_organization": "$HF_ORGANIZATION",
"private": True,
"hf_dataset_name": hf_dataset_name,
"concat_if_exist": False,
"timeout": self.timeout, # Add timeout to configuration
},
"model_list": model_list,
"model_roles": model_roles,
"pipeline": {
"ingestion": {
"source_documents_dir": f"uploaded_files/{self.session_uid}/uploaded_files/",
"output_dir": f"uploaded_files/{self.session_uid}/ingested",
"run": True,
"timeout": self.timeout, # Add timeout to ingestion
},
"upload_ingest_to_hub": {
"source_documents_dir": f"uploaded_files/{self.session_uid}/ingested",
"run": True,
"timeout": self.timeout, # Add timeout to upload
},
"summarization": {
"run": True,
"timeout": self.timeout, # Add timeout to summarization
},
"chunking": {
"run": True,
"timeout": self.timeout, # Add timeout to chunking
"chunking_configuration": {
"l_min_tokens": 64,
"l_max_tokens": 128,
"tau_threshold": 0.8,
"h_min": 2,
"h_max": 5,
"num_multihops_factor": 1,
},
},
"single_shot_question_generation": {
"run": True,
"timeout": self.timeout, # Add timeout to question generation
"additional_instructions": "Generate rich and creative questions to test a curious adult",
"chunk_sampling": {
"mode": "count",
"value": 5,
"random_seed": 123,
},
},
"multi_hop_question_generation": {
"run": False,
"timeout": self.timeout, # Add timeout to multi-hop question generation
},
"lighteval": {
"run": False,
"timeout": self.timeout, # Add timeout to lighteval
},
},
}
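    # For reference, a sketch of the YAML that save_yaml_file() emits from the
    # dict above; the model/provider pair and the timeout depend on what was
    # resolved at runtime, so the values shown here are illustrative only:
    #
    #   hf_configuration:
    #     token: $HF_TOKEN
    #     hf_organization: $HF_ORGANIZATION
    #     private: true
    #     hf_dataset_name: yourbench_<session_uid>
    #     concat_if_exist: false
    #     timeout: <timeout>
    #   model_list:
    #   - model_name: <resolved model>
    #     provider: <resolved provider>
    #     api_key: $HF_TOKEN
    #     max_concurrent_requests: 32
    #   model_roles: {...}
    #   pipeline: {...}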
def save_yaml_file(self, config: Dict[str, Any], path: str) -> str:
"""
Save the given configuration dictionary to a YAML file
Args:
config: Configuration dictionary
path: Path to save the file
Returns:
Path to the saved file
"""
try:
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w") as file:
yaml.dump(config, file, default_flow_style=False, sort_keys=False)
self._add_log(f"[INFO] Configuration saved: {path}")
return path
except Exception as e:
error_msg = f"Error saving configuration: {str(e)}"
self._add_log(f"[ERROR] {error_msg}")
raise RuntimeError(error_msg)
def _run_task(self, file_path: str) -> str:
"""
Internal method to run the task in a separate thread
Args:
file_path: Path to the uploaded file
Returns:
Path to the configuration file
"""
try:
            # Read the target organization from the environment
org_name = os.getenv("HF_ORGANIZATION")
# Check if HF token is available
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
raise RuntimeError("HF_TOKEN environment variable is not defined")
self._add_log(f"[INFO] Organization: {org_name}")
time.sleep(0.5) # Simulate delay
# Save the uploaded file
saved_file_path = self.save_uploaded_file(file_path)
time.sleep(1) # Simulate delay
# Path for the config file
config_dir = pathlib.Path(f"uploaded_files/{self.session_uid}")
config_path = config_dir / "config.yml"
# Generate dataset name based on session ID
dataset_name = f"yourbench_{self.session_uid}"
self._add_log(f"[INFO] Dataset name: {dataset_name}")
time.sleep(0.8) # Simulate delay
# Log the start of finding providers
self._add_log("[INFO] Finding available providers for models...")
# Generate and save the configuration
config = self.generate_base_config(org_name, dataset_name)
time.sleep(1.2) # Simulate delay
config_file_path = self.save_yaml_file(config, str(config_path))
self._add_log(f"[INFO] Configuration generated successfully: {config_file_path}")
# Simulate additional processing
time.sleep(1.5) # Simulate delay
self._add_log("[INFO] Starting ingestion")
time.sleep(2) # Simulate delay
self._add_log(f"[INFO] Processing file: {dataset_name}")
time.sleep(2) # Simulate delay
self._add_log("[SUCCESS] Stage completed: config_generation")
            # Task finished
self.mark_task_completed()
return str(config_path)
except Exception as e:
error_msg = f"Error generating configuration: {str(e)}"
self._add_log(f"[ERROR] {error_msg}")
self.mark_task_completed()
raise RuntimeError(error_msg)
def run(self, file_path: str, token: Optional[str] = None, timeout: Optional[float] = None) -> str:
"""
        Run the task to create and save the configuration file.
        Note: despite the running flag, the task executes synchronously in the calling thread.
Args:
file_path: Path to the uploaded file
token: Hugging Face token (not used, using HF_TOKEN from environment)
timeout: Timeout in seconds for benchmark operations (if None, uses default)
Returns:
Path to the configuration file
"""
# Update timeout if provided
if timeout is not None:
self.timeout = timeout
# Mark the task as running
self.is_running_flag.set()
# Run the task directly without threading
try:
config_path = self._run_task(file_path)
return config_path
except Exception as e:
error_msg = f"Error generating configuration: {str(e)}"
self._add_log(f"[ERROR] {error_msg}")
self.mark_task_completed()
raise RuntimeError(error_msg)
def is_running(self) -> bool:
"""
Check if the task is running
Returns:
True if running, False otherwise
"""
return self.is_running_flag.is_set() and not self.is_completed
def is_task_completed(self) -> bool:
"""
Check if the task is completed
Returns:
True if completed, False otherwise
"""
return self.is_completed
def mark_task_completed(self) -> None:
"""
Mark the task as completed
"""
self.is_completed = True
self.is_running_flag.clear()
self._add_log("[INFO] Configuration generation task completed")