|
from typing import List, Optional, Union |
|
|
|
from openai import OpenAI |
|
|
|
import litellm |
|
from litellm.llms.custom_httpx.http_handler import ( |
|
AsyncHTTPHandler, |
|
HTTPHandler, |
|
get_async_httpx_client, |
|
) |
|
from litellm.llms.openai.openai import OpenAIChatCompletion |
|
from litellm.types.llms.azure_ai import ImageEmbeddingRequest |
|
from litellm.types.utils import EmbeddingResponse |
|
from litellm.utils import convert_to_model_response_object |
|
|
|
from .cohere_transformation import AzureAICohereConfig |
|
|
|
|
|
class AzureAIEmbedding(OpenAIChatCompletion):
    """Azure AI embedding handler supporting mixed text + image inputs.

    Image items in the input list are routed to the Azure AI
    ``/images/embeddings`` endpoint; text items go through the
    OpenAI-compatible ``/v1/embeddings`` route (via the parent
    ``OpenAIChatCompletion``). The per-item results are then re-assembled
    in the original input order by ``_process_response``.
    """

    @staticmethod
    def _validate_image_embedding_credentials(
        api_base: Optional[str], api_key: Optional[str]
    ) -> None:
        """Raise a clear ValueError when api_base or api_key is missing.

        Shared by the sync and async image-embedding paths so both fail
        fast instead of issuing a request to a malformed URL.
        """
        if api_base is None:
            raise ValueError(
                "api_base is None. Please set AZURE_AI_API_BASE or dynamically via `api_base` param, to make the request."
            )
        if api_key is None:
            raise ValueError(
                "api_key is None. Please set AZURE_AI_API_KEY or dynamically via `api_key` param, to make the request."
            )

    @staticmethod
    def _convert_image_embedding_response(
        response, model_response: EmbeddingResponse
    ) -> EmbeddingResponse:
        """Convert a raw httpx response from /images/embeddings into an EmbeddingResponse."""
        embedding_response = response.json()
        embedding_headers = dict(response.headers)
        returned_response: EmbeddingResponse = convert_to_model_response_object(
            response_object=embedding_response,
            model_response_object=model_response,
            response_type="embedding",
            stream=False,
            _response_headers=embedding_headers,
        )
        return returned_response

    def _process_response(
        self,
        image_embedding_responses: Optional[List],
        text_embedding_responses: Optional[List],
        image_embeddings_idx: List[int],
        model_response: EmbeddingResponse,
        input: List,
    ):
        """Merge image and text embedding results back into input order.

        ``image_embeddings_idx`` lists the positions in ``input`` that were
        image items; those positions are filled from
        ``image_embedding_responses`` and all other positions from
        ``text_embedding_responses``. If only one kind of response exists it
        is used as-is. The merged data is written onto ``model_response``
        and the result is passed through the Azure AI (Cohere) response
        transformation before being returned.
        """
        if (
            image_embedding_responses is not None
            and text_embedding_responses is not None
        ):
            # Interleave: walk the original input and pull the next item
            # from whichever response list owns that position.
            combined_responses = []
            text_idx = 0
            image_idx = 0
            for idx in range(len(input)):
                if idx in image_embeddings_idx:
                    combined_responses.append(image_embedding_responses[image_idx])
                    image_idx += 1
                else:
                    combined_responses.append(text_embedding_responses[text_idx])
                    text_idx += 1
            model_response.data = combined_responses
        elif image_embedding_responses is not None:
            model_response.data = image_embedding_responses
        elif text_embedding_responses is not None:
            model_response.data = text_embedding_responses

        # Apply provider-specific (Cohere-on-Azure-AI) response normalization.
        response = AzureAICohereConfig()._transform_response(response=model_response)

        return response

    async def async_image_embedding(
        self,
        model: str,
        data: ImageEmbeddingRequest,
        timeout: float,
        logging_obj,
        model_response: litellm.EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> EmbeddingResponse:
        """Async call to the Azure AI ``/images/embeddings`` endpoint.

        Raises:
            ValueError: if ``api_base`` or ``api_key`` is None (previously
                this path silently built a ``None/images/embeddings`` URL;
                validation now mirrors the sync ``image_embedding``).
        """
        self._validate_image_embedding_credentials(api_base=api_base, api_key=api_key)

        if client is None or not isinstance(client, AsyncHTTPHandler):
            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.AZURE_AI,
                params={"timeout": timeout},
            )

        url = "{}/images/embeddings".format(api_base)

        response = await client.post(
            url=url,
            json=data,
            headers={"Authorization": "Bearer {}".format(api_key)},
        )

        return self._convert_image_embedding_response(
            response=response, model_response=model_response
        )

    def image_embedding(
        self,
        model: str,
        data: ImageEmbeddingRequest,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ):
        """Sync call to the Azure AI ``/images/embeddings`` endpoint.

        Raises:
            ValueError: if ``api_base`` or ``api_key`` is None.
        """
        self._validate_image_embedding_credentials(api_base=api_base, api_key=api_key)

        if client is None or not isinstance(client, HTTPHandler):
            client = HTTPHandler(timeout=timeout, concurrent_limit=1)

        url = "{}/images/embeddings".format(api_base)

        response = client.post(
            url=url,
            json=data,
            headers={"Authorization": "Bearer {}".format(api_key)},
        )

        return self._convert_image_embedding_response(
            response=response, model_response=model_response
        )

    async def async_embedding(
        self,
        model: str,
        input: List,
        timeout: float,
        logging_obj,
        model_response: litellm.EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
    ) -> EmbeddingResponse:
        """Async entry point: split input, fan out to both routes, merge."""
        # Split the input into image items (with their original indices)
        # and a plain-text v1 embeddings request.
        (
            image_embeddings_request,
            v1_embeddings_request,
            image_embeddings_idx,
        ) = AzureAICohereConfig()._transform_request(
            input=input, optional_params=optional_params, model=model
        )

        image_embedding_responses: Optional[List] = None
        text_embedding_responses: Optional[List] = None

        if image_embeddings_request["input"]:
            image_response = await self.async_image_embedding(
                model=model,
                data=image_embeddings_request,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=client,
            )

            image_embedding_responses = image_response.data
            if image_embedding_responses is None:
                raise Exception("/image/embeddings route returned None Embeddings.")

        if v1_embeddings_request["input"]:
            # NOTE(review): the full `input` (not v1_embeddings_request["input"])
            # is forwarded to the v1 route, matching the sync path — confirm
            # this is intentional when image items are present.
            response: EmbeddingResponse = await super().embedding(
                model=model,
                input=input,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=client,
                aembedding=True,
            )
            text_embedding_responses = response.data
            if text_embedding_responses is None:
                raise Exception("/v1/embeddings route returned None Embeddings.")

        return self._process_response(
            image_embedding_responses=image_embedding_responses,
            text_embedding_responses=text_embedding_responses,
            image_embeddings_idx=image_embeddings_idx,
            model_response=model_response,
            input=input,
        )

    def embedding(
        self,
        model: str,
        input: List,
        timeout: float,
        logging_obj,
        model_response: EmbeddingResponse,
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        aembedding=None,
        max_retries: Optional[int] = None,
    ) -> EmbeddingResponse:
        """
        - Separate image url from text
        -> route image url call to `/image/embeddings`
        -> route text call to `/v1/embeddings` (OpenAI route)

        assemble result in-order, and return
        """
        # Async requests are delegated wholesale; the returned coroutine is
        # awaited by the caller.
        if aembedding is True:
            return self.async_embedding(
                model,
                input,
                timeout,
                logging_obj,
                model_response,
                optional_params,
                api_key,
                api_base,
                client,
            )

        (
            image_embeddings_request,
            v1_embeddings_request,
            image_embeddings_idx,
        ) = AzureAICohereConfig()._transform_request(
            input=input, optional_params=optional_params, model=model
        )

        image_embedding_responses: Optional[List] = None
        text_embedding_responses: Optional[List] = None

        if image_embeddings_request["input"]:
            image_response = self.image_embedding(
                model=model,
                data=image_embeddings_request,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=client,
            )

            image_embedding_responses = image_response.data
            if image_embedding_responses is None:
                raise Exception("/image/embeddings route returned None Embeddings.")

        if v1_embeddings_request["input"]:
            # Keyword arguments (matching the async path) instead of a long
            # positional list, to remove positional-order fragility.
            # Only an OpenAI client instance may be forwarded to the parent.
            response: EmbeddingResponse = super().embedding(
                model=model,
                input=input,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                api_key=api_key,
                api_base=api_base,
                client=(
                    client
                    if client is not None and isinstance(client, OpenAI)
                    else None
                ),
                aembedding=aembedding,
            )

            text_embedding_responses = response.data
            if text_embedding_responses is None:
                raise Exception("/v1/embeddings route returned None Embeddings.")

        return self._process_response(
            image_embedding_responses=image_embedding_responses,
            text_embedding_responses=text_embedding_responses,
            image_embeddings_idx=image_embeddings_idx,
            model_response=model_response,
            input=input,
        )