"""
Models Registry for Hugging Face Spaces
Optimized for remote inference without local model loading.
"""

import yaml
import os
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import sys
from huggingface_hub import InferenceClient

# Add src to path for imports
sys.path.append('src')
from utils.config_loader import config_loader

# Providers that are called through the chat-completions endpoint instead of
# plain text generation.
_CHAT_COMPLETION_PROVIDERS = frozenset({"nebius", "together", "groq"})

# Providers that ModelInterface.generate_sql forwards to HuggingFaceInference.
# "groq" is listed because HuggingFaceInference.generate handles it explicitly.
# NOTE(review): "huggingface" is passed to InferenceClient unchanged; confirm
# that it is a provider name the client accepts (otherwise those models always
# end up on the mock fallback).
_SUPPORTED_PROVIDERS = frozenset(
    {"huggingface", "hf-inference", "together", "nebius", "groq"}
)


@dataclass
class ModelConfig:
    """Configuration for a model."""

    name: str  # Lookup key used by ModelsRegistry.get_model_by_name
    provider: str  # Inference provider name handed to InferenceClient
    model_id: str  # Model identifier sent to the provider
    params: Dict[str, Any]  # Generation parameters (max_new_tokens, temperature, top_p)
    description: str  # Free-form description; empty string when not configured


class ModelsRegistry:
    """Registry for managing models from YAML configuration."""

    def __init__(self, config_path: str = "config/models.yaml"):
        """Load every model declared in the YAML file at ``config_path``.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
        """
        self.config_path = config_path
        self.models = self._load_models()

    def _load_models(self) -> List[ModelConfig]:
        """Load models from YAML configuration file.

        Returns:
            One ModelConfig per entry under the top-level ``models`` key, in
            file order. An empty file or a missing/null ``models`` key yields
            an empty list.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            KeyError: If an entry lacks ``name``, ``provider`` or ``model_id``.
        """
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"Models config file not found: {self.config_path}")

        with open(self.config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document; normalise to a dict
            # so the lookup below does not fail.
            config = yaml.safe_load(f) or {}

        return [
            ModelConfig(
                name=model_data['name'],
                provider=model_data['provider'],
                model_id=model_data['model_id'],
                # ``or`` also covers keys that are present but null in the YAML.
                params=model_data.get('params') or {},
                description=model_data.get('description') or '',
            )
            for model_data in (config.get('models') or [])
        ]

    def get_models(self) -> List[ModelConfig]:
        """Get all available models."""
        return self.models

    def get_model_by_name(self, name: str) -> Optional[ModelConfig]:
        """Get a specific model by name, or None when no model matches."""
        return next((model for model in self.models if model.name == name), None)

    def get_models_by_provider(self, provider: str) -> List[ModelConfig]:
        """Get all models from a specific provider."""
        return [model for model in self.models if model.provider == provider]


class HuggingFaceInference:
    """Interface for Hugging Face Inference API using InferenceClient."""

    def __init__(self, api_token: Optional[str] = None):
        """Remember the API token, defaulting to the HF_TOKEN environment variable."""
        self.api_token = api_token or os.getenv("HF_TOKEN")
        # Clients are created per call because the provider can differ per model.

    def generate(self, model_id: str, prompt: str, params: Dict[str, Any],
                 provider: str = "hf-inference") -> str:
        """Generate text using Hugging Face Inference API with specified provider.

        Args:
            model_id: Model identifier sent to the provider.
            prompt: Prompt text; sent as a single user message for chat providers.
            params: Generation parameters. ``max_new_tokens`` (default 128),
                ``temperature`` (default 0.1) and ``top_p`` (default 0.9) are read.
            provider: Provider name passed to InferenceClient.

        Returns:
            The generated text.

        Raises:
            Exception: Any failure is re-raised as a plain Exception with a
                message classified from the original error text; the original
                error is attached as ``__cause__``.
        """
        params = params or {}
        max_tokens = params.get('max_new_tokens', 128)
        temperature = params.get('temperature', 0.1)
        top_p = params.get('top_p', 0.9)

        try:
            # Prefer the token given to __init__; fall back to the environment
            # at call time so a token exported after construction still works.
            client = InferenceClient(
                provider=provider,
                api_key=self.api_token or os.environ.get("HF_TOKEN")
            )

            if provider in _CHAT_COMPLETION_PROVIDERS:
                # These providers are driven through the conversational
                # (chat completion) task rather than raw text generation.
                completion = client.chat.completions.create(
                    model=model_id,
                    messages=[
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ],
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p
                )
                # Extract the content from the response
                return completion.choices[0].message.content

            # Other providers use text_generation
            return client.text_generation(
                prompt=prompt,
                model=model_id,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                return_full_text=False  # Only return the generated part
            )
        except Exception as e:
            error_msg = str(e)
            print(f"🔍 Debug - Full error: {error_msg}")
            # Chain the original exception so the traceback keeps the root cause.
            raise Exception(self._describe_error(error_msg, model_id, provider)) from e

    @staticmethod
    def _describe_error(error_msg: str, model_id: str, provider: str) -> str:
        """Map a raw error message to a friendlier, more specific message."""
        if "404" in error_msg or "Not Found" in error_msg:
            return f"Model not found: {model_id} - Model may not be available via {provider} provider"
        if "401" in error_msg or "Unauthorized" in error_msg:
            return f"Authentication failed - check HF_TOKEN for {provider} provider"
        if "503" in error_msg or "Service Unavailable" in error_msg:
            return f"Model {model_id} is loading on {provider}, please try again in a moment"
        if "timeout" in error_msg.lower():
            return f"Request timeout - model may be loading on {provider}"
        if "not supported for task" in error_msg:
            return f"Model {model_id} task not supported by {provider} provider: {error_msg}"
        if "not supported by provider" in error_msg:
            return f"Model {model_id} not supported by {provider} provider: {error_msg}"
        return f"{provider} API error: {error_msg}"


class ModelInterface:
    """Unified interface for all model providers."""

    def __init__(self):
        """Create the inference backend and read MOCK_MODE / HF_TOKEN once."""
        self.hf_interface = HuggingFaceInference()
        self.mock_mode = os.getenv("MOCK_MODE", "false").lower() == "true"
        self.has_hf_token = bool(os.getenv("HF_TOKEN"))

    def _generate_mock_sql(self, model_config: ModelConfig, prompt: str) -> str:
        """Generate mock SQL for demo purposes when API keys aren't available.

        The question is taken from the text between ``Question:`` and
        ``Requirements:`` in ``prompt`` and matched against the keyword
        patterns returned by ``config_loader.get_mock_sql_config()``.
        ``model_config`` is accepted for interface symmetry and is not used.
        """
        # Get mock SQL configuration
        mock_config = config_loader.get_mock_sql_config()
        patterns = mock_config["patterns"]
        templates = mock_config["templates"]

        # Extract the question from the prompt
        if "Question:" in prompt:
            question = prompt.split("Question:")[1].split("Requirements:")[0].strip()
        else:
            question = "unknown question"

        question_lower = question.lower()

        def mentions_any(pattern_key: str) -> bool:
            """Return True when any configured pattern occurs in the question."""
            return any(pattern in question_lower for pattern in patterns[pattern_key])

        # Check patterns in order of specificity
        if mentions_any("count_queries"):
            if "trips" in question_lower:
                return templates["count_trips"]
            return templates["count_generic"]
        if mentions_any("average_queries"):
            if "fare" in question_lower:
                return templates["avg_fare"]
            return templates["avg_generic"]
        if mentions_any("total_queries"):
            return templates["total_amount"]
        if mentions_any("passenger_queries"):
            return templates["passenger_count"]
        # Default fallback
        return templates["default"]

    def generate_sql(self, model_config: ModelConfig, prompt: str) -> str:
        """Generate SQL using the specified model.

        Mock SQL is returned when no HF_TOKEN was present at construction,
        when MOCK_MODE is enabled, or when the remote call fails for any
        reason (including an unsupported provider); this method therefore
        never raises for inference errors.
        """
        # Use mock mode if no HF token is available
        if not self.has_hf_token:
            print(f"🎭 No HF_TOKEN available, using mock mode for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)

        # Use mock mode only if explicitly set
        if self.mock_mode:
            print(f"🎭 Mock mode enabled for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)

        try:
            if model_config.provider not in _SUPPORTED_PROVIDERS:
                raise ValueError(f"Unsupported provider: {model_config.provider}")

            print(f"🤗 Using {model_config.provider} Inference API for {model_config.name}")
            return self.hf_interface.generate(
                model_config.model_id,
                prompt,
                model_config.params,
                model_config.provider
            )
        except Exception as e:
            # Deliberate best-effort: any failure degrades to the mock
            # generator so callers always receive some SQL.
            print(f"⚠️ Error with {model_config.name}: {str(e)}")
            print(f"🎭 Falling back to mock mode for {model_config.name}")
            return self._generate_mock_sql(model_config, prompt)


# Global instances
models_registry = ModelsRegistry()
model_interface = ModelInterface()