# Spaces: Running
import os | |
from dotenv import load_dotenv | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.schema import HumanMessage | |
from langchain_openai import OpenAIEmbeddings | |
from langchain_voyageai import VoyageAIEmbeddings | |
from langchain_pinecone import PineconeVectorStore | |
from langchain_openai import ChatOpenAI | |
from langchain.prompts import PromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
from typing import List, Tuple | |
from langchain.schema import BaseRetriever | |
from langchain_core.documents import Document | |
from langchain_core.runnables import chain | |
from pinecone import Pinecone, ServerlessSpec | |
import openai | |
import numpy as np | |
import gradio as gr | |
load_dotenv()

# Credentials are read from environment variables (load_dotenv() is called
# first so values from a .env file are visible to os.environ).
openai.api_key = os.environ.get("OPENAI_API_KEY")
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
pinecone_environment = os.environ.get("PINECONE_ENV")
voyage_api_key = os.environ.get("VOYAGE_API_KEY")

# Initialize the Pinecone client. On failure `pc` is bound to None so the
# name always exists at module level instead of being left undefined.
try:
    pc = Pinecone(api_key=pinecone_api_key)
except Exception as e:
    pc = None
    print(f"Error connecting to Pinecone: {str(e)}")

# Embedding model shared by every vector-store lookup in this module.
embeddings = VoyageAIEmbeddings(
    voyage_api_key=voyage_api_key, model="voyage-law-2"
)
def expand_query(query):
    """
    Rewrite a vague search query into a more specific one using an LLM.

    Example: "docs" -> "Find all legal documents related to case law."

    Args:
        query: The raw search query.

    Returns:
        The rewritten query returned by the model. ``query`` is returned
        unchanged when it is empty/blank or when the model's reply is empty.
    """
    # Nothing to refine: skip the API call rather than asking the model to
    # invent a query out of an empty string.
    if not query or not str(query).strip():
        return query

    llm = ChatOpenAI(model="gpt-4", openai_api_key=openai.api_key, temperature=0.3)
    prompt = f"Rewrite the following vague search query into a more specific one:\nQuery: {query}\nSpecific Query:"
    # Use .invoke(); calling the chat model object directly is deprecated.
    refined_query = llm.invoke([HumanMessage(content=prompt)]).content.strip()
    return refined_query if refined_query else query
def _normalize_groups(user_groups):
    """Return ``user_groups`` as a container of whole group names."""
    import ast  # stdlib; imported locally because only this helper needs it

    if user_groups is None:
        return []
    if not isinstance(user_groups, str):
        # Already a list/tuple/set of names: membership tests are exact.
        return user_groups

    text = user_groups.strip()
    if not text:
        return []
    try:
        # Accept a Python-literal list such as "['GroupA', 'GroupB']".
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a literal: treat it as comma-separated bare names.
        parsed = [part.strip(" \t[](){}'\"") for part in text.split(",")]
    if isinstance(parsed, str) or not isinstance(parsed, (list, tuple, set, frozenset)):
        parsed = [parsed]
    return [str(item) for item in parsed if str(item).strip()]


def search_documents(query, user_groups, index_name="briefmeta", min_score=0.01):
    """
    Run an MMR vector search and return the hits the user may see.

    Args:
        query: Text to search for.
        user_groups: Groups the caller belongs to. May be a list/tuple/set of
            names, or a string holding a Python-literal list or a
            comma-separated list of names.
        index_name: Name of the Pinecone index to query.
        min_score: Hits whose ``score`` metadata is present and not greater
            than this value are dropped.

    Returns:
        A list of dicts with the keys ``doc_id``, ``chunk_id``, ``title``,
        ``text``, ``page_number`` and ``score``. An empty list is returned
        when nothing matches or when the search fails (the error is printed).
    """
    try:
        # A raw string would turn `group in user_groups` into a substring
        # test, letting a partial group name pass the access filter.
        allowed_groups = _normalize_groups(user_groups)

        vector_store = PineconeVectorStore(index_name=index_name, embedding=embeddings)
        results = vector_store.max_marginal_relevance_search(query, k=10, fetch_k=30)

        seen_keys = set()
        context = []
        for result in results:
            metadata = result.metadata

            # De-duplicate on the chunk id; fall back to the chunk text so
            # that hits without an id are not all collapsed into one.
            chunk_id = metadata.get("id")
            dedupe_key = chunk_id if chunk_id is not None else result.page_content
            if dedupe_key in seen_keys:
                continue

            doc_groups = metadata.get("groups") or []
            if isinstance(doc_groups, str):
                doc_groups = [doc_groups]
            if not any(group in allowed_groups for group in doc_groups):
                continue

            # Apply the threshold only when a score is actually stored.
            # NOTE(review): the score is read from document metadata; if the
            # index does not store one, a default of 0 would reject every
            # hit -- confirm how "score" is populated.
            score = metadata.get("score")
            if score is not None and not (score > min_score):
                continue

            seen_keys.add(dedupe_key)
            context.append(
                {
                    "doc_id": metadata.get("doc_id", "N/A"),
                    "chunk_id": metadata.get("id", "N/A"),
                    "title": metadata.get("source", "N/A"),
                    "text": result.page_content,
                    "page_number": str(metadata.get("page_number", "N/A")),
                    "score": str(metadata.get("score", "N/A")),
                }
            )
        return context
    except Exception as e:
        # Keep the return type a list on every path; callers pass the result
        # directly to the reranker.
        print(f"Error searching documents: {str(e)}")
        return []
def rerank(query, context, top_n=5, model="bge-reranker-v2-m3"):
    """
    Rerank retrieved chunks against the query with Pinecone's rerank endpoint.

    Args:
        query: The query the documents are ranked against.
        context: List of document dicts to rerank.
        top_n: Maximum number of reranked documents to return.
        model: Name of the reranking model.

    Returns:
        The object returned by ``pc.inference.rerank``; callers read its
        ``data`` attribute. For an empty ``context`` no request is made and
        an object whose ``data`` is an empty list is returned.
    """
    if not context:
        # Nothing to rank: skip the remote call but keep the `.data` shape.
        from types import SimpleNamespace

        return SimpleNamespace(data=[])

    return pc.inference.rerank(
        model=model,
        query=query,
        documents=context,
        top_n=top_n,
        return_documents=True,
    )
def generate_output(context, query):
    """
    Answer ``query`` from the given document context using an LLM.

    Args:
        context: Document text to ground the answer in.
        query: The user's question.

    Returns:
        The model's answer; a fixed "couldn't find relevant information"
        message when ``context`` is missing or blank; "No relevant answer
        found." when the model replies with nothing; or an
        "Error generating output: ..." string if anything raises.
    """
    try:
        # Validate the context before building a client, so blank or missing
        # context never costs an API round trip or lands in the error path.
        context_text = "" if context is None else str(context)
        if not context_text.strip():
            return "I couldn't find relevant information for your query. Could you refine your question?"

        llm = ChatOpenAI(model="gpt-4", openai_api_key=openai.api_key, temperature=0.5)
        prompt_template = PromptTemplate(
            template=(
                "Use the following document context to answer accurately:\n"
                "Context: {context}\n"
                "Question: {question}\n"
                "If the answer is unclear, ask for clarification.\n"
                "Answer:"
            ),
            input_variables=["context", "question"],
        )
        prompt = prompt_template.format(context=context_text, question=query)
        # Use .invoke(); calling the chat model object directly is deprecated.
        response = llm.invoke([HumanMessage(content=prompt)]).content.strip()
        return response if response else "No relevant answer found."
    except Exception as e:
        return f"Error generating output: {str(e)}"
def generate_search_summary(search_results, document_titles, query):
    """
    Generate a short LLM-written summary of the retrieved documents.

    Args:
        search_results: List of result dicts. The keys read here are
            ``title``, ``page_number``, ``score``, ``text`` and, when
            present, ``date_uploaded``.
        document_titles: Titles of the matched documents; only its length is
            used (reported as the number of matching documents).
        query: The user's search query, quoted in the prompt.

    Returns:
        The summary text; a fixed message when there are no results;
        "No intelligent summary available." when the model replies with
        nothing; or an "Error generating search summary: ..." string if
        anything raises.
    """
    # stdlib helpers, imported locally because only this function needs them
    import string
    from collections import Counter

    try:
        if not search_results:
            return "No relevant documents were found for your search. Try refining your query."

        # Extract metadata. Values are coerced to str so that join() cannot
        # fail on non-string titles or page numbers.
        num_results = len(document_titles)
        doc_titles = [str(doc.get("title", "Unknown Document")) for doc in search_results]
        doc_pages = [str(doc.get("page_number", "N/A")) for doc in search_results]

        # Scores may be placeholders such as "N/A"; skip anything that is
        # not numeric instead of aborting the whole summary.
        relevance_scores = []
        for doc in search_results:
            try:
                relevance_scores.append(float(doc.get("score", 0)))
            except (TypeError, ValueError):
                continue

        # Recency: consider only results that actually carry an upload date.
        # NOTE(review): the maximum is taken on the string form, which picks
        # the latest date only for sortable formats such as ISO 8601 --
        # confirm the stored format.
        recency_info = ""
        dates = [str(doc["date_uploaded"]) for doc in search_results if doc.get("date_uploaded")]
        if dates:
            recency_info = f"Most recent document uploaded on {max(dates)}."

        # Most frequent terms across the first 50 words of each result.
        term_counts = Counter()
        for doc in search_results:
            words = str(doc.get("text") or "").split()[:50]
            for word in words:
                term = word.strip(string.punctuation).casefold()
                # Very short tokens are mostly articles and prepositions.
                if len(term) > 3:
                    term_counts[term] += 1
        top_terms = [term for term, _ in term_counts.most_common(10)]

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # prompt is stable between runs.
        titles_text = ", ".join(dict.fromkeys(doc_titles))
        pages_text = ", ".join(dict.fromkeys(doc_pages))
        terms_text = ", ".join(top_terms)

        summary_prompt = "\n".join(
            [
                "",
                "Generate a concise 1-3 sentence summary of the search results.",
                f'- User Query: "{query}"',
                f"- Matching Documents: {num_results} found",
                f"- Titles: {titles_text}",
                f"- Pages Referenced: {pages_text}",
                f"- Common Terms: {terms_text} (top terms)",
                f"- Recency: {recency_info}",
                f"- Relevance Scores (0-1): {relevance_scores}",
                "Provide a clear, user-friendly summary with an action suggestion.",
                "",
            ]
        )

        llm = ChatOpenAI(model="gpt-4", openai_api_key=openai.api_key, temperature=0.5)
        # Use .invoke(); calling the chat model object directly is deprecated.
        summary = llm.invoke([HumanMessage(content=summary_prompt)]).content.strip()
        return summary if summary else "No intelligent summary available."
    except Exception as e:
        return f"Error generating search summary: {str(e)}"
def complete_workflow(query, user_groups, index_name="briefmeta"):
    """
    Run the full search pipeline: expand, retrieve, rerank, then answer.

    Args:
        query: The user's search query.
        user_groups: Groups the user belongs to, passed to ``search_documents``.
        index_name: Pinecone index to search; blank values fall back to
            "briefmeta".

    Returns:
        A ``(results, text)`` pair. ``results`` is a dict with the keys
        ``results`` (one entry per reranked chunk) and ``total_results``.
        ``text`` holds the newline-separated document titles on success, or
        a status/error message when nothing could be returned.
    """
    try:
        # An empty query gives the pipeline nothing to work with.
        if not query or not str(query).strip():
            return {"results": [], "total_results": 0}, "Please enter a search query."

        # A blank UI field arrives as "" rather than the parameter default.
        if isinstance(index_name, str):
            index_name = index_name.strip()
        if not index_name:
            index_name = "briefmeta"

        # Expand the query
        refined_query = expand_query(query)

        # Proceed with refined query instead of the original
        found = search_documents(refined_query, user_groups, index_name=index_name)
        # Tolerate a (results, error_message) tuple as well as a plain list.
        if isinstance(found, tuple):
            found, search_error = found
            if search_error:
                return {"results": [], "total_results": 0}, search_error
        if not found:
            # Nothing to rerank; avoid sending an empty document list.
            return {"results": [], "total_results": 0}, "No matching documents were found."

        reranked = rerank(refined_query, found)
        context_data = [
            {
                "chunk_id": entry["document"]["chunk_id"],
                "doc_id": entry["document"]["doc_id"],
                "title": entry["document"]["title"],
                "text": entry["document"]["text"],
                "page_number": str(entry["document"]["page_number"]),
                "score": str(entry["score"]),
            }
            for entry in reranked.data
        ]

        # dict.fromkeys de-duplicates titles while keeping rank order.
        document_titles = list(
            dict.fromkeys(os.path.basename(doc["title"]) for doc in context_data)
        )
        formatted_titles = " " + "\n".join(document_titles)

        results = {
            "results": [
                {
                    "natural_language_output": generate_output(doc["text"], refined_query),  # Use refined query
                    "chunk_id": doc["chunk_id"],
                    "document_id": doc["doc_id"],
                    "title": doc["title"],
                    "text": doc["text"],
                    "page_number": doc["page_number"],
                    "score": doc["score"],
                }
                for doc in context_data
            ],
            "total_results": len(context_data),
        }
        return results, formatted_titles
    except Exception as e:
        return {"results": [], "total_results": 0}, f"Error in workflow: {str(e)}"
def gradio_app():
    """
    Build the Gradio Blocks UI for the document search prototype.

    The layout has a row of inputs (query, user groups, index name), a search
    button, a JSON panel for the results and a read-only textbox for the
    retrieved document titles. Clicking the button calls
    ``complete_workflow`` with the three inputs and writes its two return
    values to the JSON panel and the titles textbox.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    # The results panel is tagged with elem_id="result-output", so the rule
    # needs an id selector (#...); a class selector (.result-output) would
    # not match it.
    css = "#result-output {width: 150%; font-size: 16px; padding: 10px;}"

    with gr.Blocks(css=css) as app:
        gr.Markdown("### Intelligent Document Search Prototype-v0.2")
        with gr.Row():
            user_query = gr.Textbox(label=" Enter Search Query")
            user_groups = gr.Textbox(label=" User Groups", placeholder="e.g., ['KarthikPersonal']", interactive=True)
            index_name = gr.Textbox(label=" Index Name", placeholder="Default: briefmeta", interactive=True)
        search_btn = gr.Button(" Search")
        with gr.Row():
            result_output = gr.JSON(label=" Search Results", elem_id="result-output")
        with gr.Row():
            titles_output = gr.Textbox(label=" Retrieved Document Titles", interactive=False)
        search_btn.click(
            complete_workflow,
            inputs=[user_query, user_groups, index_name],
            outputs=[result_output, titles_output],
        )
    return app


# Launch the app
gradio_app().launch()