from operator import itemgetter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document
from climateqa.engine.chains.prompts import answer_prompt_template, answer_prompt_without_docs_template, answer_prompt_images_template
from climateqa.engine.chains.prompts import papers_prompt_template
import time
from ..utils import rename_chain, pass_values
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")


def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    """Format retrieved documents into a single numbered context string ("Doc 1: ...", "Doc 2: ..."),
    with newlines inside each document flattened to spaces."""
    doc_strings = []
    for i, doc in enumerate(docs):
        # chunk_type = "Doc" if doc.metadata["chunk_type"] == "text" else "Image"
        chunk_type = "Doc"
        if isinstance(doc, str):
            doc_formatted = doc
        else:
            doc_formatted = format_document(doc, document_prompt)
        doc_string = f"{chunk_type} {i+1}: " + doc_formatted
        doc_string = doc_string.replace("\n", " ")
        doc_strings.append(doc_string)
    return sep.join(doc_strings)
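
# A minimal sketch of the formatting produced above (illustrative only; Document comes from the
# standard langchain_core location):
#
#   from langchain_core.documents import Document
#   docs = [Document(page_content="First passage"), Document(page_content="Second passage")]
#   print(_combine_documents(docs))
#   # -> "Doc 1: First passage\n\nDoc 2: Second passage"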


def get_text_docs(x):
    return [doc for doc in x if doc.metadata["chunk_type"] == "text"]


def get_image_docs(x):
    return [doc for doc in x if doc.metadata["chunk_type"] == "image"]


def make_rag_chain(llm):
    # Answer chain used when documents have been retrieved: build the context string,
    # fill the answer prompt and parse the LLM output into a plain string.
    prompt = ChatPromptTemplate.from_template(answer_prompt_template)
    chain = ({
        "context": lambda x: _combine_documents(x["documents"]),
        # Debug only: prints the context length (the "context_length" value itself is None).
        "context_length": lambda x: print("CONTEXT LENGTH : ", len(_combine_documents(x["documents"]))),
        "query": itemgetter("query"),
        "language": itemgetter("language"),
        "audience": itemgetter("audience"),
    } | prompt | llm | StrOutputParser())
    return chain


def make_rag_chain_without_docs(llm):
    prompt = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
    chain = prompt | llm | StrOutputParser()
    return chain


def make_rag_node(llm, with_docs=True):
    # Wrap the answer chain in an async node taking (state, config) and returning {"answer": ...}.
    if with_docs:
        rag_chain = make_rag_chain(llm)
    else:
        rag_chain = make_rag_chain_without_docs(llm)

    async def answer_rag(state, config):
        print("---- Answer RAG ----")
        start_time = time.time()
        answer = await rag_chain.ainvoke(state, config)
        end_time = time.time()
        elapsed_time = end_time - start_time
        print("RAG elapsed time: ", elapsed_time)
        print("Answer size : ", len(answer))
        # print(f"\n\nAnswer:\n{answer}")
        return {"answer": answer}

    return answer_rag


def make_rag_papers_chain(llm):
    prompt = ChatPromptTemplate.from_template(papers_prompt_template)
    input_documents = {
        "context": lambda x: _combine_documents(x["docs"]),
        **pass_values(["question", "language"]),
    }
    chain = input_documents | prompt | llm | StrOutputParser()
    chain = rename_chain(chain, "answer")
    return chain


def make_illustration_chain(llm):
    prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)
    input_description_images = {
        "images": lambda x: _combine_documents(get_image_docs(x["docs"])),
        **pass_values(["question", "audience", "language", "answer"]),
    }
    illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
    return illustration_chain
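

# Usage sketch (illustrative, not part of the original module): awaiting the node returned by
# make_rag_node directly. The get_llm import path and the retrieved_docs variable are assumptions;
# the state keys ("documents", "query", "language", "audience") are the ones consumed above.
#
#   import asyncio
#   from climateqa.engine.llm import get_llm  # hypothetical import path
#
#   llm = get_llm()
#   answer_node = make_rag_node(llm, with_docs=True)
#   state = {"documents": retrieved_docs, "query": "What drives sea level rise?",
#            "language": "English", "audience": "experts"}
#   result = asyncio.run(answer_node(state, config={}))
#   print(result["answer"])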