|
from typing import Final, Literal |
|
|
|
from neollm.llm.platform import Platform |
|
from neollm.utils.utils import cprint |
|
|
|
|
|
def normalize_gpt_model_name(model_name: str, verbose: bool = True) -> str:
    """Return the normalized (dated) OpenAI model name for ``model_name``.

    Legacy ``*-preview`` identifiers are first rewritten to this library's
    ``*-turbo-*`` naming. If the resulting name is a date-less alias (for
    example ``"gpt-4o"``), it is expanded to the dated model name it is pinned
    to, and a yellow warning is printed via ``cprint`` when ``verbose`` is
    True. Any other name is returned unchanged.

    Args:
        model_name: Model name supplied by the caller.
        verbose: Whether to print a warning when a date-less alias is expanded.

    Returns:
        The normalized model name.
    """
    # Legacy preview identifiers -> this library's turbo naming.
    renamed_previews: Final[dict[str, str]] = {
        "gpt-4-0125-preview": "gpt-4-turbo-0125",
        "gpt-4-1106-preview": "gpt-4-turbo-1106",
        "gpt-4-vision-preview": "gpt-4v-turbo-1106",
    }

    # Date-less aliases -> the dated model name each one is pinned to.
    dated_aliases: Final[dict[str, str]] = {
        "gpt-4o": "gpt-4o-2024-05-13",
        "gpt-4o-mini": "gpt-4o-mini-2024-07-18",
        "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k": "gpt-3.5-turbo-16k-0613",
        "gpt-4": "gpt-4-0613",
        "gpt-4-32k": "gpt-4-32k-0613",
        "gpt-4-turbo": "gpt-4-turbo-1106",
        "gpt-4v-turbo": "gpt-4v-turbo-1106",
    }

    name = renamed_previews.get(model_name, model_name)
    dated = dated_aliases.get(name)

    # Not a date-less alias: nothing more to normalize.
    if dated is None:
        return name

    if verbose:
        cprint(
            f"WARNING: model_nameに日付を指定してください. {name} -> {dated}",
            color="yellow",
            background=True,
        )
    return dated
|
|
|
|
|
def normalize_claude_model_name(model_name: str, platform: Platform, verbose: bool = True) -> str:
    """Resolve a Claude model name to the identifier used on ``platform``.

    Two kinds of input are rewritten, each optionally reported with a yellow
    warning via ``cprint``:

    1. A date-less alias such as ``"claude-3-5-sonnet"`` is expanded to the
       dated identifier for ``platform``.
    2. An identifier written for a different platform (for example the
       Anthropic name ``"claude-3-5-sonnet-20240620"`` while ``platform`` is
       ``Platform.AWS``) is translated to the equivalent identifier for
       ``platform``.

    Any other name is returned unchanged.

    Args:
        model_name: Model name supplied by the caller.
        platform: Target platform whose naming scheme should be used.
        verbose: Whether to print a warning when the name is rewritten.

    Returns:
        The model name to use on ``platform``.

    Raises:
        KeyError: If the model table has no identifier for ``platform``.
    """
    # One row per model: the date-less alias under "no_date", plus the
    # concrete model identifier on each supported platform.
    models: Final[list[dict[Literal["no_date"] | Platform, str]]] = [
        {
            "no_date": "claude-3-5-sonnet-v2",
            Platform.GCP: "claude-3-5-sonnet-v2@20241022",
            Platform.AWS: "anthropic.claude-3-5-sonnet-20241022-v2:0",
            Platform.ANTHROPIC: "claude-3-5-sonnet-20241022",
        },
        {
            "no_date": "claude-3-5-sonnet",
            Platform.GCP: "claude-3-5-sonnet@20240620",
            Platform.AWS: "anthropic.claude-3-5-sonnet-20240620-v1:0",
            Platform.ANTHROPIC: "claude-3-5-sonnet-20240620",
        },
        {
            "no_date": "claude-3-5-haiku",
            Platform.GCP: "claude-3-5-haiku@20241022",
            Platform.AWS: "anthropic.claude-3-5-haiku-20241022-v1:0",
            Platform.ANTHROPIC: "claude-3-5-haiku-20241022",
        },
        {
            "no_date": "claude-3-opus",
            Platform.GCP: "claude-3-opus@20240229",
            Platform.AWS: "anthropic.claude-3-opus-20240229-v1:0",
            Platform.ANTHROPIC: "claude-3-opus-20240229",
        },
        {
            "no_date": "claude-3-sonnet",
            Platform.GCP: "claude-3-sonnet@20240229",
            Platform.AWS: "anthropic.claude-3-sonnet-20240229-v1:0",
            Platform.ANTHROPIC: "claude-3-sonnet-20240229",
        },
        {
            "no_date": "claude-3-haiku",
            Platform.GCP: "claude-3-haiku@20240307",
            Platform.AWS: "anthropic.claude-3-haiku-20240307-v1:0",
            Platform.ANTHROPIC: "claude-3-haiku-20240307",
        },
    ]

    # Every row needs an identifier for the target platform. Check up front so
    # an unsupported platform fails with a descriptive KeyError rather than a
    # bare KeyError(platform) raised from inside a comprehension.
    missing = [model["no_date"] for model in models if platform not in model]
    if missing:
        raise KeyError(
            f"Claude model identifiers are not defined for platform {platform!r}: {', '.join(missing)}"
        )

    # Date-less alias -> dated identifier on the target platform.
    CLAUDE_DATE_MAPPER: Final[dict[str, str]] = {model["no_date"]: model[platform] for model in models}

    if model_name in CLAUDE_DATE_MAPPER:
        if verbose:
            cprint(
                f"WARNING: model_nameに日付を指定してください. {model_name} -> {CLAUDE_DATE_MAPPER[model_name]}",
                color="yellow",
                background=True,
            )
        return CLAUDE_DATE_MAPPER[model_name]

    # Identifier on any other platform -> identifier on the target platform.
    # The other platforms are taken from each row itself, not from a
    # hard-coded list, so a platform added to the table is picked up here too.
    CLAUDE_FOR_PLATFORM_MAPPER: Final[dict[str, str]] = {
        other_name: model[platform]
        for model in models
        for key, other_name in model.items()
        if key != "no_date" and key != platform
    }

    if model_name in CLAUDE_FOR_PLATFORM_MAPPER:
        if verbose:
            cprint(
                f"WARNING: model_nameを{platform}用に変更してください. {model_name} -> {CLAUDE_FOR_PLATFORM_MAPPER[model_name]}",
                color="yellow",
                background=True,
            )
        return CLAUDE_FOR_PLATFORM_MAPPER[model_name]

    return model_name
|
|