import time
from typing import Any, Dict, List, Optional

import qdrant_client
from langchain import chains
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.llms import HuggingFacePipeline
from unstructured.cleaners.core import (
    clean,
    clean_extra_whitespace,
    clean_non_ascii_chars,
    group_broken_paragraphs,
    replace_unicode_quotes,
)

from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.template import PromptTemplate


class StatelessMemorySequentialChain(chains.SequentialChain):
    """
    A sequential chain that uses a stateless memory to store context between calls.

    This chain overrides the _call and prep_outputs methods to load and clear the
    memory before and after each call, respectively.
    """

    history_input_key: str = "to_load_history"

    def _call(self, inputs: Dict[str, str], **kwargs) -> Dict[str, str]:
        """
        Override _call to load the history before calling the chain.

        This method reads the conversation history from the input dictionary and
        saves it to the stateless memory. It then updates the inputs dictionary
        with the memory values and removes the history input key. Finally, it
        calls the parent _call method with the updated inputs and returns the
        results.
        """

        to_load_history = inputs[self.history_input_key]
        for human, ai in to_load_history:
            # Replay each (human, ai) turn into the memory's context.
            self.memory.save_context(
                inputs={self.memory.input_key: human},
                outputs={self.memory.output_key: ai},
            )
        memory_values = self.memory.load_memory_variables({})
        inputs.update(memory_values)

        del inputs[self.history_input_key]

        return super()._call(inputs, **kwargs)

    def prep_outputs(
        self,
        inputs: Dict[str, str],
        outputs: Dict[str, str],
        return_only_outputs: bool = False,
    ) -> Dict[str, str]:
        """
        Override prep_outputs to clear the internal memory after each call.

        This method calls the parent prep_outputs method to get the results, then
        clears the stateless memory and removes the memory key from the results
        dictionary. It then returns the updated results.
        """

        results = super().prep_outputs(inputs, outputs, return_only_outputs)

        # Clear the internal memory.
        self.memory.clear()
        if self.memory.memory_key in results:
            results[self.memory.memory_key] = ""

        return results
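
# A compatible memory for StatelessMemorySequentialChain must expose input_key,
# output_key, and memory_key, as used above. A minimal sketch, assuming a
# standard LangChain buffer memory (the key values are illustrative):
#
#   from langchain.memory import ConversationBufferWindowMemory
#
#   memory = ConversationBufferWindowMemory(
#       memory_key="chat_history",
#       input_key="question",
#       output_key="answer",
#       k=3,
#   )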


class ContextExtractorChain(Chain):
    """
    Encode the question, search the vector store for the top-k articles, and
    return the resulting context from the Alpaca news document collection.

    Attributes:
    -----------
    top_k : int
        The number of top matches to retrieve from the vector store.
    embedding_model : EmbeddingModelSingleton
        The embedding model used to encode the question.
    vector_store : qdrant_client.QdrantClient
        The vector store to search for matches.
    vector_collection : str
        The name of the vector store collection to search.
    """

    top_k: int = 1
    embedding_model: EmbeddingModelSingleton
    vector_store: qdrant_client.QdrantClient
    vector_collection: str

    @property
    def input_keys(self) -> List[str]:
        return ["about_me", "question"]

    @property
    def output_keys(self) -> List[str]:
        return ["context"]

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        _, quest_key = self.input_keys
        question_str = inputs[quest_key]
        cleaned_question = self.clean(question_str)

        # TODO: Instead of cutting the question at 'max_input_length', chunk the question
        # in 'max_input_length' chunks, pass them through the model and average the embeddings.
        cleaned_question = cleaned_question[: self.embedding_model.max_input_length]
        embeddings = self.embedding_model(cleaned_question)

        # TODO: Using the metadata, use the filter to take into consideration only the news
        # from the last 24 hours (or other time frame).
        matches = self.vector_store.search(
            query_vector=embeddings,
            limit=self.top_k,
            collection_name=self.vector_collection,
        )

        # Concatenate the summaries of the matched articles into a single context string.
        context = ""
        for match in matches:
            context += match.payload["summary"] + "\n"

        return {
            "context": context,
        }
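
    # Illustrative note (an assumption about the ingestion side, which is not
    # shown in this module): the search above expects each point in the
    # collection to carry a "summary" field in its payload, e.g. upserted as:
    #
    #   from qdrant_client.models import PointStruct
    #
    #   client.upsert(
    #       collection_name="alpaca_financial_news",  # hypothetical name
    #       points=[PointStruct(id=1, vector=embedding, payload={"summary": "..."})],
    #   )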

    def clean(self, question: str) -> str:
        """
        Clean the input question by removing unwanted characters.

        Parameters:
        -----------
        question : str
            The input question to clean.

        Returns:
        --------
        str
            The cleaned question.
        """

        question = clean(question)
        question = replace_unicode_quotes(question)
        question = clean_non_ascii_chars(question)

        return question


class FinancialBotQAChain(Chain):
    """This custom chain handles LLM generation for a given prompt."""

    hf_pipeline: HuggingFacePipeline
    template: PromptTemplate
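
    # Contract inferred from _call below (not an official PromptTemplate API note):
    # template.format_infer(...) returns a dict with a "prompt" key (the rendered
    # prompt string) and a "payload" key (the variables used to render it).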

    @property
    def input_keys(self) -> List[str]:
        """Returns a list of input keys for the chain."""

        return ["context"]

    @property
    def output_keys(self) -> List[str]:
        """Returns a list of output keys for the chain."""

        return ["answer"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Calls the chain with the given inputs and returns the output."""

        inputs = self.clean(inputs)
        prompt = self.template.format_infer(
            {
                "user_context": inputs["about_me"],
                "news_context": inputs["context"],
                "chat_history": inputs["chat_history"],
                "question": inputs["question"],
            }
        )

        # Time the generation step so the duration can be reported as metadata.
        start_time = time.time()
        response = self.hf_pipeline(prompt["prompt"])
        end_time = time.time()
        duration_milliseconds = (end_time - start_time) * 1000

        if run_manager:
            run_manager.on_chain_end(
                outputs={
                    "answer": response,
                },
                # TODO: Count tokens instead of using len().
                metadata={
                    "prompt": prompt["prompt"],
                    "prompt_template_variables": prompt["payload"],
                    "prompt_template": self.template.infer_raw_template,
                    "usage.prompt_tokens": len(prompt["prompt"]),
                    "usage.total_tokens": len(prompt["prompt"]) + len(response),
                    "usage.actual_new_tokens": len(response),
                    "duration_milliseconds": duration_milliseconds,
                },
            )

        return {"answer": response}

    def clean(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Cleans the inputs by removing extra whitespace and grouping broken paragraphs."""

        for key, value in inputs.items():
            cleaned_value = clean_extra_whitespace(value)
            cleaned_value = group_broken_paragraphs(cleaned_value)

            inputs[key] = cleaned_value

        return inputs
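

# Illustrative end-to-end wiring (a minimal sketch under assumptions, not the
# canonical setup): constructor arguments for EmbeddingModelSingleton,
# PromptTemplate, and the HuggingFace pipeline are placeholders.
#
#   context_chain = ContextExtractorChain(
#       top_k=3,
#       embedding_model=embedding_model,
#       vector_store=qdrant_client.QdrantClient(url="http://localhost:6333"),
#       vector_collection="alpaca_financial_news",  # hypothetical name
#   )
#   qa_chain = FinancialBotQAChain(hf_pipeline=hf_pipeline, template=template)
#
#   bot_chain = StatelessMemorySequentialChain(
#       memory=memory,  # e.g. the ConversationBufferWindowMemory sketched earlier
#       chains=[context_chain, qa_chain],
#       input_variables=["about_me", "question", "to_load_history"],
#       output_variables=["answer"],
#   )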