|
""" |
|
Transformation logic from OpenAI format to Gemini format. |
|
|
|
Why a separate file? To make it easy to see how the transformation works.
|
""" |
|
|
|
import os |
|
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, cast
|
|
|
import httpx |
|
from pydantic import BaseModel |
|
|
|
import litellm |
|
from litellm._logging import verbose_logger |
|
from litellm.litellm_core_utils.prompt_templates.factory import ( |
|
convert_to_anthropic_image_obj, |
|
convert_to_gemini_tool_call_invoke, |
|
convert_to_gemini_tool_call_result, |
|
response_schema_prompt, |
|
) |
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler |
|
from litellm.types.files import ( |
|
get_file_mime_type_for_file_type, |
|
get_file_type_from_extension, |
|
is_gemini_1_5_accepted_file_type, |
|
) |
|
from litellm.types.llms.openai import ( |
|
AllMessageValues, |
|
ChatCompletionAssistantMessage, |
|
ChatCompletionImageObject, |
|
ChatCompletionTextObject, |
|
) |
|
from litellm.types.llms.vertex_ai import * |
|
from litellm.types.llms.vertex_ai import ( |
|
GenerationConfig, |
|
PartType, |
|
RequestBody, |
|
SafetSettingsConfig, |
|
SystemInstructions, |
|
ToolConfig, |
|
Tools, |
|
) |
|
|
|
from ..common_utils import ( |
|
_check_text_in_content, |
|
get_supports_response_schema, |
|
get_supports_system_message, |
|
) |
|
|
|
if TYPE_CHECKING: |
|
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj |
|
|
|
LiteLLMLoggingObj = _LiteLLMLoggingObj |
|
else: |
|
LiteLLMLoggingObj = Any |
|
|
|
|
|
def _process_gemini_image(image_url: str) -> PartType: |
|
""" |
|
Given an image URL, return the appropriate PartType for Gemini |
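
    Handles three cases:
    - Google Cloud Storage URIs ("gs://...") -> file_data part referencing the URI
    - https:// URLs with a recognizable extension -> file_data part with the inferred mime type
    - other http(s):// URLs and base64 data URLs -> converted to an inline_data blob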
|
""" |
|
try: |
|
|
|
if "gs://" in image_url: |
|
|
|
extension_with_dot = os.path.splitext(image_url)[-1] |
|
extension = extension_with_dot[1:] |
|
|
|
file_type = get_file_type_from_extension(extension) |
|
|
|
|
|
if not is_gemini_1_5_accepted_file_type(file_type): |
|
                raise Exception(f"File type not supported by Gemini - {file_type}")
|
|
|
mime_type = get_file_mime_type_for_file_type(file_type) |
|
file_data = FileDataType(mime_type=mime_type, file_uri=image_url) |
|
|
|
return PartType(file_data=file_data) |
|
elif ( |
|
"https://" in image_url |
|
and (image_type := _get_image_mime_type_from_url(image_url)) is not None |
|
): |
|
file_data = FileDataType(file_uri=image_url, mime_type=image_type) |
|
return PartType(file_data=file_data) |
|
elif "http://" in image_url or "https://" in image_url or "base64" in image_url: |
|
|
|
image = convert_to_anthropic_image_obj(image_url) |
|
_blob = BlobType(data=image["data"], mime_type=image["media_type"]) |
|
return PartType(inline_data=_blob) |
|
raise Exception("Invalid image received - {}".format(image_url)) |
|
except Exception as e: |
|
raise e |
|
|
|
|
|
def _get_image_mime_type_from_url(url: str) -> Optional[str]: |
|
""" |
|
Get mime type for common image URLs |
|
See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements |
|
|
|
    Supported by Gemini and mapped here:
    - PNG (`image/png`)
    - JPEG (`image/jpeg`)
    - WebP (`image/webp`)
    - MP4 (`video/mp4`)
    - PDF (`application/pdf`)
|
Example: |
|
url = https://example.com/image.jpg |
|
Returns: image/jpeg |
|
""" |
|
url = url.lower() |
|
if url.endswith((".jpg", ".jpeg")): |
|
return "image/jpeg" |
|
elif url.endswith(".png"): |
|
return "image/png" |
|
elif url.endswith(".webp"): |
|
return "image/webp" |
|
elif url.endswith(".mp4"): |
|
return "video/mp4" |
|
elif url.endswith(".pdf"): |
|
return "application/pdf" |
|
return None |
|
|
|
|
|
def _gemini_convert_messages_with_history( |
|
messages: List[AllMessageValues], |
|
) -> List[ContentType]: |
|
""" |
|
Converts given messages from OpenAI format to Gemini format |
|
|
|
- Parts must be iterable |
|
    - Roles must alternate between 'user' and 'model' (same as Anthropic -> merge consecutive roles)

    - A function response turn must come immediately after its function call turn
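
    Example (illustrative):
        [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]
        is converted to roughly:
        [{"role": "user", "parts": [{"text": "Hi"}]}, {"role": "model", "parts": [{"text": "Hello"}]}]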
|
""" |
|
user_message_types = {"user", "system"} |
|
contents: List[ContentType] = [] |
|
|
|
last_message_with_tool_calls = None |
|
|
|
msg_i = 0 |
|
tool_call_responses = [] |
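
    # Walk the message list in order:
    #   1. merge consecutive user/system messages into a single "user" turn
    #   2. merge consecutive assistant messages into a single "model" turn
    #   3. collect tool/function responses into their own parts list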
|
try: |
|
while msg_i < len(messages): |
|
user_content: List[PartType] = [] |
|
init_msg_i = msg_i |
|
|
|
while ( |
|
msg_i < len(messages) and messages[msg_i]["role"] in user_message_types |
|
): |
|
_message_content = messages[msg_i].get("content") |
|
if _message_content is not None and isinstance(_message_content, list): |
|
_parts: List[PartType] = [] |
|
for element in _message_content: |
|
if ( |
|
element["type"] == "text" |
|
and "text" in element |
|
and len(element["text"]) > 0 |
|
): |
|
element = cast(ChatCompletionTextObject, element) |
|
_part = PartType(text=element["text"]) |
|
_parts.append(_part) |
|
elif element["type"] == "image_url": |
|
element = cast(ChatCompletionImageObject, element) |
|
img_element = element |
|
if isinstance(img_element["image_url"], dict): |
|
image_url = img_element["image_url"]["url"] |
|
else: |
|
image_url = img_element["image_url"] |
|
_part = _process_gemini_image(image_url=image_url) |
|
_parts.append(_part) |
|
user_content.extend(_parts) |
|
elif ( |
|
_message_content is not None |
|
and isinstance(_message_content, str) |
|
and len(_message_content) > 0 |
|
): |
|
_part = PartType(text=_message_content) |
|
user_content.append(_part) |
|
|
|
msg_i += 1 |
|
|
|
if user_content: |
|
""" |
|
check that user_content has 'text' parameter. |
|
- Known Vertex Error: Unable to submit request because it must have a text parameter. |
|
- Relevant Issue: https://github.com/BerriAI/litellm/issues/5515 |
|
""" |
|
has_text_in_content = _check_text_in_content(user_content) |
|
if has_text_in_content is False: |
|
verbose_logger.warning( |
|
"No text in user content. Adding a blank text to user content, to ensure Gemini doesn't fail the request. Relevant Issue - https://github.com/BerriAI/litellm/issues/5515" |
|
) |
|
user_content.append( |
|
PartType(text=" ") |
|
) |
|
contents.append(ContentType(role="user", parts=user_content)) |
|
assistant_content = [] |
|
|
|
while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": |
|
if isinstance(messages[msg_i], BaseModel): |
|
msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump() |
|
else: |
|
msg_dict = messages[msg_i] |
|
assistant_msg = ChatCompletionAssistantMessage(**msg_dict) |
|
_message_content = assistant_msg.get("content", None) |
|
if _message_content is not None and isinstance(_message_content, list): |
|
_parts = [] |
|
for element in _message_content: |
|
if isinstance(element, dict): |
|
if element["type"] == "text": |
|
_part = PartType(text=element["text"]) |
|
_parts.append(_part) |
|
assistant_content.extend(_parts) |
|
elif ( |
|
_message_content is not None |
|
and isinstance(_message_content, str) |
|
and _message_content |
|
): |
|
assistant_text = _message_content |
|
assistant_content.append(PartType(text=assistant_text)) |
|
|
|
|
|
if ( |
|
                    assistant_msg.get("tool_calls") is not None
|
or assistant_msg.get("function_call") is not None |
|
): |
|
assistant_content.extend( |
|
convert_to_gemini_tool_call_invoke(assistant_msg) |
|
) |
|
last_message_with_tool_calls = assistant_msg |
|
|
|
msg_i += 1 |
|
|
|
if assistant_content: |
|
contents.append(ContentType(role="model", parts=assistant_content)) |
|
|
|
|
|
tool_call_message_roles = ["tool", "function"] |
|
if ( |
|
msg_i < len(messages) |
|
and messages[msg_i]["role"] in tool_call_message_roles |
|
): |
|
_part = convert_to_gemini_tool_call_result( |
|
messages[msg_i], last_message_with_tool_calls |
|
) |
|
msg_i += 1 |
|
tool_call_responses.append(_part) |
|
if msg_i < len(messages) and ( |
|
messages[msg_i]["role"] not in tool_call_message_roles |
|
): |
|
if len(tool_call_responses) > 0: |
|
contents.append(ContentType(parts=tool_call_responses)) |
|
tool_call_responses = [] |
|
|
|
if msg_i == init_msg_i: |
|
raise Exception( |
|
"Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format( |
|
messages[msg_i] |
|
) |
|
) |
|
if len(tool_call_responses) > 0: |
|
contents.append(ContentType(parts=tool_call_responses)) |
|
return contents |
|
except Exception as e: |
|
raise e |
|
|
|
|
|
def _transform_request_body( |
|
messages: List[AllMessageValues], |
|
model: str, |
|
optional_params: dict, |
|
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"], |
|
litellm_params: dict, |
|
cached_content: Optional[str], |
|
) -> RequestBody: |
|
""" |
|
Common transformation logic across sync + async Gemini /generateContent calls. |
|
""" |
|
|
|
supports_system_message = get_supports_system_message( |
|
model=model, custom_llm_provider=custom_llm_provider |
|
) |
|
system_instructions, messages = _transform_system_message( |
|
supports_system_message=supports_system_message, messages=messages |
|
) |
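
    # If the model cannot enforce a response schema natively, fall back to prompting:
    # append the schema as a user message and drop the unsupported param below.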
|
|
|
if "response_schema" in optional_params: |
|
supports_response_schema = get_supports_response_schema( |
|
model=model, custom_llm_provider=custom_llm_provider |
|
) |
|
if supports_response_schema is False: |
|
user_response_schema_message = response_schema_prompt( |
|
model=model, response_schema=optional_params.get("response_schema") |
|
) |
|
messages.append({"role": "user", "content": user_response_schema_message}) |
|
optional_params.pop("response_schema") |
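
    # Lift any "litellm_param_"-prefixed keys out of optional_params and into litellm_params.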
|
|
|
|
|
|
|
remove_keys = [] |
|
for k, v in optional_params.items(): |
|
if k.startswith("litellm_param_"): |
|
litellm_params.update({k: v}) |
|
remove_keys.append(k) |
|
|
|
optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys} |
|
|
|
try: |
|
if custom_llm_provider == "gemini": |
|
content = litellm.GoogleAIStudioGeminiConfig()._transform_messages( |
|
messages=messages |
|
) |
|
else: |
|
content = litellm.VertexGeminiConfig()._transform_messages( |
|
messages=messages |
|
) |
|
tools: Optional[Tools] = optional_params.pop("tools", None) |
|
tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) |
|
safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( |
|
"safety_settings", None |
|
) |
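
        # Only keys declared on GenerationConfig are forwarded to the request;
        # any remaining optional_params are dropped by this filter.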
|
config_fields = GenerationConfig.__annotations__.keys() |
|
|
|
filtered_params = { |
|
k: v for k, v in optional_params.items() if k in config_fields |
|
} |
|
|
|
generation_config: Optional[GenerationConfig] = GenerationConfig( |
|
**filtered_params |
|
) |
|
data = RequestBody(contents=content) |
|
if system_instructions is not None: |
|
data["system_instruction"] = system_instructions |
|
if tools is not None: |
|
data["tools"] = tools |
|
if tool_choice is not None: |
|
data["toolConfig"] = tool_choice |
|
if safety_settings is not None: |
|
data["safetySettings"] = safety_settings |
|
if generation_config is not None: |
|
data["generationConfig"] = generation_config |
|
if cached_content is not None: |
|
data["cachedContent"] = cached_content |
|
except Exception as e: |
|
raise e |
|
|
|
return data |
|
|
|
|
|
def sync_transform_request_body( |
|
gemini_api_key: Optional[str], |
|
messages: List[AllMessageValues], |
|
api_base: Optional[str], |
|
model: str, |
|
client: Optional[HTTPHandler], |
|
timeout: Optional[Union[float, httpx.Timeout]], |
|
extra_headers: Optional[dict], |
|
optional_params: dict, |
|
logging_obj: LiteLLMLoggingObj, |
|
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"], |
|
litellm_params: dict, |
|
) -> RequestBody: |
|
from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints |
|
|
|
context_caching_endpoints = ContextCachingEndpoints() |
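
    # When a Gemini (Google AI Studio) API key is provided, check for / create a
    # context cache first; otherwise just read any user-supplied cached_content.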
|
|
|
if gemini_api_key is not None: |
|
messages, cached_content = context_caching_endpoints.check_and_create_cache( |
|
messages=messages, |
|
api_key=gemini_api_key, |
|
api_base=api_base, |
|
model=model, |
|
client=client, |
|
timeout=timeout, |
|
extra_headers=extra_headers, |
|
cached_content=optional_params.pop("cached_content", None), |
|
logging_obj=logging_obj, |
|
) |
|
else: |
|
cached_content = optional_params.pop("cached_content", None) |
|
|
|
return _transform_request_body( |
|
messages=messages, |
|
model=model, |
|
custom_llm_provider=custom_llm_provider, |
|
litellm_params=litellm_params, |
|
cached_content=cached_content, |
|
optional_params=optional_params, |
|
) |
|
|
|
|
|
async def async_transform_request_body( |
|
gemini_api_key: Optional[str], |
|
messages: List[AllMessageValues], |
|
api_base: Optional[str], |
|
model: str, |
|
client: Optional[AsyncHTTPHandler], |
|
timeout: Optional[Union[float, httpx.Timeout]], |
|
extra_headers: Optional[dict], |
|
optional_params: dict, |
|
    logging_obj: LiteLLMLoggingObj,
|
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"], |
|
litellm_params: dict, |
|
) -> RequestBody: |
|
from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints |
|
|
|
context_caching_endpoints = ContextCachingEndpoints() |
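
    # Async variant of the context-caching step in sync_transform_request_body.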
|
|
|
if gemini_api_key is not None: |
|
messages, cached_content = ( |
|
await context_caching_endpoints.async_check_and_create_cache( |
|
messages=messages, |
|
api_key=gemini_api_key, |
|
api_base=api_base, |
|
model=model, |
|
client=client, |
|
timeout=timeout, |
|
extra_headers=extra_headers, |
|
cached_content=optional_params.pop("cached_content", None), |
|
logging_obj=logging_obj, |
|
) |
|
) |
|
else: |
|
cached_content = optional_params.pop("cached_content", None) |
|
|
|
return _transform_request_body( |
|
messages=messages, |
|
model=model, |
|
custom_llm_provider=custom_llm_provider, |
|
litellm_params=litellm_params, |
|
cached_content=cached_content, |
|
optional_params=optional_params, |
|
) |
|
|
|
|
|
def _transform_system_message( |
|
supports_system_message: bool, messages: List[AllMessageValues] |
|
) -> Tuple[Optional[SystemInstructions], List[AllMessageValues]]: |
|
""" |
|
Extracts the system message from the openai message list. |
|
|
|
Converts the system message to Gemini format |
|
|
|
Returns |
|
- system_content_blocks: Optional[SystemInstructions] - the system message list in Gemini format. |
|
- messages: List[AllMessageValues] - filtered list of messages in OpenAI format (transformed separately) |
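
    Example (illustrative), with supports_system_message=True:
        [{"role": "system", "content": "Be brief"}, {"role": "user", "content": "Hi"}]
        returns roughly:
        ({"parts": [{"text": "Be brief"}]}, [{"role": "user", "content": "Hi"}])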
|
""" |
|
|
|
system_prompt_indices = [] |
|
system_content_blocks: List[PartType] = [] |
|
if supports_system_message is True: |
|
for idx, message in enumerate(messages): |
|
if message["role"] == "system": |
|
_system_content_block: Optional[PartType] = None |
|
if isinstance(message["content"], str): |
|
_system_content_block = PartType(text=message["content"]) |
|
elif isinstance(message["content"], list): |
|
system_text = "" |
|
for content in message["content"]: |
|
system_text += content.get("text") or "" |
|
_system_content_block = PartType(text=system_text) |
|
if _system_content_block is not None: |
|
system_content_blocks.append(_system_content_block) |
|
system_prompt_indices.append(idx) |
|
if len(system_prompt_indices) > 0: |
|
for idx in reversed(system_prompt_indices): |
|
messages.pop(idx) |
|
|
|
if len(system_content_blocks) > 0: |
|
return SystemInstructions(parts=system_content_blocks), messages |
|
|
|
return None, messages |
|
|