柿崎透真
feat: first application
55fc0a1
raw
history blame
7.57 kB
import json
from functools import wraps
from typing import Any, Callable
from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from neollm.constants import (
MAX_IMAGE_URL_LEN,
ROLE2CPRINT_PARAMS,
TITLE_COLOR,
YEN_PAR_DOLLAR,
)
from neollm.types import (
ChatCompletionMessage,
Chunk,
ClientSettings,
InputType,
LLMSettings,
Message,
Messages,
OutputType,
PriceInfo,
TokenInfo,
)
from neollm.utils.postprocess import json2dict
from neollm.utils.utils import cprint
def exception_log(func: Callable[..., None]) -> Callable[..., None]:
    """Wrap *func* so that an exception it raises is printed rather than propagated.

    On success the wrapper returns whatever *func* returned. If *func* raises
    an ``Exception``, the exception is shown with ``cprint`` (red, with
    background) and the wrapper returns ``None``. ``BaseException``s such as
    ``KeyboardInterrupt`` still propagate.
    """

    @wraps(func)
    def _guarded(*args: Any, **kwargs: Any) -> None:
        outcome = None
        try:
            outcome = func(*args, **kwargs)
        except Exception as err:  # deliberately broad: a failing printer must not break the caller
            cprint(err, color="red", background=True)
        return outcome

    return _guarded
def _get_tool_calls(message_dict: Message) -> list[ChatCompletionMessageToolCallParam]:
    """Gather every tool call attached to a message dict into one list.

    Entries of ``tool_calls`` come first. Each is copied field by field
    (``id``, ``function.arguments``, ``function.name``, ``type``) into a
    ``ChatCompletionMessageToolCallParam``. The legacy ``function_call``
    field, if present, becomes one trailing entry with ``id=""`` and
    ``type="function"``. A field that is missing, ``None``, or not a list
    (``tool_calls``) / dict (``function_call``) contributes nothing.

    Raises:
        KeyError: if a present entry lacks one of the fields read here.
    """
    raw_tool_calls = message_dict.get("tool_calls", None)
    # The isinstance checks also narrow the types; per the original note the
    # code does not pass (presumably type checking) without them.
    collected: list[ChatCompletionMessageToolCallParam] = (
        [
            ChatCompletionMessageToolCallParam(
                id=raw["id"],
                function=Function(
                    arguments=raw["function"]["arguments"],
                    name=raw["function"]["name"],
                ),
                type=raw["type"],
            )
            for raw in raw_tool_calls
        ]
        if isinstance(raw_tool_calls, list)
        else []
    )

    legacy_call = message_dict.get("function_call", None)
    if isinstance(legacy_call, dict):
        collected.append(
            ChatCompletionMessageToolCallParam(
                id="",
                function=Function(
                    arguments=legacy_call["arguments"],
                    name=legacy_call["name"],
                ),
                type="function",
            )
        )
    return collected
@exception_log
def print_metadata(time: float, token: TokenInfo, price: PriceInfo) -> None:
    """Print a one-line run summary: elapsed time, token usage and cost.

    Layout: ``[metadata] <sec>s; <total>(<input>+<output>)tokens; $<usd>; ¥<jpy>``.
    Seconds use one decimal place, token counts use thousands separators, and
    both prices use two significant digits (``.2g``). The yen figure is
    ``price.total * YEN_PAR_DOLLAR``.
    """
    cprint("[metadata]", color=TITLE_COLOR, kwargs={"end": " "})
    pieces = [
        f"{time:.1f}s",
        f"{token.total:,}({token.input:,}+{token.output:,})tokens",
        f"${price.total:.2g}",
        f"¥{price.total*YEN_PAR_DOLLAR:.2g}",
    ]
    print("; ".join(pieces))
@exception_log
def print_inputs(inputs: InputType) -> None:
    """Print an ``[inputs]`` header followed by *inputs* as indented JSON.

    Values JSON cannot represent are first turned into strings by
    ``_arange_dumpable_object``. Non-ASCII text is emitted as-is
    (``ensure_ascii=False``).
    """
    cprint("[inputs]", color=TITLE_COLOR)
    dumpable = _arange_dumpable_object(inputs)
    print(json.dumps(dumpable, ensure_ascii=False, indent=2))
@exception_log
def print_outputs(outputs: OutputType) -> None:
    """Print an ``[outputs]`` header followed by *outputs* as indented JSON.

    Values JSON cannot represent are first turned into strings by
    ``_arange_dumpable_object``. Non-ASCII text is emitted as-is
    (``ensure_ascii=False``).
    """
    cprint("[outputs]", color=TITLE_COLOR)
    dumpable = _arange_dumpable_object(outputs)
    print(json.dumps(dumpable, ensure_ascii=False, indent=2))
@exception_log
def print_messages(messages: list[ChatCompletionMessage] | Messages | None, title: bool = True) -> None:
    """Pretty-print a list of chat messages.

    Each message is printed as three parts:
    - its role, styled with ``ROLE2CPRINT_PARAMS``;
    - its content;
    - any tool calls / legacy function call it carries, with the function
      name highlighted and the arguments parsed with ``json2dict``.

    Content that is ``None`` or absent prints nothing. A string is printed
    directly. A list of parts prints its ``text``, ``image_url`` and
    ``refusal`` parts and skips other part types. Image URLs are cut to
    ``MAX_IMAGE_URL_LEN`` characters, followed by a ``...`` line when
    truncated.

    Args:
        messages: Messages to print. ``ChatCompletionMessage`` objects are
            converted with ``to_typeddict_message()`` first. ``None`` prints a
            red "Not yet running _preprocess" notice instead.
        title: When true, print the ``[messages]`` header first.
    """
    if messages is None:
        cprint("Not yet running _preprocess", color="red")
        return
    if title:
        cprint("[messages]", color=TITLE_COLOR)
    typed_messages = [
        message.to_typeddict_message() if isinstance(message, ChatCompletionMessage) else message
        for message in messages
    ]
    for message in typed_messages:
        # role ----------------------------------------
        role = message["role"]
        cprint(" " + role, **ROLE2CPRINT_PARAMS[role])
        # content ----------------------------------------
        # Use .get: the "content" key may be absent from a message dict (presumably
        # e.g. an assistant message carrying only tool calls). Indexing raised KeyError,
        # which exception_log swallowed, silently dropping this and all later messages.
        content = message.get("content")
        if isinstance(content, str):
            print(" " + content.replace("\n", "\n "))
        elif isinstance(content, list):
            for part in content:
                part_type = part["type"]
                if part_type == "text":
                    print(" " + part["text"].replace("\n", "\n "))
                elif part_type == "image_url":
                    cprint(" <image_url>", color="green", kwargs={"end": " "})
                    image_url = str(part["image_url"])
                    print(image_url[:MAX_IMAGE_URL_LEN])
                    if len(image_url) > MAX_IMAGE_URL_LEN:
                        print("...")
                elif part_type == "refusal":
                    cprint(" <refusal>", color="yellow", kwargs={"end": " "})
                    print(part["refusal"])
        # tool_calls / function_call ----------------------------------------
        for tool_call in _get_tool_calls(message):
            print(" ", end="")
            cprint(tool_call["function"]["name"], color="green", background=True)
            arguments = json2dict(tool_call["function"]["arguments"], error_key=None)
            print(" " + str(arguments).replace("\n", "\n "))
@exception_log
def print_delta(chunk: Chunk) -> None:
    """Print one streamed chunk so that successive calls render the reply as it arrives.

    Only ``chunk.choices[0]`` is printed; a chunk with no choices prints
    nothing. Depending on the delta, the following may be printed:
    - the role header;
    - a content fragment;
    - a legacy ``function_call`` name/arguments fragment;
    - ``tool_calls`` name/arguments fragments.

    Fragments are printed with ``end=""`` so they join up across calls. A
    final newline is printed once the choice reports a ``finish_reason``.
    """
    if len(chunk.choices) == 0:
        return
    choice = chunk.choices[0]  # TODO: support n >= 2 (only the first choice is printed)
    delta = choice.delta
    if delta.role is not None:
        cprint(f" {delta.role}\n" + " ", color="green")
    if delta.content is not None:
        print(delta.content.replace("\n", "\n "), end="")
    if delta.function_call is not None:
        if delta.function_call.name is not None:
            cprint(delta.function_call.name, color="green", background=True)
            print(" ", end="")
        if delta.function_call.arguments is not None:
            print(delta.function_call.arguments.replace("\n", "\n "), end="")
    if delta.tool_calls is not None:
        for tool_call in delta.tool_calls:
            if tool_call.function is None:
                continue
            if tool_call.function.name is not None:
                # The name marks the start of a tool call (assumption: in OpenAI-style
                # streaming it is sent only in each call's first delta; confirm for
                # other providers). Emit the separator for the 2nd+ call only here.
                # Emitting it on every chunk with index >= 1 would put each streamed
                # argument fragment on its own line.
                if tool_call.index >= 1:
                    print("\n ", end="")
                cprint(f"{tool_call.function.name}\n" + " ", color="green", background=True)
            if tool_call.function.arguments is not None:
                print(tool_call.function.arguments.replace("\n", "\n "), end="")
    if choice.finish_reason is not None:
        print()
@exception_log
def print_llm_settings(llm_settings: LLMSettings, model: str, engine: str | None, platform: str) -> None:
    """Print the ``[llm_settings]`` line as a dict.

    The printed dict starts with ``platform`` and ``model``, followed by the
    entries of *llm_settings*. On the ``"azure"`` platform an ``engine``
    entry is set last, overriding any ``engine`` key from *llm_settings*.

    Args:
        llm_settings: Additional settings to display.
        model: Shown under the ``"model"`` key; takes precedence over a
            ``"model"`` key in *llm_settings*.
        engine: Shown under ``"engine"``, only when ``platform == "azure"``.
        platform: Shown under the ``"platform"`` key; takes precedence over a
            ``"platform"`` key in *llm_settings*.
    """
    cprint("[llm_settings]", color=TITLE_COLOR, kwargs={"end": " "})
    shown: dict[str, Any] = {"platform": platform, "model": model}
    # Merge explicitly instead of dict(platform=..., model=..., **llm_settings): the latter
    # raises TypeError ("got multiple values for keyword argument") when llm_settings also
    # contains "platform" or "model".
    shown.update({key: value for key, value in llm_settings.items() if key not in ("platform", "model")})
    if platform == "azure":
        shown["engine"] = engine
    # `shown` always contains platform and model, so no "-" fallback is needed.
    print(shown)
@exception_log
def print_client_settings(client_settings: ClientSettings) -> None:
    """Print the ``[client_settings]`` line, or ``-`` when the settings are empty/falsy."""
    cprint("[client_settings]", color=TITLE_COLOR, kwargs={"end": " "})
    print(client_settings if client_settings else "-")
# -------
_DumplableEntity = int | float | str | bool | None | list[Any] | dict[Any, Any]
DumplableType = _DumplableEntity | list["DumplableType"] | dict["DumplableType", "DumplableType"]
def _arange_dumpable_object(obj: Any) -> DumplableType:
if obj is None or isinstance(obj, (int, float, str, bool)):
return obj
# list -> Array
if isinstance(obj, list):
return [_arange_dumpable_object(item) for item in obj]
# dict -> Object
if isinstance(obj, dict):
return {_arange_dumpable_object(key): _arange_dumpable_object(value) for key, value in obj.items()}
# Other -> String
return f"<{type(obj).__name__}>{str(obj)}"