S1131 committed
Commit a86fd2f · verified · 1 Parent(s): af84697

Included utils into app.py

Files changed (1): app.py (+296 -1)
app.py CHANGED
@@ -3,7 +3,7 @@
 import sys
 sys.path.append('/home/user/app')
 import streamlit as st
-from utils import generate_answer
+# from utils import generate_answer
 
 # Streamlit configuration
 st.set_page_config(
@@ -15,6 +15,301 @@ st.set_page_config(
 }
 )
 
+
+
+# utils.py
+"""
+Financial Chatbot Utilities
+Core functionality for RAG-based financial chatbot
+"""
+
+import os
+import re
+import nltk
+from nltk.corpus import stopwords
+from collections import deque
+from typing import Tuple
+import torch
+import streamlit as st
+
+# LangChain components
+from langchain_community.document_loaders import PyPDFLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import Chroma
+from langchain_huggingface import HuggingFaceEmbeddings
+
+# Models and ML
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+from rank_bm25 import BM25Okapi
+from sentence_transformers import CrossEncoder
+from sklearn.metrics.pairwise import cosine_similarity
+
+# Initialize NLTK stopwords
+nltk.download('stopwords')
+stop_words = set(stopwords.words('english'))
+# nltk.data.path.append('./nltk_data')  # Point to local NLTK data
+# stop_words = set(nltk.corpus.stopwords.words('english'))
+
+# mount
+import sys
+sys.path.append('/mount/src/gen_ai_dev')
+
+# Configuration
+DATA_PATH = "./Infy financial report/"
+DATA_FILES = ["INFY_2022_2023.pdf", "INFY_2023_2024.pdf"]
+EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
+LLM_MODEL = "gpt2"  # or "distilgpt2", "HuggingFaceH4/zephyr-7b-beta", "microsoft/phi-2"
+
+# Environment settings
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["CHROMA_DISABLE_TELEMETRY"] = "true"
+
+# Suppress specific warnings
+import warnings
+
+warnings.filterwarnings("ignore", message=".*oneDNN custom operations.*")
+warnings.filterwarnings("ignore", message=".*cuBLAS factory.*")
+
+
+# ------------------------------
+# Load and Chunk Documents
+# ------------------------------
+def load_and_chunk_documents():
+    """Load and split PDF documents into manageable chunks"""
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=500,
+        chunk_overlap=100,
+        separators=["\n\n", "\n", ".", " ", ""]
+    )
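+    # Separators are tried in order: paragraphs, then lines, then sentences,
+    # then words, before falling back to raw characters.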
+
+    all_chunks = []
+    for file in DATA_FILES:
+        try:
+            loader = PyPDFLoader(os.path.join(DATA_PATH, file))
+            pages = loader.load()
+            all_chunks.extend(text_splitter.split_documents(pages))
+        except Exception as e:
+            print(f"Error loading {file}: {e}")
+
+    return all_chunks
+
+
+# ------------------------------
+# Vector Store and Search Setup
+# ------------------------------
+text_chunks = load_and_chunk_documents()
+embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+
+vector_db = Chroma.from_documents(
+    documents=text_chunks,
+    embedding=embeddings,
+    persist_directory="./chroma_db"
+)
+vector_db.persist()
+
+# BM25 setup
+bm25_corpus = [chunk.page_content for chunk in text_chunks]
+bm25_tokenized = [doc.split() for doc in bm25_corpus]
+bm25 = BM25Okapi(bm25_tokenized)
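+# BM25 supplies exact keyword-overlap scores that hybrid_retrieval (below)
+# fuses with the dense embedding search results.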
+
+# Cross-encoder for re-ranking
+cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+
+
+# ------------------------------
+# Conversation Memory
+# ------------------------------
+class ConversationMemory:
+    """Stores recent conversation context"""
+
+    def __init__(self, max_size=5):
+        self.buffer = deque(maxlen=max_size)
+
+    def add_interaction(self, query: str, response: str) -> None:
+        self.buffer.append((query, response))
+
+    def get_context(self) -> str:
+        return "\n".join(
+            [f"Previous Q: {q}\nPrevious A: {r}" for q, r in self.buffer]
+        )
+
+
+memory = ConversationMemory(max_size=3)
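+# max_size=3 keeps only the three most recent Q/A pairs; older turns are
+# evicted automatically by the bounded deque.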
+
+
+# ------------------------------
+# Hybrid Retrieval System
+# ------------------------------
+def hybrid_retrieval(query: str, top_k: int = 5) -> str:
+    try:
+        # Semantic search
+        semantic_results = vector_db.similarity_search(query, k=top_k * 2)
+        print(f"\n\n[For Debug Only] Semantic Results: {semantic_results}")
+
+        # Keyword search
+        keyword_results = bm25.get_top_n(query.split(), bm25_corpus, n=top_k * 2)
+        print(f"\n\n[For Debug Only] Keyword Results: {keyword_results}\n\n")
+
+        # Combine and deduplicate results
+        combined = []
+        seen = set()
+
+        for doc in semantic_results:
+            content = doc.page_content
+            if content not in seen:
+                combined.append((content, "semantic"))
+                seen.add(content)
+
+        for doc in keyword_results:
+            if doc not in seen:
+                combined.append((doc, "keyword"))
+                seen.add(doc)
+
+        # Re-rank results using cross-encoder
+        pairs = [(query, content) for content, _ in combined]
+        scores = cross_encoder.predict(pairs)
+
+        # Sort by scores
+        sorted_results = sorted(
+            zip(combined, scores),
+            key=lambda x: x[1],
+            reverse=True
+        )
+
+        final_results = [f"[{source}] {content}" for (content, source), _ in sorted_results[:top_k]]
+
+        memory_context = memory.get_context()
+        if memory_context:
+            final_results.append(f"[memory] {memory_context}")
+
+        return "\n\n".join(final_results)
+
+    except Exception as e:
+        print(f"Retrieval error: {e}")
+        return ""
+
+
+# ------------------------------
+# Safety Guardrails
+# ------------------------------
+class SafetyGuard:
+    """Validates input and filters output"""
+
+    def __init__(self):
+        # self.financial_terms = {
+        #     'revenue', 'profit', 'ebitda', 'balance', 'cash',
+        #     'income', 'fiscal', 'growth', 'margin', 'expense'
+        # }
+        self.blocked_topics = {
+            'politics', 'sports', 'entertainment', 'religion',
+            'medical', 'hypothetical', 'opinion', 'personal'
+        }
+
+    def validate_input(self, query: str) -> Tuple[bool, str]:
+        query_lower = query.lower()
+        # if not any(term in query_lower for term in self.financial_terms):
+        #     return False, "Please ask financial questions."
+        if any(topic in query_lower for topic in self.blocked_topics):
+            return False, "I only discuss financial topics."
+        return True, ""
+
+    def filter_output(self, response: str) -> str:
+        phrases_to_remove = {
+            "I'm not sure", "I don't know", "maybe",
+            "possibly", "could be", "uncertain", "perhaps"
+        }
+        for phrase in phrases_to_remove:
+            response = response.replace(phrase, "")
+
+        sentences = re.split(r'[.!?]', response)
+        if len(sentences) > 2:
+            response = '. '.join(sentences[:2]) + '.'
+
+        return response.strip()
+
+
+guard = SafetyGuard()
+
+# ------------------------------
+# LLM Initialization
+# ------------------------------
+try:
+    @st.cache_resource
+    def load_generator():
+        tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
+        model = AutoModelForCausalLM.from_pretrained(
+            LLM_MODEL,
+            device_map="cpu",
+            torch_dtype=torch.float16,
+        )
+        return pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            max_new_tokens=100,
+            do_sample=False,
+            temperature=0.7,
+            top_k=0,
+            top_p=1
+        )
+    generator = load_generator()
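+    # Note: with do_sample=False decoding is greedy, so temperature, top_k and
+    # top_p have no effect. float16 weights on CPU can also be slow or
+    # unsupported for some ops; float32 is the safer default there.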
+except Exception as e:
+    print(f"Error loading model: {e}")
+    raise
+
+
+# ------------------------------
+# Response Generation
+# ------------------------------
+def extract_final_response(full_response: str) -> str:
+    parts = full_response.split("<|im_start|>assistant")
+    if len(parts) > 1:
+        response = parts[-1].split("<|im_end|>")[0]
+        return re.sub(r'\s+', ' ', response).strip()
+    return full_response
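+# <|im_start|>/<|im_end|> are ChatML markers understood by chat models such as
+# zephyr-7b-beta; plain gpt2 sees them as ordinary text, so this split is a
+# best-effort cleanup of the generated output.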
+
+
+def generate_answer(query: str) -> Tuple[str, float]:
+    try:
+        is_valid, msg = guard.validate_input(query)
+        if not is_valid:
+            return msg, 0.0
+
+        context = hybrid_retrieval(query)
+        vector_db.persist()
+
+        prompt = f"""<|im_start|>system
+You are a financial analyst. Provide a brief answer using the context.
+Context: {context}<|im_end|>
+<|im_start|>user
+{query}<|im_end|>
+<|im_start|>assistant
+Answer:"""
+
+        print(f"\n\n[For Debug Only] Prompt: {prompt}\n\n")
+        response = generator(prompt)[0]['generated_text']
+        print(f"\n\n[For Debug Only] response: {response}\n\n")
+
+        clean_response = extract_final_response(response)
+        clean_response = guard.filter_output(clean_response)
+        print(f"\n\n[For Debug Only] clean_response: {clean_response}\n\n")
+
+        query_embed = embeddings.embed_query(query)
+        response_embed = embeddings.embed_query(clean_response)
+
+        confidence = cosine_similarity([query_embed], [response_embed])[0][0]
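+        # Query/answer embedding similarity is a rough relevance proxy,
+        # not a factual-accuracy check.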
+        print(f"\n\n[For Debug Only] confidence: {confidence}\n\n")
+
+        memory.add_interaction(query, clean_response)
+
+        print(f"\n\n[For Debug Only] I'm Done \n\n")
+        return clean_response, round(confidence, 2)
+
+    except Exception as e:
+        return f"Error processing request: {e}", 0.0
+
+
+
 def main():
     st.title("💰 INFY Financial Analyst (2022-2024)")
     st.markdown("Ask questions about Infosys financial statements from the last 2 years.")
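
The diff is truncated inside main(), so the Streamlit wiring for generate_answer is not shown. A minimal sketch of how the remainder of main() might look (hypothetical; the widget labels and layout below are assumptions, not part of this commit):

def main():
    st.title("💰 INFY Financial Analyst (2022-2024)")
    st.markdown("Ask questions about Infosys financial statements from the last 2 years.")
    # Hypothetical wiring, for illustration only:
    query = st.text_input("Your question:")
    if query:
        with st.spinner("Analyzing..."):
            answer, confidence = generate_answer(query)  # defined in the block above
        st.write(answer)
        st.caption(f"Confidence: {confidence:.2f}")

if __name__ == "__main__":
    main()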