|
import json |
|
from functools import wraps |
|
from typing import Any, Callable |
|
|
|
from openai.types.chat.chat_completion_message_tool_call_param import ( |
|
ChatCompletionMessageToolCallParam, |
|
Function, |
|
) |
|
|
|
from neollm.constants import ( |
|
MAX_IMAGE_URL_LEN, |
|
ROLE2CPRINT_PARAMS, |
|
TITLE_COLOR, |
|
YEN_PAR_DOLLAR, |
|
) |
|
from neollm.types import ( |
|
ChatCompletionMessage, |
|
Chunk, |
|
ClientSettings, |
|
InputType, |
|
LLMSettings, |
|
Message, |
|
Messages, |
|
OutputType, |
|
PriceInfo, |
|
TokenInfo, |
|
) |
|
from neollm.utils.postprocess import json2dict |
|
from neollm.utils.utils import cprint |
|
|
|
|
|
def exception_log(func: Callable[..., None]) -> Callable[..., None]:
    """Decorate ``func`` so that exceptions are printed instead of raised.

    Any ``Exception`` raised by ``func`` is passed to ``cprint`` (red, with
    background) and swallowed; in that case the wrapper returns ``None``.
    Otherwise the wrapper returns whatever ``func`` returned.
    """

    @wraps(func)
    def guarded(*args: Any, **kwargs: Any) -> None:
        try:
            outcome = func(*args, **kwargs)
        except Exception as error:
            # Best-effort: report the failure and carry on.
            cprint(error, color="red", background=True)
            return None
        return outcome

    return guarded
|
|
|
|
|
def _get_tool_calls(message_dict: Message) -> list[ChatCompletionMessageToolCallParam]:
    """Collect the tool calls carried by a message into one list.

    Entries of a list-valued ``"tool_calls"`` field are copied into
    ``ChatCompletionMessageToolCallParam`` objects. A dict-valued
    ``"function_call"`` field is appended as one more tool call with an empty
    ``id`` and ``type="function"``. Missing, ``None`` or differently typed
    fields contribute nothing.
    """
    collected: list[ChatCompletionMessageToolCallParam] = []

    raw_tool_calls = message_dict.get("tool_calls", None)
    if isinstance(raw_tool_calls, list):
        collected.extend(
            ChatCompletionMessageToolCallParam(
                id=raw["id"],
                function=Function(
                    arguments=raw["function"]["arguments"],
                    name=raw["function"]["name"],
                ),
                type=raw["type"],
            )
            for raw in raw_tool_calls
        )

    legacy_call = message_dict.get("function_call", None)
    if isinstance(legacy_call, dict):
        # The single function_call field has no id of its own.
        legacy_function = Function(
            arguments=legacy_call["arguments"],
            name=legacy_call["name"],
        )
        collected.append(
            ChatCompletionMessageToolCallParam(
                id="",
                function=legacy_function,
                type="function",
            )
        )

    return collected
|
|
|
|
|
@exception_log
def print_metadata(time: float, token: TokenInfo, price: PriceInfo) -> None:
    """Print a one-line ``[metadata]`` summary: elapsed time, tokens and price.

    Args:
        time: Elapsed time, printed with one decimal place and an ``s`` suffix.
        token: Object whose ``total``, ``input`` and ``output`` are printed
            with thousands separators.
        price: Object whose ``total`` is printed as dollars and, multiplied by
            ``YEN_PAR_DOLLAR``, as yen (two significant digits each).
    """
    cprint("[metadata]", color=TITLE_COLOR, kwargs={"end": " "})
    elapsed_part = f"{time:.1f}s"
    token_part = f"{token.total:,}({token.input:,}+{token.output:,})tokens"
    dollar_part = f"${price.total:.2g}"
    yen_part = f"¥{price.total*YEN_PAR_DOLLAR:.2g}"
    print("; ".join((elapsed_part, token_part, dollar_part, yen_part)))
|
|
|
|
|
@exception_log
def print_inputs(inputs: InputType) -> None:
    """Print an ``[inputs]`` title followed by ``inputs`` as indented JSON.

    ``inputs`` is first passed through ``_arange_dumpable_object`` so that
    values ``json.dumps`` cannot serialise are replaced by strings. Non-ASCII
    characters are emitted as-is.
    """
    cprint("[inputs]", color=TITLE_COLOR)
    dumpable = _arange_dumpable_object(inputs)
    rendered = json.dumps(dumpable, indent=2, ensure_ascii=False)
    print(rendered)
|
|
|
|
|
@exception_log
def print_outputs(outputs: OutputType) -> None:
    """Print an ``[outputs]`` title followed by ``outputs`` as indented JSON.

    ``outputs`` is first passed through ``_arange_dumpable_object`` so that
    values ``json.dumps`` cannot serialise are replaced by strings. Non-ASCII
    characters are emitted as-is.
    """
    cprint("[outputs]", color=TITLE_COLOR)
    dumpable = _arange_dumpable_object(outputs)
    rendered = json.dumps(dumpable, indent=2, ensure_ascii=False)
    print(rendered)
|
|
|
|
|
@exception_log
def print_messages(messages: list[ChatCompletionMessage] | Messages | None, title: bool = True) -> None:
    """Print each chat message: its role, its content, then its tool calls.

    Args:
        messages: Messages to print. ``ChatCompletionMessage`` instances are
            converted with ``to_typeddict_message()``; other items are used
            as-is and accessed like dicts. ``None`` prints a red notice and
            returns.
        title: When true, print a ``[messages]`` title line first.

    A message without a ``"content"`` key is treated like ``content=None``:
    only its role and tool calls are printed. Previously the missing key
    raised ``KeyError``, which the decorator swallowed, so every remaining
    message was silently skipped.
    """
    if messages is None:
        cprint("Not yet running _preprocess", color="red")
        return
    if title:
        cprint("[messages]", color=TITLE_COLOR)
    messages = [
        message.to_typeddict_message() if isinstance(message, ChatCompletionMessage) else message
        for message in messages
    ]
    for message in messages:
        # Role header, styled per role.
        cprint(" " + message["role"], **ROLE2CPRINT_PARAMS[message["role"]])

        # Content: plain string, list of typed parts, or nothing.
        # `.get` because the key may be absent (e.g. tool-call-only messages).
        content = message.get("content")
        if isinstance(content, str):
            print(" " + content.replace("\n", "\n "))
        elif isinstance(content, list):
            for part in content:
                part_type = part["type"]
                if part_type == "text":
                    print(" " + part["text"].replace("\n", "\n "))
                elif part_type == "image_url":
                    cprint(" <image_url>", color="green", kwargs={"end": " "})
                    image_url_text = str(part["image_url"])
                    # Long URLs are cut to MAX_IMAGE_URL_LEN characters and
                    # followed by an ellipsis line.
                    print(image_url_text[:MAX_IMAGE_URL_LEN])
                    if len(image_url_text) > MAX_IMAGE_URL_LEN:
                        print("...")
                elif part_type == "refusal":
                    cprint(" <refusal>", color="yellow", kwargs={"end": " "})
                    print(part["refusal"])

        # Tool calls (including a legacy function_call, see _get_tool_calls).
        for tool_call in _get_tool_calls(message):
            print(" ", end="")
            cprint(tool_call["function"]["name"], color="green", background=True)
            # NOTE(review): json2dict presumably parses the JSON argument
            # string into a dict -- confirm against neollm.utils.postprocess.
            arguments = json2dict(tool_call["function"]["arguments"], error_key=None)
            print(" " + str(arguments).replace("\n", "\n "))
|
|
|
|
|
@exception_log
def print_delta(chunk: Chunk) -> None:
    """Print the incremental content of one streamed chunk.

    Only the first choice of ``chunk`` is used. Depending on which delta
    fields are set, this prints the role header, the content fragment, the
    ``function_call`` name and argument fragment, and each ``tool_calls``
    name and argument fragment. A final newline is printed once
    ``finish_reason`` is set. A chunk with no choices prints nothing.

    The line break separating a second or later tool call from the previous
    one is printed only on the delta that carries the call's name. It used to
    be printed for every delta with ``index >= 1``, which put each streamed
    argument fragment of those calls on its own line.
    """
    if not chunk.choices:
        return
    choice = chunk.choices[0]
    delta = choice.delta

    if delta.role is not None:
        cprint(f" {delta.role}\n" + " ", color="green")
    if delta.content is not None:
        print(delta.content.replace("\n", "\n "), end="")

    function_call = delta.function_call
    if function_call is not None:
        if function_call.name is not None:
            cprint(function_call.name, color="green", background=True)
            print(" ", end="")
        if function_call.arguments is not None:
            print(function_call.arguments.replace("\n", "\n "), end="")

    if delta.tool_calls is not None:
        for tool_call in delta.tool_calls:
            function = tool_call.function
            if function is None:
                continue
            if function.name is not None:
                # NOTE(review): assumes the name arrives only on the first
                # delta of each tool call -- confirm for every provider used.
                if tool_call.index >= 1:
                    print("\n ", end="")
                cprint(f"{function.name}\n" + " ", color="green", background=True)
            if function.arguments is not None:
                print(function.arguments.replace("\n", "\n "), end="")

    if choice.finish_reason is not None:
        print()
|
|
|
|
|
@exception_log
def print_llm_settings(llm_settings: LLMSettings, model: str, engine: str | None, platform: str) -> None:
    """Print the ``[llm_settings]`` title and the effective settings dict.

    Args:
        llm_settings: Mapping of settings; its entries are listed after
            ``platform`` and ``model``.
        model: Model name, shown under the ``"model"`` key.
        engine: Shown under the ``"engine"`` key only when ``platform`` is
            ``"azure"``.
        platform: Platform name, shown under the ``"platform"`` key.

    The dict is built with a dict display instead of ``dict(**llm_settings)``,
    so a ``llm_settings`` that already has a ``"platform"`` or ``"model"`` key
    (or non-string keys) no longer raises ``TypeError``.
    """
    cprint("[llm_settings]", color=TITLE_COLOR, kwargs={"end": " "})
    summary = {"platform": platform, "model": model, **llm_settings}
    # The explicit arguments win over same-named keys in llm_settings.
    summary.update(platform=platform, model=model)
    if platform == "azure":
        summary["engine"] = engine
    # summary always has "platform" and "model", so it is never empty and
    # needs no "-" fallback.
    print(summary)
|
|
|
|
|
@exception_log
def print_client_settings(client_settings: ClientSettings) -> None:
    """Print the ``[client_settings]`` title followed by the settings.

    A falsy ``client_settings`` (for example an empty mapping) is shown
    as ``-``.
    """
    cprint("[client_settings]", color=TITLE_COLOR, kwargs={"end": " "})
    shown = client_settings if client_settings else "-"
    print(shown)
|
|
|
|
|
|
|
|
|
_DumplableEntity = int | float | str | bool | None | list[Any] | dict[Any, Any] |
|
DumplableType = _DumplableEntity | list["DumplableType"] | dict["DumplableType", "DumplableType"] |
|
|
|
|
|
def _arange_dumpable_object(obj: Any) -> DumplableType:
    """Recursively convert ``obj`` into something ``json.dumps`` accepts.

    Dicts are rebuilt with both keys and values converted, and lists with
    every item converted. ``None``, ``int``, ``float``, ``str`` and ``bool``
    values are returned unchanged. Anything else becomes the string
    ``"<TypeName>" + str(obj)``.
    """
    if isinstance(obj, dict):
        return {
            _arange_dumpable_object(raw_key): _arange_dumpable_object(raw_value)
            for raw_key, raw_value in obj.items()
        }

    if isinstance(obj, list):
        return [_arange_dumpable_object(element) for element in obj]

    if obj is None or isinstance(obj, (int, float, str, bool)):
        return obj

    # Fallback: tag the string form with the type name.
    type_name = type(obj).__name__
    return f"<{type_name}>{str(obj)}"
|
|