"""Custom query pipeline component that generates a chat response from
retrieved nodes, prior chat history, and the user query, using an Ollama LLM."""

from llama_index.core.query_pipeline import (
    QueryPipeline,
    InputComponent,
    ArgPackComponent,
)
from llama_index.core.prompts import PromptTemplate
from llama_index.llms.ollama import Ollama
from typing import Any, Dict, List, Optional
from llama_index.core.bridge.pydantic import Field
from llama_index.core.llms import ChatMessage
from llama_index.core.query_pipeline import CustomQueryComponent
from llama_index.core.schema import NodeWithScore


DEFAULT_CONTEXT_PROMPT = (
    "Here is some context that may be relevant:\n"
    "-----\n"
    "{node_context}\n"
    "-----chat_history-----\n"
    "{chat_history}\n"
    "-----\n"
    "Please write a response to the following question, using the above context:\n"
    "{query_str}\n"
)


class ResponseComponent(CustomQueryComponent):
    """Builds a chat prompt from retrieved nodes and chat history, then queries
    the LLM for a response."""

    llm: Ollama = Field(..., description="The language model to use for generating responses.")
    system_prompt: Optional[str] = Field(default=None, description="System prompt to use for the LLM.")
    context_prompt: str = Field(
        default=DEFAULT_CONTEXT_PROMPT,
        description="Context prompt used to format the retrieved nodes, chat history, and query for the LLM.",
    )
    
    @property
    def _input_keys(self) -> set:
        """Input keys expected by run_component."""
        return {"chat_history", "nodes", "query_str"}

    @property
    def _output_keys(self) -> set:
        """Output keys produced by run_component."""
        return {"response"}

    def _prepare_context(
        self,
        chat_history: List[ChatMessage],
        nodes: List[NodeWithScore],
        query_str: str,
    ) -> List[ChatMessage]:
        """Assemble the message list sent to the LLM from the retrieved nodes,
        the chat history, and the user query."""
        # Concatenate the retrieved node contents into a single context string.
        node_context = ""
        for idx, node in enumerate(nodes):
            node_text = node.get_content(metadata_mode="llm")
            node_context += f"Context Chunk {idx}:\n{node_text}\n\n"

        # Interpolate the node context, chat history, and query into the prompt,
        # then append the result to the history as the latest user message.
        formatted_context = self.context_prompt.format(
            node_context=node_context,
            query_str=query_str,
            chat_history=chat_history,
        )
        user_message = ChatMessage(role="user", content=formatted_context)
        chat_history.append(user_message)

        # Prepend the system prompt, if one was provided.
        if self.system_prompt is not None:
            chat_history = [
                ChatMessage(role="system", content=self.system_prompt),
            ] + chat_history

        return chat_history
    
    def _run_component(self, **kwargs) -> Dict[str, Any]:
        """Run the component synchronously."""
        chat_history = kwargs["chat_history"]
        nodes = kwargs["nodes"]
        query_str = kwargs["query_str"]

        prepared_context = self._prepare_context(
            chat_history=chat_history,
            nodes=nodes,
            query_str=query_str,
        )

        response = self.llm.chat(prepared_context)

        return {"response": response}

    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run the component asynchronously."""
        chat_history = kwargs["chat_history"]
        nodes = kwargs["nodes"]
        query_str = kwargs["query_str"]

        prepared_context = self._prepare_context(
            chat_history=chat_history,
            nodes=nodes,
            query_str=query_str,
        )

        # Use the async chat API; awaiting the synchronous `chat` call would fail.
        response = await self.llm.achat(prepared_context)

        return {"response": response}