File size: 3,447 Bytes
87285ff
 
42d6456
 
 
 
87285ff
43db08f
 
3b00b81
f081971
 
42d6456
87285ff
 
 
 
 
2c973dd
 
 
 
 
 
8260884
 
 
 
 
2c973dd
 
 
87285ff
 
 
2c973dd
 
 
 
 
 
43db08f
87285ff
43db08f
 
3b00b81
43db08f
 
 
 
 
2c973dd
43db08f
 
 
 
a7268ce
43db08f
87285ff
43db08f
 
 
 
87285ff
43db08f
a560c11
3b00b81
a560c11
43db08f
3b00b81
 
 
 
 
 
 
43db08f
2c973dd
43db08f
8260884
 
 
 
f081971
8260884
f081971
 
 
 
 
8260884
f081971
 
8260884
f081971
8260884
 
2c973dd
 
 
 
f081971
24633e1
f081971
24633e1
f081971
 
 
 
24633e1
f081971
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document

from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
from climateqa.engine.chains.prompts import papers_prompt_template
import time
from ..utils import rename_chain, pass_values


# Default per-document prompt: the template consists solely of the
# `{page_content}` placeholder. Used as the default `document_prompt` of
# `_combine_documents`.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    """Render `docs` as numbered single-line entries joined by `sep`.

    Every entry is prefixed with ``"Doc <n>: "`` where ``n`` starts at 1.
    Items that are already strings are used verbatim; any other item is
    rendered with `format_document` and `document_prompt`. Newlines inside
    an entry (prefix included) are replaced by single spaces.

    Args:
        docs: iterable of strings and/or documents accepted by
            `format_document`.
        document_prompt: prompt used to render non-string items.
        sep: separator placed between the rendered entries.

    Returns:
        The rendered entries joined by `sep` (an empty string when `docs`
        yields nothing).
    """

    def _render(position, item):
        # Strings are already text; everything else goes through the prompt.
        if isinstance(item, str):
            body = item
        else:
            body = format_document(item, document_prompt)
        return f"Doc {position}: {body}".replace("\n", " ")

    return sep.join(
        _render(position, item) for position, item in enumerate(docs, start=1)
    )


def get_text_docs(x):
    """Return, in order, the items of `x` whose ``metadata["chunk_type"]`` is ``"text"``."""

    def _is_text(doc):
        return doc.metadata["chunk_type"] == "text"

    return list(filter(_is_text, x))

def get_image_docs(x):
    """Return, in order, the items of `x` whose ``metadata["chunk_type"]`` is ``"image"``."""

    def _is_image(doc):
        return doc.metadata["chunk_type"] == "image"

    return list(filter(_is_image, x))

def make_rag_chain(llm):
    """Build the answering chain that uses the supplied documents.

    The chain maps its input to the prompt variables below, fills the prompt
    built from `answer_prompt_template`, calls `llm` and passes the result
    through `StrOutputParser`.

    The input is read through the keys ``"documents"``, ``"query"``,
    ``"language"`` and ``"audience"``.
    """
    answer_prompt = ChatPromptTemplate.from_template(answer_prompt_template)

    prompt_inputs = {
        "context": lambda x: _combine_documents(x["documents"]),
        # Debug aid: prints the size of the combined context. `print` returns
        # None, so the value stored under "context_length" is None.
        "context_length": lambda x: print(
            "CONTEXT LENGTH : ", len(_combine_documents(x["documents"]))
        ),
        "query": itemgetter("query"),
        "language": itemgetter("language"),
        "audience": itemgetter("audience"),
    }

    return prompt_inputs | answer_prompt | llm | StrOutputParser()

def make_rag_chain_without_docs(llm):
    """Build the answering chain used when no documents are supplied.

    The chain input goes directly to the prompt built from
    `answer_prompt_without_docs_template`, then to `llm`, and the result is
    passed through `StrOutputParser`.
    """
    answer_prompt = ChatPromptTemplate.from_template(
        answer_prompt_without_docs_template
    )
    return answer_prompt | llm | StrOutputParser()

def make_rag_node(llm,with_docs = True):
    """Create the async node that produces the answer.

    Args:
        llm: model handed to the chain builder.
        with_docs: when truthy the chain from `make_rag_chain` is used,
            otherwise the one from `make_rag_chain_without_docs`.

    Returns:
        An async callable ``answer_rag(state, config)`` that invokes the
        selected chain with `state` and `config`, prints timing and size
        diagnostics, and returns ``{"answer": <chain output>}``.
    """
    rag_chain = make_rag_chain(llm) if with_docs else make_rag_chain_without_docs(llm)

    async def answer_rag(state,config):
        print("---- Answer RAG ----")
        # perf_counter is monotonic, so the reported duration cannot be
        # skewed (or go negative) if the system clock is adjusted mid-call.
        start_time = time.perf_counter()

        answer = await rag_chain.ainvoke(state,config)

        elapsed_time = time.perf_counter() - start_time
        print("RAG elapsed time: ", elapsed_time)
        print("Answer size : ", len(answer))

        return {"answer":answer}

    return answer_rag




def make_rag_papers_chain(llm):
    """Build the chain that answers from papers.

    The ``"context"`` prompt variable is produced by `_combine_documents`
    from the input's ``"docs"`` entry; the remaining entries come from
    ``pass_values(["question","language"])``. The mapping feeds the prompt
    built from `papers_prompt_template`, then `llm`, then `StrOutputParser`.
    The resulting chain is passed through ``rename_chain(..., "answer")``
    before being returned.
    """
    papers_prompt = ChatPromptTemplate.from_template(papers_prompt_template)

    prompt_inputs = {
        "context": lambda x: _combine_documents(x["docs"]),
        **pass_values(["question","language"]),
    }

    answer_chain = prompt_inputs | papers_prompt | llm | StrOutputParser()
    return rename_chain(answer_chain, "answer")






def make_illustration_chain(llm):
    """Build the chain that works from the image documents.

    The ``"images"`` prompt variable is produced by `_combine_documents`
    applied to the image-type items of the input's ``"docs"`` entry (as
    selected by `get_image_docs`); the remaining entries come from
    ``pass_values(["question","audience","language","answer"])``. The
    mapping feeds the prompt built from `answer_prompt_images_template`,
    then `llm`, then `StrOutputParser`.
    """
    images_prompt = ChatPromptTemplate.from_template(answer_prompt_images_template)

    prompt_inputs = {
        "images": lambda x: _combine_documents(get_image_docs(x["docs"])),
        **pass_values(["question","audience","language","answer"]),
    }

    return prompt_inputs | images_prompt | llm | StrOutputParser()