Upload 22 files
- financial_bot/__init__.py +55 -0
- financial_bot/__pycache__/__init__.cpython-310.pyc +0 -0
- financial_bot/__pycache__/base.cpython-310.pyc +0 -0
- financial_bot/__pycache__/chains.cpython-310.pyc +0 -0
- financial_bot/__pycache__/constants.cpython-310.pyc +0 -0
- financial_bot/__pycache__/embeddings.cpython-310.pyc +0 -0
- financial_bot/__pycache__/handlers.cpython-310.pyc +0 -0
- financial_bot/__pycache__/langchain_bot.cpython-310.pyc +0 -0
- financial_bot/__pycache__/models.cpython-310.pyc +0 -0
- financial_bot/__pycache__/qdrant.cpython-310.pyc +0 -0
- financial_bot/__pycache__/template.cpython-310.pyc +0 -0
- financial_bot/__pycache__/utils.cpython-310.pyc +0 -0
- financial_bot/base.py +38 -0
- financial_bot/chains.py +226 -0
- financial_bot/constants.py +23 -0
- financial_bot/embeddings.py +123 -0
- financial_bot/handlers.py +64 -0
- financial_bot/langchain_bot.py +223 -0
- financial_bot/models.py +264 -0
- financial_bot/qdrant.py +49 -0
- financial_bot/template.py +132 -0
- financial_bot/utils.py +106 -0
financial_bot/__init__.py
ADDED
@@ -0,0 +1,55 @@
import logging
import logging.config
from pathlib import Path

import yaml
from dotenv import find_dotenv, load_dotenv

logger = logging.getLogger(__name__)


def initialize(logging_config_path: str = "logging.yaml", env_file_path: str = ".env"):
    """
    Initializes the logger and environment variables.

    Args:
        logging_config_path (str): The path to the logging configuration file. Defaults to "logging.yaml".
        env_file_path (str): The path to the environment variables file. Defaults to ".env".
    """

    logger.info("Initializing logger...")
    try:
        initialize_logger(config_path=logging_config_path)
    except FileNotFoundError:
        logger.warning(
            f"No logging configuration file found at: {logging_config_path}. Setting logging level to INFO."
        )
        logging.basicConfig(level=logging.INFO)

    logger.info("Initializing env vars...")
    if env_file_path is None:
        env_file_path = find_dotenv(raise_error_if_not_found=True, usecwd=False)

    logger.info(f"Loading environment variables from: {env_file_path}")
    found_env_file = load_dotenv(env_file_path, verbose=True, override=True)
    if found_env_file is False:
        raise RuntimeError(f"Could not find environment file at: {env_file_path}")


def initialize_logger(
    config_path: str = "logging.yaml", logs_dir_name: str = "logs"
) -> logging.Logger:
    """Initialize logger from a YAML config file."""

    # Create logs directory.
    config_path_parent = Path(config_path).parent
    logs_dir = config_path_parent / logs_dir_name
    logs_dir.mkdir(parents=True, exist_ok=True)

    with open(config_path, "rt") as f:
        config = yaml.safe_load(f.read())

    # Make sure that existing loggers will still work.
    config["disable_existing_loggers"] = False

    logging.config.dictConfig(config)
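
A minimal sketch of how this bootstrap module is typically used before building the bot. The `logging.yaml` and `.env` paths are assumptions (they are not part of this commit); `initialize()` itself is the function defined above.

# Hypothetical entrypoint sketch; assumes a logging.yaml and .env exist next to it.
from financial_bot import initialize

if __name__ == "__main__":
    # Configure logging and load QDRANT_URL, QDRANT_API_KEY, COMET_* vars before anything else.
    initialize(logging_config_path="logging.yaml", env_file_path=".env")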
financial_bot/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.96 kB)

financial_bot/__pycache__/base.cpython-310.pyc
ADDED
Binary file (936 Bytes)

financial_bot/__pycache__/chains.cpython-310.pyc
ADDED
Binary file (6.98 kB)

financial_bot/__pycache__/constants.cpython-310.pyc
ADDED
Binary file (720 Bytes)

financial_bot/__pycache__/embeddings.cpython-310.pyc
ADDED
Binary file (4.37 kB)

financial_bot/__pycache__/handlers.cpython-310.pyc
ADDED
Binary file (2.59 kB)

financial_bot/__pycache__/langchain_bot.cpython-310.pyc
ADDED
Binary file (7.71 kB)

financial_bot/__pycache__/models.cpython-310.pyc
ADDED
Binary file (8.25 kB)

financial_bot/__pycache__/qdrant.cpython-310.pyc
ADDED
Binary file (1.56 kB)

financial_bot/__pycache__/template.cpython-310.pyc
ADDED
Binary file (3.84 kB)

financial_bot/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.34 kB)
financial_bot/base.py
ADDED
@@ -0,0 +1,38 @@
from threading import Lock


class SingletonMeta(type):
    """
    This is a thread-safe implementation of Singleton.
    """

    _instances = {}

    _lock: Lock = Lock()
    """
    We now have a lock object that will be used to synchronize threads during
    first access to the Singleton.
    """

    def __call__(cls, *args, **kwargs):
        """
        Possible changes to the value of the `__init__` argument do not affect
        the returned instance.
        """
        # Now, imagine that the program has just been launched. Since there's no
        # Singleton instance yet, multiple threads can simultaneously pass the
        # previous conditional and reach this point almost at the same time. The
        # first of them will acquire the lock and will proceed further, while the
        # rest will wait here.
        with cls._lock:
            # The first thread to acquire the lock, reaches this conditional,
            # goes inside and creates the Singleton instance. Once it leaves the
            # lock block, a thread that might have been waiting for the lock
            # release may then enter this section. But since the Singleton field
            # is already initialized, the thread won't create a new object.
            if cls not in cls._instances:
                instance = super().__call__(*args, **kwargs)
                cls._instances[cls] = instance

        return cls._instances[cls]
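
A quick illustration of the guarantee this metaclass provides; the `Config` class below is a made-up toy, not part of the package.

# Toy example: any class using SingletonMeta is constructed at most once.
from financial_bot.base import SingletonMeta

class Config(metaclass=SingletonMeta):
    def __init__(self, value: int = 0):
        self.value = value

a = Config(value=1)
b = Config(value=2)  # __init__ args of later calls are ignored.
assert a is b and a.value == 1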
financial_bot/chains.py
ADDED
@@ -0,0 +1,226 @@
import time
from typing import Any, Dict, List, Optional

import qdrant_client
from langchain import chains
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.llms import HuggingFacePipeline
from unstructured.cleaners.core import (
    clean,
    clean_extra_whitespace,
    clean_non_ascii_chars,
    group_broken_paragraphs,
    replace_unicode_quotes,
)

from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.template import PromptTemplate


class StatelessMemorySequentialChain(chains.SequentialChain):
    """
    A sequential chain that uses a stateless memory to store context between calls.

    This chain overrides the _call and prep_outputs methods to load and clear the memory
    before and after each call, respectively.
    """

    history_input_key: str = "to_load_history"

    def _call(self, inputs: Dict[str, str], **kwargs) -> Dict[str, str]:
        """
        Override _call to load history before calling the chain.

        This method loads the history from the input dictionary and saves it to the
        stateless memory. It then updates the inputs dictionary with the memory values
        and removes the history input key. Finally, it calls the parent _call method
        with the updated inputs and returns the results.
        """

        to_load_history = inputs[self.history_input_key]
        for (
            human,
            ai,
        ) in to_load_history:
            self.memory.save_context(
                inputs={self.memory.input_key: human},
                outputs={self.memory.output_key: ai},
            )
        memory_values = self.memory.load_memory_variables({})
        inputs.update(memory_values)

        del inputs[self.history_input_key]

        return super()._call(inputs, **kwargs)

    def prep_outputs(
        self,
        inputs: Dict[str, str],
        outputs: Dict[str, str],
        return_only_outputs: bool = False,
    ) -> Dict[str, str]:
        """
        Override prep_outputs to clear the internal memory after each call.

        This method calls the parent prep_outputs method to get the results, then
        clears the stateless memory and removes the memory key from the results
        dictionary. It then returns the updated results.
        """

        results = super().prep_outputs(inputs, outputs, return_only_outputs)

        # Clear the internal memory.
        self.memory.clear()
        if self.memory.memory_key in results:
            results[self.memory.memory_key] = ""

        return results


class ContextExtractorChain(Chain):
    """
    Encode the question, search the vector store for top-k articles and return
    context news from documents collection of Alpaca news.

    Attributes:
    -----------
    top_k : int
        The number of top matches to retrieve from the vector store.
    embedding_model : EmbeddingModelSingleton
        The embedding model to use for encoding the question.
    vector_store : qdrant_client.QdrantClient
        The vector store to search for matches.
    vector_collection : str
        The name of the collection to search in the vector store.
    """

    top_k: int = 1
    embedding_model: EmbeddingModelSingleton
    vector_store: qdrant_client.QdrantClient
    vector_collection: str

    @property
    def input_keys(self) -> List[str]:
        return ["about_me", "question"]

    @property
    def output_keys(self) -> List[str]:
        return ["context"]

    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        _, quest_key = self.input_keys
        question_str = inputs[quest_key]

        cleaned_question = self.clean(question_str)
        # TODO: Instead of cutting the question at 'max_input_length', chunk the question in 'max_input_length' chunks,
        # pass them through the model and average the embeddings.
        cleaned_question = cleaned_question[: self.embedding_model.max_input_length]
        embeddings = self.embedding_model(cleaned_question)

        # TODO: Using the metadata, use the filter to take into consideration only the news from the last 24 hours
        # (or other time frame).
        matches = self.vector_store.search(
            query_vector=embeddings,
            k=self.top_k,
            collection_name=self.vector_collection,
        )

        context = ""
        for match in matches:
            context += match.payload["summary"] + "\n"

        return {
            "context": context,
        }

    def clean(self, question: str) -> str:
        """
        Clean the input question by removing unwanted characters.

        Parameters:
        -----------
        question : str
            The input question to clean.

        Returns:
        --------
        str
            The cleaned question.
        """
        question = clean(question)
        question = replace_unicode_quotes(question)
        question = clean_non_ascii_chars(question)

        return question


class FinancialBotQAChain(Chain):
    """This custom chain handles LLM generation upon given prompt"""

    hf_pipeline: HuggingFacePipeline
    template: PromptTemplate

    @property
    def input_keys(self) -> List[str]:
        """Returns a list of input keys for the chain"""

        return ["context"]

    @property
    def output_keys(self) -> List[str]:
        """Returns a list of output keys for the chain"""

        return ["answer"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Calls the chain with the given inputs and returns the output"""

        inputs = self.clean(inputs)
        prompt = self.template.format_infer(
            {
                "user_context": inputs["about_me"],
                "news_context": inputs["context"],
                "chat_history": inputs["chat_history"],
                "question": inputs["question"],
            }
        )

        start_time = time.time()
        response = self.hf_pipeline(prompt["prompt"])
        end_time = time.time()
        duration_milliseconds = (end_time - start_time) * 1000

        if run_manager:
            run_manager.on_chain_end(
                outputs={
                    "answer": response,
                },
                # TODO: Count tokens instead of using len().
                metadata={
                    "prompt": prompt["prompt"],
                    "prompt_template_variables": prompt["payload"],
                    "prompt_template": self.template.infer_raw_template,
                    "usage.prompt_tokens": len(prompt["prompt"]),
                    "usage.total_tokens": len(prompt["prompt"]) + len(response),
                    "usage.actual_new_tokens": len(response),
                    "duration_milliseconds": duration_milliseconds,
                },
            )

        return {"answer": response}

    def clean(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Cleans the inputs by removing extra whitespace and grouping broken paragraphs"""

        for key, input in inputs.items():
            cleaned_input = clean_extra_whitespace(input)
            cleaned_input = group_broken_paragraphs(cleaned_input)

            inputs[key] = cleaned_input

        return inputs
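
The two custom chains pass a plain dict between them. A sketch of the payload shapes implied by their `input_keys`/`output_keys` above; the literal string values are made up.

# Shapes only; the values are illustrative placeholders.
context_chain_inputs = {"about_me": "I am a 30 year old investor.", "question": "Should I buy Tesla?"}
context_chain_outputs = {"context": "Tesla reported record deliveries...\n"}

qa_chain_inputs = {**context_chain_inputs, **context_chain_outputs, "chat_history": ""}
qa_chain_outputs = {"answer": "Based on the latest news..."}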
financial_bot/constants.py
ADDED
@@ -0,0 +1,23 @@
from pathlib import Path

# == Embeddings model ==
EMBEDDING_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
EMBEDDING_MODEL_MAX_INPUT_LENGTH = 384

# == VECTOR Database ==
VECTOR_DB_OUTPUT_COLLECTION_NAME = "alpaca_financial_news"
VECTOR_DB_SEARCH_TOPK = 1

# == LLM Model ==
LLM_MODEL_ID = "unsloth/mistral-7b-instruct-v0.2-bnb-4bit"
LLM_QLORA_CHECKPOINT = "plantbased/mistral-7b-instruct-v0.2-4bit"

LLM_INFERNECE_MAX_NEW_TOKENS = 500
LLM_INFERENCE_TEMPERATURE = 1.0


# == Prompt Template ==
TEMPLATE_NAME = "mistral"

# === Misc ===
CACHE_DIR = Path.home() / ".cache" / "hands-on-llms"
financial_bot/embeddings.py
ADDED
@@ -0,0 +1,123 @@
import logging
import traceback
from typing import Optional, Union

import numpy as np
from transformers import AutoModel, AutoTokenizer

from financial_bot import constants
from financial_bot.base import SingletonMeta

logger = logging.getLogger(__name__)


class EmbeddingModelSingleton(metaclass=SingletonMeta):
    """
    A singleton class that provides a pre-trained transformer model for generating embeddings of input text.

    Args:
        model_id (str): The identifier of the pre-trained transformer model to use.
        max_input_length (int): The maximum length of input text to tokenize.
        device (str): The device to use for running the model (e.g. "cpu", "cuda").
        cache_dir (Optional[Path]): The directory to cache the pre-trained model files.
            If None, the default cache directory is used.

    Attributes:
        max_input_length (int): The maximum length of input text to tokenize.
        tokenizer (AutoTokenizer): The tokenizer used to tokenize input text.
    """

    def __init__(
        self,
        model_id: str = constants.EMBEDDING_MODEL_ID,
        max_input_length: int = constants.EMBEDDING_MODEL_MAX_INPUT_LENGTH,
        device: str = "cuda:0",
        cache_dir: Optional[str] = None,
    ):
        """
        Initializes the EmbeddingModelSingleton instance.

        Args:
            model_id (str): The identifier of the pre-trained transformer model to use.
            max_input_length (int): The maximum length of input text to tokenize.
            device (str): The device to use for running the model (e.g. "cpu", "cuda").
            cache_dir (Optional[Path]): The directory to cache the pre-trained model files.
                If None, the default cache directory is used.
        """

        self._model_id = model_id
        self._device = device
        self._max_input_length = max_input_length

        self._tokenizer = AutoTokenizer.from_pretrained(model_id)
        self._model = AutoModel.from_pretrained(
            model_id,
            cache_dir=str(cache_dir) if cache_dir else None,
        ).to(self._device)
        self._model.eval()

    @property
    def max_input_length(self) -> int:
        """
        Returns the maximum length of input text to tokenize.

        Returns:
            int: The maximum length of input text to tokenize.
        """

        return self._max_input_length

    @property
    def tokenizer(self) -> AutoTokenizer:
        """
        Returns the tokenizer used to tokenize input text.

        Returns:
            AutoTokenizer: The tokenizer used to tokenize input text.
        """

        return self._tokenizer

    def __call__(
        self, input_text: str, to_list: bool = True
    ) -> Union[np.ndarray, list]:
        """
        Generates embeddings for the input text using the pre-trained transformer model.

        Args:
            input_text (str): The input text to generate embeddings for.
            to_list (bool): Whether to return the embeddings as a list or numpy array. Defaults to True.

        Returns:
            Union[np.ndarray, list]: The embeddings generated for the input text.
        """

        try:
            tokenized_text = self._tokenizer(
                input_text,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=self._max_input_length,
            ).to(self._device)
        except Exception:
            logger.error(traceback.format_exc())
            logger.error(f"Error tokenizing the following input text: {input_text}")

            return [] if to_list else np.array([])

        try:
            result = self._model(**tokenized_text)
        except Exception:
            logger.error(traceback.format_exc())
            logger.error(
                f"Error generating embeddings for the following model_id: {self._model_id} and input text: {input_text}"
            )

            return [] if to_list else np.array([])

        embeddings = result.last_hidden_state[:, 0, :].cpu().detach().numpy()
        if to_list:
            embeddings = embeddings.flatten().tolist()

        return embeddings
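
A short usage sketch for the embedding singleton. Running on CPU and the sample headline are assumptions; the class default device is "cuda:0".

from financial_bot.embeddings import EmbeddingModelSingleton

embedding_model = EmbeddingModelSingleton(device="cpu")
vector = embedding_model("Tesla stock rallies after earnings beat.")  # list[float] by default
same_instance = EmbeddingModelSingleton()  # SingletonMeta returns the already-built instance
assert same_instance is embedding_model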
financial_bot/handlers.py
ADDED
@@ -0,0 +1,64 @@
from typing import Any, Dict

import comet_llm
from langchain.callbacks.base import BaseCallbackHandler

from financial_bot import constants


class CometLLMMonitoringHandler(BaseCallbackHandler):
    """
    A callback handler for monitoring LLM models using Comet.ml.

    Args:
        project_name (str): The name of the Comet.ml project to log to.
        llm_model_id (str): The ID of the LLM model to use for inference.
        llm_qlora_model_id (str): The ID of the PEFT model to use for inference.
        llm_inference_max_new_tokens (int): The maximum number of new tokens to generate during inference.
        llm_inference_temperature (float): The temperature to use during inference.
    """

    def __init__(
        self,
        project_name: str = None,
        llm_model_id: str = constants.LLM_MODEL_ID,
        llm_qlora_model_id: str = constants.LLM_QLORA_CHECKPOINT,
        llm_inference_max_new_tokens: int = constants.LLM_INFERNECE_MAX_NEW_TOKENS,
        llm_inference_temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
    ):
        self._project_name = project_name
        self._llm_model_id = llm_model_id
        self._llm_qlora_model_id = llm_qlora_model_id
        self._llm_inference_max_new_tokens = llm_inference_max_new_tokens
        self._llm_inference_temperature = llm_inference_temperature

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """
        A callback function that logs the prompt and output to Comet.ml.

        Args:
            outputs (Dict[str, Any]): The output of the LLM model.
            **kwargs (Any): Additional arguments passed to the function.
        """

        should_log_prompt = "metadata" in kwargs
        if should_log_prompt:
            metadata = kwargs["metadata"]

            comet_llm.log_prompt(
                project=self._project_name,
                prompt=metadata["prompt"],
                output=outputs["answer"],
                prompt_template=metadata["prompt_template"],
                prompt_template_variables=metadata["prompt_template_variables"],
                metadata={
                    "usage.prompt_tokens": metadata["usage.prompt_tokens"],
                    "usage.total_tokens": metadata["usage.total_tokens"],
                    "usage.max_new_tokens": self._llm_inference_max_new_tokens,
                    "usage.temperature": self._llm_inference_temperature,
                    "usage.actual_new_tokens": metadata["usage.actual_new_tokens"],
                    "model": self._llm_model_id,
                    "peft_model": self._llm_qlora_model_id,
                },
                duration=metadata["duration_milliseconds"],
            )
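
The handler only logs when the chain passes a `metadata` kwarg to `on_chain_end`. A sketch of the keys it reads, mirroring what `FinancialBotQAChain._call` sends; the values are placeholders.

# Keys consumed by CometLLMMonitoringHandler.on_chain_end; values are placeholders.
expected_metadata = {
    "prompt": "<formatted prompt>",
    "prompt_template": "<raw inference template>",
    "prompt_template_variables": {"question": "..."},
    "usage.prompt_tokens": 123,
    "usage.total_tokens": 456,
    "usage.actual_new_tokens": 333,
    "duration_milliseconds": 1500.0,
}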
financial_bot/langchain_bot.py
ADDED
@@ -0,0 +1,223 @@
import logging
import os
from pathlib import Path
from typing import Iterable, List, Tuple

from langchain import chains
from langchain.memory import ConversationBufferWindowMemory

from financial_bot import constants
from financial_bot.chains import (
    ContextExtractorChain,
    FinancialBotQAChain,
    StatelessMemorySequentialChain,
)
from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.handlers import CometLLMMonitoringHandler
from financial_bot.models import build_huggingface_pipeline
from financial_bot.qdrant import build_qdrant_client
from financial_bot.template import get_llm_template

logger = logging.getLogger(__name__)


class FinancialBot:
    """
    A language chain bot that uses a language model to generate responses to user inputs.

    Args:
        llm_model_id (str): The ID of the Hugging Face language model to use.
        llm_qlora_model_id (str): The ID of the Hugging Face QLoRA model to use.
        llm_template_name (str): The name of the LLM template to use.
        llm_inference_max_new_tokens (int): The maximum number of new tokens to generate during inference.
        llm_inference_temperature (float): The temperature to use during inference.
        vector_collection_name (str): The name of the Qdrant vector collection to use.
        vector_db_search_topk (int): The number of nearest neighbors to search for in the Qdrant vector database.
        model_cache_dir (Path): The directory to use for caching the language model and embedding model.
        streaming (bool): Whether to use the Hugging Face streaming API for inference.
        embedding_model_device (str): The device to use for the embedding model.
        debug (bool): Whether to enable debug mode.

    Attributes:
        finbot_chain (Chain): The language chain that generates responses to user inputs.
    """

    def __init__(
        self,
        llm_model_id: str = constants.LLM_MODEL_ID,
        llm_qlora_model_id: str = constants.LLM_QLORA_CHECKPOINT,
        llm_template_name: str = constants.TEMPLATE_NAME,
        llm_inference_max_new_tokens: int = constants.LLM_INFERNECE_MAX_NEW_TOKENS,
        llm_inference_temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
        vector_collection_name: str = constants.VECTOR_DB_OUTPUT_COLLECTION_NAME,
        vector_db_search_topk: int = constants.VECTOR_DB_SEARCH_TOPK,
        model_cache_dir: Path = constants.CACHE_DIR,
        streaming: bool = False,
        embedding_model_device: str = "cuda:0",
        debug: bool = False,
    ):
        self._llm_model_id = llm_model_id
        self._llm_qlora_model_id = llm_qlora_model_id
        self._llm_template_name = llm_template_name
        self._llm_template = get_llm_template(name=self._llm_template_name)
        self._llm_inference_max_new_tokens = llm_inference_max_new_tokens
        self._llm_inference_temperature = llm_inference_temperature
        self._vector_collection_name = vector_collection_name
        self._vector_db_search_topk = vector_db_search_topk
        self._debug = debug

        self._qdrant_client = build_qdrant_client()

        self._embd_model = EmbeddingModelSingleton(
            cache_dir=model_cache_dir, device=embedding_model_device
        )
        self._llm_agent, self._streamer = build_huggingface_pipeline(
            llm_model_id=llm_model_id,
            llm_lora_model_id=llm_qlora_model_id,
            max_new_tokens=llm_inference_max_new_tokens,
            temperature=llm_inference_temperature,
            use_streamer=streaming,
            cache_dir=model_cache_dir,
            debug=debug,
        )
        self.finbot_chain = self.build_chain()

    @property
    def is_streaming(self) -> bool:
        return self._streamer is not None

    def build_chain(self) -> chains.SequentialChain:
        """
        Constructs and returns a financial bot chain.
        This chain takes the user description (`about_me`) and a `question` as input, connects to the VectorDB,
        searches for financial news relevant to the user's question, and injects that context into the payload
        that is then passed as a prompt to a fine-tuned financial LLM, which generates the answer.

        The chain consists of two primary stages:
        1. Context Extractor: This stage is responsible for embedding the user's question,
        which means converting the textual question into a numerical representation.
        This embedded question is then used to retrieve relevant context from the VectorDB.
        The output of this chain will be a dict payload.

        2. LLM Generator: Once the context is extracted,
        this stage uses it to format a full prompt for the LLM and
        then feed it to the model to get a response that is relevant to the user's question.

        Returns
        -------
        chains.SequentialChain
            The constructed financial bot chain.

        Notes
        -----
        The actual processing flow within the chain can be visualized as:
        [about: str][question: str] > ContextChain >
        [about: str][question: str] + [context: str] > FinancialChain >
        [answer: str]
        """

        logger.info("Building 1/3 - ContextExtractorChain")
        context_retrieval_chain = ContextExtractorChain(
            embedding_model=self._embd_model,
            vector_store=self._qdrant_client,
            vector_collection=self._vector_collection_name,
            top_k=self._vector_db_search_topk,
        )

        logger.info("Building 2/3 - FinancialBotQAChain")
        if self._debug:
            callbacks = []
        else:
            try:
                comet_project_name = os.environ["COMET_PROJECT_NAME"]
            except KeyError:
                raise RuntimeError(
                    "Please set the COMET_PROJECT_NAME environment variable."
                )
            callbacks = [
                CometLLMMonitoringHandler(
                    project_name=f"{comet_project_name}-monitor-prompts",
                    llm_model_id=self._llm_model_id,
                    llm_qlora_model_id=self._llm_qlora_model_id,
                    llm_inference_max_new_tokens=self._llm_inference_max_new_tokens,
                    llm_inference_temperature=self._llm_inference_temperature,
                )
            ]
        llm_generator_chain = FinancialBotQAChain(
            hf_pipeline=self._llm_agent,
            template=self._llm_template,
            callbacks=callbacks,
        )

        logger.info("Building 3/3 - Connecting chains into SequentialChain")
        seq_chain = StatelessMemorySequentialChain(
            history_input_key="to_load_history",
            memory=ConversationBufferWindowMemory(
                memory_key="chat_history",
                input_key="question",
                output_key="answer",
                k=3,
            ),
            chains=[context_retrieval_chain, llm_generator_chain],
            input_variables=["about_me", "question", "to_load_history"],
            output_variables=["answer"],
            verbose=True,
        )

        logger.info("Done building SequentialChain.")
        logger.info("Workflow:")
        logger.info(
            """
            [about: str][question: str] > ContextChain >
            [about: str][question: str] + [context: str] > FinancialChain >
            [answer: str]
            """
        )

        return seq_chain

    def answer(
        self,
        about_me: str,
        question: str,
        to_load_history: List[Tuple[str, str]] = None,
    ) -> str:
        """
        Given a short description about the user and a question, make the LLM
        generate a response.

        Parameters
        ----------
        about_me : str
            Short user description.
        question : str
            User question.

        Returns
        -------
        str
            LLM generated response.
        """

        inputs = {
            "about_me": about_me,
            "question": question,
            "to_load_history": to_load_history if to_load_history else [],
        }
        response = self.finbot_chain.run(inputs)

        return response

    def stream_answer(self) -> Iterable[str]:
        """Stream the answer from the LLM token by token after calling `answer()`."""

        assert (
            self.is_streaming
        ), "Stream answer not available. Build the bot with `use_streamer=True`."

        partial_answer = ""
        for new_token in self._streamer:
            if new_token != self._llm_template.eos:
                partial_answer += new_token

                yield partial_answer
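
A usage sketch for the bot itself. `debug=True` swaps in the mocked pipeline, so no GPU or Comet project is needed; the `about_me`/`question` strings are made up. With `streaming=True`, `answer()` is normally run from a separate thread while `stream_answer()` is consumed.

from financial_bot import initialize
from financial_bot.langchain_bot import FinancialBot

initialize()  # still requires QDRANT_URL and QDRANT_API_KEY in the environment

bot = FinancialBot(debug=True, embedding_model_device="cpu")
response = bot.answer(
    about_me="I am a 30 year old software engineer investing for retirement.",
    question="Is now a good time to buy index funds?",
    to_load_history=[],
)
print(response)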
financial_bot/models.py
ADDED
@@ -0,0 +1,264 @@
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple

import torch
from comet_ml import API
from langchain.llms import HuggingFacePipeline
from peft import LoraConfig, PeftConfig, PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    pipeline,
)

from financial_bot import constants
from financial_bot.utils import MockedPipeline

logger = logging.getLogger(__name__)


def download_from_model_registry(
    model_id: str, cache_dir: Optional[Path] = None
) -> Path:
    """
    Downloads a model from the Comet ML model registry.

    Args:
        model_id (str): The ID of the model to download, in the format "workspace/model_name:version".
        cache_dir (Optional[Path]): The directory to cache the downloaded model in. Defaults to the value of
            `constants.CACHE_DIR`.

    Returns:
        Path: The path to the downloaded model directory.
    """

    if cache_dir is None:
        cache_dir = constants.CACHE_DIR
    output_folder = cache_dir / "models" / model_id

    already_downloaded = output_folder.exists()
    if not already_downloaded:
        workspace, model_id = model_id.split("/")
        model_name, version = model_id.split(":")

        api = API()
        model = api.get_model(workspace=workspace, model_name=model_name)
        model.download(version=version, output_folder=output_folder, expand=True)
    else:
        logger.info(f"Model {model_id=} already downloaded to: {output_folder}")

    subdirs = [d for d in output_folder.iterdir() if d.is_dir()]
    if len(subdirs) == 1:
        model_dir = subdirs[0]
    else:
        raise RuntimeError(
            f"There should be only one directory inside the model folder. \
                Check the downloaded model at: {output_folder}"
        )

    logger.info(f"Model {model_id=} downloaded from the registry to: {model_dir}")

    return model_dir


class StopOnTokens(StoppingCriteria):
    """
    A stopping criteria that stops generation when a specific token is generated.

    Args:
        stop_ids (List[int]): A list of token ids that will trigger the stopping criteria.
    """

    def __init__(self, stop_ids: List[int]):
        super().__init__()

        self._stop_ids = stop_ids

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """
        Check if the last generated token is in the stop_ids list.

        Args:
            input_ids (torch.LongTensor): The input token ids.
            scores (torch.FloatTensor): The scores of the generated tokens.

        Returns:
            bool: True if the last generated token is in the stop_ids list, False otherwise.
        """

        for stop_id in self._stop_ids:
            if input_ids[0][-1] == stop_id:
                return True

        return False


def build_huggingface_pipeline(
    llm_model_id: str,
    llm_lora_model_id: str,
    max_new_tokens: int = constants.LLM_INFERNECE_MAX_NEW_TOKENS,
    temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
    gradient_checkpointing: bool = False,
    use_streamer: bool = False,
    cache_dir: Optional[Path] = None,
    debug: bool = False,
) -> Tuple[HuggingFacePipeline, Optional[TextIteratorStreamer]]:
    """
    Builds a HuggingFace pipeline for text generation using a custom LLM + fine-tuned checkpoint.

    Args:
        llm_model_id (str): The ID or path of the LLM model.
        llm_lora_model_id (str): The ID or path of the LLM LoRA model.
        max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to
            `constants.LLM_INFERNECE_MAX_NEW_TOKENS`.
        temperature (float, optional): The temperature to use for sampling. Defaults to
            `constants.LLM_INFERENCE_TEMPERATURE`.
        gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
        use_streamer (bool, optional): Whether to use a text iterator streamer. Defaults to False.
        cache_dir (Optional[Path], optional): The directory to use for caching. Defaults to None.
        debug (bool, optional): Whether to use a mocked pipeline for debugging. Defaults to False.

    Returns:
        Tuple[HuggingFacePipeline, Optional[TextIteratorStreamer]]: A tuple containing the HuggingFace pipeline
            and the text iterator streamer (if used).
    """

    if debug is True:
        return (
            HuggingFacePipeline(
                pipeline=MockedPipeline(f=lambda _: "You are doing great!")
            ),
            None,
        )

    model, tokenizer, _ = build_qlora_model(
        pretrained_model_name_or_path=llm_model_id,
        peft_pretrained_model_name_or_path=llm_lora_model_id,
        gradient_checkpointing=gradient_checkpointing,
        cache_dir=cache_dir,
    )
    model.eval()

    if use_streamer:
        streamer = TextIteratorStreamer(
            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        stop_on_tokens = StopOnTokens(stop_ids=[tokenizer.eos_token_id])
        stopping_criteria = StoppingCriteriaList([stop_on_tokens])
    else:
        streamer = None
        stopping_criteria = StoppingCriteriaList([])

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        streamer=streamer,
        stopping_criteria=stopping_criteria,
    )
    hf = HuggingFacePipeline(pipeline=pipe)

    return hf, streamer


def build_qlora_model(
    pretrained_model_name_or_path: str = "tiiuae/falcon-7b-instruct",
    peft_pretrained_model_name_or_path: Optional[str] = None,
    gradient_checkpointing: bool = True,
    cache_dir: Optional[Path] = None,
) -> Tuple[AutoModelForCausalLM, AutoTokenizer, PeftConfig]:
    """
    Function that builds a QLoRA LLM model based on the given HuggingFace name:
        1. Create and prepare the bitsandbytes configuration for QLoRA's quantization
        2. Download, load, and quantize on-the-fly Falcon-7b
        3. Create and prepare the LoRA configuration
        4. Load and configure Falcon-7B's tokenizer

    Args:
        pretrained_model_name_or_path (str): The name or path of the pretrained model to use.
        peft_pretrained_model_name_or_path (Optional[str]): The name or path of the PEFT pretrained model to use.
        gradient_checkpointing (bool): Whether to use gradient checkpointing or not.
        cache_dir (Optional[Path]): The directory to cache the downloaded models.

    Returns:
        Tuple[AutoModelForCausalLM, AutoTokenizer, PeftConfig]:
            A tuple containing the QLoRA LLM model, tokenizer, and PEFT config.
    """

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        revision="main",
        quantization_config=bnb_config,
        load_in_4bit=True,
        device_map="auto",
        trust_remote_code=False,
        cache_dir=str(cache_dir) if cache_dir else None,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=False,
        truncation=True,
        cache_dir=str(cache_dir) if cache_dir else None,
    )
    if tokenizer.pad_token_id is None:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
        with torch.no_grad():
            model.resize_token_embeddings(len(tokenizer))
        model.config.pad_token_id = tokenizer.pad_token_id

    if peft_pretrained_model_name_or_path:
        is_model_name = not os.path.isdir(peft_pretrained_model_name_or_path)
        if is_model_name:
            logger.info(
                f"Downloading {peft_pretrained_model_name_or_path} from CometML Model Registry:"
            )
            peft_pretrained_model_name_or_path = download_from_model_registry(
                model_id=peft_pretrained_model_name_or_path,
                cache_dir=cache_dir,
            )

        logger.info(f"Loading Lora Config from: {peft_pretrained_model_name_or_path}")
        lora_config = LoraConfig.from_pretrained(peft_pretrained_model_name_or_path)
        assert (
            lora_config.base_model_name_or_path == pretrained_model_name_or_path
        ), f"Lora Model trained on different base model than the one requested: \
            {lora_config.base_model_name_or_path} != {pretrained_model_name_or_path}"

        logger.info(f"Loading Peft Model from: {peft_pretrained_model_name_or_path}")
        model = PeftModel.from_pretrained(model, peft_pretrained_model_name_or_path)
    else:
        lora_config = LoraConfig(
            lora_alpha=16,
            lora_dropout=0.1,
            r=64,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["query_key_value"],
        )

    if gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_cache = (
            False  # Gradient checkpointing is not compatible with caching.
        )
    else:
        model.gradient_checkpointing_disable()
        model.config.use_cache = True  # It is good practice to enable caching when using the model for inference.

    return model, tokenizer, lora_config
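
A sketch of building the pipeline in debug mode, which returns the `MockedPipeline`-backed `HuggingFacePipeline` and no streamer. The model IDs come from `constants`; the real (non-debug) path loads the QLoRA model and needs a CUDA GPU.

from financial_bot import constants
from financial_bot.models import build_huggingface_pipeline

hf_pipeline, streamer = build_huggingface_pipeline(
    llm_model_id=constants.LLM_MODEL_ID,
    llm_lora_model_id=constants.LLM_QLORA_CHECKPOINT,
    debug=True,  # skips the QLoRA load and returns a mocked text-generation pipeline
)
assert streamer is None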
financial_bot/qdrant.py
ADDED
@@ -0,0 +1,49 @@
import logging
import os
from typing import Optional

import qdrant_client

logger = logging.getLogger(__name__)


def build_qdrant_client(
    url: Optional[str] = None,
    api_key: Optional[str] = None,
):
    """
    Builds a Qdrant client object using the provided URL and API key.

    Args:
        url (Optional[str]): The URL of the Qdrant server. If not provided, the function will attempt
            to read it from the QDRANT_URL environment variable.
        api_key (Optional[str]): The API key to use for authentication. If not provided, the function will attempt
            to read it from the QDRANT_API_KEY environment variable.

    Raises:
        KeyError: If the URL or API key is not provided and cannot be read from the environment variables.

    Returns:
        qdrant_client.QdrantClient: A Qdrant client object.
    """

    logger.info("Building QDrant Client")
    if url is None:
        try:
            url = os.environ["QDRANT_URL"]
        except KeyError:
            raise KeyError(
                "QDRANT_URL must be set as environment variable or manually passed as an argument."
            )

    if api_key is None:
        try:
            api_key = os.environ["QDRANT_API_KEY"]
        except KeyError:
            raise KeyError(
                "QDRANT_API_KEY must be set as environment variable or manually passed as an argument."
            )

    client = qdrant_client.QdrantClient(url, api_key=api_key)

    return client
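
The client can be configured purely through the environment; a short sketch (the URL and key below are placeholders, not real credentials).

import os
from financial_bot.qdrant import build_qdrant_client

os.environ["QDRANT_URL"] = "https://your-cluster.qdrant.example"  # placeholder
os.environ["QDRANT_API_KEY"] = "..."  # placeholder
client = build_qdrant_client()  # or build_qdrant_client(url=..., api_key=...)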
financial_bot/template.py
ADDED
@@ -0,0 +1,132 @@
"""
This script defines a PromptTemplate class that assists in generating
conversation/prompt templates. The script facilitates formatting prompts
for inference and training by combining various context elements and user inputs.
"""


import dataclasses
from typing import Dict, List, Union


@dataclasses.dataclass
class PromptTemplate:
    """A class that manages prompt templates"""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = "{system_message}"
    # The template for the system context
    context_template: str = "{user_context}\n{news_context}"
    # The template for the conversation history
    chat_history_template: str = "{chat_history}"
    # The template of the user question
    question_template: str = "{question}"
    # The template of the system answer
    answer_template: str = "{answer}"
    # The system message
    system_message: str = ""
    # Separator
    sep: str = "\n"
    eos: str = "</s>"

    @property
    def input_variables(self) -> List[str]:
        """Returns a list of input variables for the prompt template"""

        return ["user_context", "news_context", "chat_history", "question", "answer"]

    @property
    def train_raw_template(self):
        """Returns the training prompt template format"""

        system = self.system_template.format(system_message=self.system_message)
        context = f"{self.sep}{self.context_template}"
        chat_history = f"{self.sep}{self.chat_history_template}"
        question = f"{self.sep}{self.question_template}"
        answer = f"{self.sep}{self.answer_template}"

        return f"{system}{context}{chat_history}{question}{answer}{self.eos}"

    @property
    def infer_raw_template(self):
        """Returns the inference prompt template format"""

        system = self.system_template.format(system_message=self.system_message)
        context = f"{self.sep}{self.context_template}"
        chat_history = f"{self.sep}{self.chat_history_template}"
        question = f"{self.sep}{self.question_template}"

        return f"{system}{context}{chat_history}{question}{self.eos}"

    def format_train(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
        """Formats the data sample to a training sample"""

        prompt = self.train_raw_template.format(
            user_context=sample["user_context"],
            news_context=sample["news_context"],
            chat_history=sample.get("chat_history", ""),
            question=sample["question"],
            answer=sample["answer"],
        )
        return {"prompt": prompt, "payload": sample}

    def format_infer(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
        """Formats the data sample to a testing sample"""

        prompt = self.infer_raw_template.format(
            user_context=sample["user_context"],
            news_context=sample["news_context"],
            chat_history=sample.get("chat_history", ""),
            question=sample["question"],
        )
        return {"prompt": prompt, "payload": sample}


# Global Templates registry
templates: Dict[str, PromptTemplate] = {}


def register_llm_template(template: PromptTemplate):
    """Register a new template to the global templates registry"""

    templates[template.name] = template


def get_llm_template(name: str) -> PromptTemplate:
    """Returns the template assigned to the given name"""

    return templates[name]


##### Register Templates #####
# - Mistral 7B Instruct v0.2 Template
register_llm_template(
    PromptTemplate(
        name="mistral",
        system_template="<s>{system_message}",
        system_message="You are a helpful assistant, with financial expertise.",
        context_template="{user_context}\n{news_context}",
        chat_history_template="Summary: {chat_history}",
        question_template="[INST] {question} [/INST]",
        answer_template="{answer}",
        sep="\n",
        eos=" </s>",
    )
)

# - FALCON (spec: https://huggingface.co/tiiuae/falcon-7b/blob/main/tokenizer.json)
register_llm_template(
    PromptTemplate(
        name="falcon",
        system_template=">>INTRODUCTION<< {system_message}",
        system_message="You are a helpful assistant, with financial expertise.",
        context_template=">>DOMAIN<< {user_context}\n{news_context}",
        chat_history_template=">>SUMMARY<< {chat_history}",
        question_template=">>QUESTION<< {question}",
        answer_template=">>ANSWER<< {answer}",
        sep="\n",
        eos="<|endoftext|>",
    )
)
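
A sketch of rendering the registered "mistral" template for inference; the sample values are made up.

from financial_bot.template import get_llm_template

template = get_llm_template("mistral")
prompt = template.format_infer(
    {
        "user_context": "I am a 30 year old investor.",
        "news_context": "Tesla reported record deliveries this quarter.",
        "chat_history": "",
        "question": "Should I buy Tesla stock?",
    }
)["prompt"]
print(prompt)  # "<s>You are a helpful assistant..." followed by the context, summary and [INST] ... [/INST]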
financial_bot/utils.py
ADDED
@@ -0,0 +1,106 @@
import logging
import os
import subprocess
from typing import Callable, Dict, List

import psutil
import torch

logger = logging.getLogger(__name__)


def log_available_gpu_memory():
    """
    Logs the available GPU memory for each available GPU device.

    If no GPUs are available, logs "No GPUs available".

    Returns:
        None
    """

    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            memory_info = subprocess.check_output(
                f"nvidia-smi -i {i} --query-gpu=memory.free --format=csv,nounits,noheader",
                shell=True,
            )
            memory_info = str(memory_info).split("\\")[0][2:]

            logger.info(f"GPU {i} memory available: {memory_info} MiB")
    else:
        logger.info("No GPUs available")


def log_available_ram():
    """
    Logs the amount of available RAM in gigabytes.

    Returns:
        None
    """

    memory_info = psutil.virtual_memory()

    # convert bytes to GB
    logger.info(f"Available RAM: {memory_info.available / (1024.0 ** 3):.2f} GB")


def log_files_and_subdirs(directory_path: str):
    """
    Logs all files and subdirectories in the specified directory.

    Args:
        directory_path (str): The path to the directory to log.

    Returns:
        None
    """

    # Check if the directory exists
    if os.path.exists(directory_path) and os.path.isdir(directory_path):
        for dirpath, dirnames, filenames in os.walk(directory_path):
            logger.info(f"Directory: {dirpath}")
            for filename in filenames:
                logger.info(f"File: {os.path.join(dirpath, filename)}")
            for dirname in dirnames:
                logger.info(f"Sub-directory: {os.path.join(dirpath, dirname)}")
    else:
        logger.info(f"The directory '{directory_path}' does not exist")


class MockedPipeline:
    """
    A mocked pipeline class that is used as a replacement to the HF pipeline class.

    Attributes:
    -----------
    task : str
        The task of the pipeline, which is text-generation.
    f : Callable[[str], str]
        A function that takes a prompt string as input and returns a generated text string.
    """

    task: str = "text-generation"

    def __init__(self, f: Callable[[str], str]):
        self.f = f

    def __call__(self, prompt: str) -> List[Dict[str, str]]:
        """
        Calls the pipeline with a given prompt and returns a list of generated text.

        Parameters:
        -----------
        prompt : str
            The prompt string to generate text from.

        Returns:
        --------
        List[Dict[str, str]]
            A list of dictionaries, where each dictionary contains a generated_text key with the generated text string.
        """

        result = self.f(prompt)

        return [{"generated_text": f"{prompt}{result}"}]
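
MockedPipeline mimics the HF text-generation output format, which is what lets `HuggingFacePipeline` wrap it in debug mode. A tiny sketch; the lambda and prompt strings are made up.

from financial_bot.utils import MockedPipeline

pipe = MockedPipeline(f=lambda prompt: " Nothing to report today.")
print(pipe("Market summary:"))
# [{'generated_text': 'Market summary: Nothing to report today.'}]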