from langchain.chains.base import Chain
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import RunnableSequence, RunnablePassthrough

# Prompt for the first summarization pass; {query} is pre-filled via
# partial_variables below, while {text} is supplied by the chain at run time.
prompt_template = """Write a concise summary of the following text, based on the user input.
User input: {query}
Text:
```
{text}
```
CONCISE SUMMARY:"""
# Prompt for each subsequent refine pass; {existing_answer} carries the
# summary produced so far.
refine_template = (
    "You are iteratively crafting a summary of the text below based on the user input.\n"
    "User input: {query}\n"
    "We have provided an existing summary up to a certain point: {existing_answer}\n"
    "We have the opportunity to refine the existing summary "
    "(only if needed) with some more context below.\n"
    "------------\n"
    "{text}\n"
    "------------\n"
    "Given the new context, refine the original summary.\n"
    "If the context isn't useful, return the original summary.\n"
    "If the context is useful, refine the summary to include the new context.\n"
    "Your contribution is helping to build a comprehensive summary of a large body of knowledge.\n"
    "You do not have the complete context, so do not discard pieces of the original summary."
)
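
# A quick sanity check of how the templates behave once `query` is pre-filled:
# only {text} (plus {existing_answer} for the refine template) remains as a
# runtime variable. The query string here is illustrative only.
#
#   _p = PromptTemplate.from_template(
#       prompt_template, partial_variables={"query": "What is LangChain?"}
#   )
#   _p.input_variables  # -> ["text"]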

def get_summarization_chain(
    llm: BaseLanguageModel,
    prompt: str,
) -> Chain:
    """Build a refine-style summarization chain with the user query baked in."""
    _prompt = PromptTemplate.from_template(
        prompt_template,
        partial_variables={"query": prompt},
    )
    refine_prompt = PromptTemplate.from_template(
        refine_template,
        partial_variables={"query": prompt},
    )
    return load_summarize_chain(
        llm=llm,
        chain_type="refine",
        question_prompt=_prompt,
        refine_prompt=refine_prompt,
        return_intermediate_steps=False,
        input_key="input_documents",
        output_key="output_text",
    )
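
# A minimal offline sketch of the chain above, assuming LangChain's built-in
# FakeListLLM; the document contents and responses are illustrative only.
def _demo_summarization_chain() -> None:
    from langchain.llms.fake import FakeListLLM
    from langchain.schema import Document

    llm = FakeListLLM(responses=["First-pass summary.", "Refined summary."])
    chain = get_summarization_chain(llm, "What does the text cover?")
    result = chain.invoke(
        {
            "input_documents": [
                Document(page_content="Part one covers unit tests."),
                Document(page_content="Part two adds integration tests."),
            ]
        }
    )
    # With two documents, the refine chain calls the LLM twice and the second
    # (refined) response is kept.
    print(result["output_text"])  # -> "Refined summary."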

def get_rag_summarization_chain(
    prompt: str,
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
    input_key: str = "prompt",
) -> RunnableSequence:
    """Compose retrieval with summarization: the query fans out to the
    retriever (becoming ``input_documents``) and is passed through under
    ``input_key``; the refine chain then runs and the summary string is
    extracted from its output dict."""
    return (
        {"input_documents": retriever, input_key: RunnablePassthrough()}
        | get_summarization_chain(llm, prompt)
        | (lambda output: output["output_text"])
    )
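
if __name__ == "__main__":
    # End-to-end sketch with a toy retriever and a fake LLM so it runs offline;
    # in practice, pass a real vector-store retriever and model. The retriever
    # class and all strings below are illustrative assumptions.
    from typing import List

    from langchain.callbacks.manager import CallbackManagerForRetrieverRun
    from langchain.llms.fake import FakeListLLM
    from langchain.schema import Document

    class _StaticRetriever(BaseRetriever):
        """Toy retriever that ignores the query and returns fixed documents."""

        docs: List[Document]

        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            return self.docs

    retriever = _StaticRetriever(
        docs=[Document(page_content="LangChain composes LLM pipelines.")]
    )
    llm = FakeListLLM(responses=["LangChain composes LLM pipelines."])
    chain = get_rag_summarization_chain("What is LangChain?", retriever, llm)
    print(chain.invoke("What is LangChain?"))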