import json
from functools import wraps
from typing import Any, Callable

from openai.types.chat.chat_completion_message_tool_call_param import (
    ChatCompletionMessageToolCallParam,
    Function,
)

from neollm.constants import (
    MAX_IMAGE_URL_LEN,
    ROLE2CPRINT_PARAMS,
    TITLE_COLOR,
    YEN_PAR_DOLLAR,
)
from neollm.types import (
    ChatCompletionMessage,
    Chunk,
    ClientSettings,
    InputType,
    LLMSettings,
    Message,
    Messages,
    OutputType,
    PriceInfo,
    TokenInfo,
)
from neollm.utils.postprocess import json2dict
from neollm.utils.utils import cprint


def exception_log(func: Callable[..., None]) -> Callable[..., None]:
    """Decorator for debug-print helpers: log exceptions in red instead of raising."""

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        try:
            return func(*args, **kwargs)
        except Exception as e:
            cprint(e, color="red", background=True)

    return wrapper


def _get_tool_calls(message_dict: Message) -> list[ChatCompletionMessageToolCallParam]:
    """Collect tool_calls (and the legacy function_call) from a message as tool-call params."""
    tool_calls: list[ChatCompletionMessageToolCallParam] = []
    if "tool_calls" in message_dict:
        _tool_calls = message_dict.get("tool_calls", None)
        # the isinstance(_tool_calls, list) check is needed for type narrowing
        if _tool_calls is not None and isinstance(_tool_calls, list):
            for _tool_call in _tool_calls:
                tool_call = ChatCompletionMessageToolCallParam(
                    id=_tool_call["id"],
                    function=Function(
                        arguments=_tool_call["function"]["arguments"],
                        name=_tool_call["function"]["name"],
                    ),
                    type=_tool_call["type"],
                )
                tool_calls.append(tool_call)
    if "function_call" in message_dict:
        function_call = message_dict.get("function_call", None)
        # the isinstance(function_call, dict) check is needed for type narrowing
        if function_call is not None and isinstance(function_call, dict):
            tool_calls.append(
                ChatCompletionMessageToolCallParam(
                    id="",
                    function=Function(
                        arguments=function_call["arguments"],
                        name=function_call["name"],
                    ),
                    type="function",
                )
            )
    return tool_calls


@exception_log
def print_metadata(time: float, token: TokenInfo, price: PriceInfo) -> None:
    cprint("[metadata]", color=TITLE_COLOR, kwargs={"end": " "})
    print(
        f"{time:.1f}s; "
        f"{token.total:,}({token.input:,}+{token.output:,})tokens; "
        f"${price.total:.2g}; ¥{price.total*YEN_PAR_DOLLAR:.2g}"
    )


@exception_log
def print_inputs(inputs: InputType) -> None:
    cprint("[inputs]", color=TITLE_COLOR)
    print(json.dumps(_arange_dumpable_object(inputs), indent=2, ensure_ascii=False))


@exception_log
def print_outputs(outputs: OutputType) -> None:
    cprint("[outputs]", color=TITLE_COLOR)
    print(json.dumps(_arange_dumpable_object(outputs), indent=2, ensure_ascii=False))


@exception_log
def print_messages(messages: list[ChatCompletionMessage] | Messages | None, title: bool = True) -> None:
    if messages is None:
        cprint("Not yet running _preprocess", color="red")
        return
    if title:
        cprint("[messages]", color=TITLE_COLOR)
    messages = [
        message.to_typeddict_message() if isinstance(message, ChatCompletionMessage) else message
        for message in messages
    ]
    for message in messages:
        # print role ----------------------------------------
        cprint(" " + message["role"], **ROLE2CPRINT_PARAMS[message["role"]])
        # print content ----------------------------------------
        if message["content"] is None:
            pass
        elif isinstance(message["content"], str):
            print(" " + message["content"].replace("\n", "\n "))
        elif isinstance(message["content"], list):
            for part in message["content"]:
                if part["type"] == "text":
                    print(" " + part["text"].replace("\n", "\n "))
                elif part["type"] == "image_url":
                    cprint(" ", color="green", kwargs={"end": " "})
                    print(str(part["image_url"])[:MAX_IMAGE_URL_LEN])
                    if len(str(part["image_url"])) > MAX_IMAGE_URL_LEN:
                        print("...")
                elif part["type"] == "refusal":
                    cprint(" ", color="yellow", kwargs={"end": " "})
                    print(part["refusal"])
        # print tool_calls / function_call ----------------------------------------
        for tool_call in _get_tool_calls(message):
            print(" ", end="")
            cprint(tool_call["function"]["name"], color="green", background=True)
            print(" " + str(json2dict(tool_call["function"]["arguments"], error_key=None)).replace("\n", "\n "))


@exception_log
def print_delta(chunk: Chunk) -> None:
    """Print a streaming chunk's delta (role, content, function/tool calls) incrementally."""
    if len(chunk.choices) == 0:
        return
    choice = chunk.choices[0]  # TODO: handle n>2
    if choice.delta.role is not None:
        cprint(f" {choice.delta.role}\n" + " ", color="green")
    if choice.delta.content is not None:
        print(choice.delta.content.replace("\n", "\n "), end="")
    if choice.delta.function_call is not None:
        if choice.delta.function_call.name is not None:
            cprint(choice.delta.function_call.name, color="green", background=True)
            print(" ", end="")
        if choice.delta.function_call.arguments is not None:
            print(choice.delta.function_call.arguments.replace("\n", "\n "), end="")
    if choice.delta.tool_calls is not None:
        for tool_call in choice.delta.tool_calls:
            if tool_call.function is None:
                continue
            if tool_call.index >= 1:
                print("\n ", end="")
            if tool_call.function.name is not None:
                cprint(f"{tool_call.function.name}\n" + " ", color="green", background=True)
            if tool_call.function.arguments is not None:
                print(tool_call.function.arguments.replace("\n", "\n "), end="")
    if choice.finish_reason is not None:
        print()


@exception_log
def print_llm_settings(llm_settings: LLMSettings, model: str, engine: str | None, platform: str) -> None:
    cprint("[llm_settings]", color=TITLE_COLOR, kwargs={"end": " "})
    llm_settings_copy = dict(platform=platform, model=model, **llm_settings)
    # for Azure, also record the engine (deployment)
    if platform == "azure":
        llm_settings_copy["engine"] = engine
    print(llm_settings_copy or "-")


@exception_log
def print_client_settings(client_settings: ClientSettings) -> None:
    cprint("[client_settings]", color=TITLE_COLOR, kwargs={"end": " "})
    print(client_settings or "-")


# -------
_DumplableEntity = int | float | str | bool | None | list[Any] | dict[Any, Any]
DumplableType = _DumplableEntity | list["DumplableType"] | dict["DumplableType", "DumplableType"]


def _arange_dumpable_object(obj: Any) -> DumplableType:
    """Recursively convert obj into a JSON-serializable structure."""
    if obj is None or isinstance(obj, (int, float, str, bool)):
        return obj
    # list -> Array
    if isinstance(obj, list):
        return [_arange_dumpable_object(item) for item in obj]
    # dict -> Object
    if isinstance(obj, dict):
        return {_arange_dumpable_object(key): _arange_dumpable_object(value) for key, value in obj.items()}
    # Other -> String
    return f"<{type(obj).__name__}>{str(obj)}"
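

# A minimal usage sketch (illustrative only, not part of the module's API):
# _arange_dumpable_object is applied before json.dumps so that values json cannot
# serialize (here a datetime.date, an assumed stand-in for any non-primitive) are
# rendered as "<TypeName>value" strings, while primitives, lists, and dicts pass
# through unchanged.
#
#     from datetime import date
#
#     payload = {"when": date(2024, 1, 1), "texts": ["a", "b"], "n": 3}
#     print(json.dumps(_arange_dumpable_object(payload), ensure_ascii=False))
#     # -> {"when": "<date>2024-01-01", "texts": ["a", "b"], "n": 3}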