Tai Truong
fix readme
d202ada
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langflow.custom import Component
from langflow.inputs import HandleInput, MessageTextInput
from langflow.io import Output
from langflow.schema import Data
from langflow.schema.message import Message
class SelfQueryRetrieverComponent(Component):
    """Component that runs a query through a LangChain ``SelfQueryRetriever``.

    The retriever is assembled from the connected LLM, vector store, metadata
    field descriptions and document content description, and the retrieved
    documents are returned as ``Data`` objects.
    """

    # Component metadata.
    display_name = "Self Query Retriever"
    description = "Retriever that uses a vector store and an LLM to generate the vector store queries."
    name = "SelfQueryRetriever"
    icon = "LangChain"
    legacy: bool = True

    # Inputs, in declaration order: query, vector store, metadata field
    # descriptions, document content description, LLM.
    inputs = [
        # The query may arrive as a Message or as plain text.
        HandleInput(
            name="query",
            display_name="Query",
            info="Query to be passed as input.",
            input_types=["Message", "Text"],
        ),
        HandleInput(
            name="vectorstore",
            display_name="Vector Store",
            info="Vector Store to be passed as input.",
            input_types=["VectorStore"],
        ),
        # A list of Data objects; each one's ``data`` mapping is expanded into
        # an ``AttributeInfo`` in ``retrieve_documents``.
        HandleInput(
            name="attribute_infos",
            display_name="Metadata Field Info",
            info="Metadata Field Info to be passed as input.",
            input_types=["Data"],
            is_list=True,
        ),
        MessageTextInput(
            name="document_content_description",
            display_name="Document Content Description",
            info="Document Content Description to be passed as input.",
        ),
        HandleInput(
            name="llm",
            display_name="LLM",
            info="LLM to be passed as input.",
            input_types=["LanguageModel"],
        ),
    ]

    # Single output, produced by ``retrieve_documents``.
    outputs = [
        Output(
            display_name="Retrieved Documents",
            name="documents",
            method="retrieve_documents",
        ),
    ]

    def retrieve_documents(self) -> list[Data]:
        """Run the query through a self-query retriever and return the results.

        Returns:
            The retrieved documents, each converted with ``Data.from_document``.
            The same list is also stored on ``self.status``.

        Raises:
            TypeError: If ``self.query`` is neither a ``Message`` nor a ``str``.
        """
        # Each Data item's ``data`` mapping supplies the AttributeInfo keyword arguments.
        field_infos = [AttributeInfo(**item.data) for item in self.attribute_infos]
        retriever = SelfQueryRetriever.from_llm(
            llm=self.llm,
            vectorstore=self.vectorstore,
            document_contents=self.document_content_description,
            metadata_field_info=field_infos,
            enable_limit=True,
        )

        # The query is validated only after the retriever has been built, so a
        # retriever construction failure surfaces before a bad query type does.
        if not isinstance(self.query, (Message, str)):
            msg = f"Query type {type(self.query)} not supported."
            raise TypeError(msg)
        # Message is checked first, so a Message always contributes its ``text``.
        search_text = self.query.text if isinstance(self.query, Message) else self.query

        run_config = {"callbacks": self.get_langchain_callbacks()}
        hits = retriever.invoke(input=search_text, config=run_config)

        results = [Data.from_document(hit) for hit in hits]
        self.status = results
        return results