Spaces:
Sleeping
Sleeping
Commit
·
079558d
1
Parent(s):
c837e28
speed
Browse files
app.py
CHANGED
@@ -42,6 +42,7 @@ qrels = {}
|
|
42 |
query2qid = {}
|
43 |
datasets = ["scifact"]
|
44 |
current_dataset = "scifact"
|
|
|
45 |
|
46 |
def log_system_info():
|
47 |
logger.info("System Information:")
|
@@ -125,7 +126,7 @@ class RepLlamaModel:
|
|
125 |
model.eval()
|
126 |
return model
|
127 |
|
128 |
-
def encode(self, texts, batch_size=
|
129 |
self.model = self.model.cuda()
|
130 |
all_embeddings = []
|
131 |
for i in range(0, len(texts), batch_size):
|
@@ -183,8 +184,7 @@ def initialize_faiss_and_corpus(dataset_name):
|
|
183 |
return index
|
184 |
|
185 |
def search_queries(dataset_name, q_reps, depth=1000):
|
186 |
-
faiss_index
|
187 |
-
|
188 |
logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")
|
189 |
|
190 |
# Perform the search
|
@@ -235,6 +235,9 @@ def evaluate(qrels, results, k_values):
|
|
235 |
logger.info(f"NDCG@{k}: mean={metrics[f'NDCG@{k}']}, min={min(ndcg_scores)}, max={max(ndcg_scores)}")
|
236 |
logger.info(f"Recall@{k}: mean={metrics[f'Recall@{k}']}, min={min(recall_scores)}, max={max(recall_scores)}")
|
237 |
|
|
|
|
|
|
|
238 |
return metrics
|
239 |
|
240 |
@spaces.GPU
|
@@ -298,6 +301,7 @@ def gradio_interface(dataset, postfix):
|
|
298 |
if model is None:
|
299 |
model = RepLlamaModel(model_name_or_path=CUR_MODEL)
|
300 |
load_queries(current_dataset)
|
|
|
301 |
|
302 |
# Create Gradio interface
|
303 |
iface = gr.Interface(
|
|
|
42 |
query2qid = {}
|
43 |
datasets = ["scifact"]
|
44 |
current_dataset = "scifact"
|
45 |
+
faiss_index = None
|
46 |
|
47 |
def log_system_info():
|
48 |
logger.info("System Information:")
|
|
|
126 |
model.eval()
|
127 |
return model
|
128 |
|
129 |
+
def encode(self, texts, batch_size=32, **kwargs):
|
130 |
self.model = self.model.cuda()
|
131 |
all_embeddings = []
|
132 |
for i in range(0, len(texts), batch_size):
|
|
|
184 |
return index
|
185 |
|
186 |
def search_queries(dataset_name, q_reps, depth=1000):
|
187 |
+
global faiss_index
|
|
|
188 |
logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")
|
189 |
|
190 |
# Perform the search
|
|
|
235 |
logger.info(f"NDCG@{k}: mean={metrics[f'NDCG@{k}']}, min={min(ndcg_scores)}, max={max(ndcg_scores)}")
|
236 |
logger.info(f"Recall@{k}: mean={metrics[f'Recall@{k}']}, min={min(recall_scores)}, max={max(recall_scores)}")
|
237 |
|
238 |
+
# delete nDCG@100 and Recall@10
|
239 |
+
del metrics["NDCG@100"]
|
240 |
+
del metrics["Recall@100"]
|
241 |
return metrics
|
242 |
|
243 |
@spaces.GPU
|
|
|
301 |
if model is None:
|
302 |
model = RepLlamaModel(model_name_or_path=CUR_MODEL)
|
303 |
load_queries(current_dataset)
|
304 |
+
faiss_index = initialize_faiss_and_corpus(current_dataset)
|
305 |
|
306 |
# Create Gradio interface
|
307 |
iface = gr.Interface(
|