karthikvarunn committed on
Commit e4f6e5a · verified · 1 Parent(s): 793fcf4

Update app.py

Files changed (1)
  1. app.py +181 -146
app.py CHANGED
@@ -1,187 +1,222 @@
  import os
  from dotenv import load_dotenv
- from langchain_community.document_loaders import PyPDFLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.schema import HumanMessage
- from langchain_openai import OpenAIEmbeddings, ChatOpenAI
  from langchain_voyageai import VoyageAIEmbeddings
  from langchain_pinecone import PineconeVectorStore
- from langchain.prompts import PromptTemplate
- from pinecone import Pinecone
- from sklearn.feature_extraction.text import TfidfVectorizer
- from sklearn.metrics.pairwise import cosine_similarity
- import openai
- import gradio as gr

- # Load API keys
  load_dotenv()
- openai.api_key = os.environ.get("OPENAI_API_KEY")
- pinecone_api_key = os.environ.get("PINECONE_API_KEY")
- voyage_api_key = os.environ.get("VOYAGE_API_KEY")

  # Initialize Pinecone
  pc = Pinecone(api_key=pinecone_api_key)
  embeddings = VoyageAIEmbeddings(voyage_api_key=voyage_api_key, model="voyage-law-2")

- # 🔹 Query Expansion using GPT-4
  def expand_query(query):
-     llm = ChatOpenAI(model="gpt-4", openai_api_key=openai.api_key, temperature=0.3)
-     prompt = f"Rewrite this vague query into a more specific one:\nQuery: {query}\nSpecific Query:"
-     refined_query = llm([HumanMessage(content=prompt)]).content.strip()
-     return refined_query if refined_query else query
-
- # 🔹 Hybrid Search (TF-IDF + Semantic Retrieval)
- def hybrid_search(query, user_groups, index_name="briefmeta", min_score=0, fetch_k=50):
-     vector_store = PineconeVectorStore(index_name=index_name, embedding=embeddings)
-     semantic_results = vector_store.max_marginal_relevance_search(query, k=10, fetch_k=fetch_k)
-
-     all_texts = [doc.page_content for doc in semantic_results]
-     vectorizer = TfidfVectorizer(stop_words="english")
-     tfidf_matrix = vectorizer.fit_transform(all_texts)
-     query_tfidf = vectorizer.transform([query])
-     keyword_scores = cosine_similarity(query_tfidf, tfidf_matrix).flatten()
-
-     combined_results, seen_ids = [], set()
-     for i, doc in enumerate(semantic_results):
-         doc_id, doc_groups = doc.metadata.get("id"), doc.metadata.get("groups", [])
-         semantic_score = float(doc.metadata.get("score", 0))
-         keyword_score = float(keyword_scores[i])
-         final_score = 0.7 * semantic_score + 0.3 * keyword_score  # Hybrid score
-
-         if doc_id not in seen_ids and any(group in user_groups for group in doc_groups) and final_score > min_score:
-             seen_ids.add(doc_id)
-             doc.metadata["final_score"] = final_score
-             combined_results.append(doc)
-
-     combined_results.sort(key=lambda x: x.metadata["final_score"], reverse=True)
-     return [
-         {
-             "doc_id": doc.metadata.get("doc_id", "N/A"),
-             "chunk_id": doc.metadata.get("id", "N/A"),
-             "title": doc.metadata.get("source", "N/A"),
-             "text": doc.page_content,
-             "page_number": str(doc.metadata.get("page_number", "N/A")),
-             "score": str(doc.metadata.get("final_score", "N/A")),
-         }
-         for doc in combined_results
-     ]
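A note on the removed hybrid_search above: max_marginal_relevance_search returns plain Documents, so doc.metadata.get("score", 0) likely always falls back to 0 and the semantic half of the blend contributes nothing. The TF-IDF half of the fusion is sound on its own; a minimal standalone sketch, with texts and sem_scores as hypothetical inputs and the 0.7/0.3 weights taken from the function:

    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    def fuse_scores(query, texts, sem_scores, w_sem=0.7, w_kw=0.3):
        # Fit TF-IDF on the retrieved chunks only, then score the query against them
        vec = TfidfVectorizer(stop_words="english")
        tfidf = vec.fit_transform(texts)
        kw = cosine_similarity(vec.transform([query]), tfidf).flatten()
        return [w_sem * s + w_kw * float(k) for s, k in zip(sem_scores, kw)]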
-
- # 🔹 Metadata-Weighted Reranking
- def rerank(query, context):
-     reranker = pc.inference.rerank(
-         model="bge-reranker-v2-m3", query=query, documents=context, top_n=10, return_documents=True
-     )
-
-     final_reranked = []
-     for entry in reranker.data:
-         doc, score = entry["document"], float(entry["score"])
-         citation_boost = 1.2 if "high_citations" in doc.get("tags", []) else 1.0
-         recency_boost = 1.1 if "recent_upload" in doc.get("tags", []) else 1.0
-         final_score = score * citation_boost * recency_boost
-         doc["final_score"] = final_score
-         final_reranked.append(doc)
-
-     final_reranked.sort(key=lambda x: x["final_score"], reverse=True)
-     return final_reranked
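The boost pattern in the removed rerank is independent of the reranking model: multiply the model score by per-tag factors, then sort. The same idea in isolation, with the tag names and factors copied from the code above:

    def metadata_boost(doc, base_score):
        # Multiplicative boosts for citation-heavy and recently uploaded documents
        boost = 1.2 if "high_citations" in doc.get("tags", []) else 1.0
        boost *= 1.1 if "recent_upload" in doc.get("tags", []) else 1.0
        return base_score * boost

    # metadata_boost({"tags": ["high_citations"]}, 0.5) -> 0.6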
-
- # 🔹 Intelligent Search Summary Generator
  def generate_search_summary(search_results, query):
      if not search_results:
-         return "No relevant documents found. Try refining your query."

-     num_results = len(search_results)
-     doc_titles = [doc.get("title", "Unknown Document") for doc in search_results]
-     doc_pages = [doc.get("page_number", "N/A") for doc in search_results]
-     relevance_scores = [float(doc.get("score", 0)) for doc in search_results]

      summary_prompt = f"""
-     Generate a concise 1-3 sentence summary:
      - User Query: "{query}"
-     - Matching Documents: {num_results} found
-     - Titles: {", ".join(set(doc_titles))}
-     - Pages Referenced: {", ".join(set(doc_pages))}
-     - Relevance Scores (0-1): {relevance_scores}
-     Provide a clear, user-friendly summary with an action suggestion.
      """

-     llm = ChatOpenAI(model="gpt-4", openai_api_key=openai.api_key, temperature=0.5)
      summary = llm([HumanMessage(content=summary_prompt)]).content.strip()
      return summary if summary else "No intelligent summary available."

- # 🔹 LLM-based Answer Generation
- def generate_output(context, query):
-     if not context.strip():
-         return "No relevant information found. Try refining your query."
-
-     llm = ChatOpenAI(model="gpt-4", openai_api_key=openai.api_key, temperature=0.5)
-     prompt_template = PromptTemplate(
-         template="Use the following context to answer the question:\nContext: {context}\nQuestion: {question}\nAnswer:",
-         input_variables=["context", "question"],
-     )
-     prompt = prompt_template.format(context=context, question=query)
-     response = llm([HumanMessage(content=prompt)]).content.strip()
-     return response if response else "No relevant answer found."
-
- # 🔹 Full Workflow
  def complete_workflow(query, user_groups, index_name="briefmeta"):
      try:
-         refined_query = expand_query(query)
-         context_data = hybrid_search(refined_query, user_groups)
-         reranked_results = rerank(refined_query, context_data)

-         context_data = [
-             {
-                 'chunk_id': doc["chunk_id"],
-                 'doc_id': doc["doc_id"],
-                 'title': doc["title"],
-                 'text': doc["text"],
-                 'page_number': str(doc["page_number"]),
-                 'score': str(doc["final_score"])
-             }
-             for doc in reranked_results
-         ]
-
-         document_titles = list({os.path.basename(doc["title"]) for doc in context_data})
          formatted_titles = " " + "\n".join(document_titles)
-         intelligent_search_summary = generate_search_summary(context_data, refined_query)

          results = {
-             "results": [
-                 {
-                     "natural_language_output": generate_output(doc["text"], refined_query),
-                     "chunk_id": doc["chunk_id"],
-                     "document_id": doc["doc_id"],
-                     "title": doc["title"],
-                     "text": doc["text"],
-                     "page_number": doc["page_number"],
-                     "score": doc["score"],
-                 }
-                 for doc in context_data
-             ],
-             "total_results": len(context_data),
-             "intelligent_search_summary": intelligent_search_summary
          }

          return results, formatted_titles, intelligent_search_summary
-
      except Exception as e:
-         return {"results": [], "total_results": 0, "intelligent_search_summary": "Error generating summary."}, f"Error in workflow: {str(e)}"

- # 🔹 Gradio UI
  def gradio_app():
      with gr.Blocks() as app:
-         gr.Markdown("### 📄 Intelligent Document Search Prototype-v0.2")
-         user_query = gr.Textbox(label="🔍 Enter Search Query")
-         user_groups = gr.Textbox(label="👥 User Groups", placeholder="e.g., ['KarthikPersonal']")
-         index_name = gr.Textbox(label="📂 Index Name", placeholder="Default: briefmeta")
-         search_btn = gr.Button("🔎 Search")
-         search_summary = gr.Textbox(label="📜 Intelligent Search Summary", interactive=False)
-         result_output = gr.JSON(label="📊 Search Results")
-         titles_output = gr.Textbox(label="📂 Retrieved Document Titles", interactive=False)

-         search_btn.click(complete_workflow, inputs=[user_query, user_groups, index_name], outputs=[result_output, titles_output, search_summary])

      return app

- # Launch the App
  gradio_app().launch()
 
  import os
+ import json
+ import redis
+ import openai
+ import numpy as np
+ import gradio as gr
  from dotenv import load_dotenv
+ from pinecone import Pinecone, ServerlessSpec
+ from langchain_openai import OpenAIEmbeddings
  from langchain_voyageai import VoyageAIEmbeddings
  from langchain_pinecone import PineconeVectorStore
+ from langchain_openai import ChatOpenAI
+ from langchain_core.documents import Document
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.schema import HumanMessage
+ from sentence_transformers import CrossEncoder

+ # Load environment variables
  load_dotenv()
+ openai.api_key = os.getenv("OPENAI_API_KEY")
+ pinecone_api_key = os.getenv("PINECONE_API_KEY")
+ pinecone_environment = os.getenv("PINECONE_ENV")
+ voyage_api_key = os.getenv("VOYAGE_API_KEY")

  # Initialize Pinecone
  pc = Pinecone(api_key=pinecone_api_key)
+
+ # Redis caching for reranking
+ redis_client = redis.Redis(host='localhost', port=6379, db=0)
+
+ # Initialize embeddings
  embeddings = VoyageAIEmbeddings(voyage_api_key=voyage_api_key, model="voyage-law-2")

+ # Load Cross-Encoder model for reranking
+ reranker_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
+
+
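One caveat in the new setup: redis.Redis(...) connects lazily, so a missing local Redis (likely in a hosted Space) only fails when rerank_results first calls get(). A hedged sketch of a startup probe that would let the app fall back to uncached reranking; the cache_available helper is illustrative, not part of this commit:

    def cache_available() -> bool:
        # ping() raises ConnectionError when no Redis server is reachable
        try:
            return bool(redis_client.ping())
        except redis.exceptions.ConnectionError:
            return False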
+ # **1. Optimized Pinecone Index Initialization**
+ def initialize_pinecone_index(index_name):
+     """
+     Ensures the Pinecone index is optimized for fast ANN-based search.
+     """
+     if index_name not in pc.list_indexes():
+         pc.create_index(
+             name=index_name,
+             dimension=1024,
+             metric="cosine",
+             spec=ServerlessSpec(cloud="aws", region="us-west-2"),
+             hnsw_config={"ef_construction": 200, "M": 16}  # Fast ANN search
+         )
+
+     return PineconeVectorStore(index_name=index_name, embedding=embeddings)
+
+
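Two details of initialize_pinecone_index may not behave as intended with the v3+ Pinecone client: list_indexes() returns index descriptions, so the membership test should go through .names(), and create_index() is not documented to accept an hnsw_config argument for serverless indexes. A sketch under those assumptions:

    def ensure_index(index_name: str) -> PineconeVectorStore:
        # Compare against index names, not the IndexList object itself
        if index_name not in pc.list_indexes().names():
            pc.create_index(
                name=index_name,
                dimension=1024,  # voyage-law-2 embeddings are 1024-dimensional
                metric="cosine",
                spec=ServerlessSpec(cloud="aws", region="us-west-2"),
            )
        return PineconeVectorStore(index_name=index_name, embedding=embeddings)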
+ # **2. Query Expansion**
+ QUERY_EXPANSIONS = {
+     "docs": "Find all legal documents related to case law.",
+     "contract": "Find contracts and legal agreements relevant to the query.",
+     "policy": "Retrieve company policies and regulatory guidelines."
+ }
+
  def expand_query(query):
+     """
+     Expands the query efficiently using predefined mappings and LLM if needed.
+     """
+     query = query.strip().lower()
+
+     if query in QUERY_EXPANSIONS:
+         return QUERY_EXPANSIONS[query]
+
+     if len(query.split()) < 3:
+         llm = ChatOpenAI(model="gpt-4", openai_api_key=openai.api_key, temperature=0.3)
+         prompt = f"Rewrite the following vague search query into a more specific one:\nQuery: {query}.\nSpecific Query:"
+         refined_query = llm([HumanMessage(content=prompt)]).content.strip()
+         return refined_query if refined_query else query
+
+     return query
+
+
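The expansion is a three-tier fallback: exact table hit, LLM rewrite for very short queries, pass-through otherwise. Expected behaviour, given the table above:

    expand_query("docs")       # table hit: "Find all legal documents related to case law."
    expand_query("nda terms")  # two tokens: rewritten by GPT-4
    expand_query("force majeure clauses in supplier contracts")  # three or more tokens: unchanged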
+ # **3. Hybrid Search (Dense + Sparse Fusion)**
+ def search_documents(query, user_groups, index_name="briefmeta"):
+     """
+     Hybrid search combining semantic and sparse (keyword) retrieval.
+     """
+     try:
+         vector_store = PineconeVectorStore(index_name=index_name, embedding=embeddings)
+
+         # Dense search (Semantic embeddings)
+         dense_results = vector_store.similarity_search_with_relevance_scores(
+             query=query, k=10, sparse_weight=0.3,
+             filter={"groups": {"$in": user_groups}}
+         )
+
+         # Sparse search (BM25-style keyword search)
+         sparse_results = vector_store.sparse_search(query=query, k=10)
+
+         # Fusion of results
+         hybrid_results = {}
+         for doc, score in dense_results:
+             hybrid_results[doc.metadata["id"]] = {"doc": doc, "score": score * 0.7}
+
+         for doc, score in sparse_results:
+             if doc.metadata["id"] in hybrid_results:
+                 hybrid_results[doc.metadata["id"]]["score"] += score * 0.3
+             else:
+                 hybrid_results[doc.metadata["id"]] = {"doc": doc, "score": score * 0.3}
+
+         # Sort by final score
+         final_results = sorted(hybrid_results.values(), key=lambda x: x["score"], reverse=True)
+
+         # Format output
+         search_output = [
+             {
+                 "doc_id": item["doc"].metadata.get("doc_id", "N/A"),
+                 "title": item["doc"].metadata.get("source", "N/A"),
+                 "text": item["doc"].page_content,
+                 "score": round(item["score"], 3)
+             }
+             for item in final_results
+         ]
+
+         return search_output
+
+     except Exception as e:
+         return [], f"Error in hybrid search: {str(e)}"
+
+
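Two assumptions in search_documents are worth flagging: langchain_pinecone's PineconeVectorStore exposes neither a sparse_weight argument on similarity_search_with_relevance_scores nor a sparse_search method (hybrid retrieval there is typically done through a hybrid-search retriever with a BM25 encoder), and the except branch returns a (list, str) pair while the happy path returns a plain list, so callers cannot handle both shapes consistently. The weighted fusion itself is retriever-agnostic; a minimal sketch over two generic (doc_id, score) lists:

    def fuse(dense, sparse, w_dense=0.7, w_sparse=0.3):
        # dense and sparse are lists of (doc_id, score) pairs from separate retrievers
        scores = {}
        for doc_id, s in dense:
            scores[doc_id] = scores.get(doc_id, 0.0) + w_dense * s
        for doc_id, s in sparse:
            scores[doc_id] = scores.get(doc_id, 0.0) + w_sparse * s
        return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)

    # fuse([("a", 0.9), ("b", 0.4)], [("b", 0.8)]) -> [("a", 0.63), ("b", 0.52)]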
+ # **4. Reranking with Cross-Encoder (Cached)**
+ def rerank_results(query, search_results):
+     """
+     Uses a Cross-Encoder for reranking search results.
+     """
+     if not search_results:
+         return search_results
+
+     cache_key = f"rerank:{query}"
+     cached_result = redis_client.get(cache_key)
+     if cached_result:
+         return json.loads(cached_result)
+
+     # Prepare input pairs for reranking
+     pairs = [(query, doc["text"]) for doc in search_results]
+     scores = reranker_model.predict(pairs)
+
+     # Attach scores and sort
+     for i, score in enumerate(scores):
+         search_results[i]["rerank_score"] = round(float(score), 3)
+
+     sorted_results = sorted(search_results, key=lambda x: x["rerank_score"], reverse=True)
+
+     redis_client.setex(cache_key, 600, json.dumps(sorted_results))  # Cache for 10 min
+     return sorted_results
+
+
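The cache key above depends only on the query text, yet the cached results also depend on the index and the caller's group filter, so two users with different permissions can share an entry for up to ten minutes. A hedged sketch of a more specific key; the helper is illustrative, not part of this commit:

    import hashlib

    def rerank_cache_key(query, index_name, user_groups):
        # Key on everything the cached result depends on, not just the query text
        raw = f"{index_name}|{sorted(user_groups)}|{query}"
        return "rerank:" + hashlib.sha256(raw.encode("utf-8")).hexdigest()[:16]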
+ # **5. Intelligent Search Summary**
  def generate_search_summary(search_results, query):
+     """
+     Generates an intelligent search summary.
+     """
      if not search_results:
+         return "No relevant documents were found for your search."

+     top_docs = search_results[:3]
+     doc_titles = [doc["title"] for doc in top_docs]

      summary_prompt = f"""
+     Generate a **concise** 2-3 sentence summary of the search results.
      - User Query: "{query}"
+     - Matching Documents: {len(search_results)} found
+     - Titles: {", ".join(doc_titles)}
+
+     **Summarize in user-friendly language.**
      """

+     llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=openai.api_key, temperature=0.5)
      summary = llm([HumanMessage(content=summary_prompt)]).content.strip()
+
      return summary if summary else "No intelligent summary available."
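The llm([HumanMessage(...)]) call style used here and in expand_query still works on older LangChain releases but is deprecated; recent versions expect invoke(). The equivalent call, assuming a current langchain-openai:

    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5)
    summary = llm.invoke([HumanMessage(content=summary_prompt)]).content.strip()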

+
+ # **6. Full RAG Workflow**
  def complete_workflow(query, user_groups, index_name="briefmeta"):
+     """
+     Full RAG workflow: Hybrid Search -> Reranking -> Intelligent Summary
+     """
      try:
+         query = expand_query(query)
+         raw_results = search_documents(query, user_groups, index_name)
+         reranked_results = rerank_results(query, raw_results)

+         document_titles = list({doc["title"] for doc in reranked_results})
          formatted_titles = " " + "\n".join(document_titles)
+
+         intelligent_search_summary = generate_search_summary(reranked_results, query)

          results = {
+             "results": reranked_results[:5],
+             "total_results": len(reranked_results)
          }

          return results, formatted_titles, intelligent_search_summary
      except Exception as e:
+         return {"results": [], "total_results": 0}, f"Error in workflow: {str(e)}"
+
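Note the arity mismatch in complete_workflow: the try branch returns three values while the except branch returns two, so any failure inside the pipeline raises a second error when the caller unpacks. A usage sketch assuming the three-value contract of the happy path:

    results, titles, summary = complete_workflow("contract", ["KarthikPersonal"])
    print(results["total_results"], summary)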

+ # **7. Gradio UI**
  def gradio_app():
      with gr.Blocks() as app:
+         gr.Markdown("## 🔍 AI-Powered Document Search")
+
+         user_query = gr.Textbox(label="Enter Your Search Query")
+         user_groups = gr.Textbox(label="Enter User Groups", interactive=True)
+         search_btn = gr.Button("Search")
+         results_output = gr.JSON(label="Search Results")
+         search_summary = gr.Textbox(label="Intelligent Search Summary")

+         search_btn.click(complete_workflow, inputs=[user_query, user_groups], outputs=[results_output, search_summary])

      return app

  gradio_app().launch()
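Two wiring details in the UI: complete_workflow returns three values but only two output components are bound, and user_groups arrives as a raw Textbox string while search_documents passes it into a Pinecone $in filter, which expects a list. A corrected wiring sketch under those assumptions; run_search is illustrative, not part of this commit:

    def run_search(query, groups_text):
        # Split the comma-separated groups box into a proper list for the $in filter
        groups = [g.strip() for g in groups_text.split(",") if g.strip()]
        results, titles, summary = complete_workflow(query, groups)
        return results, summary

    # search_btn.click(run_search, inputs=[user_query, user_groups],
    #                  outputs=[results_output, search_summary])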