# wisalQA_P1 / RAG.py
# Uploaded by afouda ("Upload 7 files"), commit ea1e6bd (verified), 4.09 kB.
import os
import asyncio
from dotenv import load_dotenv
# Configuration: credentials are read from the environment (optionally
# populated from a local .env file by load_dotenv) instead of being
# hard-coded in source. Non-secret endpoint URLs keep their previous values
# as defaults and can be overridden through the environment.
# NOTE(review): API keys were previously committed in this file; they are
# exposed and should be revoked/rotated with each provider.
load_dotenv()

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv(
    "QDRANT_URL",
    "https://6a3aade6-e8ad-4a6c-a579-21f5af90b7e8.us-east4-0.gcp.cloud.qdrant.io",
)
WEAVIATE_URL = os.getenv(
    "WEAVIATE_URL",
    "yorcqe2sqswhcaivxvt9a.c0.us-west3.gcp.weaviate.cloud",
)
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")
DEEPINFRA_BASE_URL = os.getenv(
    "DEEPINFRA_BASE_URL",
    "https://api.deepinfra.com/v1/openai",
)
# Initialize a DeepInfra-backed client through DeepInfra's OpenAI-compatible API.
# NOTE: the module-level name `openai` holds this client instance (not the
# `openai` package); embed_texts() uses it to create embeddings.
from openai import OpenAI

if not DEEPINFRA_API_KEY:
    # Without this guard the OpenAI SDK falls back to the OPENAI_API_KEY
    # environment variable when api_key is None, which would send that key
    # to DeepInfra's endpoint.
    raise RuntimeError("DEEPINFRA_API_KEY is not set")

openai = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url=DEEPINFRA_BASE_URL,  # single source of truth for the endpoint
)
# Weaviate imports
import weaviate
from weaviate.classes.init import Auth
from contextlib import contextmanager
@contextmanager
def weaviate_client():
    """Yield a Weaviate Cloud client connected with the module settings.

    Uses WEAVIATE_URL and WEAVIATE_API_KEY. The client is closed in a
    ``finally`` clause, so the connection is released even when the body of
    the ``with`` statement raises.
    """
    credentials = Auth.api_key(WEAVIATE_API_KEY)
    conn = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=credentials,
        # Skip the client's startup checks on connect (per the original
        # note, this is what avoids the gRPC check).
        skip_init_checks=True,
    )
    try:
        yield conn
    finally:
        conn.close()
def embed_texts(
    texts: list[str],
    batch_size: int = 50,
    model: str = "Qwen/Qwen3-Embedding-8B",
) -> list[list[float]]:
    """Embed *texts* through the module-level client, ``batch_size`` per request.

    Batching keeps each request within the provider's input limits. The call
    is best-effort: if a request fails, or returns a different number of
    vectors than inputs sent, the error is printed and every text in that
    batch gets an empty list as a placeholder. The result therefore always
    has exactly one entry per input, in input order.

    Args:
        texts: Strings to embed.
        batch_size: Maximum number of texts per embeddings request (>= 1).
        model: Embedding model name passed to the API.

    Returns:
        One embedding (list of floats) per input text; ``[]`` for texts whose
        batch failed.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    all_embeddings: list[list[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        end = start + len(batch) - 1
        try:
            resp = openai.embeddings.create(
                model=model,
                input=batch,
                encoding_format="float",
            )
            batch_embs = [item.embedding for item in resp.data]
            if len(batch_embs) != len(batch):
                # A short/long response would shift every later embedding
                # onto the wrong text; treat it as a failed batch instead.
                raise ValueError(
                    f"expected {len(batch)} embeddings, got {len(batch_embs)}"
                )
        except Exception as e:
            # Best-effort: keep processing other batches; placeholders keep
            # the output aligned with the input.
            print(f"Embedding batch error (items {start}-{end}): {e}")
            all_embeddings.extend([] for _ in batch)
        else:
            all_embeddings.extend(batch_embs)
    return all_embeddings
def encode_query(query: str) -> list[float] | None:
    """Return the embedding vector for a single *query* string.

    Delegates to ``embed_texts`` with a one-item batch. An empty placeholder
    (or an empty result) is treated as failure: a message is printed and
    ``None`` is returned. On success, the first five dimensions are printed
    for debugging.
    """
    vectors = embed_texts([query], batch_size=1)
    vector = vectors[0] if vectors else None
    if not vector:
        print("Failed to generate query embedding.")
        return None
    print("Query embedding (first 5 dims):", vector[:5])
    return vector
async def rag_autism(query: str, top_k: int = 3) -> dict:
    """Retrieve text chunks relevant to *query* from the Weaviate 'Books' collection.

    The query is embedded with ``encode_query`` and matched against the
    collection with a ``near_vector`` search.

    Args:
        query: Free-text question to search for.
        top_k: Maximum number of chunks to return. Values below 1 return an
            empty result without contacting any service.

    Returns:
        ``{"answer": [...]}`` holding up to ``top_k`` ``text`` properties of
        the matched objects (``"[No Text]"`` when an object lacks that
        property). The list is empty when embedding fails, nothing matches,
        or the Weaviate query raises (the error is printed, not propagated).
    """
    if top_k < 1:
        return {"answer": []}
    # Both the embedding client and the Weaviate client used here are
    # synchronous; run the work in a worker thread so awaiting this
    # coroutine does not block the event loop.
    return await asyncio.to_thread(_rag_autism_blocking, query, top_k)


def _rag_autism_blocking(query: str, top_k: int) -> dict:
    """Synchronous body of ``rag_autism``; intended to run off the event loop."""
    qe = encode_query(query)
    if not qe:
        return {"answer": []}
    try:
        with weaviate_client() as client:
            coll = client.collections.get("Books")
            res = coll.query.near_vector(
                near_vector=qe,
                limit=top_k,
                return_properties=["text"],
            )
            objects = getattr(res, "objects", None)
            if not objects:
                return {"answer": []}
            return {
                "answer": [
                    obj.properties.get("text", "[No Text]") for obj in objects
                ]
            }
    except Exception as e:
        # Best-effort retrieval: report the failure and return no context.
        print("RAG Error:", e)
        return {"answer": []}
# Example manual test harness (needs network access and valid credentials).
# Uncomment to run this module as a script:
#
# if __name__ == "__main__":
#     test_queries = [
#         "What are the common early signs of autism in young children?",
#         "What diagnostic criteria are used for autism spectrum disorder?",
#         "What support strategies help improve communication skills in autistic individuals?"
#     ]
#     for q in test_queries:
#         print(f"\nQuery: {q}")
#         out = asyncio.run(rag_autism(q, top_k=3))
#         print("Retrieved contexts:")
#         for idx, ctx in enumerate(out["answer"], 1):
#             print(f"{idx}. {ctx}")