Spaces:
Running
Running
import warnings | |
import requests | |
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage | |
from pydantic.v1 import SecretStr | |
from langflow.base.models.chat_result import get_chat_result | |
from langflow.base.models.model_utils import get_model_name | |
from langflow.custom.custom_component.component import Component | |
from langflow.io import ( | |
BoolInput, | |
DropdownInput, | |
HandleInput, | |
MessageInput, | |
MessageTextInput, | |
Output, | |
SecretStrInput, | |
StrInput, | |
) | |
from langflow.schema.message import Message | |
# Maps the model name reported by a connected language model to the
# provider/model pair sent to the Not Diamond router. Only models whose name
# appears here are offered to the router as candidates; several local names
# may map to the same Not Diamond identifier.
ND_MODEL_MAPPING: dict[str, dict[str, str]] = {
    "gpt-4o": {"provider": "openai", "model": "gpt-4o"},
    "gpt-4o-mini": {"provider": "openai", "model": "gpt-4o-mini"},
    "gpt-4-turbo": {"provider": "openai", "model": "gpt-4-turbo-2024-04-09"},
    "claude-3-5-haiku-20241022": {"provider": "anthropic", "model": "claude-3-5-haiku-20241022"},
    "claude-3-5-sonnet-20241022": {"provider": "anthropic", "model": "claude-3-5-sonnet-20241022"},
    "anthropic.claude-3-5-sonnet-20241022-v2:0": {"provider": "anthropic", "model": "claude-3-5-sonnet-20241022"},
    "anthropic.claude-3-5-haiku-20241022-v1:0": {"provider": "anthropic", "model": "claude-3-5-haiku-20241022"},
    "gemini-1.5-pro": {"provider": "google", "model": "gemini-1.5-pro-latest"},
    "gemini-1.5-flash": {"provider": "google", "model": "gemini-1.5-flash-latest"},
    "llama-3.1-sonar-large-128k-online": {"provider": "perplexity", "model": "llama-3.1-sonar-large-128k-online"},
    "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
        "provider": "togetherai",
        "model": "Meta-Llama-3.1-70B-Instruct-Turbo",
    },
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {
        "provider": "togetherai",
        "model": "Meta-Llama-3.1-405B-Instruct-Turbo",
    },
    "mistral-large-latest": {"provider": "mistral", "model": "mistral-large-2407"},
}
class NotDiamondComponent(Component):
    """Route a prompt to one of the connected language models via Not Diamond.

    The component sends the conversation and the list of supported candidate
    models to the Not Diamond ``modelSelect`` endpoint, then runs the prompt on
    the model the router recommends. Whenever routing cannot be completed
    (request failure, unusable response, or a recommendation that matches no
    candidate) the first connected model is used instead.
    """

    display_name = "Not Diamond Router"
    description = "Call the right model at the right time with the world's most powerful AI model router."
    documentation: str = "https://docs.notdiamond.ai/"
    icon = "NotDiamond"
    name = "NotDiamond"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Name of the model used by the most recent `model_select` call;
        # stays None until `model_select` has run.
        self._selected_model_name = None

    inputs = [
        MessageInput(name="input_value", display_name="Input"),
        MessageTextInput(
            name="system_message",
            display_name="System Message",
            info="System message to pass to the model.",
            advanced=False,
        ),
        HandleInput(
            name="models",
            display_name="Language Models",
            input_types=["LanguageModel"],
            required=True,
            is_list=True,
            info="Link the models you want to route between.",
        ),
        SecretStrInput(
            name="api_key",
            display_name="Not Diamond API Key",
            info="The Not Diamond API Key to use for routing.",
            advanced=False,
            value="NOTDIAMOND_API_KEY",
        ),
        StrInput(
            name="preference_id",
            display_name="Preference ID",
            info="The ID of the router preference that was configured via the Dashboard.",
            advanced=False,
        ),
        DropdownInput(
            name="tradeoff",
            display_name="Tradeoff",
            info="The tradeoff between cost and latency for the router to determine the best LLM for a given query.",
            advanced=False,
            options=["quality", "cost", "latency"],
            value="quality",
        ),
        BoolInput(
            name="hash_content",
            display_name="Hash Content",
            info="Whether to hash the content before being sent to the NotDiamond API.",
            advanced=False,
            value=False,
        ),
    ]

    outputs = [
        Output(display_name="Output", name="output", method="model_select"),
        Output(
            display_name="Selected Model",
            name="selected_model",
            method="get_selected_model",
            required_inputs=["output"],
        ),
    ]

    def get_selected_model(self) -> str:
        """Return the name of the model used by the last `model_select` call.

        The value is ``None`` if `model_select` has not run yet.
        """
        return self._selected_model_name

    def model_select(self) -> Message:
        """Pick a model with Not Diamond and run the prompt on it.

        Returns:
            The chat result produced by the chosen model.

        Raises:
            ValueError: If both the input and the system message are empty, or
                if no language model is connected.
        """
        input_value = self.input_value
        system_message = self.system_message
        messages = self._format_input(input_value, system_message)

        if not self.models:
            # Fail early with a clear message instead of an IndexError after a
            # pointless round trip to the router.
            msg = "At least one language model must be connected to the router."
            raise ValueError(msg)

        api_key = SecretStr(self.api_key).get_secret_value() if self.api_key else None

        # Only models known to ND_MODEL_MAPPING can be offered to the router.
        # Each candidate pairs the router-side identifier with the model object.
        candidates = []
        for model in self.models:
            nd_model = ND_MODEL_MAPPING.get(get_model_name(model))
            if nd_model is not None:
                candidates.append((nd_model, model))

        payload = {
            "messages": messages,
            "llm_providers": [nd_model for nd_model, _ in candidates],
            "hash_content": self.hash_content,
        }
        # "quality" is never sent explicitly (presumably it is the router's
        # default - confirm against the API reference).
        if self.tradeoff != "quality":
            payload["tradeoff"] = self.tradeoff
        if self.preference_id:
            payload["preference_id"] = self.preference_id

        # The first connected model is the fallback for every failure path.
        chosen_model = self.models[0]
        self._selected_model_name = get_model_name(chosen_model)

        recommendation = self._request_recommendation(payload, api_key)
        if recommendation is not None:
            wanted = (recommendation.get("provider"), recommendation.get("model"))
            for nd_model, model in candidates:
                if (nd_model["provider"], nd_model["model"]) == wanted:
                    chosen_model = model
                    self._selected_model_name = get_model_name(chosen_model)
                    break

        return self._call_get_chat_result(chosen_model, input_value, system_message)

    def _request_recommendation(self, payload: dict, api_key: str | None) -> dict | None:
        """Ask Not Diamond for a recommendation; return ``None`` on any failure.

        Args:
            payload: JSON body for the ``modelSelect`` endpoint.
            api_key: Bearer token for the request.

        Returns:
            The first entry of the response's ``providers`` list, or ``None``
            when the request fails, the body is not JSON, or the response does
            not contain a usable ``providers`` list.
        """
        header = {
            "Authorization": f"Bearer {api_key}",
            "accept": "application/json",
            "content-type": "application/json",
        }
        try:
            response = requests.post(
                "https://api.notdiamond.ai/v2/modelRouter/modelSelect",
                json=payload,
                headers=header,
                timeout=10,
            )
            result = response.json()
        except (requests.RequestException, ValueError):
            # Routing is best-effort: a timeout, connection error or non-JSON
            # body must not prevent the fallback model from answering.
            return None

        if not isinstance(result, dict):
            return None
        providers = result.get("providers")
        if not isinstance(providers, list) or not providers:
            # No provider returned by the API, likely a failed call.
            return None
        top_choice = providers[0]
        return top_choice if isinstance(top_choice, dict) else None

    def _call_get_chat_result(self, chosen_model, input_value, system_message):
        """Run the prompt on `chosen_model` and return the chat result."""
        return get_chat_result(
            runnable=chosen_model,
            input_value=input_value,
            system_message=system_message,
        )

    def _format_input(
        self,
        input_value: str | Message,
        system_message: str | None = None,
    ) -> list[dict]:
        """Convert the input and system message into OpenAI-style chat dicts.

        Args:
            input_value: The user input, either plain text or a `Message`.
            system_message: Optional system message, placed first in the list.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts. Only human, AI
            and system messages are emitted; any other message type is skipped.

        Raises:
            ValueError: If both `input_value` and `system_message` are empty.
        """
        messages: list[BaseMessage] = []
        if not input_value and not system_message:
            msg = "The message you want to send to the router is empty."
            raise ValueError(msg)

        system_message_added = False
        if input_value:
            if isinstance(input_value, Message):
                # Warnings raised while converting the message are suppressed.
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    if "prompt" in input_value:
                        prompt = input_value.load_lc_prompt()
                        if system_message:
                            prompt.messages = [
                                SystemMessage(content=system_message),
                                *prompt.messages,  # type: ignore[has-type]
                            ]
                            system_message_added = True
                        messages.extend(prompt.messages)
                    else:
                        messages.append(input_value.to_lc_message())
            else:
                messages.append(HumanMessage(content=input_value))

        # Prepend the system message unless the prompt branch already did.
        if system_message and not system_message_added:
            messages.insert(0, SystemMessage(content=system_message))

        # Convert Langchain messages to OpenAI format
        openai_messages = []
        for lc_message in messages:
            if isinstance(lc_message, HumanMessage):
                openai_messages.append({"role": "user", "content": lc_message.content})
            elif isinstance(lc_message, AIMessage):
                openai_messages.append({"role": "assistant", "content": lc_message.content})
            elif isinstance(lc_message, SystemMessage):
                openai_messages.append({"role": "system", "content": lc_message.content})
        return openai_messages