Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os | |
import logging | |
import json | |
from huggingface_hub import model_info, InferenceClient | |
from dotenv import load_dotenv | |
# Populate the process environment from a .env file, once, at import time.
load_dotenv()

# Providers that should be tried before any other provider of a model.
PREFERRED_PROVIDERS = [
    "fireworks-ai",
    "sambanova",
    "novita",
]

# Logging setup: INFO level, "timestamp - level - message" records.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def prioritize_providers(providers):
    """Return a new list with the preferred providers moved to the front.

    No provider is dropped. Entries found in PREFERRED_PROVIDERS come first,
    all remaining entries follow, and each group keeps its incoming order.
    """
    candidates = list(providers)
    favoured = [name for name in candidates if name in PREFERRED_PROVIDERS]
    remaining = [name for name in candidates if name not in PREFERRED_PROVIDERS]
    return favoured + remaining
def test_provider(model_name: str, provider: str, verbose: bool = False) -> bool:
    """
    Test if a specific provider is available for a model using InferenceClient.

    A minimal chat completion (a single "Hello" user message, 5 tokens max)
    is sent through the provider. Failures are caught and reported as False
    rather than raised.

    Args:
        model_name: Name of the model
        provider: Provider to test
        verbose: Whether to log detailed information

    Returns:
        True if the test request succeeded, False otherwise (including when
        HF_TOKEN is not defined in the environment)
    """
    # Single source of truth for the client timeout and the timeout log
    # message below. Kept at 10 seconds to leave room for model loading.
    timeout_seconds = 10

    try:
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN not defined in environment")

        # HF_ORGANIZATION is intentionally not required: the client is
        # created without `bill_to`, so the value would never be used and a
        # missing variable must not make every provider look unavailable.

        if verbose:
            logger.info(f"Testing provider {provider} for model {model_name}")

        # Client bound to the one provider under test.
        client = InferenceClient(
            model=model_name,
            token=hf_token,
            provider=provider,
            timeout=timeout_seconds,
        )
    except Exception as e:
        # Setup problems (missing token, client construction failure).
        if verbose:
            logger.warning(f"Error in test_provider: {str(e)}")
        return False

    try:
        # The response content is irrelevant; only success or failure matters.
        client.chat_completion(
            messages=[{"role": "user", "content": "Hello"}],
            max_tokens=5,
        )
    except Exception as e:
        if verbose:
            error_message = str(e)
            logger.warning(f"Error with provider {provider}: {error_message}")

            # Add a more specific hint when the error text is recognisable.
            if "status_code=429" in error_message:
                logger.warning(f"Provider {provider} rate limited. You may need to wait or upgrade your plan.")
            elif "status_code=401" in error_message:
                logger.warning(f"Authentication failed for provider {provider}. Check your token.")
            elif "status_code=503" in error_message:
                logger.warning(f"Provider {provider} service unavailable. Model may be loading or provider is down.")
            elif "timed out" in error_message.lower():
                logger.warning(
                    f"Timeout error with provider {provider} - request timed out after {timeout_seconds} seconds"
                )
        return False

    if verbose:
        logger.info(f"Provider {provider} is available for {model_name}")
    return True
def get_available_model_provider(model_name, verbose=False):
    """
    Get the first available provider for a given model.

    The model's inference provider mapping is fetched from the Hub, the
    providers are ordered with prioritize_providers(), and each one is probed
    with test_provider() until one succeeds. Errors are caught and reported
    as None rather than raised.

    Args:
        model_name: Name of the model on the Hub
        verbose: Whether to log detailed information

    Returns:
        First available provider or None if none are available (also None
        when HF_TOKEN is not defined or the Hub lookup fails)
    """
    try:
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN not defined in environment")

        # Pass the validated token explicitly so the lookup is authenticated
        # with the same credentials the provider tests will use.
        info = model_info(
            model_name,
            expand=["inferenceProviderMapping"],
            token=hf_token,
        )

        # The attribute can exist and still be None, so test the value
        # rather than the attribute's presence.
        mapping = getattr(info, "inference_provider_mapping", None)
        if mapping is None:
            if verbose:
                logger.info(f"No inference providers found for {model_name}")
            return None

        if hasattr(mapping, "keys"):
            # Dict-shaped mapping: provider name -> details.
            providers = list(mapping.keys())
        else:
            # NOTE(review): newer huggingface_hub releases appear to return a
            # list of mapping objects exposing `.provider` - confirm against
            # the installed version.
            providers = [getattr(entry, "provider", entry) for entry in mapping]

        if not providers:
            if verbose:
                logger.info(f"Empty list of providers for {model_name}")
            return None

        providers = prioritize_providers(providers)

        if verbose:
            logger.info(f"Available providers for {model_name}: {', '.join(providers)}")

        # Probe in priority order and stop at the first provider that answers.
        for provider in providers:
            if test_provider(model_name, provider, verbose):
                return provider

        # Every provider failed: report it, but do not raise.
        if verbose:
            logger.error(f"No available providers for {model_name}")
        return None

    except Exception as e:
        if verbose:
            logger.error(f"Error in get_available_model_provider: {str(e)}")
        return None
if __name__ == "__main__":
    # Manual check: resolve a provider for a fixed set of models and print
    # a summary of what was found.
    candidate_models = [
        "Qwen/QwQ-32B",
        "Qwen/Qwen2.5-72B-Instruct",
        "meta-llama/Llama-3.3-70B-Instruct",
        "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
        "mistralai/Mistral-Small-24B-Instruct-2501",
        "meta-llama/Llama-3.1-8B-Instruct",
        "Qwen/Qwen2.5-32B-Instruct",
    ]

    # Run every lookup first (with verbose logging), then split the results.
    lookups = [
        (model_id, get_available_model_provider(model_id, verbose=True))
        for model_id in candidate_models
    ]
    resolved = [(model_id, found) for model_id, found in lookups if found]
    unresolved = [model_id for model_id, found in lookups if not found]

    for model_id, found in resolved:
        print(f"Model: {model_id}, Provider: {found}")

    if unresolved:
        print(f"Models with no available providers: {', '.join(unresolved)}")

    print(f"Total Providers {len(resolved)}: {resolved}")