NightPrince committed on
Commit
5eeaffc
·
verified ·
1 Parent(s): f70c622

Update pipeline/retrieval.py

Browse files
Files changed (1) hide show
  1. pipeline/retrieval.py +12 -7
pipeline/retrieval.py CHANGED
@@ -3,13 +3,18 @@ import chromadb
3
  class RetrievalDB:
4
  def __init__(self, prompts, embeddings, solutions, collection_name="humaneval"):
5
  self.client = chromadb.Client()
6
- self.collection = self.client.create_collection(name=collection_name)
7
- for idx, (emb, prompt, solution) in enumerate(zip(embeddings, prompts, solutions)):
8
- self.collection.add(
9
- ids=[str(idx)],
10
- embeddings=[emb.tolist()],
11
- metadatas=[{"prompt": prompt, "solution": solution}]
12
- )
 
 
 
 
 
13
 
14
  def retrieve_similar_context(self, query_emb, k=1):
15
  results = self.collection.query(query_embeddings=[query_emb], n_results=k)
 
3
  class RetrievalDB:
4
  def __init__(self, prompts, embeddings, solutions, collection_name="humaneval"):
5
  self.client = chromadb.Client()
6
+ # Check if collection exists
7
+ try:
8
+ self.collection = self.client.get_collection(collection_name)
9
+ except Exception:
10
+ # If not, create it and populate
11
+ self.collection = self.client.create_collection(name=collection_name)
12
+ for idx, (emb, prompt, solution) in enumerate(zip(embeddings, prompts, solutions)):
13
+ self.collection.add(
14
+ ids=[str(idx)],
15
+ embeddings=[emb.tolist()],
16
+ metadatas=[{"prompt": prompt, "solution": solution}]
17
+ )
18
 
19
  def retrieve_similar_context(self, query_emb, k=1):
20
  results = self.collection.query(query_embeddings=[query_emb], n_results=k)