wakeupmh committed on
Commit
92c1c48
·
1 Parent(s): 42d1dd5

fix: normalize faiss

Browse files
Files changed (1) hide show
  1. faiss_index/index.py +7 -6
faiss_index/index.py CHANGED
@@ -61,20 +61,21 @@ def build_faiss_index(papers, dataset_dir=DATASET_DIR):
61
  torch.cuda.empty_cache()
62
 
63
  # Convert to numpy array and build FAISS index
64
- embeddings = np.array(embeddings)
65
  dimension = embeddings.shape[1]
66
 
67
- # Use more efficient index type
68
- index = faiss.IndexFlatIP(dimension) # Simple but efficient dot-product index
 
69
 
70
- # Normalize vectors to use dot product as similarity
71
- faiss.normalize_L2(embeddings)
72
  index.add(embeddings)
73
 
74
  # Create and save the dataset
75
  dataset = Dataset.from_dict({
76
  "text": texts,
77
- "embeddings": embeddings,
78
  "title": [p["title"] for p in papers]
79
  })
80
 
 
61
  torch.cuda.empty_cache()
62
 
63
  # Convert to numpy array and build FAISS index
64
+ embeddings = np.array(embeddings, dtype=np.float32) # Ensure float32 type
65
  dimension = embeddings.shape[1]
66
 
67
+ # Normalize the vectors manually
68
+ norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
69
+ embeddings = embeddings / norms
70
 
71
+ # Create FAISS index
72
+ index = faiss.IndexFlatIP(dimension)
73
  index.add(embeddings)
74
 
75
  # Create and save the dataset
76
  dataset = Dataset.from_dict({
77
  "text": texts,
78
+ "embeddings": embeddings.tolist(), # Convert to list for storage
79
  "title": [p["title"] for p in papers]
80
  })
81