Joshua Sundance Bailey committed on
Commit 8aab446 · 1 Parent(s): 6467ea5

enable rag q&a

langchain-streamlit-demo/app.py CHANGED
@@ -26,7 +26,7 @@ from langchain.vectorstores import FAISS
 from langsmith.client import Client
 from streamlit_feedback import streamlit_feedback
 
-from qagen import get_qa_gen_chain, combine_qa_pair_lists
+from qagen import combine_qa_pair_lists, get_rag_qa_gen_chain
 from summarize import get_summarization_chain
 
 __version__ = "0.0.10"
@@ -353,9 +353,11 @@ if st.session_state.llm:
         if document_chat_chain_type == "Summarization":
             st.session_state.doc_chain = "summarization"
         elif document_chat_chain_type == "Q&A Generation":
-            st.session_state.doc_chain = get_qa_gen_chain(st.session_state.llm)
-            # from qagen import get_rag_qa_gen_chain
-            # st.session_state.doc_chain = get_rag_qa_gen_chain(st.session_state.retriever, st.session_state.llm)
+            # st.session_state.doc_chain = get_qa_gen_chain(st.session_state.llm)
+            st.session_state.doc_chain = get_rag_qa_gen_chain(
+                st.session_state.retriever,
+                st.session_state.llm,
+            )
         else:
             st.session_state.doc_chain = RetrievalQA.from_chain_type(
                 llm=st.session_state.llm,
@@ -426,14 +428,14 @@ if st.session_state.llm:
                 )
                 if st.session_state.provider == "Anthropic":
                     config["max_concurrency"] = 5
-                raw_results = st.session_state.doc_chain.batch(
-                    [
-                        {"input": doc.page_content, "prompt": prompt}
-                        for doc in st.session_state.texts
-                    ],
-                    config,
-                )
-                # raw_results = st.session_state.doc_chain.invoke(prompt, config)
+                # raw_results = st.session_state.doc_chain.batch(
+                #     [
+                #         {"input": doc.page_content, "prompt": prompt}
+                #         for doc in st.session_state.texts
+                #     ],
+                #     config,
+                # )
+                raw_results = st.session_state.doc_chain.invoke(prompt, config)
                 results = combine_qa_pair_lists(raw_results).QuestionAnswerPairs
 
                 def _to_str(idx, qap):
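
The effect of the app.py change: before, the Q&A-generation chain ran once per document chunk via .batch(), feeding every chunk in st.session_state.texts to the model; after, the RAG variant is invoked a single time with the user's prompt, and the retriever supplies only the relevant chunks. Below is a minimal sketch of the two call patterns outside Streamlit, using FakeListLLM and FakeEmbeddings as stand-ins for the real provider objects; the prompt wiring and the names texts, retriever, and llm are illustrative assumptions, not the repo's exact code.

# Sketch (assumed wiring): removed per-chunk .batch() vs. new single
# .invoke() RAG call. Fakes stand in for real providers; needs faiss-cpu.
from langchain.embeddings.fake import FakeEmbeddings
from langchain.llms.fake import FakeListLLM
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.vectorstores import FAISS

texts = ["LangChain composes LLM calls.", "FAISS does vector search."]
retriever = FAISS.from_texts(texts, FakeEmbeddings(size=32)).as_retriever()
llm = FakeListLLM(responses=['[{"question": "q", "answer": "a"}]'] * 4)

qa_prompt = ChatPromptTemplate.from_template(
    "Write question/answer pairs about:\n{context}\n\nUser request: {prompt}"
)

# Old pattern (removed in this commit): one model call per chunk, with the
# per-chunk outputs merged afterwards by combine_qa_pair_lists().
per_chunk_chain = qa_prompt | llm | StrOutputParser()
raw_results = per_chunk_chain.batch(
    [{"context": t, "prompt": "Make Q&A pairs"} for t in texts]
)

# New pattern: the retriever fills {context}, so a single invoke() with the
# raw user prompt replaces the whole batch loop.
rag_chain = (
    {"context": retriever, "prompt": RunnablePassthrough()}
    | qa_prompt
    | llm
    | StrOutputParser()
)
raw_result = rag_chain.invoke("Make Q&A pairs")
print(raw_result)

Since only retrieved context reaches the model, prompt size no longer grows with the number of uploaded chunks, and the single call makes the Anthropic max_concurrency cap largely moot for this path.
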
langchain-streamlit-demo/qagen.py CHANGED
@@ -64,10 +64,10 @@ def combine_qa_pair_lists(
     )
 
 
-def get_qa_gen_chain(llm: BaseLanguageModel) -> RunnableSequence:
-    return (
-        CHAT_PROMPT | llm | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
-    )
+# def get_qa_gen_chain(llm: BaseLanguageModel) -> RunnableSequence:
+#     return (
+#         CHAT_PROMPT | llm | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
+#     )
 
 
 def get_rag_qa_gen_chain(
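
The qagen.py hunk ends at the signature of get_rag_qa_gen_chain, so its body is not shown in this commit. A hypothetical sketch, assuming it mirrors the retired get_qa_gen_chain (prompt | llm | OutputFixingParser) with a retrieval step in front; the QuestionAnswerPair model, CHAT_PROMPT, and PYDANTIC_PARSER below are simplified stand-ins for the module's real objects.

# Hypothetical sketch only: the commit does not show this function's body.
from langchain.output_parsers import OutputFixingParser, PydanticOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import RunnablePassthrough, RunnableSequence


class QuestionAnswerPair(BaseModel):
    # Stand-in for the repo's real pydantic model.
    question: str
    answer: str


PYDANTIC_PARSER = PydanticOutputParser(pydantic_object=QuestionAnswerPair)

CHAT_PROMPT = ChatPromptTemplate.from_template(
    "Generate a question/answer pair from this context:\n{context}\n\n"
    "User request: {prompt}\n\n{format_instructions}"
).partial(format_instructions=PYDANTIC_PARSER.get_format_instructions())


def get_rag_qa_gen_chain(
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
) -> RunnableSequence:
    # Retrieved documents fill {context}; the raw user prompt passes through
    # to {prompt}. OutputFixingParser re-asks the LLM if parsing fails.
    return (
        {"context": retriever, "prompt": RunnablePassthrough()}
        | CHAT_PROMPT
        | llm
        | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
    )

Wiring the retriever into the leading dict is what lets app.py call .invoke(prompt) with nothing but the user's text, matching the new invocation in the first hunk.
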