# OpenFactCheck-Prerelease: src/openfactcheck/solvers/tutorial/search_engine_evidence_retriever.py
import time
import json
import openai

from typing import List, Dict, Any, Tuple

from .utils.prompt_base import QGEN_PROMPT
from .utils.api import chatgpt, search_google, search_bing
from .utils.web_util import scrape_url, select_doc_by_keyword_coverage, select_passages_by_semantic_similarity

from openfactcheck import FactCheckerState, StandardTaskSolver, Solver

class SearchEngineEvidenceRetriever(StandardTaskSolver):
    def __init__(self, args):
        super().__init__(args)
        self.search_engine = args.get("search_engine", "google")
        # Fall back to the Google search function (not the string "google",
        # which would not be callable) when an unknown engine is configured.
        self.search_engine_func = {
            "google": search_google,
            "bing": search_bing
        }.get(self.search_engine, search_google)
        self.url_merge_method = args.get("url_merge_method", "union")
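
    # Example of the options this __init__ reads (values are illustrative):
    #   {"search_engine": "bing", "url_merge_method": "intersection"}
    # Unknown engines fall back to Google; unknown merge methods are rejected
    # later, in collect_claim_url_list.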

    def __call__(self, state: FactCheckerState, *args, **kwargs):
        claims = state.get(self.input_name)
        queries = self.generate_questions_as_query(claims)
        evidences = self.search_evidence(claims, queries)
        state.set(self.output_name, evidences)
        return True, state
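
    # Rough data flow of __call__ (the key names self.input_name and
    # self.output_name are configured by the surrounding pipeline, not here):
    #   state[input_name]  : List[str] of decontextualised claims
    #   state[output_name] : Dict[claim, {"claim", "automatic_queries", "evidence_list"}]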

    # Generate questions/queries for each claim.
    def generate_questions_as_query(self, claims,
                                    num_retries: int = 3) -> List[list]:
        """
        num_retries: the number of retries when an error occurs during an OpenAI API call
        """
        query_list = []
        for i, claim in enumerate(claims):
            # Default to "" so that a claim whose API calls all fail is treated
            # as not check-worthy below, instead of raising a NameError.
            response = ""
            for _ in range(num_retries):
                try:
                    response = chatgpt(QGEN_PROMPT + claim)
                    break
                except openai.OpenAIError as exception:
                    print(f"{exception}. Retrying...")
                    time.sleep(1)
            query_list.append(response)

        # Convert each OpenAI output (a single string) into a list of questions/queries.
        # Not check-worthy claims: the query response is "", which maps to [].
        # All other responses are split into a list of questions/queries.
        automatic_query_list = []
        for query in query_list:
            if query == "":
                automatic_query_list.append([])
            else:
                new_tmp = []
                for q in query.split("\n"):
                    q = q.strip()
                    if q == "" or q == "Output:":
                        continue
                    elif q.startswith("Output"):
                        q = q[len("Output:"):].strip()
                    new_tmp.append(q)
                automatic_query_list.append(new_tmp)
        return automatic_query_list
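
    # Illustrative only: given a hypothetical model response such as
    #   "Output: Who founded Tesla?\nOutput: When was Tesla founded?"
    # the parsing above yields ["Who founded Tesla?", "When was Tesla founded?"],
    # while an empty response "" (a not check-worthy claim) yields [].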

    # ----------------------------------------------------------
    # Evidence Retrieval
    # ----------------------------------------------------------
    def collect_claim_url_list(self, queries: List[str]) -> Tuple[List[str], Dict[str, list]]:
        """
        Collect URLs for a claim given its query list.

        queries: a list of queries or questions for one claim
        The search engine (google or bing) and the URL merge method
        ('union' to merge all URLs, 'intersection' to keep only shared ones)
        are set in __init__; intersection URLs tend to be less relevant
        than expected.
        """
        if len(queries) == 0:
            print("Invalid queries: []")
            return None, {}
        urls_list: List[list] = []  # per-query lists of retrieved URLs
        url_query_dict: Dict[str, list] = {}  # url -> list of queries that retrieved it
        url_union, url_intersection = [], []
        for query in queries:
            urls = self.search_engine_func(query)
            urls_list.append(urls)
        for i, urls in enumerate(urls_list):
            for url in urls:
                if url_query_dict.get(url) is None:
                    url_query_dict[url] = [queries[i]]
                else:
                    url_query_dict[url].append(queries[i])
        if self.url_merge_method == "union":
            for urls in urls_list:
                url_union += urls
            url_union = list(set(url_union))
            assert len(url_union) == len(url_query_dict.keys())
            return list(url_query_dict.keys()), url_query_dict
        elif self.url_merge_method == "intersection":
            url_intersection = urls_list[0]
            for urls in urls_list[1:]:
                url_intersection = list(set(url_intersection).intersection(set(urls)))
            return url_intersection, url_query_dict
        else:
            print("Invalid url merge method, please choose from 'union' and 'intersection'.")
            return None, url_query_dict
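
    # Sketch of the two merge strategies on made-up URL lists (not real output):
    #   query A -> ["u1", "u2"], query B -> ["u2", "u3"]
    #   union        -> {"u1", "u2", "u3"}   (every URL any query found)
    #   intersection -> {"u2"}               (only URLs all queries found)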

    def search_evidence(self,
                        decontextualised_claims: List[str],
                        automatic_query_list: List[list],
                        path_save_evidence: str = "evidence.json",
                        save_web_text: bool = False) -> Dict[str, Dict[str, Any]]:
        assert len(decontextualised_claims) == len(automatic_query_list)
        claim_info: Dict[str, Dict[str, Any]] = {}
        for i, claim in enumerate(decontextualised_claims):
            queries = automatic_query_list[i]
            if len(queries) == 0:
                claim_info[claim] = {"claim": claim, "automatic_queries": queries, "evidence_list": []}
                print("Claim: {} This is an opinion, not check-worthy.".format(claim))
                continue

            # For each check-worthy claim, first gather URLs of related web pages
            # and keep only the pages that can actually be scraped.
            urls, url_query_dict = self.collect_claim_url_list(queries)
            docs: List[dict] = []
            for url in urls:
                web_text, _ = scrape_url(url)
                if web_text is not None:
                    docs.append({"query": url_query_dict[url], "url": url, "web_text": web_text})
            print("Claim: {}\nWe retrieved {} urls, {} web pages are accessible.".format(claim, len(urls), len(docs)))

            # We could simply take the first k entries of url_query_dict, since that
            # preserves the search engine's ranking. Instead, we select the most
            # relevant top-k docs for the claim by keyword coverage; the call returns
            # the indices of the selected documents within docs.
            if len(docs) != 0:
                docs_text = [d['web_text'] for d in docs]
                selected_docs_index = select_doc_by_keyword_coverage(claim, docs_text)
                print(selected_docs_index)
            else:
                # No related web articles were collected for this claim; move on.
                claim_info[claim] = {"claim": claim, "automatic_queries": queries, "evidence_list": []}
                continue
            selected_docs = [idx_text for idx_text in (docs_text[idx] for idx in selected_docs_index)]

            # Score the passages of the selected docs and keep the top-5.
            # Returns the passage texts plus, for each passage, a list of doc ids;
            # these ids index into selected_docs (i.e. into selected_docs_index,
            # which might look like [4, 25, 28, 32, 33]).
            topk_passages, passage_doc_id = select_passages_by_semantic_similarity(claim, selected_docs)

            # Map each doc id back to its original index in docs, which holds
            # the detailed information for that doc.
            passage_doc_index = []
            for ids in passage_doc_id:
                passage_doc_index.append([selected_docs_index[doc_id] for doc_id in ids])

            # Build the evidence list for this claim.
            evidence_list: List[dict] = []
            for pid, p in enumerate(topk_passages):
                doc_ids = passage_doc_index[pid]
                if save_web_text:
                    evidence_list.append({"evidence_id": pid, "web_page_snippet_manual": p,
                                          "query": [docs[doc_id]["query"] for doc_id in doc_ids],
                                          "url": [docs[doc_id]["url"] for doc_id in doc_ids],
                                          "web_text": [docs[doc_id]["web_text"] for doc_id in doc_ids]})
                else:
                    evidence_list.append({"evidence_id": pid, "web_page_snippet_manual": p,
                                          "query": [docs[doc_id]["query"] for doc_id in doc_ids],
                                          "url": [docs[doc_id]["url"] for doc_id in doc_ids],
                                          "web_text": []})
            claim_info[claim] = {"claim": claim, "automatic_queries": queries, "evidence_list": evidence_list}

        # Serialize claim_info and write it to a JSON file.
        with open(path_save_evidence, "w") as outfile:
            outfile.write(json.dumps(claim_info, indent=4))
        return claim_info
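
# Minimal usage sketch (kept as a comment: it needs live search-engine and
# OpenAI access, and it assumes StandardTaskSolver.__init__ accepts a plain
# dict of options, as the args.get calls above suggest):
#
#     retriever = SearchEngineEvidenceRetriever({"search_engine": "google",
#                                                "url_merge_method": "union"})
#     claims = ["Tesla was founded in 2003."]
#     queries = retriever.generate_questions_as_query(claims)
#     evidence = retriever.search_evidence(claims, queries,
#                                          path_save_evidence="evidence.json")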