demo / backend /tasks /get_model_providers.py
tfrere's picture
update lighteval results
39acd70
raw
history blame
1.71 kB
import logging

from huggingface_hub import model_info
# Provider names that filter_providers() keeps and prioritize_providers()
# moves to the front. Only membership in this list is checked; both helpers
# preserve the order of their input, not the order written here.
PREFERRED_PROVIDERS = [
    "sambanova",
    "novita",
]
def filter_providers(providers):
    """Keep only the providers that appear in ``PREFERRED_PROVIDERS``.

    Args:
        providers: Iterable of provider names.

    Returns:
        A new list with the preferred providers, in the order they were
        given; everything else is dropped.
    """
    kept = []
    for name in providers:
        if name in PREFERRED_PROVIDERS:
            kept.append(name)
    return kept
def prioritize_providers(providers):
    """Return every provider, with the preferred ones moved to the front.

    Args:
        providers: Iterable of provider names. It is traversed exactly once,
            so one-shot iterators and generators are handled correctly.

    Returns:
        A new list holding the providers found in ``PREFERRED_PROVIDERS``
        followed by all remaining providers. Within each group the input
        order is preserved.
    """
    preferred = []
    others = []
    # Partition in a single pass: iterating the input twice would leave the
    # second pass empty whenever ``providers`` is an exhaustible iterator.
    for provider in providers:
        bucket = preferred if provider in PREFERRED_PROVIDERS else others
        bucket.append(provider)
    return preferred + others
def get_model_providers(models, prioritize=True):
    """Look up the inference providers available for each model.

    For every model name the Hub metadata is fetched with
    ``model_info(..., expand="inferenceProviderMapping")`` and the keys of
    its ``inference_provider_mapping`` are used as the provider names.

    Args:
        models: Iterable of model repository names.
        prioritize: When true (the default), all providers are kept and the
            ones in ``PREFERRED_PROVIDERS`` are moved to the front. When
            false, providers outside ``PREFERRED_PROVIDERS`` are dropped.

    Returns:
        A list of ``(model_name, providers)`` tuples, one per input model and
        in input order. ``providers`` is a list; it is empty when the model
        has no provider mapping or when the lookup failed.
    """
    logger = logging.getLogger(__name__)
    select = prioritize_providers if prioritize else filter_providers

    results = []
    for model_name in models:
        try:
            info = model_info(model_name, expand="inferenceProviderMapping")
            # The attribute may be absent or set to None; treat both (and an
            # empty mapping) as "no providers" instead of relying on the
            # exception handler to absorb a failing ``None.keys()``.
            mapping = getattr(info, "inference_provider_mapping", None)
            providers = select(mapping.keys()) if mapping else []
        except Exception:
            # Best-effort lookup: one failing model must not abort the whole
            # batch, but the failure is logged rather than silently dropped.
            logger.warning(
                "Could not fetch inference providers for %s",
                model_name,
                exc_info=True,
            )
            providers = []
        results.append((model_name, providers))
    return results
if __name__ == "__main__":
    # Manual check: print the provider list for a few well-known models,
    # with preferred providers listed first.
    example_models = [
        "Qwen/Qwen2.5-72B-Instruct",
        "meta-llama/Llama-3.3-70B-Instruct",
        "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
        "Qwen/QwQ-32B",
        "mistralai/Mistral-Small-24B-Instruct-2501",
    ]
    print(get_model_providers(example_models, prioritize=True))