Joshua Sundance Bailey commited on
Commit
6467ea5
·
1 Parent(s): d992641

get_rag_qa_gen_chain

Browse files
langchain-streamlit-demo/app.py CHANGED
@@ -354,7 +354,8 @@ if st.session_state.llm:
354
  st.session_state.doc_chain = "summarization"
355
  elif document_chat_chain_type == "Q&A Generation":
356
  st.session_state.doc_chain = get_qa_gen_chain(st.session_state.llm)
357
-
 
358
  else:
359
  st.session_state.doc_chain = RetrievalQA.from_chain_type(
360
  llm=st.session_state.llm,
@@ -432,6 +433,7 @@ if st.session_state.llm:
432
  ],
433
  config,
434
  )
 
435
  results = combine_qa_pair_lists(raw_results).QuestionAnswerPairs
436
 
437
  def _to_str(idx, qap):
 
354
  st.session_state.doc_chain = "summarization"
355
  elif document_chat_chain_type == "Q&A Generation":
356
  st.session_state.doc_chain = get_qa_gen_chain(st.session_state.llm)
357
+ # from qagen import get_rag_qa_gen_chain
358
+ # st.session_state.doc_chain = get_rag_qa_gen_chain(st.session_state.retriever, st.session_state.llm)
359
  else:
360
  st.session_state.doc_chain = RetrievalQA.from_chain_type(
361
  llm=st.session_state.llm,
 
433
  ],
434
  config,
435
  )
436
+ # raw_results = st.session_state.doc_chain.invoke(prompt, config)
437
  results = combine_qa_pair_lists(raw_results).QuestionAnswerPairs
438
 
439
  def _to_str(idx, qap):
langchain-streamlit-demo/qagen.py CHANGED
@@ -6,7 +6,8 @@ from langchain.prompts.chat import (
6
  ChatPromptTemplate,
7
  )
8
  from langchain.schema.language_model import BaseLanguageModel
9
- from langchain.schema.runnable import RunnableSequence
 
10
  from pydantic import BaseModel, Field
11
 
12
 
@@ -67,3 +68,16 @@ def get_qa_gen_chain(llm: BaseLanguageModel) -> RunnableSequence:
67
  return (
68
  CHAT_PROMPT | llm | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
69
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  ChatPromptTemplate,
7
  )
8
  from langchain.schema.language_model import BaseLanguageModel
9
+ from langchain.schema.retriever import BaseRetriever
10
+ from langchain.schema.runnable import RunnablePassthrough, RunnableSequence
11
  from pydantic import BaseModel, Field
12
 
13
 
 
68
  return (
69
  CHAT_PROMPT | llm | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
70
  )
71
+
72
+
73
+ def get_rag_qa_gen_chain(
74
+ retriever: BaseRetriever,
75
+ llm: BaseLanguageModel,
76
+ input_key: str = "prompt",
77
+ ) -> RunnableSequence:
78
+ return (
79
+ {"context": retriever, input_key: RunnablePassthrough()}
80
+ | CHAT_PROMPT
81
+ | llm
82
+ | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
83
+ )