# Source: Hugging Face file "chain.py" by abhinand2 (commit 5feb488, verified; 1.54 kB)
from langchain.schema.runnable import RunnableParallel
from langchain_core.runnables import RunnableLambda
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.output_parsers import StrOutputParser
def _format_docs(docs):
    """Join retrieved documents' text into one plain-text context block.

    Without this, the list of Document objects would be stringified by the
    prompt template as Python reprs (page_content plus metadata noise),
    polluting the context the LLM sees.
    """
    return "\n\n".join(doc.page_content for doc in docs)


def get_chain(
    vectordb,
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    max_new_tokens=512,
    top_k=30,
    temperature=0.1,
    repetition_penalty=1.03,
    search_type="mmr",
    k=3,
    fetch_k=5,
    template="""Use the following sentences of context to answer the question at the end.
If you don't know the answer, that is if the answer is not in the context, then just say that you don't know, don't try to make up an answer.
Always say "Thanks for asking!" at the end of the answer.
{context}
Question: {question}
Helpful Answer:"""
):
    """Build a retrieval-augmented QA chain: retrieve -> prompt -> LLM -> str.

    The returned runnable expects an input dict with a "question" key and
    produces a plain string answer.

    Args:
        vectordb: Vector store exposing ``as_retriever`` (e.g. Chroma/FAISS).
        repo_id: Hugging Face model repository to query.
        task: Inference task passed to the endpoint.
        max_new_tokens: Generation length cap.
        top_k: Top-k sampling parameter for the LLM.
        temperature: Sampling temperature.
        repetition_penalty: Penalty discouraging repeated tokens.
        search_type: Retriever search strategy ("mmr" by default).
        k: Number of documents the retriever returns.
        fetch_k: Candidate pool size fetched before MMR re-ranking.
        template: Prompt template with {context} and {question} slots.

    Returns:
        A composed LCEL runnable: retrieval | prompt | llm | StrOutputParser.
    """
    retriever = vectordb.as_retriever(
        search_type=search_type, search_kwargs={"k": k, "fetch_k": fetch_k}
    )
    # Fan out the single input dict into the two prompt variables.
    # Retrieved docs are flattened to text so the prompt gets clean context,
    # not a repr of Document objects.
    retrieval = RunnableParallel(
        {
            "context": RunnableLambda(
                lambda x: _format_docs(retriever.invoke(x["question"]))
            ),
            "question": RunnableLambda(lambda x: x["question"]),
        }
    )
    prompt = PromptTemplate(input_variables=["context", "question"], template=template)
    llm = HuggingFaceEndpoint(
        repo_id=repo_id,
        task=task,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    return retrieval | prompt | llm | StrOutputParser()