from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever

from langflow.custom import Component
from langflow.inputs import HandleInput, MessageTextInput
from langflow.io import Output
from langflow.schema import Data
from langflow.schema.message import Message


class SelfQueryRetrieverComponent(Component):
    """Langflow component wrapping LangChain's ``SelfQueryRetriever``.

    The connected LLM is used to generate queries for the connected vector
    store, guided by a description of the document contents and by the
    metadata field descriptions supplied as ``Data`` objects.
    """

    display_name = "Self Query Retriever"
    description = "Retriever that uses a vector store and an LLM to generate the vector store queries."
    name = "SelfQueryRetriever"
    icon = "LangChain"
    legacy: bool = True

    inputs = [
        HandleInput(
            name="query",
            display_name="Query",
            info="Query to be passed as input.",
            input_types=["Message", "Text"],
        ),
        HandleInput(
            name="vectorstore",
            display_name="Vector Store",
            info="Vector Store to be passed as input.",
            input_types=["VectorStore"],
        ),
        HandleInput(
            name="attribute_infos",
            display_name="Metadata Field Info",
            info="Metadata Field Info to be passed as input.",
            input_types=["Data"],
            is_list=True,
        ),
        MessageTextInput(
            name="document_content_description",
            display_name="Document Content Description",
            info="Document Content Description to be passed as input.",
        ),
        HandleInput(
            name="llm",
            display_name="LLM",
            info="LLM to be passed as input.",
            input_types=["LanguageModel"],
        ),
    ]

    outputs = [
        Output(
            display_name="Retrieved Documents",
            name="documents",
            method="retrieve_documents",
        ),
    ]

    def retrieve_documents(self) -> list[Data]:
        """Run the self-query retriever for ``query`` and return the hits.

        Each ``Data`` in ``attribute_infos`` has its ``.data`` dict unpacked
        into a LangChain ``AttributeInfo``. These are passed, together with
        the LLM, the vector store and the document content description, to
        ``SelfQueryRetriever.from_llm``.

        Returns:
            One ``Data`` per retrieved document. The same list is also stored
            on ``self.status``.

        Raises:
            TypeError: If ``query`` is neither a ``Message`` nor a ``str``.
        """
        # Validate the query first so an unsupported input type is reported
        # before any retriever setup happens.
        input_text = self._resolve_query_text()

        # The metadata handle is optional; an unconnected handle may resolve
        # to None, which previously crashed the comprehension. Treat it as
        # "no metadata fields".
        metadata_field_infos = [AttributeInfo(**info.data) for info in self.attribute_infos or []]

        self_query_retriever = SelfQueryRetriever.from_llm(
            llm=self.llm,
            vectorstore=self.vectorstore,
            document_contents=self.document_content_description,
            metadata_field_info=metadata_field_infos,
            # LangChain option: allows the generated structured query to carry a result limit.
            enable_limit=True,
        )

        # Forward Langflow's LangChain callbacks so the retrieval is traced.
        documents = self_query_retriever.invoke(input=input_text, config={"callbacks": self.get_langchain_callbacks()})
        data = [Data.from_document(document) for document in documents]
        self.status = data
        return data

    def _resolve_query_text(self):
        """Return the raw query text from the ``query`` input.

        A ``Message`` contributes its ``.text``; a plain ``str`` is used as-is.

        Raises:
            TypeError: For any other input type.
        """
        if isinstance(self.query, Message):
            return self.query.text
        if isinstance(self.query, str):
            return self.query
        msg = f"Query type {type(self.query)} not supported."
        raise TypeError(msg)