Joshua Sundance Bailey
committed on
Commit
·
6467ea5
1
Parent(s):
d992641
get_rag_qa_gen_chain
Browse files
langchain-streamlit-demo/app.py
CHANGED
@@ -354,7 +354,8 @@ if st.session_state.llm:
|
|
354 |
st.session_state.doc_chain = "summarization"
|
355 |
elif document_chat_chain_type == "Q&A Generation":
|
356 |
st.session_state.doc_chain = get_qa_gen_chain(st.session_state.llm)
|
357 |
-
|
|
|
358 |
else:
|
359 |
st.session_state.doc_chain = RetrievalQA.from_chain_type(
|
360 |
llm=st.session_state.llm,
|
@@ -432,6 +433,7 @@ if st.session_state.llm:
|
|
432 |
],
|
433 |
config,
|
434 |
)
|
|
|
435 |
results = combine_qa_pair_lists(raw_results).QuestionAnswerPairs
|
436 |
|
437 |
def _to_str(idx, qap):
|
|
|
354 |
st.session_state.doc_chain = "summarization"
|
355 |
elif document_chat_chain_type == "Q&A Generation":
|
356 |
st.session_state.doc_chain = get_qa_gen_chain(st.session_state.llm)
|
357 |
+
# from qagen import get_rag_qa_gen_chain
|
358 |
+
# st.session_state.doc_chain = get_rag_qa_gen_chain(st.session_state.retriever, st.session_state.llm)
|
359 |
else:
|
360 |
st.session_state.doc_chain = RetrievalQA.from_chain_type(
|
361 |
llm=st.session_state.llm,
|
|
|
433 |
],
|
434 |
config,
|
435 |
)
|
436 |
+
# raw_results = st.session_state.doc_chain.invoke(prompt, config)
|
437 |
results = combine_qa_pair_lists(raw_results).QuestionAnswerPairs
|
438 |
|
439 |
def _to_str(idx, qap):
|
langchain-streamlit-demo/qagen.py
CHANGED
@@ -6,7 +6,8 @@ from langchain.prompts.chat import (
|
|
6 |
ChatPromptTemplate,
|
7 |
)
|
8 |
from langchain.schema.language_model import BaseLanguageModel
|
9 |
-
from langchain.schema.runnable import RunnableSequence
|
|
|
10 |
from pydantic import BaseModel, Field
|
11 |
|
12 |
|
@@ -67,3 +68,16 @@ def get_qa_gen_chain(llm: BaseLanguageModel) -> RunnableSequence:
|
|
67 |
return (
|
68 |
CHAT_PROMPT | llm | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
|
69 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
ChatPromptTemplate,
|
7 |
)
|
8 |
from langchain.schema.language_model import BaseLanguageModel
|
9 |
+
from langchain.schema.retriever import BaseRetriever
|
10 |
+
from langchain.schema.runnable import RunnablePassthrough, RunnableSequence
|
11 |
from pydantic import BaseModel, Field
|
12 |
|
13 |
|
|
|
68 |
return (
|
69 |
CHAT_PROMPT | llm | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
|
70 |
)
|
71 |
+
|
72 |
+
|
73 |
+
def get_rag_qa_gen_chain(
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
    input_key: str = "prompt",
) -> RunnableSequence:
    """Build a retrieval-augmented question/answer-generation chain.

    The retriever populates ``"context"`` while the raw user input is passed
    through unchanged under ``input_key``; both feed ``CHAT_PROMPT``, whose
    rendered messages go to ``llm``, and the model output is finally parsed
    (and repaired on failure) by an ``OutputFixingParser`` wrapping
    ``PYDANTIC_PARSER``.

    NOTE(review): assumes ``CHAT_PROMPT`` declares both ``context`` and
    ``input_key`` as template variables — confirm against its definition.
    """
    # Build the pieces as named steps, then compose; a plain dict as the
    # left-most element is coerced into a parallel runnable by LCEL's `|`.
    fixing_parser = OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
    inputs = {"context": retriever, input_key: RunnablePassthrough()}
    return inputs | CHAT_PROMPT | llm | fixing_parser
|