import os
import json  # used by the (optional) Redis rerank cache below
# import redis  # optional: only needed if the rerank cache is enabled
import openai
import gradio as gr
from dotenv import load_dotenv
from pinecone import Pinecone, ServerlessSpec
from langchain_voyageai import VoyageAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from sentence_transformers import CrossEncoder

# Load environment variables
load_dotenv()
openai.api_key = os.getenv("OPENAI_API_KEY")
pinecone_api_key = os.getenv("PINECONE_API_KEY")
pinecone_environment = os.getenv("PINECONE_ENV")
voyage_api_key = os.getenv("VOYAGE_API_KEY")

# Initialize Pinecone
pc = Pinecone(api_key=pinecone_api_key)

# Redis caching for reranking
# redis_client = redis.Redis(host='localhost', port=6379, db=0)

# Initialize embeddings
embeddings = VoyageAIEmbeddings(voyage_api_key=voyage_api_key, model="voyage-law-2")

# Load Cross-Encoder model for reranking
reranker_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
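# Sanity check of the cross-encoder (illustrative, not executed here): scores
# are raw logits where higher means more relevant; exact values vary by model.
#   reranker_model.predict([("query", "a candidate passage")])  # -> array of floats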


# **1. Pinecone Index Initialization**
def initialize_pinecone_index(index_name):
    """
    Ensures the Pinecone index exists and returns a vector store backed by it.
    """
    # list_indexes() returns index descriptions; compare against their names.
    if index_name not in pc.list_indexes().names():
        pc.create_index(
            name=index_name,
            dimension=1024,  # voyage-law-2 produces 1024-dim embeddings
            metric="cosine",
            spec=ServerlessSpec(cloud="aws", region="us-west-2"),
            # Note: serverless indexes tune ANN search internally;
            # create_index() does not accept an hnsw_config argument.
        )

    return PineconeVectorStore(index_name=index_name, embedding=embeddings)
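
# Index creation is asynchronous for serverless; a minimal readiness wait,
# assuming the standard describe_index() status field (illustrative helper,
# not called anywhere in this script):
def wait_for_index(index_name, timeout_s=60):
    """Polls Pinecone until the index reports ready, or times out."""
    import time
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        if pc.describe_index(index_name).status["ready"]:
            return True
        time.sleep(1)
    return False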


# **2. Query Expansion**
QUERY_EXPANSIONS = {
    "docs": "Find all legal documents related to case law.",
    "contract": "Find contracts and legal agreements relevant to the query.",
    "policy": "Retrieve company policies and regulatory guidelines."
}

def expand_query(query):
    """
    Expands the query using predefined mappings, falling back to an LLM
    rewrite for very short (likely vague) queries.
    """
    query = query.strip().lower()

    if query in QUERY_EXPANSIONS:
        return QUERY_EXPANSIONS[query]

    # Queries under three words are usually too vague to retrieve well;
    # ask the LLM to rewrite them into something more specific.
    if len(query.split()) < 3:
        llm = ChatOpenAI(model="gpt-4", openai_api_key=openai.api_key, temperature=0.3)
        prompt = f"Rewrite the following vague search query into a more specific one:\nQuery: {query}\nSpecific Query:"
        refined_query = llm.invoke([HumanMessage(content=prompt)]).content.strip()
        return refined_query if refined_query else query

    return query
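
# Illustrative behavior (not executed at import time):
#   expand_query("docs")  -> the mapped case-law expansion above
#   expand_query("nda")   -> an LLM rewrite into a fuller, more specific query
#   expand_query("find all contracts signed in 2023")  -> returned unchanged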


# **3. Hybrid Search (Dense + Sparse Fusion)**
def search_documents(query, user_groups, index_name="briefmeta"):
    """
    Hybrid search combining dense (semantic) and sparse (keyword) retrieval,
    fused with a fixed 0.7/0.3 weighting.
    """
    try:
        vector_store = initialize_pinecone_index(index_name)

        # Dense search (semantic embeddings), restricted to the user's groups.
        dense_results = vector_store.similarity_search_with_relevance_scores(
            query=query, k=10,
            filter={"groups": {"$in": user_groups}}
        )

        # Sparse (BM25-style keyword) search. Note: langchain's
        # PineconeVectorStore does not ship a sparse_search method; this
        # assumes a custom subclass or hybrid-enabled wrapper. Fall back to
        # dense-only retrieval if no such method is available.
        if hasattr(vector_store, "sparse_search"):
            sparse_results = vector_store.sparse_search(query=query, k=10)
        else:
            sparse_results = []

        # Weighted fusion: dense scores count 70%, sparse 30% (see the RRF
        # sketch after this function for a scale-free alternative).
        hybrid_results = {}
        for doc, score in dense_results:
            doc_id = doc.metadata.get("id", doc.page_content[:64])
            hybrid_results[doc_id] = {"doc": doc, "score": score * 0.7}

        for doc, score in sparse_results:
            doc_id = doc.metadata.get("id", doc.page_content[:64])
            if doc_id in hybrid_results:
                hybrid_results[doc_id]["score"] += score * 0.3
            else:
                hybrid_results[doc_id] = {"doc": doc, "score": score * 0.3}

        # Sort by fused score, highest first.
        final_results = sorted(hybrid_results.values(), key=lambda x: x["score"], reverse=True)

        # Format output
        return [
            {
                "doc_id": item["doc"].metadata.get("doc_id", "N/A"),
                "title": item["doc"].metadata.get("source", "N/A"),
                "text": item["doc"].page_content,
                "score": round(item["score"], 3)
            }
            for item in final_results
        ]

    except Exception as e:
        # Return an empty list so downstream reranking still works.
        print(f"Error in hybrid search: {e}")
        return []
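
# An alternative to the fixed 0.7/0.3 weighting is reciprocal rank fusion
# (RRF), which combines rankings without assuming the two score scales are
# comparable. Minimal sketch (illustrative; k=60 is the conventional
# smoothing constant, and this helper is not wired into search_documents):
def rrf_fuse(dense_results, sparse_results, k=60):
    """Fuses two ranked lists of (doc, score) pairs by reciprocal rank."""
    fused = {}
    for results in (dense_results, sparse_results):
        for rank, (doc, _score) in enumerate(results):
            doc_id = doc.metadata.get("id", doc.page_content[:64])
            entry = fused.setdefault(doc_id, {"doc": doc, "score": 0.0})
            entry["score"] += 1.0 / (k + rank + 1)
    return sorted(fused.values(), key=lambda x: x["score"], reverse=True)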


# **4. Reranking with Cross-Encoder (Cached)**
def rerank_results(query, search_results):
    """
    Uses a Cross-Encoder for reranking search results.
    """
    if not search_results:
        return search_results  

    cache_key = f"rerank:{query}"
    # cached_result = redis_client.get(cache_key)
    # if cached_result:
    #     return json.loads(cached_result)

    # Prepare input pairs for reranking
    pairs = [(query, doc["text"]) for doc in search_results]
    scores = reranker_model.predict(pairs)

    # Attach scores and sort
    for i, score in enumerate(scores):
        search_results[i]["rerank_score"] = round(float(score), 3)

    sorted_results = sorted(search_results, key=lambda x: x["rerank_score"], reverse=True)

    # redis_client.setex(cache_key, 600, json.dumps(sorted_results))  # Cache for 10 min
    return sorted_results
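
# To enable the Redis cache sketched above, uncomment the redis import and
# client at the top of the file plus the commented lines in rerank_results().
# Illustrative round trip, assuming a local Redis on the default port:
#   redis_client = redis.Redis(host="localhost", port=6379, db=0)
#   redis_client.setex("rerank:test", 600, json.dumps(sorted_results))
#   json.loads(redis_client.get("rerank:test"))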


# **5. Intelligent Search Summary**
def generate_search_summary(search_results, query):
    """
    Generates an intelligent search summary.
    """
    if not search_results:
        return "No relevant documents were found for your search."

    top_docs = search_results[:3]
    doc_titles = [doc["title"] for doc in top_docs]

    summary_prompt = f"""
    Generate a **concise** 2-3 sentence summary of the search results.
    - User Query: "{query}"
    - Matching Documents: {len(search_results)} found
    - Titles: {", ".join(doc_titles)}

    **Summarize in user-friendly language.**
    """

    llm = ChatOpenAI(model="gpt-3.5-turbo", openai_api_key=openai.api_key, temperature=0.5)
    summary = llm.invoke([HumanMessage(content=summary_prompt)]).content.strip()

    return summary if summary else "No intelligent summary available."


# **6. Full RAG Workflow**
def complete_workflow(query, user_groups, index_name="briefmeta"):
    """
    Full RAG workflow: Query Expansion -> Hybrid Search -> Reranking -> Summary
    """
    try:
        # The Gradio textbox delivers user_groups as a comma-separated string,
        # while the Pinecone metadata filter expects a list.
        if isinstance(user_groups, str):
            user_groups = [g.strip() for g in user_groups.split(",") if g.strip()]

        query = expand_query(query)
        raw_results = search_documents(query, user_groups, index_name)
        reranked_results = rerank_results(query, raw_results)

        document_titles = list({doc["title"] for doc in reranked_results})
        formatted_titles = "\n".join(document_titles)

        intelligent_search_summary = generate_search_summary(reranked_results, query)

        results = {
            "results": reranked_results[:5],
            "total_results": len(reranked_results)
        }

        return results, formatted_titles, intelligent_search_summary
    except Exception as e:
        # Match the arity of the success path so Gradio gets three outputs.
        return {"results": [], "total_results": 0}, "", f"Error in workflow: {str(e)}"


# **7. Gradio UI**
def gradio_app():
    with gr.Blocks() as app:
        gr.Markdown("## πŸ” AI-Powered Document Search")

        user_query = gr.Textbox(label="Enter Your Search Query")
        user_groups = gr.Textbox(label="Enter User Groups (comma-separated)", interactive=True)
        search_btn = gr.Button("Search")
        results_output = gr.JSON(label="Search Results")
        matched_titles = gr.Textbox(label="Matching Document Titles")
        search_summary = gr.Textbox(label="Intelligent Search Summary")

        # complete_workflow returns three values, so wire up three outputs.
        search_btn.click(
            complete_workflow,
            inputs=[user_query, user_groups],
            outputs=[results_output, matched_titles, search_summary]
        )

    return app


if __name__ == "__main__":
    gradio_app().launch()