Spaces:
Runtime error
Runtime error
Update chain.py
Browse files
chain.py
CHANGED
@@ -5,6 +5,9 @@ from langchain_core.prompts import PromptTemplate
|
|
5 |
from langchain_huggingface import HuggingFaceEndpoint
|
6 |
from langchain_core.output_parsers import StrOutputParser
|
7 |
|
|
|
|
|
|
|
8 |
|
9 |
def get_chain(
|
10 |
vectordb,
|
@@ -26,8 +29,11 @@ Always say "Thanks for asking!" at the end of the answer.
|
|
26 |
Question: {question}
|
27 |
Helpful Answer:"""
|
28 |
):
|
29 |
-
|
|
|
|
|
30 |
|
|
|
31 |
retrieval = RunnableParallel(
|
32 |
{
|
33 |
"context": RunnableLambda(lambda x: retriever.invoke(x["question"])),
|
@@ -35,8 +41,10 @@ Helpful Answer:"""
|
|
35 |
}
|
36 |
)
|
37 |
|
|
|
38 |
prompt = PromptTemplate(input_variables=["context", "question"], template=template)
|
39 |
|
|
|
40 |
llm = HuggingFaceEndpoint(
|
41 |
repo_id=repo_id,
|
42 |
task=task,
|
@@ -46,5 +54,6 @@ Helpful Answer:"""
|
|
46 |
repetition_penalty=repetition_penalty,
|
47 |
)
|
48 |
|
|
|
49 |
return retrieval | prompt | llm | StrOutputParser()
|
50 |
|
|
|
5 |
from langchain_huggingface import HuggingFaceEndpoint
|
6 |
from langchain_core.output_parsers import StrOutputParser
|
7 |
|
8 |
+
import logging
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
|
12 |
def get_chain(
|
13 |
vectordb,
|
|
|
29 |
Question: {question}
|
30 |
Helpful Answer:"""
|
31 |
):
|
32 |
+
search_kwargs = {"k": k, "fetch_k": fetch_k}
|
33 |
+
logger.info(f'Setting up vectordb retriever with search_type={search_type} and search_kwargs={search_kwargs}')
|
34 |
+
retriever = vectordb.as_retriever(search_type=search_type, search_kwargs=search_kwargs)
|
35 |
|
36 |
+
logger.info('Setting up retrieval runnable')
|
37 |
retrieval = RunnableParallel(
|
38 |
{
|
39 |
"context": RunnableLambda(lambda x: retriever.invoke(x["question"])),
|
|
|
41 |
}
|
42 |
)
|
43 |
|
44 |
+
logger.info(f'Setting up prompt from the template:\n{template}')
|
45 |
prompt = PromptTemplate(input_variables=["context", "question"], template=template)
|
46 |
|
47 |
+
logger.info(f'Instantiating llm with repo_id={repo_id}, task={task}, max_new_tokens={max_new_tokens}, top_k={top_k}, temperature={temperature} and repetition_penalty={repetition_penalty}')
|
48 |
llm = HuggingFaceEndpoint(
|
49 |
repo_id=repo_id,
|
50 |
task=task,
|
|
|
54 |
repetition_penalty=repetition_penalty,
|
55 |
)
|
56 |
|
57 |
+
logger.info('Instantiating and returning chain = retrieval | prompt | llm | StrOutputParser()')
|
58 |
return retrieval | prompt | llm | StrOutputParser()
|
59 |
|