VictorTomas09 committed
Commit ff2408d · verified · 1 Parent(s): e40d8f8

Create Evaluators

Files changed (1)
  1. Evaluators +93 -0
Evaluators ADDED
@@ -0,0 +1,93 @@
+ import os
+ import pickle
+ import numpy as np
+ import faiss
+ import torch
+ from datasets import load_dataset
+ import evaluate
+
+ # Import RAG setup and retrieval logic from app.py
+ # (retrieve_and_answer is assumed to be defined there as well; it is used by qa_eval_answerable below)
+ from app import setup_rag, retrieve, retrieve_and_answer
+
+
+ def retrieval_recall(dataset, passages, embedder, index, k=20, rerank_k=None, num_samples=100):
+     """
+     Compute raw Retrieval Recall@k on the first num_samples examples.
+     If rerank_k is set, also apply cross-encoder reranking.
+     """
+     hits = 0
+     for ex in dataset.select(range(num_samples)):
+         question = ex["question"]
+         gold_answers = ex["answers"]["text"]
+         # get top-k retrieved contexts
+         if rerank_k:
+             ctxs, _ = retrieve(question, passages, embedder, index, k=k, rerank_k=rerank_k)
+         else:
+             # skip reranking: use top-k directly
+             q_emb = embedder.encode([question], convert_to_numpy=True)
+             distances, idxs = index.search(q_emb, k)
+             ctxs = [passages[i] for i in idxs[0]]
+         # check if any gold span appears
+         if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
+             hits += 1
+     recall = hits / num_samples
+     print(f"Retrieval Recall@{k} (rerank_k={rerank_k}): {recall:.3f} ({hits}/{num_samples})")
+     return recall
+
+
+ def retrieval_recall_answerable(dataset, passages, embedder, index, k=20, rerank_k=None, num_samples=100):
+     """
+     Retrieval Recall@k evaluated only on answerable questions.
+     """
+     hits, total = 0, 0
+     for ex in dataset.select(range(num_samples)):
+         if not ex["answers"]["text"]:
+             continue
+         total += 1
+         question = ex["question"]
+         if rerank_k:
+             ctxs, _ = retrieve(question, passages, embedder, index, k=k, rerank_k=rerank_k)
+         else:
+             q_emb = embedder.encode([question], convert_to_numpy=True)
+             distances, idxs = index.search(q_emb, k)
+             ctxs = [passages[i] for i in idxs[0]]
+         if any(any(ans in ctx for ctx in ctxs) for ans in ex["answers"]["text"]):
+             hits += 1
+     recall = hits / total if total > 0 else 0.0
+     print(f"Retrieval Recall@{k} on answerable only (rerank_k={rerank_k}): {recall:.3f} ({hits}/{total})")
+     return recall
+
+
+ def qa_eval_answerable(dataset, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100):
+     """
+     End-to-end QA EM/F1 on answerable subset using the retrieve_and_answer logic.
+     """
+     squad_metric = evaluate.load("squad")
+     preds, refs = [], []
+     for ex in dataset.select(range(num_samples)):
+         if not ex["answers"]["text"]:
+             continue
+         qid = ex["id"]
+         # retrieve and generate
+         answer, _ = retrieve_and_answer(ex["question"], passages, embedder, reranker, index, qa_pipe)
+         preds.append({"id": qid, "prediction_text": answer})
+         refs.append({"id": qid, "answers": ex["answers"]})
+     results = squad_metric.compute(predictions=preds, references=refs)
+     print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
+     return results
+
+
+ def main():
+     # Setup RAG components
+     passages, embedder, reranker, index, qa_pipe = setup_rag()
+     # Load SQuAD v2 validation set
+     squad = load_dataset("rajpurkar/squad_v2", split="validation")
+
+     # Run evaluations
+     retrieval_recall(squad, passages, embedder, index, k=20, rerank_k=5, num_samples=100)
+     retrieval_recall_answerable(squad, passages, embedder, index, k=20, rerank_k=5, num_samples=100)
+     qa_eval_answerable(squad, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100)
+
+
+ if __name__ == "__main__":
+     main()
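
For quick experiments it can be handy to sweep the retrieval depth with the same functions. The sketch below is illustrative only: it assumes this file has been saved as evaluators.py so it can be imported (the module name is hypothetical), and that app.setup_rag() returns the same five components unpacked in main() above.

    from datasets import load_dataset
    from app import setup_rag
    from evaluators import retrieval_recall, retrieval_recall_answerable  # hypothetical module name

    passages, embedder, reranker, index, qa_pipe = setup_rag()
    squad = load_dataset("rajpurkar/squad_v2", split="validation")

    # Compare recall at several retrieval depths, with and without cross-encoder reranking.
    for k in (5, 10, 20):
        retrieval_recall(squad, passages, embedder, index, k=k, rerank_k=None, num_samples=100)
        retrieval_recall(squad, passages, embedder, index, k=k, rerank_k=5, num_samples=100)
        retrieval_recall_answerable(squad, passages, embedder, index, k=k, rerank_k=5, num_samples=100)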