Update utils.py
utils.py
CHANGED
@@ -12,8 +12,6 @@ from collections import deque
 from typing import Tuple
 import torch
 
-import streamlit as st
-
 # LangChain components
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -26,31 +24,21 @@ from rank_bm25 import BM25Okapi
 from sentence_transformers import CrossEncoder
 from sklearn.metrics.pairwise import cosine_similarity
 
-import sys
-
-sys.path.append('/mount/src/gen_ai_dev')
-
-# these three lines swap the stdlib sqlite3 lib with the pysqlite3 package
-import pysqlite3
-import sys
-sys.modules["sqlite3"] = pysqlite3
-
-__import__('pysqlite3')
-import sys
-
-sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
-
 # Initialize NLTK stopwords
 # nltk.download('stopwords')
 # stop_words = set(stopwords.words('english'))
 nltk.data.path.append('./nltk_data')  # Point to local NLTK data
 stop_words = set(nltk.corpus.stopwords.words('english'))
 
+# mount
+import sys
+sys.path.append('/mount/src/gen_ai_dev')
+
 # Configuration
 DATA_PATH = "./Infy financial report/"
 DATA_FILES = ["INFY_2022_2023.pdf", "INFY_2023_2024.pdf"]
 EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
-LLM_MODEL = "
+LLM_MODEL = "microsoft/phi-2"
 
 # Environment settings
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -92,24 +80,12 @@ def load_and_chunk_documents():
 text_chunks = load_and_chunk_documents()
 embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
 
-
-
-
-
-
-
-    # Initialize embeddings
-    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
-
-    # Create and return Chroma vector store
-    return Chroma.from_documents(
-        documents=text_chunks,
-        embedding=embeddings,
-        persist_directory="./chroma_db"
-    )
-
-# Initialize vector_db
-vector_db = load_vector_db()
+vector_db = Chroma.from_documents(
+    documents=text_chunks,
+    embedding=embeddings,
+    persist_directory="./chroma_db"
+)
+vector_db.persist()
 
 # BM25 setup
 bm25_corpus = [chunk.page_content for chunk in text_chunks]
@@ -137,8 +113,10 @@ class ConversationMemory:
             [f"Previous Q: {q}\nPrevious A: {r}" for q, r in self.buffer]
         )
 
+
 memory = ConversationMemory(max_size=3)
 
+
 # ------------------------------
 # Hybrid Retrieval System
 # ------------------------------
@@ -211,8 +189,8 @@ class SafetyGuard:
         query_lower = query.lower()
         if any(topic in query_lower for topic in self.blocked_topics):
             return False, "I only discuss financial topics."
-
-
+        if not any(term in query_lower for term in self.financial_terms):
+            return False, "Please ask financial questions."
         return True, ""
 
     def filter_output(self, response: str) -> str:
@@ -236,37 +214,24 @@ guard = SafetyGuard()
 # LLM Initialization
 # ------------------------------
 try:
-
-
-
-
-
-
-        device_map="auto",
-        torch_dtype=torch.bfloat16,
-        load_in_4bit=True
-    )
-    else:
-        model = AutoModelForCausalLM.from_pretrained(
-            LLM_MODEL,
-            device_map="cpu",
-            torch_dtype=torch.float32
-        )
-    return pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        max_new_tokens=400,
-        do_sample=True,
-        temperature=0.3,
-        top_k=30,
-        top_p=0.9,
-        repetition_penalty=1.2
-    )
-
+    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
+    model = AutoModelForCausalLM.from_pretrained(
+        LLM_MODEL,
+        device_map="cpu",
+        torch_dtype=torch.float32
+    )
 
-
-
+    generator = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_new_tokens=400,
+        do_sample=True,
+        temperature=0.3,
+        top_k=30,
+        top_p=0.9,
+        repetition_penalty=1.2
+    )
 except Exception as e:
     print(f"Error loading model: {e}")
     raise
@@ -285,15 +250,13 @@ def extract_final_response(full_response: str) -> str:
 
 def generate_answer(query: str) -> Tuple[str, float]:
     try:
-        # Input validation
         is_valid, msg = guard.validate_input(query)
         if not is_valid:
             return msg, 0.0
 
-        # Retrieve context
         context = hybrid_retrieval(query)
+        vector_db.persist()
 
-        # Generate response
         prompt = f"""<|im_start|>system
 You are a financial analyst. Provide a brief answer using the context.
 Context: {context}<|im_end|>
@@ -302,19 +265,19 @@ Context: {context}<|im_end|>
 <|im_start|>assistant
 Answer:"""
 
+        print(f"\n\n[For Debug Only] Prompt: {prompt}\n\n")
+
         response = generator(prompt)[0]['generated_text']
         clean_response = extract_final_response(response)
         clean_response = guard.filter_output(clean_response)
 
-        # Calculate confidence
         query_embed = embeddings.embed_query(query)
         response_embed = embeddings.embed_query(clean_response)
         confidence = cosine_similarity([query_embed], [response_embed])[0][0]
 
-        # Update memory
         memory.add_interaction(query, clean_response)
 
         return clean_response, round(confidence, 2)
 
     except Exception as e:
-        return f"Error processing request: {e}", 0.0
+        return f"Error processing request: {e}", 0.0
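With this change the Chroma index, BM25 corpus, and phi-2 text-generation pipeline are all built when utils.py is imported, so the first import is slow; afterwards callers only need generate_answer. A minimal usage sketch, assuming the PDFs under DATA_PATH are present and the microsoft/phi-2 weights can be downloaded (the calling script and the sample question below are hypothetical, not part of this commit):

# Hypothetical caller, not part of this commit: importing utils triggers
# PDF chunking, Chroma indexing, and model loading as defined above.
from utils import generate_answer

answer, confidence = generate_answer("What revenue did Infosys report for FY2023-24?")
print(f"Answer: {answer}")          # filtered response text
print(f"Confidence: {confidence}")  # cosine similarity of query/answer embeddings, rounded to 2 decimals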