|
from llama_index.core.query_pipeline import (
    QueryPipeline,
    InputComponent,
    ArgPackComponent,
)
from llama_index.core.prompts import PromptTemplate
from llama_index.llms.ollama import Ollama
from llama_index.core.retrievers import RecursiveRetriever
from llama_index.core import (
    VectorStoreIndex,
    SimpleDirectoryReader,
    Settings,
    SummaryIndex,
    load_index_from_storage,
    StorageContext,
)
from utils.recursive_retrieve import get_file_name

from llama_index.core.chat_engine import ContextChatEngine
from utils.history import RedisChatHistory


# Entry point of the query pipeline: receives the raw user query and chat history.
input_component = InputComponent()
|
# Prompt for rewriting the latest user message into a standalone query for the
# semantic search engine, using the conversation so far.
rewrite = (
    "Please write a query to a semantic search engine using the current conversation.\n"
    "\n"
    "\n"
    "{chat_history_str}"
    "\n"
    "\n"
    "Latest message: {query_str}\n"
    'Query:"""\n'
)
rewrite_template = PromptTemplate(rewrite)

# Local LLM served by Ollama.
llm = Ollama(model="pornchat", request_timeout=120)
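
# Illustration only: how the rewrite prompt looks once filled in. The history
# string and latest message below are made-up values.
# example_rewrite_prompt = rewrite_template.format(
#     chat_history_str="user: Hi\nassistant: Hello! How can I help?",
#     query_str="Tell me more about that",
# )
# print(example_rewrite_prompt)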
|
|
|
|
|
# Packs positional outputs into a single list; instantiated here but not wired
# into the pipeline below.
argpack_component = ArgPackComponent()
|
# Load the top-level vector index plus one vector index per document, then
# combine them into a single recursive retriever.
top_vector_index = load_index_from_storage(
    storage_context=StorageContext.from_defaults(
        persist_dir="/data1/home/purui/projects/chatbot/kb/top_index"
    )
)
data_dir = "/data1/home/purui/projects/chatbot/data/txt"
index_dir = "/data1/home/purui/projects/chatbot/kb"
titles = get_file_name(data_dir)

vector_retrievers = {}
for title in titles:
    vector_index = load_index_from_storage(
        storage_context=StorageContext.from_defaults(persist_dir=f"{index_dir}/{title}")
    )
    vector_retriever = vector_index.as_retriever(similarity_top_k=1)
    vector_retrievers[title] = vector_retriever

recursive_retriever = RecursiveRetriever(
    "vector",
    retriever_dict={
        "vector": top_vector_index.as_retriever(similarity_top_k=1),
        **vector_retrievers,
    },
)
retriever = recursive_retriever
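
# Optional sanity check (sketch): the recursive retriever can be exercised on
# its own before it is wired into the pipeline; the query string is made up.
# sample_nodes = retriever.retrieve("example question about the knowledge base")
# for n in sample_nodes:
#     print(n.score, n.node.get_content()[:100])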
|
|
from typing import Any, Dict, List, Optional

from llama_index.core.bridge.pydantic import Field
from llama_index.core.llms import ChatMessage
from llama_index.core.query_pipeline import CustomQueryComponent
from llama_index.core.schema import NodeWithScore

# Prompt that merges retrieved node text, the chat history, and the user's
# question into a single message for the LLM.
DEFAULT_CONTEXT_PROMPT = (
    "Here is some context that may be relevant:\n"
    "-----\n"
    "{node_context}\n"
    "-----chat_history-----\n"
    "{chat_history}\n"
    "-----\n"
    "Please write a response to the following question, using the above context:\n"
    "{query_str}\n"
)
|
class ResponseWithChatHistory(CustomQueryComponent):
    """Custom pipeline component that answers the user with retrieved context
    nodes plus the running chat history."""

    llm: Ollama = Field(..., description="Local LLM")
    system_prompt: Optional[str] = Field(
        default=None, description="System prompt to use for the LLM"
    )
    context_prompt: str = Field(
        default=DEFAULT_CONTEXT_PROMPT,
        description="Context prompt to use for the LLM",
    )

    def _validate_component_inputs(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        return input

    @property
    def _input_keys(self) -> set:
        """Input keys dict."""
        return {"chat_history", "nodes", "query_str"}

    @property
    def _output_keys(self) -> set:
        return {"response"}

    def _prepare_context(
        self,
        chat_history: List[ChatMessage],
        nodes: List[NodeWithScore],
        query_str: str,
    ) -> List[ChatMessage]:
        """Format the retrieved nodes and query into a user message, append it
        to the chat history, and prepend the system prompt if one is set."""
        node_context = ""
        for idx, node in enumerate(nodes):
            node_text = node.get_content(metadata_mode="llm")
            node_context += f"Context Chunk {idx}:\n{node_text}\n\n"

        formatted_context = self.context_prompt.format(
            node_context=node_context, query_str=query_str, chat_history=chat_history
        )
        user_message = ChatMessage(role="user", content=formatted_context)

        chat_history.append(user_message)

        if self.system_prompt is not None:
            chat_history = [
                ChatMessage(role="system", content=self.system_prompt)
            ] + chat_history

        return chat_history

    def _run_component(self, **kwargs) -> Dict[str, Any]:
        """Run the component."""
        chat_history = kwargs["chat_history"]
        nodes = kwargs["nodes"]
        query_str = kwargs["query_str"]

        prepared_context = self._prepare_context(chat_history, nodes, query_str)
        print(prepared_context)

        # Use the component's configured LLM rather than the module-level one.
        response = self.llm.chat(prepared_context)

        return {"response": response}

    async def _arun_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Run the component asynchronously."""
        chat_history = kwargs["chat_history"]
        nodes = kwargs["nodes"]
        query_str = kwargs["query_str"]

        prepared_context = self._prepare_context(chat_history, nodes, query_str)

        response = await self.llm.achat(prepared_context)

        return {"response": response}
|
response_component = ResponseWithChatHistory(
    llm=llm,
    system_prompt=(
        "You are a Q&A system. You will be provided with the previous chat history, "
        "as well as possibly relevant context, to assist in answering a user message."
    ),
)
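
# Sketch of exercising the custom response component on its own, outside the
# pipeline. The chat message and query below are placeholders, and the nodes
# come from a throwaway retrieval call.
# example_nodes = retriever.retrieve("example question")
# out = response_component.run_component(
#     chat_history=[ChatMessage(role="user", content="Hi")],
#     nodes=example_nodes,
#     query_str="example question",
# )
# print(out["response"])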
|
# Assemble the pipeline: rewrite the query, retrieve context, then respond
# using the chat history.
pipeline = QueryPipeline(
    modules={
        "input": input_component,
        "rewrite_template": rewrite_template,
        "llm": llm,
        "query_retriever": retriever,
        "response_component": response_component,
    },
    verbose=False,
)
|
# Wire the modules together:
# 1) the raw query and serialized history fill the rewrite prompt;
pipeline.add_link(
    "input", "rewrite_template", src_key="query_str", dest_key="query_str"
)
pipeline.add_link(
    "input",
    "rewrite_template",
    src_key="chat_history_str",
    dest_key="chat_history_str",
)
# 2) the LLM turns the rewrite prompt into a standalone search query;
pipeline.add_link("rewrite_template", "llm")
# 3) the rewritten query drives the recursive retriever;
pipeline.add_link("llm", "query_retriever")
# 4) retrieved nodes, the original query, and the chat history feed the
#    response component.
pipeline.add_link("query_retriever", "response_component", dest_key="nodes")
pipeline.add_link(
    "input", "response_component", src_key="query_str", dest_key="query_str"
)
pipeline.add_link(
    "input",
    "response_component",
    src_key="chat_history",
    dest_key="chat_history",
)
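
# Sketch of driving the assembled pipeline for a single turn. It assumes the
# Redis-backed `memory` created further down behaves like a standard
# llama_index chat memory (exposing .get()); the user message is illustrative.
# chat_history = memory.get()
# chat_history_str = "\n".join(str(m) for m in chat_history)
# result = pipeline.run(
#     query_str="What topics can you help with?",
#     chat_history=chat_history,
#     chat_history_str=chat_history_str,
# )
# print(str(result))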
|
|
from llama_index.core.query_engine import RetrieverQueryEngine

# Plain query engine over the same recursive retriever (not used in the chat
# flow below).
query_engine_base = RetrieverQueryEngine.from_args(
    recursive_retriever, llm=Ollama(model="pornchat", request_timeout=120), verbose=True
)

# Redis-backed chat memory and a context chat engine for multi-turn conversation.
memory = RedisChatHistory(userId="2343").as_memory()
chat_engine = ContextChatEngine.from_defaults(
    retriever=retriever,
    llm=Ollama(model="pornchat", request_timeout=120),
    system_prompt="You are a helpful sexual education professor to chat with users, named Winnie. You will answer any questions in a Kind and Friendly tone.",
    memory=memory,
)

response = chat_engine.chat("Who are you?")
print(response.response)
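
# Because the chat engine's memory is backed by Redis, a follow-up turn reuses
# the stored history; a second illustrative call:
# followup = chat_engine.chat("Can you tell me more about yourself?")
# print(followup.response)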