VeganSquirrel commited on
Commit
897a5c1
·
verified ·
1 Parent(s): 576d291
Files changed (1) hide show
  1. app.py +18 -2
app.py CHANGED
@@ -45,8 +45,24 @@ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="balanced",
45
 
46
  # Step 4: Define the Retrieval Function
47
  def retrieve_documents(query, top_k=3):
48
- query_embedding = np.mean([embeddings[i] for i in range(len(metadata)) if query.lower() in metadata[i].lower()], axis=0)
49
- distances, indices = index.search(np.array([query_embedding]), top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  retrieved_docs = [metadata[idx] for idx in indices[0]]
51
  return retrieved_docs
52
 
 
45
 
46
  # Step 4: Define the Retrieval Function
47
  def retrieve_documents(query, top_k=3):
48
+ # Find embeddings matching the query
49
+ matched_embeddings = [embeddings[i] for i in range(len(metadata)) if query.lower() in metadata[i].lower()]
50
+
51
+ # If no matches found, set a default query embedding
52
+ if matched_embeddings:
53
+ query_embedding = np.mean(matched_embeddings, axis=0)
54
+ else:
55
+ # Fallback: use the mean of all embeddings as a default embedding
56
+ query_embedding = np.mean(embeddings, axis=0)
57
+ print("No exact matches found for query. Using default query embedding.")
58
+
59
+ # Reshape query_embedding to match FAISS expected shape (1, d)
60
+ query_embedding = query_embedding.reshape(1, -1)
61
+
62
+ # Perform the search
63
+ distances, indices = index.search(query_embedding, top_k)
64
+
65
+ # Retrieve document metadata based on indices
66
  retrieved_docs = [metadata[idx] for idx in indices[0]]
67
  return retrieved_docs
68