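"""List Hugging Face inference providers for a set of models.

Providers named in PREFERRED_PROVIDERS are either kept exclusively
(filter) or moved to the front of each model's provider list (prioritize).
"""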
from huggingface_hub import model_info

PREFERRED_PROVIDERS = ["sambanova", "novita"]

def filter_providers(providers):
    """Filter providers to only include preferred ones."""
    return [provider for provider in providers if provider in PREFERRED_PROVIDERS]
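
# For example, with PREFERRED_PROVIDERS = ["sambanova", "novita"]:
#   filter_providers(["cerebras", "novita", "together"]) -> ["novita"]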

def prioritize_providers(providers):
    """Prioritize preferred providers, keeping all others."""
    preferred = [provider for provider in providers if provider in PREFERRED_PROVIDERS]
    non_preferred = [provider for provider in providers if provider not in PREFERRED_PROVIDERS]
    return preferred + non_preferred
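
# For example, prioritize_providers(["cerebras", "novita", "together"])
# returns ["novita", "cerebras", "together"]: preferred providers move to
# the front while the relative order is otherwise preserved.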

def get_model_providers(models, prioritize=True):
    """Get model providers, optionally prioritizing preferred ones."""
    results = []
    
    for model_name in models:
        try:
            # The provider mapping is only returned when explicitly expanded.
            info = model_info(model_name, expand=["inferenceProviderMapping"])
            # The attribute may be None when a model has no live providers.
            mapping = getattr(info, "inference_provider_mapping", None) or {}
            providers = list(mapping)  # provider names, e.g. "sambanova"
            if prioritize:
                providers = prioritize_providers(providers)
            else:
                providers = filter_providers(providers)
            results.append((model_name, providers))
        except Exception as exc:
            # Lookup can fail (missing repo, network error, etc.); keep the
            # model in the results with an empty provider list.
            print(f"Could not fetch provider info for {model_name}: {exc}")
            results.append((model_name, []))
    
    return results

if __name__ == "__main__":
    example_models = [
        "Qwen/Qwen2.5-72B-Instruct",
        "meta-llama/Llama-3.3-70B-Instruct",
        "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
        "Qwen/QwQ-32B",
        "mistralai/Mistral-Small-24B-Instruct-2501"
    ]
    results = get_model_providers(example_models, prioritize=True)
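    # Each entry is (model_name, providers); the exact provider lists depend
    # on what the Hub reports at call time.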
    print(results)