from llama_index.core.query_pipeline import (
    QueryPipeline,
    InputComponent,
    ArgPackComponent,
    CustomQueryComponent,
)
from llama_index.core.prompts import PromptTemplate
from llama_index.llms.ollama import Ollama
from typing import Any, Dict, List, Optional
from llama_index.core.bridge.pydantic import Field
from llama_index.core.llms import ChatMessage
from llama_index.core.schema import NodeWithScore


DEFAULT_CONTEXT_PROMPT = (
    "Here is some context that may be relevant:\n"
    "-----\n"
    "{node_context}\n"
    "-----chat_history-----\n"
    "{chat_history}\n"
    "-----\n"
    "Please write a response to the following question, using the above context:\n"
    "{query_str}\n"
)


class ResponseComponent(CustomQueryComponent):
    """Builds a chat prompt from retrieved nodes and chat history, then calls the LLM."""

    llm: Ollama = Field(..., description="The language model to use for generating responses.")
    system_prompt: Optional[str] = Field(default=None, description="System prompt to use for the LLM")
    context_prompt: str = Field(
        default=DEFAULT_CONTEXT_PROMPT,
        description="Context prompt to use for the LLM",
    )

    @property
    def _input_keys(self) -> set:
        # Inputs this component expects from upstream pipeline modules.
        return {"chat_history", "nodes", "query_str"}

    @property
    def _output_keys(self) -> set:
        return {"response"}

    def _prepare_context(
        self,
        chat_history: List[ChatMessage],
        nodes: List[NodeWithScore],
        query_str: str,
    ) -> List[ChatMessage]:
        # Concatenate the retrieved nodes into a single context string.
        node_context = ""
        for idx, node in enumerate(nodes):
            node_text = node.get_content(metadata_mode="llm")
            node_context += f"Context Chunk {idx}:\n{node_text}\n\n"

        # Render the context prompt; flatten the chat history to readable
        # "role: content" lines instead of the raw list repr.
        formatted_context = self.context_prompt.format(
            node_context=node_context,
            query_str=query_str,
            chat_history="\n".join(str(message) for message in chat_history),
        )
        user_message = ChatMessage(role="user", content=formatted_context)
        chat_history.append(user_message)

        # Prepend the system prompt, if one was provided.
        if self.system_prompt is not None:
            chat_history = [
                ChatMessage(role="system", content=self.system_prompt),
            ] + chat_history

        return chat_history

    def _run_component(self, **kwargs) -> Dict[str, Any]:
        # Synchronous entry point invoked by the query pipeline.
        chat_history = kwargs["chat_history"]
        nodes = kwargs["nodes"]
        query_str = kwargs["query_str"]

        prepared_context = self._prepare_context(
            chat_history=chat_history,
            nodes=nodes,
            query_str=query_str,
        )

        response = self.llm.chat(prepared_context)

        return {"response": response}

    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
        # Async entry point; mirrors _run_component but awaits the LLM call.
        chat_history = kwargs["chat_history"]
        nodes = kwargs["nodes"]
        query_str = kwargs["query_str"]

        prepared_context = self._prepare_context(
            chat_history=chat_history,
            nodes=nodes,
            query_str=query_str,
        )

        # Use the async chat API; llm.chat() is synchronous and cannot be awaited.
        response = await self.llm.achat(prepared_context)

        return {"response": response}
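

How this component slots into a pipeline is easiest to see end to end. Below is a minimal sketch, not the article's exact setup: the model names ("llama3", "nomic-embed-text"), the tiny in-memory index, and the single-retriever wiring are placeholder assumptions, so substitute your own retriever and models.

from llama_index.core import Document, Settings, VectorStoreIndex
from llama_index.embeddings.ollama import OllamaEmbedding

# Assumed local models served by Ollama; swap in whatever you have pulled.
Settings.embed_model = OllamaEmbedding(model_name="nomic-embed-text")
llm = Ollama(model="llama3")

# Placeholder index standing in for the article's real data.
index = VectorStoreIndex.from_documents(
    [Document(text="LlamaIndex query pipelines compose retrieval and LLM calls.")]
)
retriever = index.as_retriever(similarity_top_k=2)

pipeline = QueryPipeline(verbose=True)
pipeline.add_modules(
    {
        "input": InputComponent(),
        "retriever": retriever,
        "response": ResponseComponent(llm=llm, system_prompt="You are a helpful assistant."),
    }
)

# Fan the input out to the retriever and the response component.
pipeline.add_link("input", "retriever", src_key="query_str")
pipeline.add_link("retriever", "response", dest_key="nodes")
pipeline.add_link("input", "response", src_key="query_str", dest_key="query_str")
pipeline.add_link("input", "response", src_key="chat_history", dest_key="chat_history")

output = pipeline.run(query_str="What do query pipelines do?", chat_history=[])
print(str(output))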