from typing import Type | |
from neollm.llm.llm.google_generativeai.google_generativeai import ( | |
Gemini20FlashExp, | |
GeminiExp1114, | |
GeminiExp1121, | |
GeminiExp1206, | |
_GoogleGenerativeLLM, | |
) | |
from neollm.llm.model_name._abstract_model_name import AbstractModelName | |
from neollm.types import ClientSettings | |
class GoogleGenerativeAIModelName(AbstractModelName):
    """Model names that resolve to Google Generative AI (Gemini) LLM wrappers.

    Each class-level constant carries the string identifier of one model.
    ``to_llm`` turns a name into an instance of the matching
    ``_GoogleGenerativeLLM`` subclass.
    """

    # Gemini ------------------------------------------------
    GEMINI_EXP_1114 = "gemini-exp-1114"
    GEMINI_EXP_1121 = "gemini-exp-1121"
    GEMINI_EXP_1206 = "gemini-exp-1206"

    # Gemini 2.0 ------------------------------------------------
    GEMINI_20_FLASH_EXP = "gemini-20-flash-exp"

    def to_llm(self, client_settings: ClientSettings, model_name: str | None = None) -> _GoogleGenerativeLLM:
        """Instantiate the LLM wrapper class registered for this model name.

        Args:
            client_settings: Settings forwarded unchanged to the LLM class
                constructor.
            model_name: Not used by this implementation (presumably kept so the
                signature matches ``AbstractModelName.to_llm`` -- confirm).

        Returns:
            A new instance of the ``_GoogleGenerativeLLM`` subclass mapped to
            ``self``.

        Raises:
            KeyError: If ``self`` has no entry in the dispatch table.
        """
        names = GoogleGenerativeAIModelName
        # Built per call on purpose: a dict assigned in the class body would be
        # picked up as a class-level constant rather than staying a helper.
        dispatch: dict["GoogleGenerativeAIModelName", Type[_GoogleGenerativeLLM]] = {
            # Gemini
            names.GEMINI_EXP_1114: GeminiExp1114,
            names.GEMINI_EXP_1121: GeminiExp1121,
            names.GEMINI_EXP_1206: GeminiExp1206,
            # Gemini 2.0
            names.GEMINI_20_FLASH_EXP: Gemini20FlashExp,
        }
        llm_cls = dispatch[self]
        return llm_cls(client_settings)