# NOTE: "Spaces / Runtime error" was a scraped web-page header, not code;
# preserved here as a comment so the module parses.
# Environment / configuration.
import os
import asyncio
from dotenv import load_dotenv

# SECURITY: the original file embedded live API keys directly in source.
# They are now read from the environment (.env); the leaked keys must be
# rotated with each provider since they were committed in plain text.
load_dotenv()

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY", "")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "")
QDRANT_URL = os.getenv("QDRANT_URL", "")
WEAVIATE_URL = os.getenv("WEAVIATE_URL", "")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY", "")
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY", "")
DEEPINFRA_BASE_URL = os.getenv("DEEPINFRA_BASE_URL", "https://api.deepinfra.com/v1/openai")

# DeepInfra exposes an OpenAI-compatible endpoint, so we reuse the OpenAI SDK
# client pointed at its base URL. Used below by embed_texts().
from openai import OpenAI

openai = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url=DEEPINFRA_BASE_URL,
)
# Weaviate imports
import weaviate
from weaviate.classes.init import Auth
from contextlib import contextmanager


@contextmanager
def weaviate_client():
    """
    Context manager that yields a connected Weaviate cloud client and
    guarantees client.close() on exit.

    Yields:
        An authenticated Weaviate client for WEAVIATE_URL.
    """
    # BUG FIX: the original generator function was never decorated with
    # @contextmanager (the import existed but was unused), so
    # `with weaviate_client() as client:` raised a TypeError at runtime
    # ("'generator' object does not support the context manager protocol").
    client = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
        skip_init_checks=True,  # disables the gRPC startup health check
    )
    try:
        yield client
    finally:
        # Always release the connection, even if the caller's body raises.
        client.close()
def embed_texts(texts: list[str], batch_size: int = 50) -> list[list[float]]:
    """Embed *texts* in batches of at most *batch_size* to stay under API limits.

    Returns one vector per input text, in order. A failed batch contributes
    empty lists, so the output always aligns 1:1 with the input.
    """
    vectors: list[list[float]] = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start : start + batch_size]
        try:
            response = openai.embeddings.create(
                model="Qwen/Qwen3-Embedding-8B",
                input=chunk,
                encoding_format="float",
            )
            vectors.extend(item.embedding for item in response.data)
        except Exception as exc:
            # Keep positions aligned: one placeholder per text in the failed chunk.
            print(f"Embedding batch error (items {start}–{start+len(chunk)-1}): {exc}")
            vectors.extend([] for _ in chunk)
    return vectors
def encode_query(query: str) -> list[float] | None:
    """Generate a single embedding vector for a query string."""
    result = embed_texts([query], batch_size=1)
    # Guard clause: either the call failed outright or the single batch
    # produced an empty placeholder vector.
    if not result or not result[0]:
        print("Failed to generate query embedding.")
        return None
    print("Query embedding (first 5 dims):", result[0][:5])
    return result[0]
async def rag_autism(query: str, top_k: int = 3) -> dict:
    """
    Run a RAG retrieval against the 'Books' collection in Weaviate.

    Args:
        query: Natural-language question to embed and search with.
        top_k: Maximum number of matching text chunks to return.

    Returns:
        A dict with key "answer" holding a list of retrieved text chunks;
        the list is empty on embedding failure, no hits, or any error.
    """
    vector = encode_query(query)
    if not vector:
        return {"answer": []}
    try:
        with weaviate_client() as client:
            books = client.collections.get("Books")
            result = books.query.near_vector(
                near_vector=vector,
                limit=top_k,
                return_properties=["text"],
            )
            hits = getattr(result, "objects", None)
            if not hits:
                return {"answer": []}
            texts = [hit.properties.get("text", "[No Text]") for hit in hits]
            return {"answer": texts}
    except Exception as e:
        print("RAG Error:", e)
        return {"answer": []}
# Example test harness
# if __name__ == "__main__":
#     test_queries = [
#         "What are the common early signs of autism in young children?",
#         "What diagnostic criteria are used for autism spectrum disorder?",
#         "What support strategies help improve communication skills in autistic individuals?"
#     ]
#     for q in test_queries:
#         print(f"\nQuery: {q}")
#         out = asyncio.run(rag_autism(q, top_k=3))
#         print("Retrieved contexts:")
#         for idx, ctx in enumerate(out["answer"], 1):
#             print(f"{idx}. {ctx}")