Spaces:
Runtime error
Runtime error
import logging
from operator import itemgetter

from langchain.schema.runnable import RunnableParallel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_huggingface import HuggingFaceEndpoint
logger = logging.getLogger(__name__)


def _format_docs(docs):
    """Concatenate the ``page_content`` of retrieved documents into one string.

    Documents are separated by a blank line. Without this step the prompt's
    ``{context}`` slot would receive ``str()`` of the retriever's list of
    ``Document`` objects, i.e. their reprs including metadata.
    """
    return "\n\n".join(doc.page_content for doc in docs)


def get_chain(
    vectordb,
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    max_new_tokens=512,
    top_k=30,
    temperature=0.1,
    repetition_penalty=1.03,
    search_type="mmr",
    k=3,
    fetch_k=5,
    template="""Use the following sentences of context to answer the question at the end.
If you don't know the answer, that is if the answer is not in the context, then just say that you don't know, don't try to make up an answer.
Always say "Thanks for asking!" at the end of the answer.
{context}
Question: {question}
Helpful Answer:""",
):
    """Build a retrieval-augmented question-answering chain over ``vectordb``.

    The returned runnable expects a mapping with a ``"question"`` key. It
    retrieves documents for that question and fills ``template``:
    ``{context}`` gets the joined document text and ``{question}`` gets the
    question. The prompt is then sent to a ``HuggingFaceEndpoint`` LLM, and
    the generated text is returned as a string.

    Args:
        vectordb: Object exposing ``as_retriever(search_type=..., search_kwargs=...)``
            (presumably a LangChain vector store).
        repo_id: Hugging Face model repository used by the endpoint.
        task: Inference task passed to ``HuggingFaceEndpoint``.
        max_new_tokens: Generation parameter forwarded to ``HuggingFaceEndpoint``.
        top_k: Sampling parameter forwarded to ``HuggingFaceEndpoint``.
        temperature: Sampling parameter forwarded to ``HuggingFaceEndpoint``.
        repetition_penalty: Generation parameter forwarded to ``HuggingFaceEndpoint``.
        search_type: Retriever search type, e.g. ``"mmr"``.
        k: Number of documents requested from the retriever.
        fetch_k: Candidate-pool size for MMR. It is only sent to the retriever
            when ``search_type == "mmr"``.
        template: Prompt template text. It must contain ``{context}`` and
            ``{question}`` placeholders.

    Returns:
        The runnable ``retrieval | prompt | llm | StrOutputParser()``.
    """
    # fetch_k is the MMR candidate-pool size. Plain similarity search does not
    # use it, and some vector stores may reject it as an unexpected keyword.
    search_kwargs = {"k": k}
    if search_type == "mmr":
        search_kwargs["fetch_k"] = fetch_k
    logger.info(
        "Setting up vectordb retriever with search_type=%s and search_kwargs=%s",
        search_type,
        search_kwargs,
    )
    retriever = vectordb.as_retriever(search_type=search_type, search_kwargs=search_kwargs)

    logger.info("Setting up retrieval runnable")
    get_question = RunnableLambda(itemgetter("question"))
    retrieval = RunnableParallel(
        {
            # The retriever is piped in as a chain step (instead of being invoked
            # inside a lambda). Its Document list is then flattened to plain text
            # so the prompt does not receive Document reprs.
            "context": get_question | retriever | RunnableLambda(_format_docs),
            "question": get_question,
        }
    )

    logger.info("Setting up prompt from the template:\n%s", template)
    prompt = PromptTemplate(input_variables=["context", "question"], template=template)

    logger.info(
        "Instantiating llm with repo_id=%s, task=%s, max_new_tokens=%s, top_k=%s, "
        "temperature=%s and repetition_penalty=%s",
        repo_id,
        task,
        max_new_tokens,
        top_k,
        temperature,
        repetition_penalty,
    )
    llm = HuggingFaceEndpoint(
        repo_id=repo_id,
        task=task,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )

    logger.info("Instantiating and returning chain = retrieval | prompt | llm | StrOutputParser()")
    return retrieval | prompt | llm | StrOutputParser()