abhinand2 committed · verified
Commit dd958d7 · 1 Parent(s): 66f2ba6

Update chain.py

Files changed (1): chain.py (+10 -1)
chain.py CHANGED
@@ -5,6 +5,9 @@ from langchain_core.prompts import PromptTemplate
 from langchain_huggingface import HuggingFaceEndpoint
 from langchain_core.output_parsers import StrOutputParser
 
+import logging
+logger = logging.getLogger(__name__)
+
 
 def get_chain(
     vectordb,
@@ -26,8 +29,11 @@ Always say "Thanks for asking!" at the end of the answer.
 Question: {question}
 Helpful Answer:"""
 ):
-    retriever = vectordb.as_retriever(search_type=search_type, search_kwargs={"k": k, "fetch_k": fetch_k})
+    search_kwargs = {"k": k, "fetch_k": fetch_k}
+    logger.info(f'Setting up vectordb retriever with search_type={search_type} and search_kwargs={search_kwargs}')
+    retriever = vectordb.as_retriever(search_type=search_type, search_kwargs=search_kwargs)
 
+    logger.info('Setting up retrieval runnable')
     retrieval = RunnableParallel(
         {
             "context": RunnableLambda(lambda x: retriever.invoke(x["question"])),
@@ -35,8 +41,10 @@ Helpful Answer:"""
         }
     )
 
+    logger.info(f'Setting up prompt from the template:\n{template}')
     prompt = PromptTemplate(input_variables=["context", "question"], template=template)
 
+    logger.info(f'Instantiating llm with repo_id={repo_id}, task={task}, max_new_tokens={max_new_tokens}, top_k={top_k}, temperature={temperature} and repetition_penalty={repetition_penalty}')
     llm = HuggingFaceEndpoint(
         repo_id=repo_id,
         task=task,
@@ -46,5 +54,6 @@ Helpful Answer:"""
         repetition_penalty=repetition_penalty,
     )
 
+    logger.info('Instantiating and returning chain = retrieval | prompt | llm | StrOutputParser()')
     return retrieval | prompt | llm | StrOutputParser()
 
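
For context, a minimal usage sketch of the chain that get_chain builds after this change. The embedding model, FAISS store, sample text, and question below are illustrative assumptions, not part of this commit; the remaining get_chain parameters appear to have defaults in its signature, so they are left untouched. Running with INFO-level logging surfaces the messages added here (HuggingFaceEndpoint still needs a Hugging Face API token at call time).

import logging

# Illustrative only: FAISS + HuggingFaceEmbeddings stand in for whatever
# vector store this app actually uses; get_chain only needs .as_retriever().
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

from chain import get_chain

# INFO level makes the new logger.info() messages visible.
logging.basicConfig(level=logging.INFO)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectordb = FAISS.from_texts(
    ["LangChain runnables are composed into a chain with the | operator."],
    embedding=embeddings,
)

# search_type, k, fetch_k, repo_id, template, etc. are assumed to have
# defaults in get_chain's signature, as the diff context suggests.
chain = get_chain(vectordb)

# The retrieval step reads x["question"], so the chain is invoked with a dict.
print(chain.invoke({"question": "How are LangChain runnables composed?"}))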