# What is this?
## handler file for TextCompletionCodestral Integration - https://codestral.com/

import json
from functools import partial
from typing import Callable, List, Optional, Union

import httpx  # type: ignore

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)
from litellm.types.utils import TextChoices
from litellm.utils import CustomStreamWrapper, TextCompletionResponse


class TextCompletionCodestralError(Exception):
    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message = message
        if request is not None:
            self.request = request
        else:
            self.request = httpx.Request(
                method="POST",
                url="https://docs.codestral.com/user-guide/inference/rest_api",
            )
        if response is not None:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


async def make_call(
    client: AsyncHTTPHandler,
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
):
    response = await client.post(api_base, headers=headers, data=data, stream=True)

    if response.status_code != 200:
        raise TextCompletionCodestralError(
            status_code=response.status_code, message=response.text
        )

    completion_stream = response.aiter_lines()
    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream


class CodestralTextCompletion:
    def __init__(self) -> None:
        super().__init__()

    def _validate_environment(
        self,
        api_key: Optional[str],
        user_headers: dict,
    ) -> dict:
        if api_key is None:
            raise ValueError(
                "Missing CODESTRAL_API_KEY - Please add CODESTRAL_API_KEY to your environment variables"
            )
        headers = {
            "content-type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }
        if user_headers is not None and isinstance(user_headers, dict):
            headers = {**headers, **user_headers}
        return headers

    def output_parser(self, generated_text: str):
        """
        Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.

        Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
        """
        chat_template_tokens = [
            "<|assistant|>",
            "<|system|>",
            "<|user|>",
            "<s>",
            "</s>",
        ]
        for token in chat_template_tokens:
            if generated_text.strip().startswith(token):
                generated_text = generated_text.replace(token, "", 1)
            if generated_text.endswith(token):
                generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
        return generated_text
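    # Worked example of output_parser, traced from the token list above (the
    # snippet is illustrative and not part of the Codestral API contract):
    #
    #   output_parser("<|assistant|> def foo(): ...</s>")
    #   # -> " def foo(): ..."
    #
    # A leading template token is stripped once from the front and a trailing
    # token once from the back; interior occurrences are left untouched.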
    def process_text_completion_response(
        self,
        model: str,
        response: httpx.Response,
        model_response: TextCompletionResponse,
        stream: bool,
        logging_obj: LiteLLMLogging,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: list,
        print_verbose,
        encoding,
    ) -> TextCompletionResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"codestral api: raw model_response: {response.text}")
        ## RESPONSE OBJECT
        if response.status_code != 200:
            raise TextCompletionCodestralError(
                message=str(response.text),
                status_code=response.status_code,
            )
        try:
            completion_response = response.json()
        except Exception:
            raise TextCompletionCodestralError(message=response.text, status_code=422)

        _original_choices = completion_response.get("choices", [])
        _choices: List[TextChoices] = []
        for choice in _original_choices:
            # This is what 1 choice looks like from codestral API
            # {
            #     "index": 0,
            #     "message": {
            #         "role": "assistant",
            #         "content": "\n assert is_odd(1)\n assert",
            #         "tool_calls": null
            #     },
            #     "finish_reason": "length",
            #     "logprobs": null
            # }
            _choice_message = choice.get("message", {})
            _choice = litellm.utils.TextChoices(
                finish_reason=choice.get("finish_reason"),
                index=choice.get("index"),
                text=_choice_message.get("content"),
                logprobs=choice.get("logprobs"),
            )
            _choices.append(_choice)

        _response = litellm.TextCompletionResponse(
            id=completion_response.get("id"),
            choices=_choices,
            created=completion_response.get("created"),
            model=completion_response.get("model"),
            usage=completion_response.get("usage"),
            stream=False,
            object=completion_response.get("object"),
        )
        return _response

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key: str,
        logging_obj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        acompletion=None,
        litellm_params=None,
        logger_fn=None,
        headers: dict = {},
    ) -> Union[TextCompletionResponse, CustomStreamWrapper]:
        headers = self._validate_environment(api_key, headers)

        if optional_params.pop("custom_endpoint", None) is True:
            completion_url = api_base
        else:
            completion_url = (
                api_base or "https://codestral.mistral.ai/v1/fim/completions"
            )

        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)

        ## Load Config
        config = litellm.CodestralTextCompletionConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > codestral_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        stream = optional_params.pop("stream", False)

        data = {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }
        input_text = prompt
        ## LOGGING
        logging_obj.pre_call(
            input=input_text,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "headers": headers,
                "api_base": completion_url,
                "acompletion": acompletion,
            },
        )
        ## COMPLETION CALL
        if acompletion is True:
            ### ASYNC STREAMING
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore
            else:
                ### ASYNC COMPLETION
                return self.async_completion(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=False,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore

        ### SYNC STREAMING
        if stream is True:
            response = litellm.module_level_client.post(
                completion_url,
                headers=headers,
                data=json.dumps(data),
                stream=stream,
            )
            _response = CustomStreamWrapper(
                response.iter_lines(),
                model,
                custom_llm_provider="codestral",
                logging_obj=logging_obj,
            )
            return _response
        ### SYNC COMPLETION
        else:
            response = litellm.module_level_client.post(
                url=completion_url,
                headers=headers,
                data=json.dumps(data),
            )
        return self.process_text_completion_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=optional_params.get("stream", False),
            logging_obj=logging_obj,  # type: ignore
            optional_params=optional_params,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            encoding=encoding,
        )
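    # For reference, a typical JSON body POSTed to the FIM endpoint used by
    # completion() above. Field values are illustrative assumptions; "suffix"
    # and "max_tokens" appear only if the caller passed them via
    # optional_params:
    #
    #   {
    #       "model": "codestral-latest",
    #       "prompt": "def is_odd(n):",
    #       "suffix": "    return result",
    #       "max_tokens": 64
    #   }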
    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> TextCompletionResponse:
        async_handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.TEXT_COMPLETION_CODESTRAL,
            params={"timeout": timeout},
        )
        try:
            response = await async_handler.post(
                api_base, headers=headers, data=json.dumps(data)
            )
        except httpx.HTTPStatusError as e:
            raise TextCompletionCodestralError(
                status_code=e.response.status_code,
                message="HTTPStatusError - {}".format(e.response.text),
            )
        except Exception as e:
            raise TextCompletionCodestralError(
                status_code=500, message="{}".format(str(e))
            )  # don't use verbose_logger.exception, if exception is raised
        return self.process_text_completion_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: TextCompletionResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> CustomStreamWrapper:
        data["stream"] = True

        streamwrapper = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_call,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
            ),
            model=model,
            custom_llm_provider="text-completion-codestral",
            logging_obj=logging_obj,
        )
        return streamwrapper
    def embedding(self, *args, **kwargs):
        pass
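# ---------------------------------------------------------------------------
# Usage sketch (commented out; illustrative only). The provider route
# "text-completion-codestral/<model>" matches the custom_llm_provider set in
# async_streaming above, but the model name "codestral-latest" is an
# assumption; check the litellm provider docs for exact identifiers.
# Requires CODESTRAL_API_KEY in the environment.
# ---------------------------------------------------------------------------
# import litellm
#
# response = litellm.text_completion(
#     model="text-completion-codestral/codestral-latest",
#     prompt="def is_odd(n):",
#     max_tokens=64,
# )
# print(response.choices[0].text)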