|
from typing import Type |
|
|
|
from neollm.llm.llm.abstract_llm import AbstractLLM |
|
from neollm.llm.model_name._abstract_model_name import AbstractModelName |
|
from neollm.types import ClientSettings |
|
|
|
from .platform import Platform |
|
|
|
|
|
def get_llm(model_name: str, platform: str, client_settings: ClientSettings) -> AbstractLLM:
    """Resolve a platform name and a model name into an LLM instance.

    The ``platform`` string is validated against the ``Platform`` enum. The
    model-name class for that platform is then imported (inside the matching
    branch, so only the selected platform's module is loaded). ``model_name``
    is validated by constructing a member of that class, and the member's
    ``to_llm`` builds the result.

    Args:
        model_name: Model identifier; must be a valid value of the selected
            platform's model-name class.
        platform: Platform identifier; must be a valid ``Platform`` value.
        client_settings: Client settings forwarded unchanged to ``to_llm``.

    Returns:
        Whatever ``model_name_enum.to_llm(client_settings, model_name)`` returns.

    Raises:
        ValueError: If ``platform`` is not a supported platform, or if
            ``model_name`` is not a supported model for that platform. The
            original ``ValueError`` is chained as the cause.
    """
    try:
        platform_enum = Platform(platform)
    except ValueError as e:
        raise ValueError(
            f"{str(e)}\n"
            f"{platform} is not supported. Supported platforms are {', '.join([member.value for member in Platform])}."
        ) from e

    # Imports are deferred into each branch so only the chosen platform's module
    # is loaded (presumably to avoid pulling in SDKs for unused platforms).
    model_name_class: Type[AbstractModelName]
    if platform_enum == Platform.AZURE:
        from neollm.llm.model_name.azure_model_name import AzureModelName

        model_name_class = AzureModelName
    elif platform_enum == Platform.OPENAI:
        from neollm.llm.model_name.openai_model_name import OpenAIModelName

        model_name_class = OpenAIModelName
    elif platform_enum == Platform.ANTHROPIC:
        from neollm.llm.model_name.anthropic_model_name import AnthropicModelName

        model_name_class = AnthropicModelName
    elif platform_enum == Platform.GCP:
        from neollm.llm.model_name.gcp_model_name import GCPModelName

        model_name_class = GCPModelName
    elif platform_enum == Platform.AWS:
        from neollm.llm.model_name.aws_model_name import AWSModelName

        model_name_class = AWSModelName
    elif platform_enum == Platform.LOCAL_VLLM:
        from neollm.llm.model_name.local_vllm_model_name import LocalvLLMModelName

        model_name_class = LocalvLLMModelName
    elif platform_enum == Platform.GOOGLE_GENERATIVEAI:
        from neollm.llm.model_name.google_generativeai_model_name import (
            GoogleGenerativeAIModelName,
        )

        model_name_class = GoogleGenerativeAIModelName
    else:
        # Reached only if Platform defines a member that has no branch above.
        raise ValueError(f"{platform} is not supported.")

    try:
        model_name_enum = model_name_class(model_name)
    except ValueError as e:
        # Report the rejected *model* name and list the valid model names for
        # this platform (the previous message wrongly called them platforms).
        supported_models = ", ".join(member.value for member in model_name_class)
        raise ValueError(
            f"{str(e)}\n"
            f"{model_name} is not supported on {platform}. Supported models are {supported_models}."
        ) from e

    return model_name_enum.to_llm(client_settings, model_name)
|
|