import os
import logging
import json

from huggingface_hub import model_info, InferenceClient
from dotenv import load_dotenv

from config.models_config import PREFERRED_PROVIDERS

# Load environment variables once at the module level
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Per-request timeout (seconds) used by test_provider() when probing a provider.
DEFAULT_PROVIDER_TEST_TIMEOUT = 3


def prioritize_providers(providers):
    """Prioritize preferred providers, keeping all others.

    Returns a new list containing every element of *providers*, with those in
    ``PREFERRED_PROVIDERS`` first. ``sorted`` is stable, so the relative input
    order is kept inside each group (preferred / non-preferred).
    """
    return sorted(providers, key=lambda provider: provider not in PREFERRED_PROVIDERS)


def _error_status_code(error):
    """Best-effort extraction of an HTTP status code from an inference error.

    Prefers ``error.response.status_code`` (present on HTTP errors) and falls
    back to matching ``status_code=NNN`` in the error message. Returns ``None``
    when no recognised code can be found.
    """
    response = getattr(error, "response", None)
    status = getattr(response, "status_code", None)
    if isinstance(status, int):
        return status
    message = str(error)
    for code in (429, 401, 503):
        if f"status_code={code}" in message:
            return code
    return None


def _log_provider_error(provider, error, timeout):
    """Log a failed provider probe, adding a hint for recognised failure modes."""
    error_message = str(error)
    logger.warning(f"Error with provider {provider}: {error_message}")

    status = _error_status_code(error)
    if status == 429:
        logger.warning(f"Provider {provider} rate limited. You may need to wait or upgrade your plan.")
    elif status == 401:
        logger.warning(f"Authentication failed for provider {provider}. Check your token.")
    elif status == 503:
        logger.warning(f"Provider {provider} service unavailable. Model may be loading or provider is down.")
    elif "timed out" in error_message.lower():
        logger.warning(f"Timeout error with provider {provider} - request timed out after {timeout} seconds")


def test_provider(model_name: str, provider: str, verbose: bool = False,
                  timeout: float = DEFAULT_PROVIDER_TEST_TIMEOUT) -> bool:
    """
    Test if a specific provider is available for a model using InferenceClient.

    Sends a tiny chat-completion request ("Hello", max 5 tokens) through the
    given provider. Never raises: any failure is reported as ``False``.

    Args:
        model_name: Name of the model
        provider: Provider to test
        verbose: Whether to log detailed information
        timeout: Request timeout in seconds passed to InferenceClient

    Returns:
        True if the provider is available, False otherwise
    """
    try:
        # Get HF token from environment
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN not defined in environment")

        # Get HF organization from environment.
        # NOTE(review): the organization is currently unused (bill_to is
        # disabled below), yet a missing value still makes every test fail.
        hf_organization = os.environ.get("HF_ORGANIZATION")
        if not hf_organization:
            raise ValueError("HF_ORGANIZATION not defined in environment")

        if verbose:
            logger.info(f"Testing provider {provider} for model {model_name}")

        # Initialize the InferenceClient with the specific provider
        client = InferenceClient(
            model=model_name,
            token=hf_token,
            provider=provider,
            # bill_to=hf_organization,
            timeout=timeout,
        )

        try:
            # A minimal chat completion is enough to prove the route works.
            client.chat_completion(
                messages=[{"role": "user", "content": "Hello"}],
                max_tokens=5,
            )
        except Exception as e:
            if verbose:
                _log_provider_error(provider, e, timeout)
            return False

        if verbose:
            logger.info(f"Provider {provider} is available for {model_name}")
        return True

    except Exception as e:
        if verbose:
            logger.warning(f"Error in test_provider: {str(e)}")
        return False


def _provider_names(mapping):
    """Return provider names from a model's ``inference_provider_mapping``.

    Accepts either a dict keyed by provider name or an iterable of mapping
    objects exposing a ``provider`` attribute (the shape differs between
    huggingface_hub versions). Falsy input yields an empty list.
    """
    if not mapping:
        return []
    if hasattr(mapping, "keys"):
        return list(mapping.keys())
    return [getattr(entry, "provider", entry) for entry in mapping]


def get_available_model_provider(model_name, verbose=False):
    """
    Get the first available provider for a given model.

    Providers listed for the model on the Hub are reordered so preferred ones
    come first, then each is probed with test_provider() until one succeeds.

    Args:
        model_name: Name of the model on the Hub
        verbose: Whether to log detailed information

    Returns:
        First available provider or None if none are available
    """
    try:
        # Get HF token from environment
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            if verbose:
                logger.error("HF_TOKEN not defined in environment")
            return None

        # Get providers for the model
        try:
            info = model_info(model_name, token=hf_token, expand="inferenceProviderMapping")
            # The attribute may exist but be None when no providers are mapped.
            mapping = getattr(info, "inference_provider_mapping", None)
            if mapping is None:
                if verbose:
                    logger.info(f"No inference providers found for {model_name}")
                return None

            providers = _provider_names(mapping)
            if not providers:
                if verbose:
                    logger.info(f"Empty list of providers for {model_name}")
                return None
        except Exception as e:
            if verbose:
                logger.error(f"Error retrieving model info for {model_name}: {str(e)}")
            return None

        # Prioritize providers
        prioritized_providers = prioritize_providers(providers)

        if verbose:
            logger.info(f"Available providers for {model_name}: {', '.join(providers)}")
            logger.info(f"Prioritized providers: {', '.join(prioritized_providers)}")

        # prioritize_providers() returns a reordering of *all* providers, so a
        # single pass tries preferred providers first and then every other one.
        failed_providers = []
        for provider in prioritized_providers:
            if verbose:
                logger.info(f"Testing provider {provider} for {model_name}")
            try:
                available = test_provider(model_name, provider, verbose)
            except Exception as e:
                failed_providers.append(provider)
                if verbose:
                    logger.error(f"Exception while testing provider {provider} for {model_name}: {str(e)}")
                continue

            if available:
                if verbose:
                    logger.info(f"Provider {provider} is available for {model_name}")
                return provider

            failed_providers.append(provider)
            if verbose:
                logger.warning(f"Provider {provider} test failed for {model_name}")

        # If we've tried all providers and none worked, log this but don't raise an exception
        if verbose:
            logger.error(f"No available providers for {model_name}. Tried {len(failed_providers)} providers.")
        return None

    except Exception as e:
        if verbose:
            logger.error(f"Error in get_available_model_provider: {str(e)}")
        return None


if __name__ == "__main__":
    models = [
        "Qwen/QwQ-32B",
        "Qwen/Qwen2.5-72B-Instruct",
        "meta-llama/Llama-3.3-70B-Instruct",
        "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
        "mistralai/Mistral-Small-24B-Instruct-2501",
        "meta-llama/Llama-3.1-8B-Instruct",
        "Qwen/Qwen2.5-32B-Instruct",
    ]

    providers = []
    unavailable_models = []

    for model in models:
        provider = get_available_model_provider(model, verbose=True)
        if provider:
            providers.append((model, provider))
        else:
            unavailable_models.append(model)

    for model, provider in providers:
        print(f"Model: {model}, Provider: {provider}")

    if unavailable_models:
        print(f"Models with no available providers: {', '.join(unavailable_models)}")

    print(f"Total Providers {len(providers)}: {providers}")