import os
import json  # used by the optional Redis cache below
import redis  # optional: reranking cache (disabled by default)
import openai
import gradio as gr
from dotenv import load_dotenv
from pinecone import Pinecone, ServerlessSpec
from langchain_voyageai import VoyageAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from sentence_transformers import CrossEncoder

# Load environment variables
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")
pinecone_api_key = os.getenv("PINECONE_API_KEY")
voyage_api_key = os.getenv("VOYAGE_API_KEY")

# Initialize Pinecone
pc = Pinecone(api_key=pinecone_api_key)

# Optional Redis cache for reranking results (uncomment to enable)
# redis_client = redis.Redis(host='localhost', port=6379, db=0)

# Initialize embeddings (voyage-law-2 is tuned for legal text; 1024-dim output)
embeddings = VoyageAIEmbeddings(voyage_api_key=voyage_api_key, model="voyage-law-2")

# Cross-Encoder model for reranking: scores (query, passage) pairs jointly
reranker_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")

# **1. Pinecone Index Initialization**
def initialize_pinecone_index(index_name):
    """
    Creates the Pinecone index if it does not already exist and returns
    a vector store bound to it.
    """
    # list_indexes() returns index descriptions, so compare against the names
    if index_name not in [idx.name for idx in pc.list_indexes()]:
        pc.create_index(
            name=index_name,
            dimension=1024,  # must match voyage-law-2's embedding size
            metric="cosine",
            # Serverless indexes manage ANN tuning internally;
            # create_index accepts no HNSW parameters.
            spec=ServerlessSpec(cloud="aws", region="us-west-2"),
        )
    return PineconeVectorStore(index_name=index_name, embedding=embeddings)
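
# Usage sketch (assumption: an index named "briefmeta", matching the default
# used by the search functions below). Calling this once at startup creates
# the index on first run:
#
#   vector_store = initialize_pinecone_index("briefmeta")
#   vector_store.add_documents([...])  # ingest documents before searching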

# **2. Query Expansion**
QUERY_EXPANSIONS = {
    "docs": "Find all legal documents related to case law.",
    "contract": "Find contracts and legal agreements relevant to the query.",
    "policy": "Retrieve company policies and regulatory guidelines.",
}

def expand_query(query):
    """
    Expands the query using predefined mappings, falling back to an LLM
    rewrite for very short queries.
    """
    query = query.strip().lower()
    if query in QUERY_EXPANSIONS:
        return QUERY_EXPANSIONS[query]
    if len(query.split()) < 3:
        llm = ChatOpenAI(model="gpt-4", openai_api_key=openai.api_key, temperature=0.3)
        prompt = f"Rewrite the following vague search query into a more specific one:\nQuery: {query}\nSpecific Query:"
        refined_query = llm.invoke([HumanMessage(content=prompt)]).content.strip()
        return refined_query if refined_query else query
    return query
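
# Behavior sketch: exact matches hit the static map, short queries go to the
# LLM, everything else passes through unchanged (queries are hypothetical):
#
#   expand_query("docs")
#   # -> "Find all legal documents related to case law."
#   expand_query("nda clause")  # under 3 words, rewritten by GPT-4
#   expand_query("termination rights in employment contracts")  # unchanged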

# **3. Hybrid Search (Dense + Lexical Fusion)**
def search_documents(query, user_groups, index_name="briefmeta"):
    """
    Hybrid search combining semantic retrieval with a lexical re-score.
    Note: PineconeVectorStore has no built-in sparse/BM25 search, so keyword
    relevance is approximated by query-term overlap on the dense candidates.
    """
    try:
        vector_store = PineconeVectorStore(index_name=index_name, embedding=embeddings)

        # Dense search (semantic embeddings), restricted to the user's groups
        dense_results = vector_store.similarity_search_with_relevance_scores(
            query=query, k=10,
            filter={"groups": {"$in": user_groups}},
        )

        # Lexical score: fraction of query terms appearing in the document
        query_terms = set(query.lower().split())
        hybrid_results = []
        for doc, dense_score in dense_results:
            doc_terms = set(doc.page_content.lower().split())
            lexical_score = len(query_terms & doc_terms) / max(len(query_terms), 1)
            # Weighted fusion: semantic similarity dominates, keywords refine
            hybrid_results.append({"doc": doc, "score": 0.7 * dense_score + 0.3 * lexical_score})

        # Sort by fused score
        final_results = sorted(hybrid_results, key=lambda x: x["score"], reverse=True)

        # Format output
        return [
            {
                "doc_id": item["doc"].metadata.get("doc_id", "N/A"),
                "title": item["doc"].metadata.get("source", "N/A"),
                "text": item["doc"].page_content,
                "score": round(item["score"], 3),
            }
            for item in final_results
        ]
    except Exception as e:
        print(f"Error in hybrid search: {e}")
        return []  # keep the return type consistent for callers
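
# Usage sketch (assumption: documents were ingested with a "groups" metadata
# field, which the filter above matches against):
#
#   hits = search_documents("indemnification obligations", ["legal", "hr"])
#   for hit in hits[:3]:
#       print(hit["score"], hit["title"])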

# **4. Reranking with Cross-Encoder (Optional Caching)**
def rerank_results(query, search_results):
    """
    Uses a Cross-Encoder to rerank search results by joint (query, passage) relevance.
    """
    if not search_results:
        return search_results

    cache_key = f"rerank:{query}"
    # cached_result = redis_client.get(cache_key)
    # if cached_result:
    #     return json.loads(cached_result)

    # Score each (query, passage) pair with the cross-encoder
    pairs = [(query, doc["text"]) for doc in search_results]
    scores = reranker_model.predict(pairs)

    # Attach scores and sort
    for i, score in enumerate(scores):
        search_results[i]["rerank_score"] = round(float(score), 3)
    sorted_results = sorted(search_results, key=lambda x: x["rerank_score"], reverse=True)

    # redis_client.setex(cache_key, 600, json.dumps(sorted_results))  # Cache for 10 min
    return sorted_results
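
# Usage sketch: reranking refines the fused retrieval scores with a slower
# but more accurate model that reads query and passage together:
#
#   hits = search_documents("indemnification obligations", ["legal"])
#   reranked = rerank_results("indemnification obligations", hits)
#   print(reranked[0]["rerank_score"])  # cross-encoder score of the best match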

# **5. Intelligent Search Summary**
def generate_search_summary(search_results, query):
    """
    Generates a short natural-language summary of the search results.
    """
    if not search_results:
        return "No relevant documents were found for your search."

    top_docs = search_results[:3]
    doc_titles = [doc["title"] for doc in top_docs]
    summary_prompt = f"""
    Generate a **concise** 2-3 sentence summary of the search results.
    - User Query: "{query}"
    - Matching Documents: {len(search_results)} found
    - Titles: {", ".join(doc_titles)}
    **Summarize in user-friendly language.**
    """
    llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=openai.api_key, temperature=0.5)
    summary = llm.invoke([HumanMessage(content=summary_prompt)]).content.strip()
    return summary if summary else "No intelligent summary available."
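
# Usage sketch: the summary prompt only sees result titles and counts, so it
# is cheap enough to run on every search:
#
#   print(generate_search_summary(reranked, "indemnification obligations"))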

# **6. Full RAG Workflow**
def complete_workflow(query, user_groups, index_name="briefmeta"):
    """
    Full RAG workflow: Query Expansion -> Hybrid Search -> Reranking -> Summary.
    """
    try:
        # Gradio passes user_groups as a comma-separated string; normalize to a list
        if isinstance(user_groups, str):
            user_groups = [g.strip() for g in user_groups.split(",") if g.strip()]

        query = expand_query(query)
        raw_results = search_documents(query, user_groups, index_name)
        reranked_results = rerank_results(query, raw_results)

        document_titles = list({doc["title"] for doc in reranked_results})
        formatted_titles = "\n".join(document_titles)
        intelligent_search_summary = generate_search_summary(reranked_results, query)

        results = {
            "results": reranked_results[:5],
            "total_results": len(reranked_results),
        }
        return results, formatted_titles, intelligent_search_summary
    except Exception as e:
        # Return the same three-value shape on error so the UI outputs stay aligned
        return {"results": [], "total_results": 0}, "", f"Error in workflow: {str(e)}"

# **7. Gradio UI**
def gradio_app():
    with gr.Blocks() as app:
        gr.Markdown("## 🔍 AI-Powered Document Search")
        user_query = gr.Textbox(label="Enter Your Search Query")
        user_groups = gr.Textbox(label="Enter User Groups (comma-separated)", interactive=True)
        search_btn = gr.Button("Search")
        results_output = gr.JSON(label="Search Results")
        titles_output = gr.Textbox(label="Matching Document Titles")
        search_summary = gr.Textbox(label="Intelligent Search Summary")
        # complete_workflow returns three values, so wire three output components
        search_btn.click(
            complete_workflow,
            inputs=[user_query, user_groups],
            outputs=[results_output, titles_output, search_summary],
        )
    return app

gradio_app().launch()