Update utils.py
utils.py CHANGED
@@ -12,8 +12,6 @@ from collections import deque
 from typing import Tuple
 import torch
 
-import streamlit as st
-
 # LangChain components
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -26,31 +24,21 @@ from rank_bm25 import BM25Okapi
 from sentence_transformers import CrossEncoder
 from sklearn.metrics.pairwise import cosine_similarity
 
-import sys
-
-sys.path.append('/mount/src/gen_ai_dev')
-
-# these three lines swap the stdlib sqlite3 lib with the pysqlite3 package
-import pysqlite3
-import sys
-sys.modules["sqlite3"] = pysqlite3
-
-__import__('pysqlite3')
-import sys
-
-sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
-
 # Initialize NLTK stopwords
 # nltk.download('stopwords')
 # stop_words = set(stopwords.words('english'))
 nltk.data.path.append('./nltk_data')  # Point to local NLTK data
 stop_words = set(nltk.corpus.stopwords.words('english'))
 
+# mount
+import sys
+sys.path.append('/mount/src/gen_ai_dev')
+
 # Configuration
 DATA_PATH = "./Infy financial report/"
 DATA_FILES = ["INFY_2022_2023.pdf", "INFY_2023_2024.pdf"]
 EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
-LLM_MODEL = "
+LLM_MODEL = "microsoft/phi-2"
 
 # Environment settings
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -92,24 +80,12 @@ def load_and_chunk_documents():
 text_chunks = load_and_chunk_documents()
 embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
 
-
-
-
-
-
-
-    # Initialize embeddings
-    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
-
-    # Create and return Chroma vector store
-    return Chroma.from_documents(
-        documents=text_chunks,
-        embedding=embeddings,
-        persist_directory="./chroma_db"
-    )
-
-# Initialize vector_db
-vector_db = load_vector_db()
+vector_db = Chroma.from_documents(
+    documents=text_chunks,
+    embedding=embeddings,
+    persist_directory="./chroma_db"
+)
+vector_db.persist()
 
 # BM25 setup
 bm25_corpus = [chunk.page_content for chunk in text_chunks]
@@ -137,8 +113,10 @@ class ConversationMemory:
             [f"Previous Q: {q}\nPrevious A: {r}" for q, r in self.buffer]
         )
 
+
 memory = ConversationMemory(max_size=3)
 
+
 # ------------------------------
 # Hybrid Retrieval System
 # ------------------------------
@@ -211,8 +189,8 @@ class SafetyGuard:
         query_lower = query.lower()
         if any(topic in query_lower for topic in self.blocked_topics):
             return False, "I only discuss financial topics."
-
-
+        if not any(term in query_lower for term in self.financial_terms):
+            return False, "Please ask financial questions."
         return True, ""
 
     def filter_output(self, response: str) -> str:
@@ -236,37 +214,24 @@ guard = SafetyGuard()
 # LLM Initialization
 # ------------------------------
 try:
-
-
-
-
-
-
-            device_map="auto",
-            torch_dtype=torch.bfloat16,
-            load_in_4bit=True
-        )
-    else:
-        model = AutoModelForCausalLM.from_pretrained(
-            LLM_MODEL,
-            device_map="cpu",
-            torch_dtype=torch.float32
-        )
-    return pipeline(
-        "text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        max_new_tokens=400,
-        do_sample=True,
-        temperature=0.3,
-        top_k=30,
-        top_p=0.9,
-        repetition_penalty=1.2
-    )
-
+    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
+    model = AutoModelForCausalLM.from_pretrained(
+        LLM_MODEL,
+        device_map="cpu",
+        torch_dtype=torch.float32
+    )
 
-
-
+    generator = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_new_tokens=400,
+        do_sample=True,
+        temperature=0.3,
+        top_k=30,
+        top_p=0.9,
+        repetition_penalty=1.2
+    )
 except Exception as e:
     print(f"Error loading model: {e}")
     raise
@@ -285,15 +250,13 @@ def extract_final_response(full_response: str) -> str:
 
 def generate_answer(query: str) -> Tuple[str, float]:
     try:
-        # Input validation
         is_valid, msg = guard.validate_input(query)
         if not is_valid:
             return msg, 0.0
 
-        # Retrieve context
         context = hybrid_retrieval(query)
+        vector_db.persist()
 
-        # Generate response
         prompt = f"""<|im_start|>system
 You are a financial analyst. Provide a brief answer using the context.
 Context: {context}<|im_end|>
@@ -302,19 +265,19 @@ Context: {context}<|im_end|>
 <|im_start|>assistant
 Answer:"""
 
+        print(f"\n\n[For Debug Only] Prompt: {prompt}\n\n")
+
         response = generator(prompt)[0]['generated_text']
         clean_response = extract_final_response(response)
         clean_response = guard.filter_output(clean_response)
 
-        # Calculate confidence
         query_embed = embeddings.embed_query(query)
        	response_embed = embeddings.embed_query(clean_response)
         confidence = cosine_similarity([query_embed], [response_embed])[0][0]
 
-        # Update memory
        	memory.add_interaction(query, clean_response)
 
         return clean_response, round(confidence, 2)
 
     except Exception as e:
-        return f"Error processing request: {e}", 0.0
+        return f"Error processing request: {e}", 0.0
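For reference, a minimal driver sketch for the module after this change. The script name and the sample question are hypothetical; it assumes utils.py and the PDFs under DATA_PATH are available locally, and it relies only on generate_answer(query) -> Tuple[str, float] as defined above. Note that importing utils triggers the module-level document chunking, Chroma build, and model loading, so the first call is slow.

# run_query.py -- hypothetical usage example, not part of this commit
from utils import generate_answer

# Illustrative question against the Infosys annual-report corpus.
answer, confidence = generate_answer("What was Infosys' total revenue in FY 2023?")
print(f"Answer: {answer}")
print(f"Confidence: {confidence}")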