S1131 committed on
Commit 06f5b6a · verified · 1 Parent(s): c71db70

Create utils.py

Files changed (1)
  1. utils.py +290 -0
utils.py ADDED
@@ -0,0 +1,290 @@
# utils.py
"""
Financial Chatbot Utilities
Core functionality for a RAG-based financial chatbot.
"""

import os
import re
import sys
import warnings
from collections import deque
from typing import Tuple

import nltk
from nltk.corpus import stopwords
import torch
import streamlit as st

# LangChain components
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

# Models and ML
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity

# Initialize NLTK stopwords
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))
# nltk.data.path.append('./nltk_data')  # Point to local NLTK data
# stop_words = set(nltk.corpus.stopwords.words('english'))

# Make the repository modules importable when the app runs on Streamlit Cloud
sys.path.append('/mount/src/gen_ai_dev')

# Configuration
DATA_PATH = "./Infy financial report/"
DATA_FILES = ["INFY_2022_2023.pdf", "INFY_2023_2024.pdf"]
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "gpt2"  # Alternatives: "distilgpt2", "HuggingFaceH4/zephyr-7b-beta", "microsoft/phi-2"

# Environment settings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CHROMA_DISABLE_TELEMETRY"] = "true"

# Suppress specific warnings
warnings.filterwarnings("ignore", message=".*oneDNN custom operations.*")
warnings.filterwarnings("ignore", message=".*cuBLAS factory.*")

# ------------------------------
# Load and Chunk Documents
# ------------------------------
def load_and_chunk_documents():
    """Load and split PDF documents into manageable chunks"""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", " ", ""]
    )

    all_chunks = []
    for file in DATA_FILES:
        try:
            loader = PyPDFLoader(os.path.join(DATA_PATH, file))
            pages = loader.load()
            all_chunks.extend(text_splitter.split_documents(pages))
        except Exception as e:
            print(f"Error loading {file}: {e}")

    return all_chunks
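
# Note: RecursiveCharacterTextSplitter tries the separators in order
# ("\n\n", "\n", ".", " ", ""), so chunks break on paragraph and sentence
# boundaries where possible; the 100-character overlap keeps text that
# straddles a chunk boundary visible in both neighbouring chunks.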

# ------------------------------
# Vector Store and Search Setup
# ------------------------------
text_chunks = load_and_chunk_documents()
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

vector_db = Chroma.from_documents(
    documents=text_chunks,
    embedding=embeddings,
    persist_directory="./chroma_db"
)
vector_db.persist()

# BM25 setup
bm25_corpus = [chunk.page_content for chunk in text_chunks]
bm25_tokenized = [doc.split() for doc in bm25_corpus]
bm25 = BM25Okapi(bm25_tokenized)

# Cross-encoder for re-ranking
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
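
# Note: the chunking, embedding, and BM25 indexing above run at import time,
# so every fresh process re-embeds the PDFs into ./chroma_db before the app
# can answer its first query; persist() flushes that index to disk.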

# ------------------------------
# Conversation Memory
# ------------------------------
class ConversationMemory:
    """Stores recent conversation context"""

    def __init__(self, max_size=5):
        self.buffer = deque(maxlen=max_size)

    def add_interaction(self, query: str, response: str) -> None:
        self.buffer.append((query, response))

    def get_context(self) -> str:
        return "\n".join(
            [f"Previous Q: {q}\nPrevious A: {r}" for q, r in self.buffer]
        )


memory = ConversationMemory(max_size=3)
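
# Because the buffer is a deque with maxlen=3, adding a fourth interaction
# silently evicts the oldest one, so get_context() always reflects at most
# the three most recent question/answer pairs.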

# ------------------------------
# Hybrid Retrieval System
# ------------------------------
def hybrid_retrieval(query: str, top_k: int = 5) -> str:
    try:
        # Semantic search
        semantic_results = vector_db.similarity_search(query, k=top_k * 2)
        print(f"\n\n[For Debug Only] Semantic Results: {semantic_results}")

        # Keyword search
        keyword_results = bm25.get_top_n(query.split(), bm25_corpus, n=top_k * 2)
        print(f"\n\n[For Debug Only] Keyword Results: {keyword_results}\n\n")

        # Combine and deduplicate results
        combined = []
        seen = set()

        for doc in semantic_results:
            content = doc.page_content
            if content not in seen:
                combined.append((content, "semantic"))
                seen.add(content)

        for doc in keyword_results:
            if doc not in seen:
                combined.append((doc, "keyword"))
                seen.add(doc)

        # Re-rank results using the cross-encoder
        pairs = [(query, content) for content, _ in combined]
        scores = cross_encoder.predict(pairs)

        # Sort by cross-encoder score, best first
        sorted_results = sorted(
            zip(combined, scores),
            key=lambda x: x[1],
            reverse=True
        )

        final_results = [
            f"[{source}] {content}"
            for (content, source), _ in sorted_results[:top_k]
        ]

        memory_context = memory.get_context()
        if memory_context:
            final_results.append(f"[memory] {memory_context}")

        return "\n\n".join(final_results)

    except Exception as e:
        print(f"Retrieval error: {e}")
        return ""
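
# Example (illustrative only -- assumes the Infosys PDFs above were indexed):
#   context = hybrid_retrieval("What was the revenue growth in FY2023?", top_k=3)
# Returns the top_k re-ranked chunks, each prefixed with its source
# ("[semantic]" or "[keyword]"), plus a trailing "[memory]" block when
# earlier question/answer pairs exist.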

# ------------------------------
# Safety Guardrails
# ------------------------------
class SafetyGuard:
    """Validates input and filters output"""

    def __init__(self):
        # self.financial_terms = {
        #     'revenue', 'profit', 'ebitda', 'balance', 'cash',
        #     'income', 'fiscal', 'growth', 'margin', 'expense'
        # }
        self.blocked_topics = {
            'politics', 'sports', 'entertainment', 'religion',
            'medical', 'hypothetical', 'opinion', 'personal'
        }

    def validate_input(self, query: str) -> Tuple[bool, str]:
        query_lower = query.lower()
        # if not any(term in query_lower for term in self.financial_terms):
        #     return False, "Please ask financial questions."
        if any(topic in query_lower for topic in self.blocked_topics):
            return False, "I only discuss financial topics."
        return True, ""

    def filter_output(self, response: str) -> str:
        phrases_to_remove = {
            "I'm not sure", "I don't know", "maybe",
            "possibly", "could be", "uncertain", "perhaps"
        }
        for phrase in phrases_to_remove:
            response = response.replace(phrase, "")

        sentences = re.split(r'[.!?]', response)
        if len(sentences) > 2:
            response = '. '.join(sentences[:2]) + '.'

        return response.strip()


guard = SafetyGuard()
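
# The input check is a simple substring match on the lowercased query, e.g.:
#   guard.validate_input("Tell me about politics")            -> (False, "I only discuss financial topics.")
#   guard.validate_input("What was FY2024 operating margin?") -> (True, "")
# filter_output() strips hedging phrases and truncates the reply to two sentences.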

# ------------------------------
# LLM Initialization
# ------------------------------
try:
    @st.cache_resource
    def load_generator():
        tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
        model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL,
            device_map="cpu",
            # float32 rather than float16: half precision is poorly supported
            # for CPU inference and can fail at generation time
            torch_dtype=torch.float32,
        )
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=100,
            # Greedy decoding; temperature/top_k/top_p are omitted because
            # they are ignored when do_sample=False
            do_sample=False,
        )

    generator = load_generator()
except Exception as e:
    print(f"Error loading model: {e}")
    raise
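
# st.cache_resource keeps a single tokenizer/model/pipeline instance alive
# across Streamlit reruns, so the weights are downloaded and loaded only once
# per server process instead of on every user interaction.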

# ------------------------------
# Response Generation
# ------------------------------
def extract_final_response(full_response: str) -> str:
    parts = full_response.split("<|im_start|>assistant")
    if len(parts) > 1:
        response = parts[-1].split("<|im_end|>")[0]
        return re.sub(r'\s+', ' ', response).strip()
    return full_response
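
# The text-generation pipeline returns the prompt followed by the new tokens,
# so splitting on the last "<|im_start|>assistant" marker isolates the model's
# answer. gpt2 has no ChatML special tokens (the markers are plain text to it);
# if the marker is missing, the full output is returned unchanged.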

def generate_answer(query: str) -> Tuple[str, float]:
    try:
        is_valid, msg = guard.validate_input(query)
        if not is_valid:
            return msg, 0.0

        context = hybrid_retrieval(query)
        vector_db.persist()

        prompt = f"""<|im_start|>system
You are a financial analyst. Provide a brief answer using the context.
Context: {context}<|im_end|>
<|im_start|>user
{query}<|im_end|>
<|im_start|>assistant
Answer:"""

        print(f"\n\n[For Debug Only] Prompt: {prompt}\n\n")
        response = generator(prompt)[0]['generated_text']
        print(f"\n\n[For Debug Only] response: {response}\n\n")

        clean_response = extract_final_response(response)
        clean_response = guard.filter_output(clean_response)
        print(f"\n\n[For Debug Only] clean_response: {clean_response}\n\n")

        query_embed = embeddings.embed_query(query)
        response_embed = embeddings.embed_query(clean_response)

        confidence = cosine_similarity([query_embed], [response_embed])[0][0]
        print(f"\n\n[For Debug Only] confidence: {confidence}\n\n")

        memory.add_interaction(query, clean_response)

        print(f"\n\n[For Debug Only] I'm Done \n\n")
        return clean_response, round(confidence, 2)

    except Exception as e:
        return f"Error processing request: {e}", 0.0