柿崎透真
feat: enable Gemini-2.0-Flash-Exp
3d7c096
raw
history blame
2.4 kB
from typing import Type
from neollm.llm.llm.abstract_llm import AbstractLLM
from neollm.llm.model_name._abstract_model_name import AbstractModelName
from neollm.types import ClientSettings
from .platform import Platform
def get_llm(model_name: str, platform: str, client_settings: ClientSettings) -> AbstractLLM:
    """Resolve a (platform, model_name) string pair into a concrete LLM instance.

    Args:
        model_name: Model identifier string; must match a value of the
            platform's model-name enum.
        platform: Platform identifier string; must match a ``Platform`` value.
        client_settings: Client configuration forwarded to the LLM constructor.

    Returns:
        The ``AbstractLLM`` built by the resolved model-name enum's ``to_llm``.

    Raises:
        ValueError: If ``platform`` is not a supported platform, or
            ``model_name`` is not a supported model on that platform.
    """
    try:
        platform_enum = Platform(platform)
    except ValueError as e:
        raise ValueError(
            f"{str(e)}\n"
            f"{platform} is not supported. Supported platforms are {', '.join([member.value for member in Platform])}."
        ) from e
    model_name_class: Type[AbstractModelName]
    # Imports are deliberately local so that only the selected platform's
    # SDK dependencies are loaded.
    if platform_enum == Platform.AZURE:
        from neollm.llm.model_name.azure_model_name import AzureModelName

        model_name_class = AzureModelName
    elif platform_enum == Platform.OPENAI:
        from neollm.llm.model_name.openai_model_name import OpenAIModelName

        model_name_class = OpenAIModelName
    elif platform_enum == Platform.ANTHROPIC:
        from neollm.llm.model_name.anthropic_model_name import AnthropicModelName

        model_name_class = AnthropicModelName
    elif platform_enum == Platform.GCP:
        from neollm.llm.model_name.gcp_model_name import GCPModelName

        model_name_class = GCPModelName
    elif platform_enum == Platform.AWS:
        from neollm.llm.model_name.aws_model_name import AWSModelName

        model_name_class = AWSModelName
    elif platform_enum == Platform.LOCAL_VLLM:
        from neollm.llm.model_name.local_vllm_model_name import LocalvLLMModelName

        model_name_class = LocalvLLMModelName
    elif platform_enum == Platform.GOOGLE_GENERATIVEAI:
        from neollm.llm.model_name.google_generativeai_model_name import (
            GoogleGenerativeAIModelName,
        )

        model_name_class = GoogleGenerativeAIModelName
    else:
        # Defensive: a Platform member was added without a branch above.
        raise ValueError(f"{platform} is not supported.")
    try:
        # TODO: Platformのmethodで`model_name`を吐き出すようにしたら簡素化できそう
        # (TODO: could be simplified if Platform exposed a method returning
        # its model-name enum)
        model_name_enum = model_name_class(model_name)  # type: ignore[abstract]
    except ValueError as e:
        # BUGFIX: previously reported the *platform* as unsupported and called
        # the model-name list "platforms"; report the model name instead.
        raise ValueError(
            f"{str(e)}\n"
            f"{model_name} is not supported on {platform}. "
            f"Supported model names are {', '.join([member.value for member in model_name_class])}."
        ) from e
    return model_name_enum.to_llm(client_settings, model_name)