File size: 2,531 Bytes
622ac66 6467ea5 930d412 622ac66 bfaa0c3 622ac66 bfaa0c3 622ac66 3550ebd 622ac66 6467ea5 923e6fa 6467ea5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from typing import List
from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
from langchain.prompts.chat import (
ChatPromptTemplate,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import RunnablePassthrough, RunnableSequence
from pydantic import BaseModel, Field
# A single generated question together with its answer.
# NOTE: described with comments rather than a class docstring on purpose:
# pydantic copies a model's docstring into its JSON schema "description",
# which would alter any schema text derived from this model.
class QuestionAnswerPair(BaseModel):
    question: str = Field(..., description="The question that will be answered.")
    answer: str = Field(..., description="The answer to the question that was asked.")

    def to_str(self, idx: int) -> str:
        """Render this pair as one numbered Markdown entry.

        Args:
            idx: Number printed in front of the question.

        Returns:
            The ``"<idx>. **Q:** ..."`` line, a blank line, then the
            ``"**A:** ..."`` line indented by the width of ``"<idx>. "``.
        """
        number_prefix = f"{idx}. "
        # Same width as the numeric prefix so the answer lines up under "**Q:**".
        answer_indent = " " * len(number_prefix)
        return (
            f"{number_prefix}**Q:** {self.question}"
            "\n\n"
            f"{answer_indent}**A:** {self.answer}"
        )
# Container model holding every generated question/answer pair.
# NOTE: described with comments rather than a class docstring on purpose:
# pydantic copies a model's docstring into its JSON schema "description",
# which would alter any schema text derived from this model.
class QuestionAnswerPairList(BaseModel):
    QuestionAnswerPairs: List[QuestionAnswerPair]

    def to_str(self) -> str:
        """Render all pairs as numbered entries separated by blank lines.

        Returns:
            Each pair's ``to_str`` output, numbered from 1 in list order and
            joined with ``"\\n\\n"``; an empty string when there are no pairs.
        """
        pairs = self.QuestionAnswerPairs
        return "\n\n".join(
            pair.to_str(number) for number, pair in enumerate(pairs, 1)
        )
# Shared parser that targets the QuestionAnswerPairList model; reused for both
# the prompt's format instructions and the chain's output parsing step.
PYDANTIC_PARSER: PydanticOutputParser = PydanticOutputParser(pydantic_object=QuestionAnswerPairList)
# System-message template. Its only placeholder is {format_instructions};
# the text asks for raw JSON question/answer pairs with no extra commentary.
templ1 = """You are a smart assistant designed to help college professors come up with reading comprehension questions.
Given a piece of text, you must come up with question and answer pairs that can be used to test a student's reading comprehension abilities.
Generate as many question/answer pairs as you can.
When coming up with the question/answer pairs, you must respond in the following format:
{format_instructions}
Do not provide additional commentary and do not wrap your response in Markdown formatting. Return RAW, VALID JSON.
"""
# Human-message template. Placeholders: {prompt} (placed first) and {context}
# (the text to generate question/answer pairs for, placed after the rule line).
templ2 = """{prompt}
Please create question/answer pairs, in the specified JSON format, for the following text:
----------------
{context}"""
# Two-message chat prompt (system + human) with {format_instructions}
# pre-filled. The parser's bound method is passed uncalled; LangChain prompt
# templates accept a callable as a partial value and invoke it when the prompt
# is formatted.
CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [("system", templ1), ("human", templ2)],
).partial(
    format_instructions=PYDANTIC_PARSER.get_format_instructions,
)
def get_rag_qa_gen_chain(
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
    input_key: str = "prompt",
) -> RunnableSequence:
    """Assemble the retrieval-backed question/answer generation chain.

    The steps are piped together in this order:

    1. a mapping with ``retriever`` under ``"context"`` and a
       ``RunnablePassthrough`` under ``input_key``;
    2. ``CHAT_PROMPT``;
    3. ``llm``;
    4. an ``OutputFixingParser`` built from ``llm`` and ``PYDANTIC_PARSER``;
    5. a final step that calls ``to_str()`` on the parsed output.

    Args:
        retriever: Retriever whose result is stored under ``"context"``.
        llm: Model used for generation and passed to the output-fixing parser.
        input_key: Key under which the passthrough value is stored.

    Returns:
        The composed runnable.
    """

    def _render(parsed_output):
        return parsed_output.to_str()

    # NOTE(review): the human template references {prompt} literally, so an
    # input_key other than "prompt" presumably leaves that placeholder
    # unfilled -- confirm against callers before relying on a custom key.
    prompt_inputs = {"context": retriever, input_key: RunnablePassthrough()}
    fixing_parser = OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
    return prompt_inputs | CHAT_PROMPT | llm | fixing_parser | _render
|