S1131 committed
Commit 69a10b1 · verified · 1 Parent(s): 1a4d9d3

Update utils.py

Files changed (1)
  1. utils.py +36 -73
utils.py CHANGED
@@ -12,8 +12,6 @@ from collections import deque
 from typing import Tuple
 import torch
 
-import streamlit as st
-
 # LangChain components
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -26,31 +24,21 @@ from rank_bm25 import BM25Okapi
 from sentence_transformers import CrossEncoder
 from sklearn.metrics.pairwise import cosine_similarity
 
-import sys
-
-sys.path.append('/mount/src/gen_ai_dev')
-
-# these three lines swap the stdlib sqlite3 lib with the pysqlite3 package
-import pysqlite3
-import sys
-sys.modules["sqlite3"] = pysqlite3
-
-__import__('pysqlite3')
-import sys
-
-sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
-
 # Initialize NLTK stopwords
 # nltk.download('stopwords')
 # stop_words = set(stopwords.words('english'))
 nltk.data.path.append('./nltk_data') # Point to local NLTK data
 stop_words = set(nltk.corpus.stopwords.words('english'))
 
+# mount
+import sys
+sys.path.append('/mount/src/gen_ai_dev')
+
 # Configuration
 DATA_PATH = "./Infy financial report/"
 DATA_FILES = ["INFY_2022_2023.pdf", "INFY_2023_2024.pdf"]
 EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
-LLM_MODEL = "HuggingFaceH4/zephyr-7b-beta" #"microsoft/phi-2"
+LLM_MODEL = "microsoft/phi-2"
 
 # Environment settings
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
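
Note: the removed block swapped the stdlib sqlite3 for pysqlite3 in three slightly different ways (plain import, `sys.modules` assignment, and pop). If the swap is ever needed again (chromadb requires a newer SQLite than some hosting environments ship), it reduces to the two-line idiom below. This is a hypothetical re-addition, not part of this commit; it assumes the pysqlite3-binary package is installed and that it runs before chromadb is imported.

```python
# Hypothetical re-addition (not in this commit): alias pysqlite3 as the stdlib
# sqlite3 module so chromadb sees a recent SQLite. Must run before importing chromadb.
__import__("pysqlite3")   # requires the pysqlite3-binary package
import sys
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
```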
@@ -92,24 +80,12 @@ def load_and_chunk_documents():
 text_chunks = load_and_chunk_documents()
 embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
 
-
-@st.cache_resource(show_spinner=False)
-def load_vector_db():
-    # Load and chunk documents
-    text_chunks = load_and_chunk_documents()
-
-    # Initialize embeddings
-    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
-
-    # Create and return Chroma vector store
-    return Chroma.from_documents(
-        documents=text_chunks,
-        embedding=embeddings,
-        persist_directory="./chroma_db"
-    )
-
-# Initialize vector_db
-vector_db = load_vector_db()
+vector_db = Chroma.from_documents(
+    documents=text_chunks,
+    embedding=embeddings,
+    persist_directory="./chroma_db"
+)
+vector_db.persist()
 
 # BM25 setup
 bm25_corpus = [chunk.page_content for chunk in text_chunks]
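
With the cached `load_vector_db()` helper gone, the Chroma store is now rebuilt from the chunked PDFs on every import of utils.py, even when `./chroma_db` already holds a persisted copy, and `persist()` is called explicitly (on chromadb 0.4+ persistence is automatic, so that call may be a no-op or unavailable depending on the installed versions). A possible refinement, sketched below under those assumptions rather than taken from this commit, is to reopen the existing directory when it is present:

```python
# Sketch (not in this commit): reuse the persisted store if it exists, otherwise
# build it once from the freshly chunked documents. Relies on the module's
# existing `embeddings` and `text_chunks` objects.
import os
from langchain_community.vectorstores import Chroma

if os.path.isdir("./chroma_db") and os.listdir("./chroma_db"):
    vector_db = Chroma(
        persist_directory="./chroma_db",
        embedding_function=embeddings,
    )
else:
    vector_db = Chroma.from_documents(
        documents=text_chunks,
        embedding=embeddings,
        persist_directory="./chroma_db",
    )
```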
@@ -137,8 +113,10 @@ class ConversationMemory:
             [f"Previous Q: {q}\nPrevious A: {r}" for q, r in self.buffer]
         )
 
+
 memory = ConversationMemory(max_size=3)
 
+
 # ------------------------------
 # Hybrid Retrieval System
 # ------------------------------
@@ -211,8 +189,8 @@ class SafetyGuard:
         query_lower = query.lower()
         if any(topic in query_lower for topic in self.blocked_topics):
             return False, "I only discuss financial topics."
-        # if not any(term in query_lower for term in self.financial_terms):
-        #     return False, "Please ask financial questions."
+        if not any(term in query_lower for term in self.financial_terms):
+            return False, "Please ask financial questions."
         return True, ""
 
     def filter_output(self, response: str) -> str:
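
Un-commenting the second check turns the guard from a pure blocklist into a blocklist plus keyword whitelist: a query must now contain at least one entry from `self.financial_terms` or it is rejected. A minimal standalone sketch of that gating pattern (the term lists here are illustrative, not the ones defined in `SafetyGuard.__init__`):

```python
# Standalone sketch of the re-enabled gating logic; term lists are illustrative only.
from typing import Tuple

BLOCKED_TOPICS = {"politics", "medicine", "sports"}
FINANCIAL_TERMS = {"revenue", "profit", "margin", "dividend", "cash flow"}

def validate_input(query: str) -> Tuple[bool, str]:
    q = query.lower()
    if any(topic in q for topic in BLOCKED_TOPICS):
        return False, "I only discuss financial topics."
    if not any(term in q for term in FINANCIAL_TERMS):
        return False, "Please ask financial questions."
    return True, ""
```

The trade-off is that legitimate questions phrased without any whitelisted keyword now get bounced with "Please ask financial questions."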
@@ -236,37 +214,24 @@ guard = SafetyGuard()
 # LLM Initialization
 # ------------------------------
 try:
-    @st.cache_resource(show_spinner=False)
-    def load_generator():
-        tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
-        if torch.cuda.is_available():
-            model = AutoModelForCausalLM.from_pretrained(
-                LLM_MODEL,
-                device_map="auto",
-                torch_dtype=torch.bfloat16,
-                load_in_4bit=True
-            )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(
-                LLM_MODEL,
-                device_map="cpu",
-                torch_dtype=torch.float32
-            )
-        return pipeline(
-            "text-generation",
-            model=model,
-            tokenizer=tokenizer,
-            max_new_tokens=400,
-            do_sample=True,
-            temperature=0.3,
-            top_k=30,
-            top_p=0.9,
-            repetition_penalty=1.2
-        )
-
+    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
+    model = AutoModelForCausalLM.from_pretrained(
+        LLM_MODEL,
+        device_map="cpu",
+        torch_dtype=torch.float32
+    )
 
-    # Later in your generate_answer function:
-    generator = load_generator()
+    generator = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_new_tokens=400,
+        do_sample=True,
+        temperature=0.3,
+        top_k=30,
+        top_p=0.9,
+        repetition_penalty=1.2
+    )
 except Exception as e:
     print(f"Error loading model: {e}")
     raise
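
The rewritten initializer always loads the model on CPU in float32 and drops both the `@st.cache_resource` wrapper and the CUDA/4-bit branch. If GPU support needs to come back without Streamlit, the removed branch could be restored roughly as below; this is a sketch of a hypothetical re-addition, not part of the commit, and 4-bit loading additionally requires the bitsandbytes package (newer transformers releases prefer a `BitsAndBytesConfig` passed via `quantization_config` over the bare `load_in_4bit` flag).

```python
# Hypothetical re-addition of the removed device branch (not in this commit).
import torch
from transformers import AutoModelForCausalLM

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL,                  # "microsoft/phi-2" after this commit
        device_map="auto",
        torch_dtype=torch.bfloat16,
        load_in_4bit=True,          # requires bitsandbytes
    )
else:
    model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL,
        device_map="cpu",
        torch_dtype=torch.float32,
    )
```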
@@ -285,15 +250,13 @@ def extract_final_response(full_response: str) -> str:
 
 def generate_answer(query: str) -> Tuple[str, float]:
     try:
-        # Input validation
         is_valid, msg = guard.validate_input(query)
         if not is_valid:
             return msg, 0.0
 
-        # Retrieve context
         context = hybrid_retrieval(query)
+        vector_db.persist()
 
-        # Generate response
         prompt = f"""<|im_start|>system
 You are a financial analyst. Provide a brief answer using the context.
 Context: {context}<|im_end|>
@@ -302,19 +265,19 @@ Context: {context}<|im_end|>
 <|im_start|>assistant
 Answer:"""
 
+        print(f"\n\n[For Debug Only] Prompt: {prompt}\n\n")
+
         response = generator(prompt)[0]['generated_text']
         clean_response = extract_final_response(response)
         clean_response = guard.filter_output(clean_response)
 
-        # Calculate confidence
         query_embed = embeddings.embed_query(query)
         response_embed = embeddings.embed_query(clean_response)
         confidence = cosine_similarity([query_embed], [response_embed])[0][0]
 
-        # Update memory
         memory.add_interaction(query, clean_response)
 
         return clean_response, round(confidence, 2)
 
     except Exception as e:
-        return f"Error processing request: {e}", 0.0
+        return f"Error processing request: {e}", 0.0
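
One thing the commit leaves untouched: `generate_answer` still builds a ChatML-style prompt with `<|im_start|>`/`<|im_end|>` markers, but phi-2's model card documents a plain "Instruct: ... Output:" QA format and the model has no chat template for those tags, so they most likely pass through as literal text. A phi-2-friendly prompt could look like the sketch below (an assumption, using the same `context` and `query` variables as in `generate_answer`); `extract_final_response` would then need a matching adjustment to whatever marker the prompt ends with.

```python
# Sketch (assumption, not in this commit): Instruct/Output-style prompt for phi-2.
prompt = (
    "Instruct: You are a financial analyst. Provide a brief answer using the context.\n"
    f"Context: {context}\n"
    f"Question: {query}\n"
    "Output:"
)
```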
 