S1131 committed on
Commit c71db70 · verified · 1 Parent(s): a86fd2f

utils merged into app, so deleting this now

Files changed (1)
  1. utils.py +0 -293
utils.py DELETED
@@ -1,293 +0,0 @@
# utils.py
"""
Financial Chatbot Utilities
Core functionality for RAG-based financial chatbot
"""

import os
import re
import nltk
from nltk.corpus import stopwords
from collections import deque
from typing import Tuple
import torch
import streamlit as st

# LangChain components
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

# Models and ML
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity

# Initialize NLTK stopwords
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))
# nltk.data.path.append('./nltk_data')  # Point to local NLTK data
# stop_words = set(nltk.corpus.stopwords.words('english'))

# mount
import sys
sys.path.append('/mount/src/gen_ai_dev')

# Configuration
DATA_PATH = "./Infy financial report/"
DATA_FILES = ["INFY_2022_2023.pdf", "INFY_2023_2024.pdf"]
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "gpt2"  # Or "distilgpt2" # Or "HuggingFaceH4/zephyr-7b-beta" or "microsoft/phi-2"

# Environment settings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CHROMA_DISABLE_TELEMETRY"] = "true"

# Suppress specific warnings
import warnings

warnings.filterwarnings("ignore", message=".*oneDNN custom operations.*")
warnings.filterwarnings("ignore", message=".*cuBLAS factory.*")


# ------------------------------
# Load and Chunk Documents
# ------------------------------
def load_and_chunk_documents():
    """Load and split PDF documents into manageable chunks"""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100,
        separators=["\n\n", "\n", ".", " ", ""]
    )

    all_chunks = []
    for file in DATA_FILES:
        try:
            loader = PyPDFLoader(os.path.join(DATA_PATH, file))
            pages = loader.load()
            all_chunks.extend(text_splitter.split_documents(pages))
        except Exception as e:
            print(f"Error loading {file}: {e}")

    return all_chunks


# ------------------------------
# Vector Store and Search Setup
# ------------------------------
text_chunks = load_and_chunk_documents()
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

vector_db = Chroma.from_documents(
    documents=text_chunks,
    embedding=embeddings,
    persist_directory="./chroma_db"
)
vector_db.persist()

# BM25 setup
bm25_corpus = [chunk.page_content for chunk in text_chunks]
bm25_tokenized = [doc.split() for doc in bm25_corpus]
bm25 = BM25Okapi(bm25_tokenized)

# Cross-encoder for re-ranking
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')


# ------------------------------
# Conversation Memory
# ------------------------------
class ConversationMemory:
    """Stores recent conversation context"""

    def __init__(self, max_size=5):
        self.buffer = deque(maxlen=max_size)

    def add_interaction(self, query: str, response: str) -> None:
        self.buffer.append((query, response))

    def get_context(self) -> str:
        return "\n".join(
            [f"Previous Q: {q}\nPrevious A: {r}" for q, r in self.buffer]
        )


memory = ConversationMemory(max_size=3)


# ------------------------------
# Hybrid Retrieval System
# ------------------------------
def hybrid_retrieval(query: str, top_k: int = 5) -> str:
    try:
        # Semantic search
        semantic_results = vector_db.similarity_search(query, k=top_k * 2)
        print(f"\n\n[For Debug Only] Semantic Results: {semantic_results}")

        # Keyword search
        keyword_results = bm25.get_top_n(query.split(), bm25_corpus, n=top_k * 2)
        print(f"\n\n[For Debug Only] Keyword Results: {keyword_results}\n\n")

        # Combine and deduplicate results
        combined = []
        seen = set()

        for doc in semantic_results:
            content = doc.page_content
            if content not in seen:
                combined.append((content, "semantic"))
                seen.add(content)

        for doc in keyword_results:
            if doc not in seen:
                combined.append((doc, "keyword"))
                seen.add(doc)

        # Re-rank results using cross-encoder
        pairs = [(query, content) for content, _ in combined]
        scores = cross_encoder.predict(pairs)

        # Sort by scores
        sorted_results = sorted(
            zip(combined, scores),
            key=lambda x: x[1],
            reverse=True
        )

        final_results = [f"[{source}] {content}" for (content, source), _ in sorted_results[:top_k]]

        memory_context = memory.get_context()
        if memory_context:
            final_results.append(f"[memory] {memory_context}")

        return "\n\n".join(final_results)

    except Exception as e:
        print(f"Retrieval error: {e}")
        return ""


# ------------------------------
# Safety Guardrails
# ------------------------------
class SafetyGuard:
    """Validates input and filters output"""

    def __init__(self):
        # self.financial_terms = {
        #     'revenue', 'profit', 'ebitda', 'balance', 'cash',
        #     'income', 'fiscal', 'growth', 'margin', 'expense'
        # }
        self.blocked_topics = {
            'politics', 'sports', 'entertainment', 'religion',
            'medical', 'hypothetical', 'opinion', 'personal'
        }

    def validate_input(self, query: str) -> Tuple[bool, str]:
        query_lower = query.lower()
        # if not any(term in query_lower for term in self.financial_terms):
        #     return False, "Please ask financial questions."
        if any(topic in query_lower for topic in self.blocked_topics):
            return False, "I only discuss financial topics."
        return True, ""

    def filter_output(self, response: str) -> str:
        phrases_to_remove = {
            "I'm not sure", "I don't know", "maybe",
            "possibly", "could be", "uncertain", "perhaps"
        }
        for phrase in phrases_to_remove:
            response = response.replace(phrase, "")

        sentences = re.split(r'[.!?]', response)
        if len(sentences) > 2:
            response = '. '.join(sentences[:2]) + '.'

        return response.strip()


guard = SafetyGuard()

# ------------------------------
# LLM Initialization
# ------------------------------
try:
    @st.cache_resource
    def load_generator():
        tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
        model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL,
            device_map="cpu",
            torch_dtype=torch.float16,
        )
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=100,
            do_sample=False,
            temperature=0.7,
            top_k=0,
            top_p=1
        )
    generator = load_generator()
except Exception as e:
    print(f"Error loading model: {e}")
    raise


# ------------------------------
# Response Generation
# ------------------------------
def extract_final_response(full_response: str) -> str:
    parts = full_response.split("<|im_start|>assistant")
    if len(parts) > 1:
        response = parts[-1].split("<|im_end|>")[0]
        return re.sub(r'\s+', ' ', response).strip()
    return full_response


def generate_answer(query: str) -> Tuple[str, float]:
    try:
        is_valid, msg = guard.validate_input(query)
        if not is_valid:
            return msg, 0.0

        context = hybrid_retrieval(query)
        vector_db.persist()

        prompt = f"""<|im_start|>system
You are a financial analyst. Provide a brief answer using the context.
Context: {context}<|im_end|>
<|im_start|>user
{query}<|im_end|>
<|im_start|>assistant
Answer:"""

        print(f"\n\n[For Debug Only] Prompt: {prompt}\n\n")
        response = generator(prompt)[0]['generated_text']
        print(f"\n\n[For Debug Only] response: {response}\n\n")

        clean_response = extract_final_response(response)
        clean_response = guard.filter_output(clean_response)
        print(f"\n\n[For Debug Only] clean_response: {clean_response}\n\n")

        query_embed = embeddings.embed_query(query)
        print(f"\n\n[For Debug Only] query_embed: {query_embed}\n\n")

        response_embed = embeddings.embed_query(clean_response)
        print(f"\n\n[For Debug Only] response_embed: {response_embed}\n\n")

        confidence = cosine_similarity([query_embed], [response_embed])[0][0]
        print(f"\n\n[For Debug Only] confidence: {confidence}\n\n")

        memory.add_interaction(query, clean_response)

        print(f"\n\n[For Debug Only] I'm Done \n\n")
        return clean_response, round(confidence, 2)

    except Exception as e:
        return f"Error processing request: {e}", 0.0