# InferBench: api/__init__.py (author: nifleisch)
# Commit 2c50826: "feat: add core logic for project" (1.37 kB)
from typing import Optional, Type
from api.baseline import BaselineAPI
from api.fireworks import FireworksAPI
from api.flux import FluxAPI
from api.pruna import PrunaAPI
from api.replicate import ReplicateAPI
from api.together import TogetherAPI
from api.fal import FalAPI
def create_api(api_type: str) -> FluxAPI:
    """
    Factory function to create API instances.

    Args:
        api_type (str): The type of API to create. Must be one of:
            - "baseline"
            - "fireworks"
            - "pruna_<speed_mode>" (where <speed_mode> is the desired speed
              mode; everything after the "pruna_" prefix is passed verbatim
              to ``PrunaAPI`` and must be non-empty)
            - "replicate"
            - "together"
            - "fal"

    Returns:
        FluxAPI: An instance of the requested API implementation.

    Raises:
        ValueError: If an invalid API type is provided, including a bare
            "pruna_" with no speed mode after the prefix.
    """
    pruna_prefix = "pruna_"
    if api_type.startswith(pruna_prefix):
        # Slice by the prefix length instead of a hard-coded 6 so the prefix
        # literal and the slice offset cannot drift apart.
        speed_mode = api_type[len(pruna_prefix):]
        if not speed_mode:
            # "pruna_" alone would otherwise construct PrunaAPI("") with an
            # empty speed mode; reject it with the documented exception type.
            raise ValueError(
                f"Invalid API type: {api_type}. "
                f"A speed mode must follow the '{pruna_prefix}' prefix"
            )
        return PrunaAPI(speed_mode)

    # Fixed-name providers; each class is instantiated with no arguments.
    api_map: dict[str, Type[FluxAPI]] = {
        "baseline": BaselineAPI,
        "fireworks": FireworksAPI,
        "replicate": ReplicateAPI,
        "together": TogetherAPI,
        "fal": FalAPI,
    }

    api_cls = api_map.get(api_type)
    if api_cls is None:
        raise ValueError(
            f"Invalid API type: {api_type}. "
            f"Must be one of {list(api_map.keys())} or start with 'pruna_'"
        )
    return api_cls()