import chromadb | |
class RetrievalDB: | |
def __init__(self, prompts, embeddings, solutions, collection_name="humaneval"): | |
self.client = chromadb.Client() | |
# Check if collection exists | |
try: | |
self.collection = self.client.get_collection(collection_name) | |
except Exception: | |
# If not, create it and populate | |
self.collection = self.client.create_collection(name=collection_name) | |
for idx, (emb, prompt, solution) in enumerate(zip(embeddings, prompts, solutions)): | |
self.collection.add( | |
ids=[str(idx)], | |
embeddings=[emb.tolist()], | |
metadatas=[{"prompt": prompt, "solution": solution}] | |
) | |
def retrieve_similar_context(self, query_emb, k=1): | |
results = self.collection.query(query_embeddings=[query_emb], n_results=k) | |
return results["metadatas"] |