import logging
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import httpx

from litellm.types.llms.openai import AllMessageValues, ChatCompletionRequest

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any

from litellm.llms.base_llm.chat.transformation import BaseLLMException

from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import HuggingFaceError, _fetch_inference_provider_mapping

logger = logging.getLogger(__name__)

BASE_URL = "https://router.huggingface.co"


class HuggingFaceChatConfig(OpenAIGPTConfig):
    """
    Reference: https://huggingface.co/docs/huggingface_hub/guides/inference
    """

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        default_headers = {
            "content-type": "application/json",
        }
        if api_key is not None:
            default_headers["Authorization"] = f"Bearer {api_key}"
        headers = {**headers, **default_headers}
        return headers

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return HuggingFaceError(
            status_code=status_code, message=error_message, headers=headers
        )

    def get_base_url(self, model: str, base_url: Optional[str]) -> Optional[str]:
        """
        Get the API base for the Huggingface API.
        Do not add the chat/embedding/rerank extension here. Let the handler do this.
        """
        if model.startswith(("http://", "https://")):
            base_url = model
        elif base_url is None:
            base_url = os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE", "")
        return base_url

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the API call.
        Handles provider-specific routing through the Hugging Face router.
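
        Illustrative routing (derived from the branches below; the model ids
        are examples only):
            "together/meta-llama/Llama-3-8b"
                -> https://router.huggingface.co/together/v1/chat/completions
            "google/gemma-7b" (single "/", so hf-inference is assumed)
                -> https://router.huggingface.co/hf-inference/models/google/gemma-7b/v1/chat/completions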
        """
        # 1. Check if api_base is provided
        if api_base is not None:
            complete_url = api_base
        # 2. Check the environment variables
        elif os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE"):
            # Read the values directly: wrapping os.getenv in str() would turn
            # a missing variable into the truthy literal "None", making the
            # second variable unreachable.
            complete_url = (
                os.getenv("HF_API_BASE") or os.getenv("HUGGINGFACE_API_BASE") or ""
            )
        # 3. Check if the model itself is a full URL
        elif model.startswith(("http://", "https://")):
            complete_url = model
        # 4. Default construction with provider
        else:
            # "provider/org/model" routes via that provider; a bare or
            # single-"/" model id defaults to hf-inference. partition(),
            # unlike split(), never raises on ids without a "/".
            first_part, _, remaining = model.partition("/")
            provider = first_part if "/" in remaining else "hf-inference"
            if provider == "hf-inference":
                route = f"{provider}/models/{model}/v1/chat/completions"
            elif provider == "novita":
                route = f"{provider}/chat/completions"
            else:
                route = f"{provider}/v1/chat/completions"
            complete_url = f"{BASE_URL}/{route}"
        # Ensure URL doesn't end with a slash
        complete_url = complete_url.rstrip("/")
        return complete_url

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        if "max_retries" in optional_params:
            logger.warning("`max_retries` is not supported. It will be ignored.")
            optional_params.pop("max_retries", None)
        first_part, remaining = model.split("/", 1)
        if "/" in remaining:
            provider = first_part
            model_id = remaining
        else:
            provider = "hf-inference"
            model_id = model
        provider_mapping = _fetch_inference_provider_mapping(model_id)
        if provider not in provider_mapping:
            raise HuggingFaceError(
                message=f"Model {model_id} is not supported for provider {provider}",
                status_code=404,
                headers={},
            )
        provider_mapping = provider_mapping[provider]
        if provider_mapping["status"] == "staging":
            logger.warning(
                f"Model {model_id} is in staging mode for provider {provider}. "
                "It is intended for testing purposes only."
            )
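        # "providerId" is the model name as the downstream provider knows it.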
        mapped_model = provider_mapping["providerId"]
        messages = self._transform_messages(messages=messages, model=mapped_model)
        return dict(
            ChatCompletionRequest(
                model=mapped_model, messages=messages, **optional_params
            )
        )
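

# Usage sketch (not part of the original module; the model id and env var
# below are illustrative). litellm invokes the hooks above when this provider
# is selected via litellm.completion():
#
#   import os
#   import litellm
#
#   response = litellm.completion(
#       model="huggingface/together/meta-llama/Llama-3.1-8B-Instruct",
#       messages=[{"role": "user", "content": "Hello!"}],
#       api_key=os.environ["HF_TOKEN"],
#   )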