# Tai Truong
# fix readme
# d202ada
from typing import Any
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.inputs import DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput
from langflow.inputs.inputs import HandleInput
from langflow.schema.dotdict import dotdict
class NVIDIAModelComponent(LCModelComponent):
    """Langflow component that generates text using NVIDIA-hosted LLMs.

    Wraps ``langchain_nvidia_ai_endpoints.ChatNVIDIA`` and exposes its
    configuration (model name, base URL, API key, temperature, seed,
    max tokens) as component inputs.
    """

    display_name = "NVIDIA"
    description = "Generates text using NVIDIA LLMs."
    icon = "NVIDIA"
    inputs = [
        *LCModelComponent._base_inputs,
        IntInput(
            name="max_tokens",
            display_name="Max Tokens",
            advanced=True,
            info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
        ),
        DropdownInput(
            name="model_name",
            display_name="Model Name",
            advanced=False,
            options=["mistralai/mixtral-8x7b-instruct-v0.1"],
            value="mistralai/mixtral-8x7b-instruct-v0.1",
        ),
        StrInput(
            name="base_url",
            display_name="NVIDIA Base URL",
            value="https://integrate.api.nvidia.com/v1",
            refresh_button=True,
            info="The base URL of the NVIDIA API. Defaults to https://integrate.api.nvidia.com/v1.",
        ),
        SecretStrInput(
            name="nvidia_api_key",
            display_name="NVIDIA API Key",
            info="The NVIDIA API Key.",
            advanced=False,
            value="NVIDIA_API_KEY",
        ),
        FloatInput(name="temperature", display_name="Temperature", value=0.1),
        IntInput(
            name="seed",
            display_name="Seed",
            info="The seed controls the reproducibility of the job.",
            advanced=True,
            value=1,
        ),
        HandleInput(
            name="output_parser",
            display_name="Output Parser",
            info="The parser to use to parse the output of the model",
            advanced=True,
            input_types=["OutputParser"],
        ),
    ]

    def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict:
        """Refresh the model-name dropdown when the base URL changes.

        Queries the endpoint at the new base URL for its available models and
        replaces the ``model_name`` options (and current value) with the result.

        Args:
            build_config: The component's current build configuration.
            field_value: The new value of the changed field.
            field_name: The name of the field that changed; only ``"base_url"``
                triggers a refresh.

        Returns:
            The (possibly updated) build configuration.

        Raises:
            ValueError: If the model list cannot be fetched from the endpoint,
                or the endpoint reports no models.
        """
        if field_name == "base_url" and field_value:
            try:
                build_model = self.build_model()
                ids = [model.id for model in build_model.available_models]
                if not ids:
                    # Without this guard an empty listing surfaced as an opaque
                    # "list index out of range" from ids[0] below.
                    msg = f"No models found at base URL: {field_value}"
                    raise ValueError(msg)
                build_config["model_name"]["options"] = ids
                build_config["model_name"]["value"] = ids[0]
            except Exception as e:
                msg = f"Error getting model names: {e}"
                raise ValueError(msg) from e
        return build_config

    def build_model(self) -> LanguageModel:  # type: ignore[type-var]
        """Instantiate a ``ChatNVIDIA`` chat model from the component inputs.

        Returns:
            The configured ``ChatNVIDIA`` language model.

        Raises:
            ImportError: If ``langchain-nvidia-ai-endpoints`` is not installed.
        """
        try:
            from langchain_nvidia_ai_endpoints import ChatNVIDIA
        except ImportError as e:
            msg = "Please install langchain-nvidia-ai-endpoints to use the NVIDIA model."
            raise ImportError(msg) from e
        return ChatNVIDIA(
            # The UI uses 0 to mean "unlimited"; ChatNVIDIA expects None for that.
            max_tokens=self.max_tokens or None,
            model=self.model_name,
            base_url=self.base_url,
            api_key=self.nvidia_api_key,
            # NOTE(review): `or 0.1` means an explicit temperature of 0 is
            # silently replaced by the 0.1 default — preserved as-is; confirm
            # this is intended before changing it.
            temperature=self.temperature or 0.1,
            seed=self.seed,
        )