OpenFactCheck-Prerelease
/
src
/openfactcheck
/solvers
/rarr_solvers
/rarr_utils
/evidence_selection.py
| import itertools | |
| from typing import Any, Dict, List | |
| import torch | |
| from sentence_transformers import CrossEncoder | |
| PASSAGE_RANKER = CrossEncoder( | |
| "cross-encoder/ms-marco-MiniLM-L-6-v2", | |
| max_length=512, | |
| device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), | |
| ) | |
| def compute_score_matrix( | |
| questions: List[str], evidences: List[str] | |
| ) -> List[List[float]]: | |
| """Scores the relevance of all evidence against all questions using a CrossEncoder. | |
| Args: | |
| questions: A list of unique questions. | |
| evidences: A list of unique evidences. | |
| Returns: | |
| score_matrix: A 2D list list of question X evidence relevance scores. | |
| """ | |
| score_matrix = [] | |
| for q in questions: | |
| evidence_scores = PASSAGE_RANKER.predict([(q, e) for e in evidences]).tolist() | |
| score_matrix.append(evidence_scores) | |
| return score_matrix | |
| def question_coverage_objective_fn( | |
| score_matrix: List[List[float]], evidence_indices: List[int] | |
| ) -> float: | |
| """Given (query, evidence) scores and a subset of evidence, return the coverage. | |
| Given all pairwise query and evidence scores, and a subset of the evidence | |
| specified by indices, return a value indicating how well this subset of evidence | |
| covers (i.e., helps answer) all questions. | |
| Args: | |
| score_matrix: A 2D list list of question X evidence relevance scores. | |
| evidence_indicies: A subset of the evidence to to get the coverage score of. | |
| Returns: | |
| total: The coverage we would get by using the subset of evidence in | |
| `evidence_indices` over all questions. | |
| """ | |
| # Compute sum_{question q} max_{selected evidence e} score(q, e). | |
| # This encourages all questions to be explained by at least one evidence. | |
| total = 0.0 | |
| for scores_for_question in score_matrix: | |
| total += max(scores_for_question[j] for j in evidence_indices) | |
| return total | |
| def select_evidences( | |
| example: Dict[str, Any], max_selected: int = 5, prefer_fewer: bool = False | |
| ) -> List[Dict[str, Any]]: | |
| """Selects the set of evidence that maximizes information converage over the claim. | |
| Args: | |
| example: The result of running the editing pipeline on one claim. | |
| max_selected: Maximum number of evidences to select. | |
| prefer_fewer: If True and the maximum objective value can be achieved by | |
| fewer evidences than `max_selected`, prefer selecting fewer evidences. | |
| Returns: | |
| selected_evidences: Selected evidences that serve as the attribution report. | |
| """ | |
| questions = sorted(set(example["questions"])) | |
| evidences = sorted(set(e["text"] for e in example["revisions"][0]["evidences"])) | |
| num_evidences = len(evidences) | |
| if not num_evidences: | |
| return [] | |
| score_matrix = compute_score_matrix(questions, evidences) | |
| best_combo = tuple() | |
| best_objective_value = float("-inf") | |
| max_selected = min(max_selected, num_evidences) | |
| min_selected = 1 if prefer_fewer else max_selected | |
| for num_selected in range(min_selected, max_selected + 1): | |
| for combo in itertools.combinations(range(num_evidences), num_selected): | |
| objective_value = question_coverage_objective_fn(score_matrix, combo) | |
| if objective_value > best_objective_value: | |
| best_combo = combo | |
| best_objective_value = objective_value | |
| selected_evidences = [{"text": evidences[idx]} for idx in best_combo] | |
| return selected_evidences | |