orionweller committed on
Commit 7743187 · 1 Parent(s): 9813925

try to fix ndcg bug

Files changed (1):
  app.py +11 -7
app.py CHANGED

@@ -94,7 +94,7 @@ class RepLlamaModel:
         model.eval()
         return model
 
-    def encode(self, texts, batch_size=32, **kwargs):
+    def encode(self, texts, batch_size=16, **kwargs):
         self.model = self.model.cuda()
         all_embeddings = []
         for i in range(0, len(texts), batch_size):
@@ -108,6 +108,7 @@ class RepLlamaModel:
             outputs = self.model(**batch_dict)
             embeddings = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
             embeddings = F.normalize(embeddings, p=2, dim=-1)
+            logger.info(f"Encoded shape: {embeddings.shape}, Norm of first embedding: {torch.norm(embeddings[0]).item()}")
             all_embeddings.append(embeddings.cpu().numpy())
 
         self.model = self.model.cpu()
@@ -118,7 +119,7 @@ def load_faiss_index(dataset_name):
     index_path = f"{dataset_name}/faiss_index.bin"
     if os.path.exists(index_path):
         logger.info(f"Loading existing FAISS index for {dataset_name} from {index_path}")
-        return faiss.read_index(index_path, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
+        return faiss.read_index(index_path)
     return None
 
 def search_queries(dataset_name, q_reps, depth=1000):
@@ -126,16 +127,15 @@ def search_queries(dataset_name, q_reps, depth=1000):
     if faiss_index is None:
         raise ValueError(f"No FAISS index found for dataset {dataset_name}")
 
-    # Ensure q_reps is a 2D numpy array of the correct type
-    q_reps = np.ascontiguousarray(q_reps.astype('float16'))
+    logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")
 
     # Perform the search
    all_scores, all_indices = faiss_index.search(q_reps, depth)
 
-    psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
+    logger.info(f"Search completed. Shape of all_scores: {all_scores.shape}, all_indices: {all_indices.shape}")
+    logger.info(f"Sample scores: {all_scores[0][:5]}, Sample indices: {all_indices[0][:5]}")
 
-    # Clean up
-    del faiss_index
+    psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
 
     return all_scores, np.array(psg_indices)
 
@@ -149,6 +149,7 @@ def load_corpus_lookups(dataset_name):
         with open(file, 'rb') as f:
             _, p_lookup = pickle.load(f)
             corpus_lookups[dataset_name] += p_lookup
+    logger.info(f"Loaded corpus lookups for {dataset_name}. Total entries: {len(corpus_lookups[dataset_name])}")
 
 def load_queries(dataset_name):
     global queries, q_lookups, qrels
@@ -166,6 +167,9 @@ def load_queries(dataset_name):
             qrels[dataset_name][qrel.query_id] = {}
         qrels[dataset_name][qrel.query_id][qrel.doc_id] = qrel.relevance
 
+    logger.info(f"Loaded queries for {dataset_name}. Total queries: {len(queries[dataset_name])}")
+    logger.info(f"Loaded qrels for {dataset_name}. Total query IDs: {len(qrels[dataset_name])}")
+
 
 def evaluate(qrels, results, k_values):
     evaluator = pytrec_eval.RelevanceEvaluator(
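
For context, a minimal end-to-end sketch of the retrieve-then-score path that search_queries and evaluate implement, assuming an IndexFlatIP index over L2-normalized embeddings and pytrec_eval for NDCG; the dimension, document ids, and relevance labels below are made up for illustration and are not the app's real data. FAISS's Python bindings operate on contiguous float32 arrays, which is why keeping q_reps in float32 (rather than the removed float16 cast) is one plausible angle on the NDCG bug this commit tries to fix.

    import numpy as np
    import faiss
    import pytrec_eval

    # Toy data -- dimension, ids, and qrels are illustrative only.
    d = 64
    doc_embs = np.random.rand(100, d).astype("float32")
    faiss.normalize_L2(doc_embs)                      # cosine similarity via inner product
    index = faiss.IndexFlatIP(d)
    index.add(doc_embs)

    q_reps = np.random.rand(2, d).astype("float32")   # queries stay float32, not float16
    faiss.normalize_L2(q_reps)
    all_scores, all_indices = index.search(np.ascontiguousarray(q_reps), 10)

    # Map FAISS row positions back to document ids (stand-in for corpus_lookups).
    doc_ids = [f"d{i}" for i in range(100)]
    run = {
        f"q{qi}": {doc_ids[di]: float(s) for di, s in zip(idx_row, score_row)}
        for qi, (score_row, idx_row) in enumerate(zip(all_scores, all_indices))
    }

    # NDCG with pytrec_eval (what evaluate() in app.py builds on).
    qrels = {"q0": {"d3": 1}, "q1": {"d7": 1}}
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"ndcg_cut.10"})
    print(evaluator.evaluate(run))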