S1131 committed
Commit 4380a9e · verified · 1 Parent(s): 06f5b6a

Update app.py

Files changed (1)
  1. app.py +2 -296
app.py CHANGED
@@ -3,7 +3,7 @@
 import sys
 sys.path.append('/home/user/app')
 import streamlit as st
-# from utils import generate_answer
+from utils import generate_answer
 
 # Streamlit configuration
 st.set_page_config(
@@ -14,300 +14,6 @@ st.set_page_config(
         'About': "Infosys Financial Analyst v1.0"
     }
 )
-
-
-
-# utils.py
-"""
-Financial Chatbot Utilities
-Core functionality for RAG-based financial chatbot
-"""
-
-import os
-import re
-import nltk
-from nltk.corpus import stopwords
-from collections import deque
-from typing import Tuple
-import torch
-import streamlit as st
-
-# LangChain components
-from langchain_community.document_loaders import PyPDFLoader
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import Chroma
-from langchain_huggingface import HuggingFaceEmbeddings
-
-# Models and ML
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-from rank_bm25 import BM25Okapi
-from sentence_transformers import CrossEncoder
-from sklearn.metrics.pairwise import cosine_similarity
-
-# Initialize NLTK stopwords
-nltk.download('stopwords')
-stop_words = set(stopwords.words('english'))
-# nltk.data.path.append('./nltk_data')  # Point to local NLTK data
-# stop_words = set(nltk.corpus.stopwords.words('english'))
-
-# mount
-import sys
-sys.path.append('/mount/src/gen_ai_dev')
-
-# Configuration
-DATA_PATH = "./Infy financial report/"
-DATA_FILES = ["INFY_2022_2023.pdf", "INFY_2023_2024.pdf"]
-EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
-LLM_MODEL = "gpt2"  # Or "distilgpt2" # Or "HuggingFaceH4/zephyr-7b-beta" or "microsoft/phi-2"
-
-# Environment settings
-os.environ["TOKENIZERS_PARALLELISM"] = "false"
-os.environ["CHROMA_DISABLE_TELEMETRY"] = "true"
-
-# Suppress specific warnings
-import warnings
-
-warnings.filterwarnings("ignore", message=".*oneDNN custom operations.*")
-warnings.filterwarnings("ignore", message=".*cuBLAS factory.*")
-
-
-# ------------------------------
-# Load and Chunk Documents
-# ------------------------------
-def load_and_chunk_documents():
-    """Load and split PDF documents into manageable chunks"""
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=500,
-        chunk_overlap=100,
-        separators=["\n\n", "\n", ".", " ", ""]
-    )
-
-    all_chunks = []
-    for file in DATA_FILES:
-        try:
-            loader = PyPDFLoader(os.path.join(DATA_PATH, file))
-            pages = loader.load()
-            all_chunks.extend(text_splitter.split_documents(pages))
-        except Exception as e:
-            print(f"Error loading {file}: {e}")
-
-    return all_chunks
-
-
-# ------------------------------
-# Vector Store and Search Setup
-# ------------------------------
-text_chunks = load_and_chunk_documents()
-embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
-
-vector_db = Chroma.from_documents(
-    documents=text_chunks,
-    embedding=embeddings,
-    persist_directory="./chroma_db"
-)
-vector_db.persist()
-
-# BM25 setup
-bm25_corpus = [chunk.page_content for chunk in text_chunks]
-bm25_tokenized = [doc.split() for doc in bm25_corpus]
-bm25 = BM25Okapi(bm25_tokenized)
-
-# Cross-encoder for re-ranking
-cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
-
-
-# ------------------------------
-# Conversation Memory
-# ------------------------------
-class ConversationMemory:
-    """Stores recent conversation context"""
-
-    def __init__(self, max_size=5):
-        self.buffer = deque(maxlen=max_size)
-
-    def add_interaction(self, query: str, response: str) -> None:
-        self.buffer.append((query, response))
-
-    def get_context(self) -> str:
-        return "\n".join(
-            [f"Previous Q: {q}\nPrevious A: {r}" for q, r in self.buffer]
-        )
-
-
-memory = ConversationMemory(max_size=3)
-
-
-# ------------------------------
-# Hybrid Retrieval System
-# ------------------------------
-def hybrid_retrieval(query: str, top_k: int = 5) -> str:
-    try:
-        # Semantic search
-        semantic_results = vector_db.similarity_search(query, k=top_k * 2)
-        print(f"\n\n[For Debug Only] Semantic Results: {semantic_results}")
-
-        # Keyword search
-        keyword_results = bm25.get_top_n(query.split(), bm25_corpus, n=top_k * 2)
-        print(f"\n\n[For Debug Only] Keyword Results: {keyword_results}\n\n")
-
-        # Combine and deduplicate results
-        combined = []
-        seen = set()
-
-        for doc in semantic_results:
-            content = doc.page_content
-            if content not in seen:
-                combined.append((content, "semantic"))
-                seen.add(content)
-
-        for doc in keyword_results:
-            if doc not in seen:
-                combined.append((doc, "keyword"))
-                seen.add(doc)
-
-        # Re-rank results using cross-encoder
-        pairs = [(query, content) for content, _ in combined]
-        scores = cross_encoder.predict(pairs)
-
-        # Sort by scores
-        sorted_results = sorted(
-            zip(combined, scores),
-            key=lambda x: x[1],
-            reverse=True
-        )
-
-        final_results = [f"[{source}] {content}" for (content, source), _ in sorted_results[:top_k]]
-
-        memory_context = memory.get_context()
-        if memory_context:
-            final_results.append(f"[memory] {memory_context}")
-
-        return "\n\n".join(final_results)
-
-    except Exception as e:
-        print(f"Retrieval error: {e}")
-        return ""
-
-
-# ------------------------------
-# Safety Guardrails
-# ------------------------------
-class SafetyGuard:
-    """Validates input and filters output"""
-
-    def __init__(self):
-        # self.financial_terms = {
-        #     'revenue', 'profit', 'ebitda', 'balance', 'cash',
-        #     'income', 'fiscal', 'growth', 'margin', 'expense'
-        # }
-        self.blocked_topics = {
-            'politics', 'sports', 'entertainment', 'religion',
-            'medical', 'hypothetical', 'opinion', 'personal'
-        }
-
-    def validate_input(self, query: str) -> Tuple[bool, str]:
-        query_lower = query.lower()
-        # if not any(term in query_lower for term in self.financial_terms):
-        #     return False, "Please ask financial questions."
-        if any(topic in query_lower for topic in self.blocked_topics):
-            return False, "I only discuss financial topics."
-        return True, ""
-
-    def filter_output(self, response: str) -> str:
-        phrases_to_remove = {
-            "I'm not sure", "I don't know", "maybe",
-            "possibly", "could be", "uncertain", "perhaps"
-        }
-        for phrase in phrases_to_remove:
-            response = response.replace(phrase, "")
-
-        sentences = re.split(r'[.!?]', response)
-        if len(sentences) > 2:
-            response = '. '.join(sentences[:2]) + '.'
-
-        return response.strip()
-
-
-guard = SafetyGuard()
-
-# ------------------------------
-# LLM Initialization
-# ------------------------------
-try:
-    @st.cache_resource
-    def load_generator():
-        tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
-        model = AutoModelForCausalLM.from_pretrained(
-            LLM_MODEL,
-            device_map="cpu",
-            torch_dtype=torch.float16,
-        )
-        return pipeline(
-            "text-generation",
-            model=model,
-            tokenizer=tokenizer,
-            max_new_tokens=100,
-            do_sample=False,
-            temperature=0.7,
-            top_k=0,
-            top_p=1
-        )
-    generator = load_generator()
-except Exception as e:
-    print(f"Error loading model: {e}")
-    raise
-
-
-# ------------------------------
-# Response Generation
-# ------------------------------
-def extract_final_response(full_response: str) -> str:
-    parts = full_response.split("<|im_start|>assistant")
-    if len(parts) > 1:
-        response = parts[-1].split("<|im_end|>")[0]
-        return re.sub(r'\s+', ' ', response).strip()
-    return full_response
-
-
-def generate_answer(query: str) -> Tuple[str, float]:
-    try:
-        is_valid, msg = guard.validate_input(query)
-        if not is_valid:
-            return msg, 0.0
-
-        context = hybrid_retrieval(query)
-        vector_db.persist()
-
-        prompt = f"""<|im_start|>system
-You are a financial analyst. Provide a brief answer using the context.
-Context: {context}<|im_end|>
-<|im_start|>user
-{query}<|im_end|>
-<|im_start|>assistant
-Answer:"""
-
-        print(f"\n\n[For Debug Only] Prompt: {prompt}\n\n")
-        response = generator(prompt)[0]['generated_text']
-        print(f"\n\n[For Debug Only] response: {response}\n\n")
-
-        clean_response = extract_final_response(response)
-        clean_response = guard.filter_output(clean_response)
-        print(f"\n\n[For Debug Only] clean_response: {clean_response}\n\n")
-
-        query_embed = embeddings.embed_query(query)
-        response_embed = embeddings.embed_query(clean_response)
-
-        confidence = cosine_similarity([query_embed], [response_embed])[0][0]
-        print(f"\n\n[For Debug Only] confidence: {confidence}\n\n")
-
-        memory.add_interaction(query, clean_response)
-
-        print(f"\n\n[For Debug Only] I'm Done \n\n")
-        return clean_response, round(confidence, 2)
-
-    except Exception as e:
-        return f"Error processing request: {e}", 0.0
-
 
 
 def main():
@@ -329,4 +35,4 @@ def main():
     st.markdown(f"{confidence * 100:.1f}% relevance confidence")
 
 if __name__ == "__main__":
-    main()
+    main()
 
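After this commit, app.py keeps only the UI and pulls the RAG logic from utils. Below is a minimal sketch, not the committed file, of how the slimmed-down app can wire things together: the sys.path line, the utils import, the 'About' text, and the confidence line appear verbatim in the diff above, while the page title and the text_input/write widgets are illustrative assumptions.

# app_sketch.py — a minimal sketch of the post-commit app.py wiring
import sys
sys.path.append('/home/user/app')  # same path the committed app.py appends

import streamlit as st
from utils import generate_answer  # returns (response: str, confidence: float)

st.set_page_config(
    page_title="Infosys Financial Analyst",  # assumed; only the About text is in the diff
    menu_items={
        'About': "Infosys Financial Analyst v1.0"
    }
)

def main():
    # Hypothetical input widget; the confidence caption matches the diff context.
    query = st.text_input("Ask a question about the Infosys financial reports:")
    if query:
        response, confidence = generate_answer(query)
        st.write(response)
        st.markdown(f"{confidence * 100:.1f}% relevance confidence")

if __name__ == "__main__":
    main()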