File size: 921 Bytes
fec11df
 
 
 
 
5eeaffc
 
 
 
 
 
 
 
 
 
 
 
fec11df
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import chromadb

class RetrievalDB:
    """Embedding-indexed store of (prompt, solution) pairs backed by a Chroma collection.

    Each record is stored under id ``str(position)`` with metadata
    ``{"prompt": ..., "solution": ...}`` and the caller-supplied embedding.
    Queries must use embeddings of the same dimensionality as the stored ones.
    """

    # Max records per ``collection.add`` call. Chroma rejects batches larger
    # than its backend limit, so big corpora are inserted in chunks.
    _ADD_BATCH_SIZE = 1000

    def __init__(self, prompts, embeddings, solutions, collection_name="humaneval"):
        """Open (or create and populate) the collection ``collection_name``.

        Args:
            prompts: Prompt texts, stored under the "prompt" metadata key.
            embeddings: One vector per prompt; plain sequences or objects
                exposing ``tolist()`` (e.g. numpy arrays).
            solutions: Solution texts, stored under the "solution" metadata key.
            collection_name: Name of the Chroma collection to use.

        Raises:
            ValueError: If the collection is empty and the three inputs do not
                have the same length.

        If the collection already contains records, the inputs are ignored
        (and not validated) and the existing data is reused.
        """
        self.client = chromadb.Client()
        # get_or_create avoids a broad ``except Exception`` around get_collection,
        # which would also mask unrelated failures (and whose "not found"
        # exception type varies between Chroma releases).
        self.collection = self.client.get_or_create_collection(name=collection_name)
        # Reuse an already-populated collection rather than re-adding the same ids.
        if self.collection.count() > 0:
            return

        prompts = list(prompts)
        embeddings = list(embeddings)
        solutions = list(solutions)
        # zip() would silently drop the tail of the longer inputs.
        if not len(prompts) == len(embeddings) == len(solutions):
            raise ValueError(
                "prompts, embeddings and solutions must have the same length, got "
                f"{len(prompts)}, {len(embeddings)} and {len(solutions)}"
            )

        batch = self._ADD_BATCH_SIZE
        for start in range(0, len(prompts), batch):
            stop = min(start + batch, len(prompts))
            self.collection.add(
                ids=[str(idx) for idx in range(start, stop)],
                embeddings=[self._as_list(emb) for emb in embeddings[start:stop]],
                metadatas=[
                    {"prompt": prompt, "solution": solution}
                    for prompt, solution in zip(prompts[start:stop], solutions[start:stop])
                ],
            )

    @staticmethod
    def _as_list(vector):
        """Return *vector* as a plain list (uses ``tolist()`` when available)."""
        return vector.tolist() if hasattr(vector, "tolist") else list(vector)

    def retrieve_similar_context(self, query_emb, k=1):
        """Return metadata of up to ``k`` stored records nearest to ``query_emb``.

        Args:
            query_emb: Query vector; a plain sequence or an object exposing
                ``tolist()``.
            k: Number of neighbours requested; clamped to the number of
                stored records.

        Returns:
            ``results["metadatas"]`` from ``collection.query``: a list holding
            one inner list (for the single query) of
            ``{"prompt": ..., "solution": ...}`` dicts, in the order Chroma
            returns them. ``[[]]`` if the collection is empty.
        """
        stored = self.collection.count()
        if stored == 0:
            return [[]]
        # Requesting more results than stored records errors in some Chroma versions.
        results = self.collection.query(
            query_embeddings=[self._as_list(query_emb)],
            n_results=min(k, stored),
        )
        return results["metadatas"]