|
""" |
|
Utility functions for base LLM classes. |
|
""" |
|
|
|
import copy |
|
from abc import ABC, abstractmethod |
|
from typing import List, Optional, Type, Union |
|
|
|
from openai.lib import _parsing, _pydantic |
|
from pydantic import BaseModel |
|
|
|
from litellm.types.llms.openai import AllMessageValues |
|
from litellm.types.utils import ProviderSpecificModelInfo |
|
|
|
|
|
class BaseLLMModelInfo(ABC): |
|
def get_provider_info( |
|
self, |
|
model: str, |
|
) -> Optional[ProviderSpecificModelInfo]: |
|
return None |
|
|
|
@abstractmethod |
|
def get_models(self) -> List[str]: |
|
pass |
|
|
|
@staticmethod |
|
@abstractmethod |
|
def get_api_key(api_key: Optional[str] = None) -> Optional[str]: |
|
pass |
|
|
|
@staticmethod |
|
@abstractmethod |
|
def get_api_base(api_base: Optional[str] = None) -> Optional[str]: |
|
pass |
|
|
|
|
|
def _dict_to_response_format_helper( |
|
response_format: dict, ref_template: Optional[str] = None |
|
) -> dict: |
|
if ref_template is not None and response_format.get("type") == "json_schema": |
|
|
|
modified_format = copy.deepcopy(response_format) |
|
schema = modified_format["json_schema"]["schema"] |
|
|
|
|
|
def update_refs(schema): |
|
stack = [(schema, [])] |
|
visited = set() |
|
|
|
while stack: |
|
obj, path = stack.pop() |
|
obj_id = id(obj) |
|
|
|
if obj_id in visited: |
|
continue |
|
visited.add(obj_id) |
|
|
|
if isinstance(obj, dict): |
|
if "$ref" in obj: |
|
ref_path = obj["$ref"] |
|
model_name = ref_path.split("/")[-1] |
|
obj["$ref"] = ref_template.format(model=model_name) |
|
|
|
for k, v in obj.items(): |
|
if isinstance(v, (dict, list)): |
|
stack.append((v, path + [k])) |
|
|
|
elif isinstance(obj, list): |
|
for i, item in enumerate(obj): |
|
if isinstance(item, (dict, list)): |
|
stack.append((item, path + [i])) |
|
|
|
update_refs(schema) |
|
return modified_format |
|
return response_format |
|
|
|
|
|
def type_to_response_format_param( |
|
response_format: Optional[Union[Type[BaseModel], dict]], |
|
ref_template: Optional[str] = None, |
|
) -> Optional[dict]: |
|
""" |
|
Re-implementation of openai's 'type_to_response_format_param' function |
|
|
|
Used for converting pydantic object to api schema. |
|
""" |
|
if response_format is None: |
|
return None |
|
|
|
if isinstance(response_format, dict): |
|
return _dict_to_response_format_helper(response_format, ref_template) |
|
|
|
|
|
|
|
|
|
if not _parsing._completions.is_basemodel_type(response_format): |
|
raise TypeError(f"Unsupported response_format type - {response_format}") |
|
|
|
if ref_template is not None: |
|
schema = response_format.model_json_schema(ref_template=ref_template) |
|
else: |
|
schema = _pydantic.to_strict_json_schema(response_format) |
|
|
|
return { |
|
"type": "json_schema", |
|
"json_schema": { |
|
"schema": schema, |
|
"name": response_format.__name__, |
|
"strict": True, |
|
}, |
|
} |
|
|
|
|
|
def map_developer_role_to_system_role( |
|
messages: List[AllMessageValues], |
|
) -> List[AllMessageValues]: |
|
""" |
|
Translate `developer` role to `system` role for non-OpenAI providers. |
|
""" |
|
new_messages: List[AllMessageValues] = [] |
|
for m in messages: |
|
if m["role"] == "developer": |
|
new_messages.append({"role": "system", "content": m["content"]}) |
|
else: |
|
new_messages.append(m) |
|
return new_messages |
|
|