#!/usr/bin/env python
"""
Script to test rate limits of Hugging Face Inference API providers.
Spams requests to a model/provider and collects error messages.
Usage: python test_provider_rate_limits.py --model "model_name" --provider "provider_name" --requests 50
"""
import argparse
import json
import time
import os
import requests
import sys
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from collections import Counter
from typing import Dict, List, Tuple
from dotenv import load_dotenv

# Add parent directory to path to import from tasks
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from tasks.get_available_model_provider import prioritize_providers

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("rate_limit_test")

# Default model to test
DEFAULT_MODEL = "meta-llama/Llama-3.3-70B-Instruct"

def send_request(model: str, provider: str, token: str, request_id: int) -> Dict:
    """
    Send a single request to the model with the given provider.

    Args:
        model: Model name
        provider: Provider name
        token: HF token
        request_id: ID for this request

    Returns:
        Dictionary with request info and result
    """
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }

    payload = {
"inputs": f"Request {request_id}: Hello, what do you thing about the future of AI? And divide me 10 by {request_id}", | |
"parameters": { | |
"max_new_tokens": 10000, | |
"provider": provider | |
} | |
} | |
api_url = f"https://api-inference.huggingface.co/models/{model}" | |
start_time = time.time() | |
try: | |
response = requests.post(api_url, headers=headers, json=payload, timeout=15) | |
end_time = time.time() | |
result = { | |
"request_id": request_id, | |
"status_code": response.status_code, | |
"time_taken": end_time - start_time, | |
"headers": dict(response.headers), | |
"success": response.status_code == 200, | |
} | |
if response.status_code != 200: | |
try: | |
error_data = response.json() | |
if isinstance(error_data, dict) and "error" in error_data: | |
result["error_message"] = error_data["error"] | |
else: | |
result["error_message"] = str(error_data) | |
            except Exception:
                # Response body is not valid JSON; keep the raw text
                result["error_message"] = response.text

        return result
    except Exception as e:
        end_time = time.time()
        return {
            "request_id": request_id,
            "status_code": 0,
            "time_taken": end_time - start_time,
            "success": False,
            "error_message": str(e)
        }
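
# Shape of the dictionary returned by send_request (illustrative values only,
# assuming a rate-limited 429 response):
#   {"request_id": 3, "status_code": 429, "time_taken": 0.42, "headers": {...},
#    "success": False, "error_message": "Rate limit reached. ..."}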

def run_rate_limit_test(model: str, provider: str = None, num_requests: int = 50,
                        max_workers: int = 10, delay: float = 0.1) -> List[Dict]:
    """
    Run a rate limit test by sending multiple requests to the specified model/provider.

    Args:
        model: Model to test
        provider: Provider to test (if None, will use first available)
        num_requests: Number of requests to send
        max_workers: Maximum number of concurrent workers
        delay: Delay between batches of requests

    Returns:
        List of results for each request
    """
    # Load environment variables
    load_dotenv()

    # Get HF token
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        logger.error("HF_TOKEN not defined in environment")
        return []

    # If provider not specified, get first available
    if not provider:
        from tasks.get_available_model_provider import get_available_model_provider
        provider = get_available_model_provider(model)
        if not provider:
            logger.error(f"No available provider found for {model}")
            return []

    logger.info(f"Testing rate limits for {model} with provider: {provider}")
    logger.info(f"Sending {num_requests} requests with {max_workers} concurrent workers")

    # Send requests in parallel, in batches of max_workers, pausing `delay` seconds
    # between batches to avoid overwhelming the API
    results = []
    completed = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for batch_start in range(0, num_requests, max_workers):
            batch_ids = range(batch_start, min(batch_start + max_workers, num_requests))
            futures = [
                executor.submit(send_request, model, provider, hf_token, i)
                for i in batch_ids
            ]
            for future in as_completed(futures):
                results.append(future.result())
                completed += 1
                if completed % 10 == 0:
                    logger.info(f"Completed {completed}/{num_requests} requests")
            if completed < num_requests:
                time.sleep(delay)

    return results

def analyze_results(results: List[Dict]) -> Dict:
    """
    Analyze the results of the rate limit test.

    Args:
        results: List of request results

    Returns:
        Dictionary with analysis
    """
    total_requests = len(results)
    successful = sum(1 for r in results if r["success"])
    failed = total_requests - successful

    # Count different error messages
    error_messages = Counter(r.get("error_message") for r in results if not r["success"])

    # Calculate timing statistics
    times = [r["time_taken"] for r in results]
    avg_time = sum(times) / len(times) if times else 0

    # Check for rate limiting headers
    rate_limit_headers = set()
    for r in results:
        if "headers" in r:
            for header in r["headers"]:
                if "rate" in header.lower() or "limit" in header.lower():
                    rate_limit_headers.add(header)

    return {
        "total_requests": total_requests,
        "successful_requests": successful,
        "failed_requests": failed,
        "success_rate": successful / total_requests if total_requests > 0 else 0,
        "average_time": avg_time,
        "error_messages": dict(error_messages),
        "rate_limit_headers": list(rate_limit_headers)
    }
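
# Shape of the analysis dictionary (illustrative values and header name only):
#   {"total_requests": 50, "successful_requests": 42, "failed_requests": 8,
#    "success_rate": 0.84, "average_time": 1.73,
#    "error_messages": {"Rate limit reached. ...": 8},
#    "rate_limit_headers": ["x-ratelimit-remaining"]}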

def display_results(results: List[Dict], analysis: Dict) -> None:
    """
    Display the results of the rate limit test.

    Args:
        results: List of request results
        analysis: Analysis of results
    """
    print("\n" + "="*80)
    print("RATE LIMIT TEST RESULTS")
    print("="*80)

    print(f"\nTotal Requests: {analysis['total_requests']}")
    print(f"Successful: {analysis['successful_requests']} ({analysis['success_rate']*100:.1f}%)")
    print(f"Failed: {analysis['failed_requests']}")
    print(f"Average Time: {analysis['average_time']:.3f} seconds")

    if analysis["rate_limit_headers"]:
        print("\nRate Limit Headers Found:")
        for header in analysis["rate_limit_headers"]:
            print(f" - {header}")

    if analysis["error_messages"]:
        print("\nError Messages:")
        for msg, count in analysis["error_messages"].items():
            print(f" - [{count} occurrences] {msg}")

    # Print sample of headers from a failed request
    failed_requests = [r for r in results if not r["success"]]
    if failed_requests:
        print("\nSample Headers from a Failed Request:")
        for header, value in failed_requests[0].get("headers", {}).items():
            print(f" {header}: {value}")

def main():
    """
    Main entry point for the script.
    """
    parser = argparse.ArgumentParser(description="Test rate limits of Hugging Face Inference API providers.")
    parser.add_argument("--model", type=str, default=DEFAULT_MODEL, help="Name of the model to test")
    parser.add_argument("--provider", type=str, help="Name of the provider to test (if not specified, will use first available)")
    parser.add_argument("--requests", type=int, default=50, help="Number of requests to send")
    parser.add_argument("--workers", type=int, default=10, help="Maximum number of concurrent workers")
    parser.add_argument("--delay", type=float, default=0.1, help="Delay between batches of requests")
    parser.add_argument("--output", type=str, help="Path to save results as JSON (optional)")

    args = parser.parse_args()

    # Run the test
    results = run_rate_limit_test(
        model=args.model,
        provider=args.provider,
        num_requests=args.requests,
        max_workers=args.workers,
        delay=args.delay
    )

    if not results:
        logger.error("Test failed to run properly")
        return

    # Analyze the results
    analysis = analyze_results(results)

    # Display the results
    display_results(results, analysis)

    # Save results if requested
    if args.output:
        with open(args.output, "w") as f:
            json.dump({
                "results": results,
                "analysis": analysis
            }, f, indent=2)
        logger.info(f"Results saved to {args.output}")


if __name__ == "__main__":
    main()