from datasets import load_dataset
import evaluate

# Import RAG setup, retrieval, and answer-generation logic from app.py
from app import setup_rag, retrieve, retrieve_and_answer
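# Interface assumed from the calls below (these functions live in app.py, not here):
#   setup_rag() -> (passages, embedder, reranker, index, qa_pipe)
#   retrieve(question, passages, embedder, index, k=..., rerank_k=...) -> (contexts, ...)
#   retrieve_and_answer(question, passages, embedder, reranker, index, qa_pipe) -> (answer, ...)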


def retrieval_recall(dataset, passages, embedder, index, k=20, rerank_k=None, num_samples=100):
    """
    Compute raw Retrieval Recall@k on the first num_samples examples.
    If rerank_k is set, also apply cross-encoder reranking.
    """
    hits = 0
    for ex in dataset.select(range(num_samples)):
        question = ex["question"]
        gold_answers = ex["answers"]["text"]
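        # For unanswerable SQuAD v2 questions gold_answers is empty, so the hit
        # check below can never fire and those examples always count as misses.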
        # get top-k retrieved contexts
        if rerank_k:
            ctxs, _ = retrieve(question, passages, embedder, index, k=k, rerank_k=rerank_k)
        else:
            # skip reranking: use top-k directly
            q_emb = embedder.encode([question], convert_to_numpy=True)
            distances, idxs = index.search(q_emb, k)
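            # index.search on a single query returns (distances, indices) arrays of
            # shape (1, k); distances are unused for recall, so only idxs[0] is kept.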
            ctxs = [passages[i] for i in idxs[0]]
        # hit if any gold answer string appears verbatim (case-sensitive substring) in any retrieved context
        if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers):
            hits += 1
    recall = hits / num_samples
    print(f"Retrieval Recall@{k} (rerank_k={rerank_k}): {recall:.3f} ({hits}/{num_samples})")
    return recall


def retrieval_recall_answerable(dataset, passages, embedder, index, k=20, rerank_k=None, num_samples=100):
    """
    Retrieval Recall@k evaluated only on answerable questions.
    """
    hits, total = 0, 0
    for ex in dataset.select(range(num_samples)):
        if not ex["answers"]["text"]:
            continue
        total += 1
        question = ex["question"]
        if rerank_k:
            ctxs, _ = retrieve(question, passages, embedder, index, k=k, rerank_k=rerank_k)
        else:
            q_emb = embedder.encode([question], convert_to_numpy=True)
            distances, idxs = index.search(q_emb, k)
            ctxs = [passages[i] for i in idxs[0]]
        if any(any(ans in ctx for ctx in ctxs) for ans in ex["answers"]["text"]):
            hits += 1
    recall = hits / total if total > 0 else 0.0
    print(f"Retrieval Recall@{k} on answerable only (rerank_k={rerank_k}): {recall:.3f} ({hits}/{total})")
    return recall


def qa_eval_answerable(dataset, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100):
    """
    End-to-end QA EM/F1 on answerable subset using the retrieve_and_answer logic.
    """
    squad_metric = evaluate.load("squad")
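    # The "squad" metric reports exact_match and f1; unanswerable examples are
    # skipped below, so the SQuAD v1-style metric is appropriate here.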
    preds, refs = [], []
    for ex in dataset.select(range(num_samples)):
        if not ex["answers"]["text"]:
            continue
        qid = ex["id"]
        # retrieve and generate
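        # Note: the k argument of this function is not forwarded here;
        # retrieve_and_answer is assumed to apply its own retrieval/reranking
        # defaults as configured in app.py.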
        answer, _ = retrieve_and_answer(ex["question"], passages, embedder, reranker, index, qa_pipe)
        preds.append({"id": qid, "prediction_text": answer})
        refs.append({"id": qid, "answers": ex["answers"]})
    results = squad_metric.compute(predictions=preds, references=refs)
    print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}")
    return results


def main():
    # Setup RAG components
    passages, embedder, reranker, index, qa_pipe = setup_rag()
    # Load SQuAD v2 validation set
    squad = load_dataset("rajpurkar/squad_v2", split="validation")
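    # SQuAD v2 mixes answerable and unanswerable questions, which is why the
    # answerable-only variants above filter on ex["answers"]["text"].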

    # Run evaluations
    retrieval_recall(squad, passages, embedder, index, k=20, rerank_k=5, num_samples=100)
    retrieval_recall_answerable(squad, passages, embedder, index, k=20, rerank_k=5, num_samples=100)
    qa_eval_answerable(squad, passages, embedder, reranker, index, qa_pipe, k=20, num_samples=100)


if __name__ == "__main__":
    main()