Jatin Mehra committed
Commit eb07e3c · 1 Parent(s): 75d04ae

Refactor preprocessing.py to enhance PDF processing and integrate FAISS for similarity search

Files changed (1)
  1. preprocessing.py +126 -117
preprocessing.py CHANGED
@@ -1,128 +1,137 @@
 import os
-import PyPDF2
-from groq import Groq
-import streamlit as st
-from collections import defaultdict
-
-class Model:
-    """
-    A class that represents a model for generating responses based on a given context and query.
-    """
-
-    def __init__(self):
-        """
-        Initializes the Model object and sets up the Groq client.
-        """
-        # api_key = os.getenv("GROQ_API_KEY")
-        api_key = st.secrets["GROQ_API_KEY"]
-        if not api_key:
-            raise ValueError("GROQ_API_KEY environment variable is not set.")
-        self.client = Groq(api_key=api_key)
-        self.contexts = []
-        self.cache = defaultdict(dict)  # Caching for repeated queries
-
-    def extract_text_from_pdf(self, pdf_file):
-        """
-        Extracts text from a PDF file.
-        Args:
-        - pdf_file: The file-like object of the PDF.
-        Returns:
-        - text: The extracted text from the PDF file.
-        """
-        try:
-            pdf_reader = PyPDF2.PdfReader(pdf_file)
-            text = ""
-            for page in pdf_reader.pages:
-                text += page.extract_text()
-            return text
-        except Exception as e:
-            raise ValueError(f"Error extracting text: {str(e)}")
-
-    def generate_response(self, context, query, temperature, max_tokens, model):
-        """
-        Generates a response based on the given context and query.
-        Args:
-        - context: The context for generating the response.
-        - query: The query or question.
-        - temperature: The sampling temperature for response generation.
-        - max_tokens: The maximum number of tokens for the response.
-        - model: The model ID to be used for generating the response.
-        Returns:
-        - response: The generated response.
-        """
-        # Caching check
-        if query in self.cache and self.cache[query]["context"] == context:
-            return self.cache[query]["response"]
-
-        messages = [
-            {"role": "system", "content": f"Context: {context}"},
-            {"role": "user", "content": query},
-        ]
-        try:
-            completion = self.client.chat.completions.create(
-                model=model,  # Model ID
-                messages=messages,
-                temperature=temperature,
-                max_tokens=max_tokens,
-            )
-            response = completion.choices[0].message.content
-            self.cache[query]["context"] = context
-            self.cache[query]["response"] = response  # Cache the response
-            return response
-        except Exception as e:
-            return f"API request failed: {str(e)}"
-
-    def add_to_context(self, file_path: str):
-        """
-        Reads a PDF file and appends its content to the context for generating responses.
-        Args:
-        - file_path: The path to the PDF file.
-        """
-        try:
-            with open(file_path, "rb") as pdf_file:
-                context = self.extract_text_from_pdf(pdf_file)
-            self.contexts.append(context)
-        except Exception as e:
-            raise ValueError(f"Error processing PDF: {str(e)}")
-
-    def remove_from_context(self, index: int):
-        """
-        Removes a document from the context based on its index.
-        Args:
-        - index: The index of the document to remove.
-        """
-        if 0 <= index < len(self.contexts):
-            self.contexts.pop(index)
-        else:
-            raise ValueError("Invalid index for removing context.")
-
-    def get_combined_context(self):
-        """
-        Combines all contexts into a single context string.
-        Returns:
-        - combined_context: The combined context from all documents.
-        """
-        return "\n".join(self.contexts)
-
-    def get_response(self, question: str, temperature: float, max_tokens: int, model: str):
-        """
-        Generates a response based on the given question and the current combined context.
-        Args:
-        - question: The user's question.
-        - temperature: The sampling temperature for response generation.
-        - max_tokens: The maximum number of tokens for the response.
-        - model: The model ID to be used for generating the response.
-        Returns:
-        - response: The generated response or a prompt to upload a document.
-        """
-        if not self.contexts:
-            return "Please upload a document."
-        combined_context = self.get_combined_context()
-        return self.generate_response(combined_context, question, temperature, max_tokens, model)
-
-    def clear(self):
-        """
-        Clears the current context.
-        """
-        self.contexts = []
-        self.cache.clear()
+from langchain_community.document_loaders import PyMuPDFLoader
+import faiss
+from langchain_groq import ChatGroq
+from langchain.agents import AgentExecutor, create_tool_calling_agent
+from langchain_community.tools.tavily_search import TavilySearchResults
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.memory import ConversationBufferMemory
+from sentence_transformers import SentenceTransformer
+import dotenv
+dotenv.load_dotenv()
+# Initialize LLM and tools globally
+
+def model_selection(model_name):
+    llm = ChatGroq(model=model_name, api_key=os.getenv("GROQ_API_KEY"))
+    return llm
+
+tools = [TavilySearchResults(max_results=5)]
+
+# Initialize memory for conversation history
+memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+
+def estimate_tokens(text):
+    """Estimate the number of tokens in a text (rough approximation)."""
+    return len(text) // 4
+
+def process_pdf_file(file_path):
+    """Load a PDF file and extract its text."""
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"The file {file_path} does not exist.")
+    loader = PyMuPDFLoader(file_path)
+    documents = loader.load()
+    text = "".join(doc.page_content for doc in documents)
+    return text
+
+def chunk_text(text, max_length=1500):
+    """Split text into chunks based on paragraphs, respecting max_length."""
+    paragraphs = text.split("\n\n")
+    chunks = []
+    current_chunk = ""
+    for paragraph in paragraphs:
+        if len(current_chunk) + len(paragraph) <= max_length:
+            current_chunk += paragraph + "\n\n"
+        else:
+            chunks.append(current_chunk.strip())
+            current_chunk = paragraph + "\n\n"
+    if current_chunk:
+        chunks.append(current_chunk.strip())
+    return chunks
+
+def create_embeddings(texts, model):
+    """Create embeddings for a list of texts using the provided model."""
+    embeddings = model.encode(texts, show_progress_bar=True, convert_to_tensor=True)
+    return embeddings.cpu().numpy()
+
+def build_faiss_index(embeddings):
+    """Build a FAISS index from embeddings for similarity search."""
+    dim = embeddings.shape[1]
+    index = faiss.IndexFlatL2(dim)
+    index.add(embeddings)
+    return index
+
+def retrieve_similar_chunks(query, index, texts, model, k=3, max_chunk_length=3500):
+    """Retrieve top k similar chunks to the query from the FAISS index."""
+    query_embedding = model.encode([query], convert_to_tensor=True).cpu().numpy()
+    distances, indices = index.search(query_embedding, k)
+    return [(texts[i][:max_chunk_length], distances[0][j]) for j, i in enumerate(indices[0])]
+
+def agentic_rag(llm, tools, query, context, Use_Tavily=False):
+    # Define the prompt template for the agent
+    search_instructions = (
+        "Use the search tool if the context is insufficient to answer the question or you are unsure. Give source links if you use the search tool."
+        if Use_Tavily
+        else "Use the context provided to answer the question."
+    )
+
+    prompt = ChatPromptTemplate.from_messages([
+        ("system", """
+        You are a helpful assistant. {search_instructions}
+        Instructions:
+        1. Use the provided context to answer the user's question.
+        2. Provide a clear answer, if you don't know the answer, say 'I don't know'.
+        """),
+        ("human", "Context: {context}\n\nQuestion: {input}"),
+        MessagesPlaceholder(variable_name="chat_history"),
+        MessagesPlaceholder(variable_name="agent_scratchpad"),
+    ])
+
+    # Only use tools when Tavily is enabled
+    agent_tools = tools if Use_Tavily else []
+
+    try:
+        # Create the agent and executor with appropriate tools
+        agent = create_tool_calling_agent(llm, agent_tools, prompt)
+        agent_executor = AgentExecutor(agent=agent, tools=agent_tools, memory=memory, verbose=True)
+
+        # Execute the agent
+        return agent_executor.invoke({
+            "input": query,
+            "context": context,
+            "search_instructions": search_instructions
+        })
+    except Exception as e:
+        print(f"Error during agent execution: {str(e)}")
+        # Fallback to direct LLM call without agent framework
+        fallback_prompt = ChatPromptTemplate.from_messages([
+            ("system", "You are a helpful assistant. Use the provided context to answer the user's question."),
+            ("human", "Context: {context}\n\nQuestion: {input}")
+        ])
+        response = llm.invoke(fallback_prompt.format(context=context, input=query))
+        return {"output": response.content}
+
+if __name__ == "__main__":
+    # Process PDF and prepare index
+    dotenv.load_dotenv()
+    pdf_file = "JatinCV.pdf"
+    llm = model_selection("meta-llama/llama-4-scout-17b-16e-instruct")
+    texts = process_pdf_file(pdf_file)
+    chunks = chunk_text(texts, max_length=1500)
+    model = SentenceTransformer('all-MiniLM-L6-v2')
+    embeddings = create_embeddings(chunks, model)
+    index = build_faiss_index(embeddings)
+
+    # Chat loop
+    print("Chat with the assistant (type 'exit' or 'quit' to stop):")
+    while True:
+        query = input("User: ")
+        if query.lower() in ["exit", "quit"]:
+            break
+
+        # Retrieve similar chunks
+        similar_chunks = retrieve_similar_chunks(query, index, chunks, model, k=3)
+        context = "\n".join([chunk for chunk, _ in similar_chunks])
+
+        # Generate response
+        response = agentic_rag(llm, tools, query=query, context=context, Use_Tavily=True)
+        print("Assistant:", response["output"])