Spaces: Running on Zero
Commit · 7743187
Parent(s): 9813925
try to fix ndcg bug

app.py CHANGED
@@ -94,7 +94,7 @@ class RepLlamaModel:
         model.eval()
         return model
 
-    def encode(self, texts, batch_size=
+    def encode(self, texts, batch_size=16, **kwargs):
         self.model = self.model.cuda()
         all_embeddings = []
         for i in range(0, len(texts), batch_size):
@@ -108,6 +108,7 @@ class RepLlamaModel:
             outputs = self.model(**batch_dict)
             embeddings = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
             embeddings = F.normalize(embeddings, p=2, dim=-1)
+            logger.info(f"Encoded shape: {embeddings.shape}, Norm of first embedding: {torch.norm(embeddings[0]).item()}")
             all_embeddings.append(embeddings.cpu().numpy())
 
         self.model = self.model.cpu()
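Note: the pool(..., 'last') helper is not part of this diff. For RepLLaMA-style encoders it usually means last-token pooling, i.e. taking the hidden state of each sequence's final non-padding token before L2-normalization. A minimal sketch of such a helper, under that assumption (signature and padding handling are guesses, not the app's actual code):

import torch

def pool(last_hidden_state, attention_mask, mode='last'):
    # Assumed helper: return the hidden state of each sequence's final
    # non-padding token ("last-token pooling"), as used by RepLLaMA-style encoders.
    if mode != 'last':
        raise ValueError(f"unsupported pooling mode: {mode}")
    # With left padding, the right-most position is always a real token.
    if attention_mask[:, -1].sum() == attention_mask.shape[0]:
        return last_hidden_state[:, -1]
    # Otherwise index the last attended position of each sequence.
    lengths = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
    return last_hidden_state[batch_idx, lengths]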
@@ -118,7 +119,7 @@ def load_faiss_index(dataset_name):
     index_path = f"{dataset_name}/faiss_index.bin"
     if os.path.exists(index_path):
         logger.info(f"Loading existing FAISS index for {dataset_name} from {index_path}")
-        return faiss.read_index(index_path
+        return faiss.read_index(index_path)
     return None
 
 def search_queries(dataset_name, q_reps, depth=1000):
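For context, an index file of this form could be produced with a flat inner-product index, which matches the L2-normalized embeddings returned by encode (inner product then equals cosine similarity). A sketch of how the .bin file might be built; the function name, variable names, and the choice of IndexFlatIP are assumptions, not taken from this repo:

import faiss
import numpy as np

def build_faiss_index(corpus_embeddings: np.ndarray, index_path: str) -> faiss.Index:
    # corpus_embeddings: (num_passages, dim), already L2-normalized, so
    # inner-product search is equivalent to cosine similarity.
    corpus_embeddings = np.ascontiguousarray(corpus_embeddings.astype('float32'))
    index = faiss.IndexFlatIP(corpus_embeddings.shape[1])
    index.add(corpus_embeddings)
    faiss.write_index(index, index_path)
    return index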
@@ -126,16 +127,15 @@ def search_queries(dataset_name, q_reps, depth=1000):
     if faiss_index is None:
         raise ValueError(f"No FAISS index found for dataset {dataset_name}")
 
-
-    q_reps = np.ascontiguousarray(q_reps.astype('float16'))
+    logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")
 
     # Perform the search
     all_scores, all_indices = faiss_index.search(q_reps, depth)
 
-
+    logger.info(f"Search completed. Shape of all_scores: {all_scores.shape}, all_indices: {all_indices.shape}")
+    logger.info(f"Sample scores: {all_scores[0][:5]}, Sample indices: {all_indices[0][:5]}")
 
-
-    del faiss_index
+    psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]
 
     return all_scores, np.array(psg_indices)
 
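The added psg_indices line converts FAISS's integer row positions into string passage IDs via corpus_lookups before they reach the evaluator. A standalone sketch of that mapping step; the float32 cast in the usage comment reflects the usual requirement of standard FAISS CPU indexes and is an assumption about the intended fix, not a line from this commit:

import numpy as np

def map_to_passage_ids(all_indices: np.ndarray, lookup) -> np.ndarray:
    # all_indices: (num_queries, depth) integer row positions from faiss_index.search.
    # lookup: position -> passage id, as built by load_corpus_lookups.
    return np.array([[str(lookup[x]) for x in row] for row in all_indices])

# Hypothetical usage mirroring search_queries:
# q_reps = np.ascontiguousarray(q_reps.astype('float32'))   # standard FAISS indexes expect float32
# all_scores, all_indices = faiss_index.search(q_reps, depth)
# psg_indices = map_to_passage_ids(all_indices, corpus_lookups[dataset_name])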
@@ -149,6 +149,7 @@ def load_corpus_lookups(dataset_name):
         with open(file, 'rb') as f:
             _, p_lookup = pickle.load(f)
         corpus_lookups[dataset_name] += p_lookup
+    logger.info(f"Loaded corpus lookups for {dataset_name}. Total entries: {len(corpus_lookups[dataset_name])}")
 
 def load_queries(dataset_name):
     global queries, q_lookups, qrels
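The "_, p_lookup = pickle.load(f)" line suggests each shard pickle holds an (embeddings, passage_id_lookup) pair, with only the lookup kept here and the shards concatenated in order. A sketch of that assumed loading pattern; the glob pattern and shard layout are guesses:

import glob
import pickle

def load_lookup_shards(dataset_name: str) -> list:
    # Assumed shard layout: each pickle stores an (embeddings, passage_id_lookup) pair;
    # only the lookup half is needed to map search results back to passage ids.
    lookup = []
    for file in sorted(glob.glob(f"{dataset_name}/corpus_emb*.pkl")):  # hypothetical path pattern
        with open(file, 'rb') as f:
            _, p_lookup = pickle.load(f)
        lookup += p_lookup
    return lookup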
@@ -166,6 +167,9 @@ def load_queries(dataset_name):
             qrels[dataset_name][qrel.query_id] = {}
         qrels[dataset_name][qrel.query_id][qrel.doc_id] = qrel.relevance
 
+    logger.info(f"Loaded queries for {dataset_name}. Total queries: {len(queries[dataset_name])}")
+    logger.info(f"Loaded qrels for {dataset_name}. Total query IDs: {len(qrels[dataset_name])}")
+
 
 def evaluate(qrels, results, k_values):
     evaluator = pytrec_eval.RelevanceEvaluator(
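Since the commit message targets an nDCG bug: pytrec_eval's RelevanceEvaluator expects qrels as {query_id: {doc_id: relevance}} and the run as {query_id: {doc_id: score}}, with string IDs and numeric scores, which is why psg_indices is stringified above. A small self-contained example; the measure set and the macro-averaging are illustrative, not taken from app.py:

import pytrec_eval

# Toy qrels and run: string ids, integer relevance labels, float retrieval scores.
qrels_example = {"q1": {"d1": 1, "d2": 0}}
run_example = {"q1": {"d1": 12.3, "d2": 7.1, "d3": 3.4}}

evaluator = pytrec_eval.RelevanceEvaluator(qrels_example, {"ndcg_cut.10", "recall.100"})
per_query = evaluator.evaluate(run_example)   # {"q1": {"ndcg_cut_10": ..., "recall_100": ...}}
ndcg_10 = sum(q["ndcg_cut_10"] for q in per_query.values()) / len(per_query)
print(f"nDCG@10 = {ndcg_10:.4f}")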
|