import os
import json
import redis
import gradio as gr
from dotenv import load_dotenv
from pinecone import Pinecone, ServerlessSpec
from langchain_voyageai import VoyageAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from sentence_transformers import CrossEncoder
# Load environment variables
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
pinecone_api_key = os.getenv("PINECONE_API_KEY")
voyage_api_key = os.getenv("VOYAGE_API_KEY")
# Initialize Pinecone
pc = Pinecone(api_key=pinecone_api_key)
# Redis caching for reranking (enable by uncommenting; requires a running Redis server)
# redis_client = redis.Redis(host='localhost', port=6379, db=0)
# Initialize embeddings (voyage-law-2 produces 1024-dimensional vectors)
embeddings = VoyageAIEmbeddings(voyage_api_key=voyage_api_key, model="voyage-law-2")
# Load Cross-Encoder model for reranking
reranker_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
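# Illustrative use of the cross-encoder (values are hypothetical): it scores
# (query, passage) pairs jointly, which is slower than the bi-encoder
# embeddings above but more precise for reranking a short candidate list.
#   scores = reranker_model.predict([
#       ("termination clause", "This agreement may be terminated upon..."),
#       ("termination clause", "Quarterly revenue grew by four percent..."),
#   ])  # returns an array of floats; higher means more relevant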
# **1. Pinecone Index Initialization**
def initialize_pinecone_index(index_name):
    """
    Creates the Pinecone index if it does not exist and returns a vector store.
    """
    if index_name not in pc.list_indexes().names():
        pc.create_index(
            name=index_name,
            dimension=1024,  # must match the embedding model's output size
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-west-2"),
            # Note: serverless indexes do not accept HNSW tuning parameters
            # (ef_construction / M); ANN configuration is managed by Pinecone.
        )
    return PineconeVectorStore(index_name=index_name, embedding=embeddings)
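# Usage sketch ("briefmeta" is the index name this app assumes elsewhere):
#   vector_store = initialize_pinecone_index("briefmeta")
# dimension=1024 above matches the output size of the voyage-law-2 model.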
# **2. Query Expansion**
QUERY_EXPANSIONS = {
    "docs": "Find all legal documents related to case law.",
    "contract": "Find contracts and legal agreements relevant to the query.",
    "policy": "Retrieve company policies and regulatory guidelines."
}
def expand_query(query):
    """
    Expands the query using predefined mappings, falling back to an LLM
    rewrite for very short queries.
    """
    query = query.strip().lower()
    if query in QUERY_EXPANSIONS:
        return QUERY_EXPANSIONS[query]
    if len(query.split()) < 3:
        llm = ChatOpenAI(model="gpt-4", openai_api_key=openai_api_key, temperature=0.3)
        prompt = f"Rewrite the following vague search query into a more specific one:\nQuery: {query}.\nSpecific Query:"
        refined_query = llm.invoke([HumanMessage(content=prompt)]).content.strip()
        return refined_query if refined_query else query
    return query
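# Illustrative behavior (the LLM rewrite shown is hypothetical):
#   expand_query("docs")       -> predefined expansion from QUERY_EXPANSIONS
#   expand_query("nda terms")  -> LLM rewrite, e.g. a fuller query about NDA clauses
#   expand_query("find all employment contracts from 2021") -> returned unchanged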
# **3. Hybrid Search (Dense + Sparse Fusion)**
def search_documents(query, user_groups, index_name="briefmeta"):
    """
    Weighted-fusion search over dense (semantic) and sparse (keyword) results.

    Note: PineconeVectorStore exposes dense similarity search only; true
    BM25-style sparse retrieval requires a hybrid-enabled index and a sparse
    encoder, so sparse_results is left as an empty stub here.
    """
    try:
        vector_store = PineconeVectorStore(index_name=index_name, embedding=embeddings)
        # Dense search (semantic embeddings), restricted to the user's groups
        dense_results = vector_store.similarity_search_with_relevance_scores(
            query=query, k=10,
            filter={"groups": {"$in": user_groups}}
        )
        # Sparse search stub: plug a BM25/SPLADE retriever in here if available
        sparse_results = []
        # Weighted fusion of results (0.7 dense / 0.3 sparse)
        hybrid_results = {}
        for doc, score in dense_results:
            hybrid_results[doc.metadata["id"]] = {"doc": doc, "score": score * 0.7}
        for doc, score in sparse_results:
            if doc.metadata["id"] in hybrid_results:
                hybrid_results[doc.metadata["id"]]["score"] += score * 0.3
            else:
                hybrid_results[doc.metadata["id"]] = {"doc": doc, "score": score * 0.3}
        # Sort by final fused score
        final_results = sorted(hybrid_results.values(), key=lambda x: x["score"], reverse=True)
        # Format output
        search_output = [
            {
                "doc_id": item["doc"].metadata.get("doc_id", "N/A"),
                "title": item["doc"].metadata.get("source", "N/A"),
                "text": item["doc"].page_content,
                "score": round(item["score"], 3)
            }
            for item in final_results
        ]
        return search_output
    except Exception as e:
        # Return an empty list so downstream reranking still receives a list
        print(f"Error in hybrid search: {e}")
        return []
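# Usage sketch (group names are hypothetical and must match the "groups"
# metadata stored on the indexed documents):
#   hits = search_documents("indemnification clauses", ["legal", "compliance"])
#   # -> list of dicts with doc_id, title, text, and a fused score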
# **4. Reranking with Cross-Encoder (Cached)**
def rerank_results(query, search_results):
    """
    Uses a Cross-Encoder to rerank search results by joint query-passage relevance.
    """
    if not search_results:
        return search_results
    cache_key = f"rerank:{query}"
    # cached_result = redis_client.get(cache_key)
    # if cached_result:
    #     return json.loads(cached_result)
    # Prepare (query, passage) input pairs for the cross-encoder
    pairs = [(query, doc["text"]) for doc in search_results]
    scores = reranker_model.predict(pairs)
    # Attach scores and sort descending
    for i, score in enumerate(scores):
        search_results[i]["rerank_score"] = round(float(score), 3)
    sorted_results = sorted(search_results, key=lambda x: x["rerank_score"], reverse=True)
    # redis_client.setex(cache_key, 600, json.dumps(sorted_results))  # Cache for 10 min
    return sorted_results
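# To enable caching, start a Redis server and uncomment the redis_client
# lines above and in the setup section; setex stores each reranked list for
# 600 seconds, so repeated identical queries skip the cross-encoder pass.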
# **5. Intelligent Search Summary**
def generate_search_summary(search_results, query):
    """
    Generates a short, user-friendly summary of the search results.
    """
    if not search_results:
        return "No relevant documents were found for your search."
    top_docs = search_results[:3]
    doc_titles = [doc["title"] for doc in top_docs]
    summary_prompt = f"""
    Generate a **concise** 2-3 sentence summary of the search results.
    - User Query: "{query}"
    - Matching Documents: {len(search_results)} found
    - Titles: {", ".join(doc_titles)}
    **Summarize in user-friendly language.**
    """
    llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=openai_api_key, temperature=0.5)
    summary = llm.invoke([HumanMessage(content=summary_prompt)]).content.strip()
    return summary if summary else "No intelligent summary available."
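# Design note: gpt-3.5-turbo keeps summary latency and cost low; any chat
# model accepted by ChatOpenAI can be swapped in via the `model` argument.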
# **6. Full RAG Workflow**
def complete_workflow(query, user_groups, index_name="briefmeta"):
    """
    Full RAG workflow: Query Expansion -> Hybrid Search -> Reranking -> Summary
    """
    try:
        # The Gradio textbox delivers user_groups as a comma-separated string
        if isinstance(user_groups, str):
            user_groups = [g.strip() for g in user_groups.split(",") if g.strip()]
        query = expand_query(query)
        raw_results = search_documents(query, user_groups, index_name)
        reranked_results = rerank_results(query, raw_results)
        document_titles = list({doc["title"] for doc in reranked_results})
        formatted_titles = "\n".join(document_titles)
        intelligent_search_summary = generate_search_summary(reranked_results, query)
        results = {
            "results": reranked_results[:5],
            "total_results": len(reranked_results)
        }
        return results, formatted_titles, intelligent_search_summary
    except Exception as e:
        # Return the same arity as the success path so the UI outputs line up
        return {"results": [], "total_results": 0}, "", f"Error in workflow: {str(e)}"
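# End-to-end sketch (query and groups are hypothetical):
#   results, titles, summary = complete_workflow("docs", ["legal"])
#   # results["results"] holds the top 5 reranked hits; `titles` and
#   # `summary` feed the two text outputs in the UI below.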
# **7. Gradio UI**
def gradio_app():
    with gr.Blocks() as app:
        gr.Markdown("## 🔍 AI-Powered Document Search")
        user_query = gr.Textbox(label="Enter Your Search Query")
        user_groups = gr.Textbox(label="Enter User Groups (comma-separated)", interactive=True)
        search_btn = gr.Button("Search")
        results_output = gr.JSON(label="Search Results")
        titles_output = gr.Textbox(label="Matching Document Titles")
        search_summary = gr.Textbox(label="Intelligent Search Summary")
        # complete_workflow returns three values, so map all three outputs
        search_btn.click(
            complete_workflow,
            inputs=[user_query, user_groups],
            outputs=[results_output, titles_output, search_summary]
        )
    return app

if __name__ == "__main__":
    gradio_app().launch()
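# launch() starts a local Gradio server (default http://127.0.0.1:7860); on a
# hosted Space the same entry point is picked up automatically.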