File size: 4,094 Bytes
ea1e6bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import asyncio
from dotenv import load_dotenv

# Configuration.
# Secrets come from the environment, optionally populated from a local .env
# file by load_dotenv(), instead of being hard-coded in source.
# NOTE(review): the keys previously committed in this file must be treated as
# leaked. Revoke/rotate them with each provider.
load_dotenv()

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

QDRANT_URL = os.getenv(
    "QDRANT_URL",
    "https://6a3aade6-e8ad-4a6c-a579-21f5af90b7e8.us-east4-0.gcp.cloud.qdrant.io",
)

WEAVIATE_URL = os.getenv(
    "WEAVIATE_URL",
    "yorcqe2sqswhcaivxvt9a.c0.us-west3.gcp.weaviate.cloud",
)

WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")

DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")

DEEPINFRA_BASE_URL = os.getenv(
    "DEEPINFRA_BASE_URL", "https://api.deepinfra.com/v1/openai"
)

# Initialize DeepInfra-compatible OpenAI client.
# NOTE: the name `openai` shadows the `openai` package. It is kept because
# embed_texts() refers to this client by that name.
from openai import OpenAI

if not DEEPINFRA_API_KEY:
    # Fail fast with a clear message. With api_key=None the OpenAI SDK would
    # fall back to OPENAI_API_KEY from the environment and send that key to
    # DeepInfra's endpoint.
    raise RuntimeError(
        "DEEPINFRA_API_KEY is not set; export it or add it to a .env file."
    )

openai = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url=DEEPINFRA_BASE_URL,
)

# Weaviate imports
import weaviate
from weaviate.classes.init import Auth
from contextlib import contextmanager

@contextmanager
def weaviate_client():
    """
    Yield a client connected to the Weaviate Cloud cluster.

    The cluster is configured by WEAVIATE_URL and WEAVIATE_API_KEY. The
    client is closed when the ``with`` block exits, whether it finishes
    normally or raises.
    """
    conn = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
        # Skip the client's connection-time init checks.
        # Per the original note, this disables the gRPC check.
        skip_init_checks=True,
    )
    try:
        yield conn
    finally:
        conn.close()

def embed_texts(
    texts: list[str],
    batch_size: int = 50,
    *,
    model: str = "Qwen/Qwen3-Embedding-8B",
) -> list[list[float]]:
    """
    Embed ``texts`` in batches to stay under per-request API limits.

    Args:
        texts: Strings to embed.
        batch_size: Maximum number of strings sent per request; must be >= 1.
        model: Embedding model name passed to the module-level ``openai``
            client.

    Returns:
        One embedding per input, in input order. Each text in a batch whose
        request fails, or whose response has the wrong number of vectors,
        gets an empty list as a placeholder. The result therefore always has
        ``len(texts)`` entries.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    all_embeddings: list[list[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        end = start + len(batch) - 1
        try:
            resp = openai.embeddings.create(
                model=model,
                input=batch,
                encoding_format="float",
            )
            batch_embs = [item.embedding for item in resp.data]
            # A short or long response would misalign every later vector
            # with its text, so treat it as a failed batch.
            if len(batch_embs) != len(batch):
                raise ValueError(
                    f"expected {len(batch)} embeddings, got {len(batch_embs)}"
                )
        except Exception as e:
            # Best-effort: report and continue so other batches still embed.
            print(f"Embedding batch error (items {start}-{end}): {e}")
            all_embeddings.extend([[] for _ in batch])
            continue
        all_embeddings.extend(batch_embs)
    return all_embeddings

def encode_query(query: str) -> list[float] | None:
    """
    Embed a single query string.

    Returns the embedding vector, or None when embed_texts() produced an
    empty (failed) vector for it. Logs a short preview or a failure notice
    to stdout.
    """
    vectors = embed_texts([query], batch_size=1)
    vector = vectors[0] if vectors else []
    if not vector:
        print("Failed to generate query embedding.")
        return None
    print("Query embedding (first 5 dims):", vector[:5])
    return vector

async def rag_autism(
    query: str, top_k: int = 3, *, collection: str = "Books"
) -> dict:
    """
    Run a RAG retrieval against a Weaviate collection (``"Books"`` by default).

    The query is embedded with encode_query() and used for a near-vector
    search. Only the ``text`` property of each hit is returned.

    Args:
        query: Natural-language query to embed and search with.
        top_k: Maximum number of chunks to return.
        collection: Name of the Weaviate collection to search.

    Returns:
        ``{"answer": [...]}`` with up to ``top_k`` text chunks. A hit without
        a ``text`` property yields ``"[No Text]"``. The list is empty when the
        query could not be embedded, nothing matched, or the Weaviate query
        failed.
    """
    # The embedding request and the Weaviate client are both synchronous.
    # Run them in a worker thread so the event loop is not blocked.
    return await asyncio.to_thread(_rag_autism_sync, query, top_k, collection)


def _rag_autism_sync(query: str, top_k: int, collection: str) -> dict:
    """Blocking implementation of rag_autism(); see that function's docs."""
    qe = encode_query(query)
    if not qe:
        return {"answer": []}

    try:
        with weaviate_client() as client:
            coll = client.collections.get(collection)
            res = coll.query.near_vector(
                near_vector=qe,
                limit=top_k,
                return_properties=["text"],
            )
        hits = getattr(res, "objects", None) or []
        return {
            "answer": [obj.properties.get("text", "[No Text]") for obj in hits]
        }
    except Exception as e:
        # Best-effort retrieval: report and fall back to no context.
        print("RAG Error:", e)
        return {"answer": []}

# Example test harness
# if __name__ == "__main__":
#     test_queries = [
#         "What are the common early signs of autism in young children?",
#         "What diagnostic criteria are used for autism spectrum disorder?",
#         "What support strategies help improve communication skills in autistic individuals?"
#     ]
#     for q in test_queries:
#         print(f"\nQuery: {q}")
#         out = asyncio.run(rag_autism(q, top_k=3))
#         print("Retrieved contexts:")
#         for idx, ctx in enumerate(out["answer"], 1):
#             print(f"{idx}. {ctx}")