orionweller commited on
Commit
079558d
·
1 Parent(s): c837e28
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -42,6 +42,7 @@ qrels = {}
42
  query2qid = {}
43
  datasets = ["scifact"]
44
  current_dataset = "scifact"
 
45
 
46
  def log_system_info():
47
  logger.info("System Information:")
@@ -125,7 +126,7 @@ class RepLlamaModel:
125
  model.eval()
126
  return model
127
 
128
- def encode(self, texts, batch_size=16, **kwargs):
129
  self.model = self.model.cuda()
130
  all_embeddings = []
131
  for i in range(0, len(texts), batch_size):
@@ -183,8 +184,7 @@ def initialize_faiss_and_corpus(dataset_name):
183
  return index
184
 
185
  def search_queries(dataset_name, q_reps, depth=1000):
186
- faiss_index = initialize_faiss_and_corpus(dataset_name)
187
-
188
  logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")
189
 
190
  # Perform the search
@@ -235,6 +235,9 @@ def evaluate(qrels, results, k_values):
235
  logger.info(f"NDCG@{k}: mean={metrics[f'NDCG@{k}']}, min={min(ndcg_scores)}, max={max(ndcg_scores)}")
236
  logger.info(f"Recall@{k}: mean={metrics[f'Recall@{k}']}, min={min(recall_scores)}, max={max(recall_scores)}")
237
 
 
 
 
238
  return metrics
239
 
240
  @spaces.GPU
@@ -298,6 +301,7 @@ def gradio_interface(dataset, postfix):
298
  if model is None:
299
  model = RepLlamaModel(model_name_or_path=CUR_MODEL)
300
  load_queries(current_dataset)
 
301
 
302
  # Create Gradio interface
303
  iface = gr.Interface(
 
42
  query2qid = {}
43
  datasets = ["scifact"]
44
  current_dataset = "scifact"
45
+ faiss_index = None
46
 
47
  def log_system_info():
48
  logger.info("System Information:")
 
126
  model.eval()
127
  return model
128
 
129
+ def encode(self, texts, batch_size=32, **kwargs):
130
  self.model = self.model.cuda()
131
  all_embeddings = []
132
  for i in range(0, len(texts), batch_size):
 
184
  return index
185
 
186
  def search_queries(dataset_name, q_reps, depth=1000):
187
+ global faiss_index
 
188
  logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")
189
 
190
  # Perform the search
 
235
  logger.info(f"NDCG@{k}: mean={metrics[f'NDCG@{k}']}, min={min(ndcg_scores)}, max={max(ndcg_scores)}")
236
  logger.info(f"Recall@{k}: mean={metrics[f'Recall@{k}']}, min={min(recall_scores)}, max={max(recall_scores)}")
237
 
238
+ # delete nDCG@100 and Recall@10
239
+ del metrics["NDCG@100"]
240
+ del metrics["Recall@100"]
241
  return metrics
242
 
243
  @spaces.GPU
 
301
  if model is None:
302
  model = RepLlamaModel(model_name_or_path=CUR_MODEL)
303
  load_queries(current_dataset)
304
+ faiss_index = initialize_faiss_and_corpus(current_dataset)
305
 
306
  # Create Gradio interface
307
  iface = gr.Interface(