Tai Truong
fix readme
d202ada
from typing import Any, cast
from langchain.retrievers import ContextualCompressionRetriever
from langflow.base.vectorstores.model import (
LCVectorStoreComponent,
check_cached_vector_store,
)
from langflow.field_typing import Retriever, VectorStore
from langflow.io import (
DropdownInput,
HandleInput,
MultilineInput,
SecretStrInput,
StrInput,
)
from langflow.schema import Data
from langflow.schema.dotdict import dotdict
from langflow.template.field.base import Output
class NvidiaRerankComponent(LCVectorStoreComponent):
    """Langflow component that reranks retrieved documents with an NVIDIA rerank model.

    The component wraps the retriever supplied on the ``retriever`` input in a
    ``ContextualCompressionRetriever`` whose compressor is an ``NVIDIARerank``
    instance configured from the ``api_key``, ``model`` and ``base_url`` inputs.
    """

    display_name = "NVIDIA Rerank"
    description = "Rerank documents using the NVIDIA API and a retriever."
    icon = "NVIDIA"
    legacy: bool = True

    inputs = [
        MultilineInput(
            name="search_query",
            display_name="Search Query",
        ),
        StrInput(
            name="base_url",
            display_name="Base URL",
            value="https://integrate.api.nvidia.com/v1",
            refresh_button=True,
            info="The base URL of the NVIDIA API. Defaults to https://integrate.api.nvidia.com/v1.",
        ),
        DropdownInput(
            name="model",
            display_name="Model",
            options=["nv-rerank-qa-mistral-4b:1"],
            value="nv-rerank-qa-mistral-4b:1",
        ),
        SecretStrInput(name="api_key", display_name="API Key"),
        HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"]),
    ]

    outputs = [
        Output(
            display_name="Retriever",
            name="base_retriever",
            method="build_base_retriever",
        ),
        Output(
            display_name="Search Results",
            name="search_results",
            method="search_documents",
        ),
    ]

    def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):
        """Refresh the ``model`` dropdown when the ``base_url`` field changes.

        Args:
            build_config: The component's build configuration; its ``"model"``
                entry is updated in place.
            field_value: The new value of the field that changed.
            field_name: The name of the field that changed.

        Returns:
            The (possibly updated) ``build_config``.

        Raises:
            ValueError: If the model list cannot be fetched, is empty, or the
                configuration cannot be updated. The original error is chained.
        """
        # Only a non-empty base URL triggers a model refresh.
        if field_name != "base_url" or not field_value:
            return build_config

        try:
            reranker = self.build_model()
            model_ids = [model.id for model in reranker.available_models]
            # Validate before touching build_config so a failed refresh never
            # leaves the dropdown with an empty options list.
            if not model_ids:
                msg = "no models were returned by the NVIDIA API"
                raise ValueError(msg)  # noqa: TRY301 - deliberately re-wrapped below
            build_config["model"]["options"] = model_ids
            build_config["model"]["value"] = model_ids[0]
        except Exception as e:
            msg = f"Error getting model names: {e}"
            raise ValueError(msg) from e
        return build_config

    def build_model(self):
        """Create the ``NVIDIARerank`` client from the component's inputs.

        Returns:
            An ``NVIDIARerank`` instance built from ``api_key``, ``model`` and
            ``base_url``.

        Raises:
            ImportError: If ``langchain_nvidia_ai_endpoints`` is not installed.
        """
        # Imported lazily so the dependency is only required when the component is used.
        try:
            from langchain_nvidia_ai_endpoints import NVIDIARerank
        except ImportError as e:
            msg = "Please install langchain-nvidia-ai-endpoints to use the NVIDIA model."
            raise ImportError(msg) from e
        return NVIDIARerank(api_key=self.api_key, model=self.model, base_url=self.base_url)

    def build_base_retriever(self) -> Retriever:  # type: ignore[type-var]
        """Wrap the input retriever so its results are passed through the NVIDIA reranker.

        Returns:
            A ``ContextualCompressionRetriever`` using the NVIDIA reranker as the
            compressor and ``self.retriever`` as the base retriever.
        """
        reranker = self.build_model()
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=reranker,
            base_retriever=self.retriever,
        )
        return cast("Retriever", compression_retriever)

    async def search_documents(self) -> list[Data]:  # type: ignore[override]
        """Run ``search_query`` through the reranking retriever.

        Returns:
            The retrieved documents converted with ``self.to_data``; the same
            list is also stored on ``self.status``.
        """
        retriever = self.build_base_retriever()
        run_config = {"callbacks": self.get_langchain_callbacks()}
        documents = await retriever.ainvoke(self.search_query, config=run_config)
        data = self.to_data(documents)
        self.status = data
        return data

    @check_cached_vector_store
    def build_vector_store(self) -> VectorStore:
        """Always raise: this component does not provide a vector store.

        Raises:
            NotImplementedError: Always.
        """
        msg = "NVIDIA Rerank does not support vector stores."
        raise NotImplementedError(msg)