|
import json |
|
from typing import List, Optional |
|
|
|
from litellm.llms.base_llm.chat.transformation import BaseLLMException |
|
from litellm.types.llms.openai import AllMessageValues |
|
from litellm.types.utils import ( |
|
ChatCompletionToolCallChunk, |
|
ChatCompletionUsageBlock, |
|
GenericStreamingChunk, |
|
) |
|
|
|
|
|
class CohereError(BaseLLMException):
    """Exception raised for Cohere API failures.

    Carries the HTTP status code and error message up to litellm's
    shared LLM-exception handling.
    """

    def __init__(self, status_code, message):
        super().__init__(message=message, status_code=status_code)
|
|
|
|
|
def validate_environment(
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    api_key: Optional[str] = None,
) -> dict:
    """
    Build the headers for a Cohere chat completion request.

    Mutates and returns the given ``headers`` dict, adding litellm's
    request-source tag and JSON content negotiation, plus a bearer
    Authorization header when ``api_key`` is provided.

    Cohere API Ref: https://docs.cohere.com/reference/chat
    Expected headers:
    {
        "Request-Source": "unspecified:litellm",
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": "bearer $CO_API_KEY"
    }
    """
    default_headers = {
        "Request-Source": "unspecified:litellm",
        "accept": "application/json",
        "content-type": "application/json",
    }
    headers.update(default_headers)
    if api_key:
        headers["Authorization"] = f"bearer {api_key}"
    return headers
|
|
|
|
|
class ModelResponseIterator:
    """Iterate over a Cohere chat streaming response (sync or async).

    Each raw chunk (str or bytes) is expected to contain one JSON object;
    it is decoded and parsed into a ``GenericStreamingChunk``.
    """

    def __init__(
        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
    ):
        """
        Args:
            streaming_response: iterable (sync) or async-iterable of raw
                stream chunks, each a str or bytes JSON payload.
            sync_stream: whether the stream is synchronous. Not read by the
                visible code; kept for interface compatibility.
            json_mode: whether the request used JSON mode.
        """
        self.streaming_response = streaming_response
        self.response_iterator = self.streaming_response
        self.content_blocks: List = []
        self.tool_index = -1
        self.json_mode = json_mode

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        """Convert one decoded Cohere stream event (a dict) into a GenericStreamingChunk.

        Recognized keys: ``text`` (content delta), ``is_finished`` +
        ``finish_reason`` (terminal event), ``citations`` (passed through as
        provider-specific fields), ``index`` (choice index, defaults to 0).

        Raises:
            ValueError: if JSON decoding fails inside parsing.
        """
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None

            index = int(chunk.get("index", 0))

            if "text" in chunk:
                text = chunk["text"]
            elif "is_finished" in chunk and chunk["is_finished"] is True:
                is_finished = chunk["is_finished"]
                finish_reason = chunk["finish_reason"]

            if "citations" in chunk:
                provider_specific_fields = {"citations": chunk["citations"]}

            returned_chunk = GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )

            return returned_chunk

        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

    @staticmethod
    def _extract_json_str(chunk) -> str:
        """Decode a raw chunk to str and strip any leading SSE ``data:`` prefix.

        Bug fix: the previous code sliced ``str_line[index:]``, which KEPT the
        ``"data:"`` prefix in the string, so ``json.loads`` was guaranteed to
        fail on any SSE-formatted chunk. We now slice past the prefix itself.
        The prefix strip also applies to str chunks, not only decoded bytes.
        """
        str_line = chunk
        if isinstance(chunk, bytes):
            str_line = chunk.decode("utf-8")
        index = str_line.find("data:")
        if index != -1:
            str_line = str_line[index + len("data:") :]
        return str_line

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next parsed chunk from the sync stream.

        Raises:
            StopIteration: when the underlying stream is exhausted.
            RuntimeError: when receiving or parsing a chunk fails
                (wraps the underlying ValueError / JSONDecodeError).
        """
        try:
            chunk = self.response_iterator.__next__()
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            data_json = json.loads(self._extract_json_str(chunk))
            return self.chunk_parser(chunk=data_json)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            # JSONDecodeError is a ValueError subclass, so bad JSON lands here.
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        """Return the next parsed chunk from the async stream.

        Raises:
            StopAsyncIteration: when the underlying stream is exhausted.
            RuntimeError: when receiving or parsing a chunk fails
                (wraps the underlying ValueError / JSONDecodeError).
        """
        try:
            chunk = await self.async_response_iterator.__anext__()
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            data_json = json.loads(self._extract_json_str(chunk))
            return self.chunk_parser(chunk=data_json)
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            # JSONDecodeError is a ValueError subclass, so bad JSON lands here.
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
|