tfrere commited on
Commit
13efede
·
1 Parent(s): 00d9f6f

refactor get available model provider

Browse files
backend/tasks/get_available_model_provider.py CHANGED
@@ -4,53 +4,19 @@ import json
4
  from huggingface_hub import model_info, InferenceClient
5
  from dotenv import load_dotenv
6
 
 
 
 
7
  # Define preferred providers
8
  PREFERRED_PROVIDERS = ["sambanova", "novita"]
9
 
10
- def filter_providers(providers):
11
- """Filter providers to only include preferred ones."""
12
- return [provider for provider in providers if provider in PREFERRED_PROVIDERS]
13
-
14
- def prioritize_providers(providers):
15
- """Prioritize preferred providers, keeping all others."""
16
- preferred = [provider for provider in providers if provider in PREFERRED_PROVIDERS]
17
- non_preferred = [provider for provider in providers if provider not in PREFERRED_PROVIDERS]
18
- return preferred + non_preferred
19
-
20
  # Configure logging
21
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
22
  logger = logging.getLogger(__name__)
23
 
24
- def is_vision_model(model_name: str) -> bool:
25
- """
26
- Check if the model is a vision model based on its name
27
-
28
- Args:
29
- model_name: Name of the model
30
-
31
- Returns:
32
- True if it's a vision model, False otherwise
33
- """
34
- vision_indicators = ["-VL-", "vision", "clip", "image"]
35
- return any(indicator in model_name.lower() for indicator in vision_indicators)
36
-
37
- def get_test_payload(model_name: str) -> dict:
38
- """
39
- Get the appropriate test payload based on model type
40
-
41
- Args:
42
- model_name: Name of the model
43
-
44
- Returns:
45
- Dictionary containing the test payload
46
- """
47
- # We're only testing text models now
48
- return {
49
- "inputs": "Hello",
50
- "parameters": {
51
- "max_new_tokens": 5
52
- }
53
- }
54
 
55
  def test_provider(model_name: str, provider: str, verbose: bool = False) -> bool:
56
  """
@@ -65,9 +31,6 @@ def test_provider(model_name: str, provider: str, verbose: bool = False) -> bool
65
  True if the provider is available, False otherwise
66
  """
67
  try:
68
- # Load environment variables
69
- load_dotenv()
70
-
71
  # Get HF token from environment
72
  hf_token = os.environ.get("HF_TOKEN")
73
  if not hf_token:
@@ -128,9 +91,6 @@ def get_available_model_provider(model_name, verbose=False):
128
  First available provider or None if none are available
129
  """
130
  try:
131
- # Load environment variables
132
- load_dotenv()
133
-
134
  # Get HF token from environment
135
  hf_token = os.environ.get("HF_TOKEN")
136
  if not hf_token:
@@ -168,23 +128,6 @@ def get_available_model_provider(model_name, verbose=False):
168
  return None
169
 
170
  if __name__ == "__main__":
171
- # # Example usage with verbose mode enabled
172
- # model = "Qwen/Qwen2.5-72B-Instruct"
173
-
174
- # # Test sambanova provider
175
- # print("\nTesting sambanova provider:")
176
- # sambanova_available = test_provider(model, "sambanova", verbose=True)
177
- # print(f"sambanova available: {sambanova_available}")
178
-
179
- # # Test novita provider
180
- # print("\nTesting novita provider:")
181
- # novita_available = test_provider(model, "novita", verbose=True)
182
- # print(f"novita available: {novita_available}")
183
-
184
- # # Test automatic provider selection
185
- # print("\nTesting automatic provider selection:")
186
- # provider = get_available_model_provider(model, verbose=True)
187
- # print(f"Selected provider: {provider}")
188
 
189
  models = [
190
  "Qwen/QwQ-32B",
@@ -201,8 +144,4 @@ if __name__ == "__main__":
201
  providers.append(provider)
202
 
203
  print(f"Providers {len(providers)}: {providers}")
204
-
205
-
206
- # print("\nTesting novita provider:")
207
- # novita_available = test_provider("deepseek-ai/DeepSeek-V3-0324", "novita", verbose=True)
208
- # print(f"novita available: {novita_available}")
 
4
  from huggingface_hub import model_info, InferenceClient
5
  from dotenv import load_dotenv
6
 
7
+ # Load environment variables once at the module level
8
+ load_dotenv()
9
+
10
  # Define preferred providers
11
  PREFERRED_PROVIDERS = ["sambanova", "novita"]
12
 
 
 
 
 
 
 
 
 
 
 
13
  # Configure logging
14
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
15
  logger = logging.getLogger(__name__)
16
 
17
+ def prioritize_providers(providers):
18
+ """Prioritize preferred providers, keeping all others."""
19
+ return sorted(providers, key=lambda provider: provider not in PREFERRED_PROVIDERS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def test_provider(model_name: str, provider: str, verbose: bool = False) -> bool:
22
  """
 
31
  True if the provider is available, False otherwise
32
  """
33
  try:
 
 
 
34
  # Get HF token from environment
35
  hf_token = os.environ.get("HF_TOKEN")
36
  if not hf_token:
 
91
  First available provider or None if none are available
92
  """
93
  try:
 
 
 
94
  # Get HF token from environment
95
  hf_token = os.environ.get("HF_TOKEN")
96
  if not hf_token:
 
128
  return None
129
 
130
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  models = [
133
  "Qwen/QwQ-32B",
 
144
  providers.append(provider)
145
 
146
  print(f"Providers {len(providers)}: {providers}")
147
+