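"""Gradio app: retrieval-augmented Q&A over previously uploaded ("old") documents.

Queries are embedded via a DeepInfra-hosted model (through DeepInfra's
OpenAI-compatible API), matched against a Weaviate collection, and gated by
the autism-relevance checks provided by query_utils.
"""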
import os
import shutil
import asyncio
from dotenv import load_dotenv
import gradio as gr
from query_utils import (
    process_query_for_rewrite,
    get_non_autism_response,
    check_answer_autism_relevance,
    get_non_autism_answer_response,
)

# Load configuration from the environment; hardcoding live API keys in source
# leaks credentials, so all of these values belong in .env instead.
load_dotenv()

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")    # not used in this file
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")    # not used in this file
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")    # not used in this file
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")    # not used in this file
QDRANT_URL = os.getenv("QDRANT_URL")            # not used in this file
WEAVIATE_URL = os.getenv("WEAVIATE_URL")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")
DEEPINFRA_BASE_URL = os.getenv("DEEPINFRA_BASE_URL", "https://api.deepinfra.com/v1/openai")

if not (DEEPINFRA_API_KEY and WEAVIATE_URL and WEAVIATE_API_KEY):
    raise ValueError("Please set all required keys in .env")
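# A minimal .env for this app might look like the following (placeholder
# values, not real credentials):
#
#   DEEPINFRA_API_KEY=your-deepinfra-key
#   WEAVIATE_URL=your-cluster.weaviate.cloud
#   WEAVIATE_API_KEY=your-weaviate-key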
# DeepInfra client: DeepInfra exposes an OpenAI-compatible endpoint, so the
# official OpenAI SDK works once base_url points at it. The client is named
# openai_client to avoid shadowing the `openai` module itself.
from openai import OpenAI

openai_client = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url=DEEPINFRA_BASE_URL,
)
# Weaviate client
import weaviate
from weaviate.classes.init import Auth
from contextlib import contextmanager

# @contextmanager turns this generator into a context manager, so it can be
# used in `with` statements and the connection is always closed afterwards.
@contextmanager
def weaviate_client():
    client = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
        skip_init_checks=True,  # skip the gRPC health check at startup
    )
    try:
        yield client
    finally:
        client.close()
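# Usage sketch: each retrieval opens a short-lived connection, e.g.
#   with weaviate_client() as client:
#       coll = client.collections.get("user")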
# Global path tracker
last_uploaded_path = None
# Embed function
def embed_texts(texts: list[str], batch_size: int = 50) -> list[list[float]]:
    """Embed texts in batches; a failed batch contributes empty placeholder lists."""
    all_embeddings: list[list[float]] = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        try:
            resp = openai_client.embeddings.create(
                model="Qwen/Qwen3-Embedding-8B",
                input=batch,
                encoding_format="float",
            )
            all_embeddings.extend(item.embedding for item in resp.data)
        except Exception as e:
            print(f"Embedding error: {e}")
            all_embeddings.extend([[] for _ in batch])
    return all_embeddings
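# Illustrative call (vector dimensionality depends on the embedding model):
#   vecs = embed_texts(["what is autism?"])
#   len(vecs) == 1 and len(vecs[0]) > 0  # one dense vector per input text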
def encode_query(query: str) -> list[float] | None:
    embs = embed_texts([query], batch_size=1)
    if embs and embs[0]:
        return embs[0]
    return None
async def old_Document(query: str, top_k: int = 1) -> dict:
    """Retrieve the top_k most similar chunks from the "user" collection."""
    qe = encode_query(query)
    if not qe:
        return {"answer": []}
    try:
        with weaviate_client() as client:
            coll = client.collections.get("user")
            res = coll.query.near_vector(
                near_vector=qe,
                limit=top_k,
                return_properties=["text"],
            )
            if not getattr(res, "objects", None):
                return {"answer": []}
            return {
                "answer": [obj.properties.get("text", "[No Text]") for obj in res.objects]
            }
    except Exception as e:
        print("RAG Error:", e)
        return {"answer": []}
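# Illustrative direct call outside Gradio (assumes the "user" collection exists
# and holds chunked document text under a "text" property):
#   asyncio.run(old_Document("early signs of autism", top_k=3))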
# New functions to support the Gradio app
def ingest_file(path: str) -> str:
    """Record the uploaded file path. Note: this only tracks the path; the
    chunking/embedding that populates Weaviate is assumed to happen elsewhere."""
    global last_uploaded_path
    last_uploaded_path = path
    return f"Old document ingested: {os.path.basename(path)}"
def answer_question(query: str) -> str:
    try:
        # Rewrite the query and check whether it is autism-related
        corrected_query, is_autism_related, rewritten_query = process_query_for_rewrite(query)

        # If not autism-related, return the standard rejection message
        if not is_autism_related:
            return get_non_autism_response()

        # Use the corrected query for retrieval
        rag_resp = asyncio.run(old_Document(corrected_query))
        chunks = rag_resp.get("answer", [])
        if not chunks:
            return "Sorry, I couldn't find relevant content in the old document."

        # Combine chunks into a single answer for relevance checking
        combined_answer = "\n".join(f"- {c}" for c in chunks)

        # Check whether the retrieved content is sufficiently autism-related;
        # refuse answers scoring below the 50% threshold
        answer_relevance_score = check_answer_autism_relevance(combined_answer)
        if answer_relevance_score < 50:
            return get_non_autism_answer_response()

        return combined_answer
    except Exception as e:
        return f"Error processing your request: {e}"
# Gradio interface for Old Documents
with gr.Blocks(title="Old Documents RAG") as demo:
    gr.Markdown("## Old Documents RAG")
    query = gr.Textbox(placeholder="Your question...", lines=2, label="Ask about Old Documents")
    doc_file = gr.File(label="Upload Old Document (PDF, DOCX, TXT)")
    btn = gr.Button("Submit")
    out = gr.Textbox(label="Answer from Old Documents", lines=8, interactive=False)
    def process_old_doc(query, doc_file):
        if doc_file:
            # Save and ingest the uploaded file. Depending on the Gradio
            # version, gr.File yields a filepath string or a tempfile-like
            # object with a .name attribute, so handle both; copying by path
            # avoids calling .read() on an object that may not support it.
            src_path = doc_file if isinstance(doc_file, str) else doc_file.name
            upload_dir = os.path.join(os.path.dirname(__file__), "uploaded_docs")
            os.makedirs(upload_dir, exist_ok=True)
            safe_filename = os.path.basename(src_path)
            save_path = os.path.join(upload_dir, safe_filename)
            shutil.copyfile(src_path, save_path)
            status = ingest_file(save_path)
            answer = answer_question(query)
            return f"{status}\n\n{answer}"
        # Otherwise fall back to the last uploaded file, if any
        if last_uploaded_path:
            answer = answer_question(query)
            return f"[Using previously uploaded document: {os.path.basename(last_uploaded_path)}]\n\n{answer}"
        return "No document uploaded. Please upload an old document to proceed."

    btn.click(fn=process_old_doc, inputs=[query, doc_file], outputs=out)
if __name__ == "__main__":
    demo.launch(debug=True)