File size: 2,531 Bytes
622ac66
 
 
 
 
 
 
6467ea5
 
930d412
622ac66
 
 
 
 
 
bfaa0c3
 
 
 
 
 
622ac66
 
 
 
bfaa0c3
 
 
 
 
 
 
 
622ac66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3550ebd
622ac66
 
 
 
 
 
 
 
6467ea5
 
 
 
 
 
 
 
 
 
923e6fa
6467ea5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from typing import List

from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
from langchain.prompts.chat import (
    ChatPromptTemplate,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import RunnablePassthrough, RunnableSequence
from pydantic import BaseModel, Field


# A single reading-comprehension question together with its answer.
# NOTE: documented with comments rather than a class docstring on purpose:
# pydantic copies a model's docstring into the JSON schema "description",
# which would alter the schema this model produces.
class QuestionAnswerPair(BaseModel):
    question: str = Field(..., description="The question that will be answered.")
    answer: str = Field(..., description="The answer to the question that was asked.")

    def to_str(self, idx: int) -> str:
        """Render this pair as a numbered Markdown entry.

        Args:
            idx: Number shown in front of the question.

        Returns:
            Two lines separated by a blank line: ``"<idx>. **Q:** <question>"``
            followed by the answer line, which is indented so that ``**A:**``
            starts in the same column as ``**Q:**``.
        """
        # The numbering prefix is "<idx>. ", i.e. the digits plus two characters.
        pad = " " * (2 + len(str(idx)))
        lines = (
            f"{idx}. **Q:** {self.question}",
            f"{pad}**A:** {self.answer}",
        )
        return "\n\n".join(lines)


# Container for every question/answer pair produced for one piece of text.
# NOTE: documented with comments rather than a class docstring on purpose:
# pydantic copies a model's docstring into the JSON schema "description",
# which would alter the schema this model produces.
class QuestionAnswerPairList(BaseModel):
    QuestionAnswerPairs: List[QuestionAnswerPair]

    def to_str(self) -> str:
        """Render all pairs as one string, numbered from 1.

        Returns:
            Each pair's ``to_str(number)`` output, joined by blank lines. An
            empty list yields an empty string.
        """
        numbered = enumerate(self.QuestionAnswerPairs, start=1)
        return "\n\n".join(pair.to_str(number) for number, pair in numbered)


# Shared parser targeting the QuestionAnswerPairList model; it is used below
# both for the prompt's format instructions and for parsing the LLM output.
PYDANTIC_PARSER: PydanticOutputParser = PydanticOutputParser(
    pydantic_object=QuestionAnswerPairList
)


# System message: describes the task and embeds the required output format
# through the {format_instructions} variable.
templ1 = """You are a smart assistant designed to help college professors come up with reading comprehension questions.
Given a piece of text, you must come up with question and answer pairs that can be used to test a student's reading comprehension abilities.
Generate as many question/answer pairs as you can.
When coming up with the question/answer pairs, you must respond in the following format:
{format_instructions}

Do not provide additional commentary and do not wrap your response in Markdown formatting. Return RAW, VALID JSON.
"""

# Human message: expects the {prompt} and {context} variables at format time.
templ2 = """{prompt}
Please create question/answer pairs, in the specified JSON format, for the following text:
----------------
{context}"""

# Two-message chat prompt with {format_instructions} pre-bound.
# The bound method is passed uncalled.
# NOTE(review): this relies on LangChain evaluating callable partial variables
# when the prompt is formatted; confirm against the installed version.
CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [("system", templ1), ("human", templ2)]
).partial(format_instructions=PYDANTIC_PARSER.get_format_instructions)


def get_rag_qa_gen_chain(
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
    input_key: str = "prompt",
) -> RunnableSequence:
    """Build a retrieval-backed chain that returns question/answer pairs as text.

    The chain input is given to ``retriever`` (its result fills the
    ``{context}`` template variable) and is also passed through unchanged to
    fill ``{prompt}``. The formatted ``CHAT_PROMPT`` is sent to ``llm``, the
    reply is parsed with ``PYDANTIC_PARSER`` wrapped in an
    ``OutputFixingParser`` built from the same ``llm``, and the parsed object
    is converted to a string via its ``to_str()`` method.

    Args:
        retriever: Retriever invoked with the chain input to supply the
            context.
        llm: Language model used for generation and handed to
            ``OutputFixingParser.from_llm``.
        input_key: Key under which the raw chain input is exposed to the
            prompt step. Defaults to ``"prompt"``, the variable the template
            uses.

    Returns:
        A runnable that maps the chain input to the rendered question/answer
        string.
    """
    # CHAT_PROMPT hard-codes the {context} and {prompt} variables, so both are
    # always populated. Previously only ``input_key`` was mapped, which left
    # {prompt} unfilled (and, for input_key="context", replaced the retriever)
    # whenever a non-default key was given.
    prompt_inputs = {"context": retriever, "prompt": RunnablePassthrough()}
    # Still expose the input under a custom key; never override the two
    # template variables above.
    prompt_inputs.setdefault(input_key, RunnablePassthrough())
    return (
        prompt_inputs
        | CHAT_PROMPT
        | llm
        | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
        | (lambda parsed_output: parsed_output.to_str())
    )