Spaces:
Paused
Paused
| import enum | |
| from typing import Any, List, Optional, Tuple, cast | |
| from urllib.parse import urlparse | |
| import httpx | |
| from httpx import Response | |
| import litellm | |
| from litellm._logging import verbose_logger | |
| from litellm.litellm_core_utils.prompt_templates.common_utils import ( | |
| _audio_or_image_in_message_content, | |
| convert_content_list_to_str, | |
| ) | |
| from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj | |
| from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error | |
| from litellm.llms.openai.openai import OpenAIConfig | |
| from litellm.secret_managers.main import get_secret_str | |
| from litellm.types.llms.openai import AllMessageValues | |
| from litellm.types.utils import ModelResponse, ProviderField | |
| from litellm.utils import _add_path_to_api_base, supports_tool_choice | |
class AzureFoundryErrorStrings(str, enum.Enum):
    """Substrings of Azure error responses that this module matches against.

    Members are also ``str`` instances, so they can be compared with or
    searched for in raw response text.
    """

    # Matched against the response text to detect that the endpoint rejected
    # parameters it does not recognise.
    SET_EXTRA_PARAMETERS_TO_PASS_THROUGH = "Set extra-parameters to 'pass-through'"
class AzureAIStudioConfig(OpenAIConfig):
    """Chat-completion configuration for Azure AI Studio endpoints.

    Extends ``OpenAIConfig`` and overrides auth-header selection, request URL
    construction, message flattening, and the parameter-dropping retry logic
    used after certain HTTP errors.
    """

    def get_supported_openai_params(self, model: str) -> List:
        """Return the OpenAI parameters supported for ``model``.

        Starts from the parent's list and removes ``"tool_choice"`` when
        ``supports_tool_choice`` reports that ``azure_ai/<model>`` does not
        support it.
        """
        supported_params = super().get_supported_openai_params(model)
        if supports_tool_choice(model=f"azure_ai/{model}"):
            return supported_params
        return [param for param in supported_params if param != "tool_choice"]

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Add the authentication header to ``headers`` and return it.

        Uses the ``api-key`` header when ``api_base`` points at a host
        accepted by ``_should_use_api_key_header``; otherwise uses
        ``Authorization: Bearer <api_key>``. ``headers`` is mutated in place.

        When no ``api_key`` is supplied the headers are returned unchanged,
        so a caller-provided ``Authorization`` header is not overwritten with
        the unusable value ``"Bearer None"`` (or ``api-key`` set to ``None``).
        """
        if not api_key:
            return headers

        if api_base and self._should_use_api_key_header(api_base):
            headers["api-key"] = api_key
        else:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def _should_use_api_key_header(self, api_base: str) -> bool:
        """
        Returns True if the request should use `api-key` header for authentication.

        That is the case when the hostname of ``api_base`` ends with
        ``.services.ai.azure.com`` or ``.openai.azure.com``.
        """
        host = urlparse(api_base).hostname
        if not host:
            return False
        return host.endswith((".services.ai.azure.com", ".openai.azure.com"))

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Constructs a complete URL for the API request.

        Args:
        - api_base: Base URL, e.g.,
            "https://litellm8397336933.services.ai.azure.com"
            OR
            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
        - api_key: Unused here.
        - model: Model name (unused here).
        - optional_params: Unused here.
        - litellm_params: Read for "api_version", which is added as the
            "api-version" query parameter unless the URL already has one.
        - stream: If streaming is required (optional, unused here).

        Returns:
        - A complete URL string, e.g.,
            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"

        Raises:
        - ValueError: if ``api_base`` is None.
        """
        if api_base is None:
            raise ValueError(
                f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
            )

        original_url = httpx.URL(api_base)

        # Query parameters already present on api_base take precedence; the
        # configured api_version is only added when the URL has none.
        api_version = cast(Optional[str], litellm_params.get("api_version"))
        query_params = dict(original_url.params)
        if "api-version" not in query_params and api_version:
            query_params["api-version"] = api_version

        # services.ai.azure.com endpoints serve chat completions under /models.
        if "services.ai.azure.com" in api_base:
            ending_path = "/models/chat/completions"
        else:
            ending_path = "/chat/completions"
        new_url = _add_path_to_api_base(api_base=api_base, ending_path=ending_path)

        final_url = httpx.URL(new_url).copy_with(params=query_params)
        return str(final_url)

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description"""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Azure AI Studio API Key.",
                field_value="zEJ...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Azure AI Studio API Base.",
                field_value="https://Mistral-serverless.",
            ),
        ]

    def _transform_messages(
        self,
        messages: List[AllMessageValues],
        model: str,
    ) -> List:
        """
        - Azure AI Studio doesn't support content as a list. This handles:
            1. Transforms list content to a string.
            2. If message contains an image or audio, send as is (user-intended)

        The messages are modified in place and the same list is returned.
        """
        for message in messages:
            # Do nothing if the message contains an image or audio
            if _audio_or_image_in_message_content(message):
                continue
            texts = convert_content_list_to_str(message=message)
            if texts:
                message["content"] = texts
        return messages

    def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool:
        """Return True if ``model`` is listed among litellm's OpenAI models.

        Anything before the first "/" in ``model`` is dropped before the
        lookup. ``api_base`` is not used. Any exception during the check is
        treated as "not an OpenAI model".
        """
        try:
            if "/" in model:
                model = model.split("/", 1)[1]
            return (
                model in litellm.open_ai_chat_completion_models
                or model in litellm.open_ai_text_completion_models
                or model in litellm.open_ai_embedding_models
            )
        except Exception:
            # Deliberate best-effort: a failed lookup must not break routing.
            return False

    def _get_openai_compatible_provider_info(
        self,
        model: str,
        api_base: Optional[str],
        api_key: Optional[str],
        custom_llm_provider: str,
    ) -> Tuple[Optional[str], Optional[str], str]:
        """Resolve ``(api_base, api_key, custom_llm_provider)`` for a request.

        Missing ``api_base`` / ``api_key`` fall back to the
        ``AZURE_AI_API_BASE`` / ``AZURE_AI_API_KEY`` secrets. When the model
        is recognised by ``_is_azure_openai_model`` the provider is switched
        to ``"azure"``.
        """
        api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
        dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY")

        if self._is_azure_openai_model(model=model, api_base=api_base):
            verbose_logger.debug(
                "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
                    model
                )
            )
            custom_llm_provider = "azure"
        return api_base, dynamic_api_key, custom_llm_provider

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the request body via the parent after flattening ``extra_body``.

        The entries of a dict ``extra_body`` are merged into
        ``optional_params``, and ``max_retries`` is removed from it, before
        delegating. ``optional_params`` is mutated in place.
        """
        extra_body = optional_params.pop("extra_body", {})
        if extra_body and isinstance(extra_body, dict):
            optional_params.update(extra_body)
        optional_params.pop("max_retries", None)
        return super().transform_request(
            model, messages, optional_params, litellm_params, headers
        )

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Set ``model_response.model`` to ``azure_ai/<model>``, then delegate to the parent."""
        model_response.model = f"azure_ai/{model}"
        return super().transform_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

    def should_retry_llm_api_inside_llm_translation_on_http_error(
        self, e: httpx.HTTPStatusError, litellm_params: dict
    ) -> bool:
        """Decide whether the failed request should be transformed and retried.

        Returns True when the response text contains one of the error
        messages this class handles; otherwise defers to the parent.
        """
        should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
        error_text = e.response.text

        # Unknown-parameter errors are only retried when dropping is enabled.
        if should_drop_params and "Extra inputs are not permitted" in error_text:
            return True
        # remove index from tool calls
        if "unknown field: parameter index is not a valid field" in error_text:
            return True
        # remove extra-parameters from tool calls
        if (
            AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
            in error_text
        ):
            return True
        return super().should_retry_llm_api_inside_llm_translation_on_http_error(
            e=e, litellm_params=litellm_params
        )

    def max_retry_on_unprocessable_entity_error(self) -> int:
        """Return the retry limit for unprocessable-entity errors (2)."""
        return 2

    def transform_request_on_unprocessable_entity_error(
        self, e: httpx.HTTPStatusError, request_data: dict
    ) -> dict:
        """Rewrite ``request_data`` so that a retry avoids the reported error.

        Depending on the response text this removes ``index`` from tool calls
        or drops the parameters named in an "extra parameters" error, and
        then always applies ``drop_params_from_unprocessable_entity_error``.
        """
        _messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
        if (
            "unknown field: parameter index is not a valid field" in e.response.text
            and _messages is not None
        ):
            litellm.remove_index_from_tool_calls(
                messages=_messages,
            )
        elif (
            AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
            in e.response.text
        ):
            request_data = self._drop_extra_params_from_request_data(
                request_data, e.response.text
            )
        data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
        return data

    def _drop_extra_params_from_request_data(
        self, request_data: dict, error_text: str
    ) -> dict:
        """Remove from ``request_data`` every parameter named in ``error_text``.

        ``request_data`` is mutated in place and also returned. Names that
        are not keys of ``request_data`` are ignored.
        """
        params_to_drop = self._extract_params_to_drop_from_error_text(error_text)
        for param in params_to_drop or []:
            request_data.pop(param, None)
        return request_data

    def _extract_params_to_drop_from_error_text(
        self, error_text: str
    ) -> Optional[List[str]]:
        """
        Error text looks like this:
        "Extra parameters ['stream_options', 'extra-parameters'] are not allowed when extra-parameters is not set or set to be 'error'."

        Returns the parameter names found inside the square brackets, with
        surrounding quotes and whitespace removed, or an empty list when the
        text contains no bracketed list.
        """
        import re

        # Prefer the list that directly follows "Extra parameters" so that an
        # earlier, unrelated "[...]" in the text is not parsed by mistake;
        # fall back to the first bracketed list otherwise.
        match = re.search(r"Extra parameters\s*\[(.*?)\]", error_text) or re.search(
            r"\[(.*?)\]", error_text
        )
        if not match:
            return []

        # Clean up each parameter name (remove quotes, spaces)
        cleaned_names = (
            raw_name.strip().strip("'").strip('"')
            for raw_name in match.group(1).split(",")
        )
        return [name for name in cleaned_names if name]