Tai Truong
fix readme
d202ada
from langchain.chains import RetrievalQA
from langflow.base.chains.model import LCChainComponent
from langflow.field_typing import Message
from langflow.inputs import BoolInput, DropdownInput, HandleInput, MultilineInput
class RetrievalQAComponent(LCChainComponent):
    """Legacy component that answers a query with LangChain's ``RetrievalQA`` chain.

    The chain is built from the connected language model and retriever (plus an
    optional memory object) and is invoked with the text of ``input_value``.
    """

    display_name = "Retrieval QA"
    description = "Chain for question-answering querying sources from a retriever."
    name = "RetrievalQA"
    legacy: bool = True
    icon = "LangChain"
    inputs = [
        MultilineInput(
            name="input_value",
            display_name="Input",
            info="The input value to pass to the chain.",
            required=True,
        ),
        DropdownInput(
            name="chain_type",
            display_name="Chain Type",
            info="Chain type to use.",
            options=["Stuff", "Map Reduce", "Refine", "Map Rerank"],
            value="Stuff",
            advanced=True,
        ),
        HandleInput(
            name="llm",
            display_name="Language Model",
            input_types=["LanguageModel"],
            required=True,
        ),
        HandleInput(
            name="retriever",
            display_name="Retriever",
            input_types=["Retriever"],
            required=True,
        ),
        HandleInput(
            name="memory",
            display_name="Memory",
            input_types=["BaseChatMemory"],
        ),
        BoolInput(
            name="return_source_documents",
            display_name="Return Source Documents",
            value=False,
        ),
    ]

    def invoke_chain(self) -> Message:
        """Run the RetrievalQA chain on ``input_value`` and return the answer text.

        Side effects:
            * When a memory object is connected, its ``input_key`` and
              ``output_key`` are overwritten with ``"query"`` and ``"result"``.
            * ``self.status`` is set to the raw chain result merged with the
              converted source documents and the final output string.

        Returns:
            The chain's ``"result"`` value as a string. When
            ``return_source_documents`` is enabled and at least one source
            document came back, a references section is appended on a new line.
            NOTE(review): the value returned is a plain ``str`` although the
            annotation says ``Message`` - confirm what callers expect.
        """
        # Dropdown labels map onto chain type identifiers,
        # e.g. "Map Reduce" -> "map_reduce".
        chain_type = self.chain_type.lower().replace(" ", "_")

        if self.memory:
            # Point the memory at the keys this chain is invoked with / reads
            # back below, so it records the right input and output fields.
            self.memory.input_key = "query"
            self.memory.output_key = "result"

        runnable = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type=chain_type,
            retriever=self.retriever,
            memory=self.memory,
            # Always request the sources from the chain so they can be shown in
            # ``self.status`` for debugging; the ``return_source_documents``
            # input only controls whether they are appended to the answer text.
            return_source_documents=True,
        )
        result = runnable.invoke(
            {"query": self.input_value},
            config={"callbacks": self.get_langchain_callbacks()},
        )

        # ``dict.get`` accepts its default positionally only; the previous
        # ``keys=[]`` keyword raised ``TypeError`` on every invocation.
        source_docs = self.to_data(result.get("source_documents", []))
        result_str = str(result.get("result", ""))

        if self.return_source_documents and source_docs:
            references_str = self.create_references_from_data(source_docs)
            result_str = f"{result_str}\n{references_str}"

        # Put the entire result into the status for debugging: history, query
        # and content, plus the converted sources and the final output.
        self.status = {**result, "source_documents": source_docs, "output": result_str}
        return result_str