File size: 8,909 Bytes
8360ec7
ec53a03
 
8360ec7
ec53a03
 
 
8360ec7
 
ec53a03
8360ec7
ec53a03
8360ec7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import time
import json
import openai
from typing import List, Dict, Any

from .utils.prompt_base import QGEN_PROMPT
from .utils.api import chatgpt, search_google, search_bing
from .utils.web_util import scrape_url, select_doc_by_keyword_coverage, select_passages_by_semantic_similarity

from openfactcheck import FactCheckerState, StandardTaskSolver, Solver

@Solver.register("search_engine_evidence_retriever", "claims", "evidences")
class SearchEngineEvidenceRetriever(StandardTaskSolver):
    """Retrieve web evidence for a list of claims.

    Pipeline per claim: (1) ask the LLM to generate search queries,
    (2) collect candidate URLs from a search engine, (3) scrape the
    pages and rank them against the claim, (4) keep the top passages
    as the claim's evidence list.
    """

    def __init__(self, args):
        super().__init__(args)
        # Which search backend to use: "google" (default) or "bing".
        self.search_engine = args.get("search_engine", "google")
        # BUGFIX: the original fallback was the *string* "google", which
        # would raise TypeError when called; fall back to the google
        # search function instead.
        self.search_engine_func = {
            "google": search_google,
            "bing": search_bing,
        }.get(self.search_engine, search_google)

        # How per-query URL lists are merged: "union" or "intersection".
        self.url_merge_method = args.get("url_merge_method", "union")

    def __call__(self, state: FactCheckerState, *args, **kwargs):
        """Read claims from the state, attach retrieved evidences."""
        claims = state.get(self.input_name)
        queries = self.generate_questions_as_query(claims)
        evidences = self.search_evidence(claims, queries)
        state.set(self.output_name, evidences)
        return True, state

    # generate questions and queries based on a claim
    def generate_questions_as_query(self, claims,
                                    num_retries: int = 3) -> List[list]:
        """Generate search queries/questions for each claim via the LLM.

        Args:
            claims: list of claim strings.
            num_retries: number of retries per claim when the OpenAI API
                call raises an error.

        Returns:
            One list of query strings per claim. An empty list marks a
            claim judged not check-worthy (the LLM returned "").
        """
        query_list = []
        for claim in claims:
            # BUGFIX: initialise response per claim so that a claim whose
            # API calls all fail yields "" (treated as not check-worthy)
            # instead of raising UnboundLocalError or silently reusing
            # the previous claim's response.
            response = ""
            for _ in range(num_retries):
                try:
                    response = chatgpt(QGEN_PROMPT + claim)
                    break
                except openai.OpenAIError as exception:
                    print(f"{exception}. Retrying...")
                    time.sleep(1)
            query_list.append(response)

        # Convert each raw LLM response into a list of queries:
        # "" (not check-worthy) -> []; otherwise split on newlines,
        # dropping blank lines and the "Output:" marker the prompt elicits.
        automatic_query_list = []
        for raw in query_list:
            if raw == "":
                automatic_query_list.append([])
                continue
            parsed = []
            for line in raw.split("\n"):
                line = line.strip()
                if line == "" or line == "Output:":
                    continue
                if line.startswith("Output"):
                    # Strip the "Output" marker plus any following colon.
                    line = line[len("Output"):].lstrip(":").strip()
                parsed.append(line)
            automatic_query_list.append(parsed)

        return automatic_query_list

    # ----------------------------------------------------------
    # Evidence Retrieval
    # ----------------------------------------------------------
    def collect_claim_url_list(self, queries: List[str]):
        """Collect candidate URLs for one claim.

        Args:
            queries: the queries/questions generated for the claim.

        Returns:
            (urls, url_query_dict) where ``urls`` is the merged URL list
            (union or intersection across per-query results, per
            ``self.url_merge_method``) and ``url_query_dict`` maps each
            URL to the list of queries that retrieved it. Returns None
            when ``queries`` is empty.

        Raises:
            ValueError: if ``url_merge_method`` is neither "union" nor
                "intersection".

        Note: intersection URLs tend to be small in number and less
        relevant than the union.
        """
        if not queries:
            print("Invalid queries: []")
            return None

        # One URL list per query, in query order.
        urls_list: List[list] = [self.search_engine_func(query) for query in queries]

        # Map each URL to every query that retrieved it.
        url_query_dict: Dict[str, list] = {}
        for query, urls in zip(queries, urls_list):
            for url in urls:
                url_query_dict.setdefault(url, []).append(query)

        if self.url_merge_method == "union":
            # dict keys preserve insertion order, so this is exactly the
            # de-duplicated union of all per-query URL lists.
            return list(url_query_dict.keys()), url_query_dict
        if self.url_merge_method == "intersection":
            common = set(urls_list[0])
            for urls in urls_list[1:]:
                common &= set(urls)
            return list(common), url_query_dict
        # BUGFIX: the original printed a warning and returned None for the
        # URL list, which made the caller crash later with an opaque
        # TypeError; fail fast with an explicit error instead.
        raise ValueError(
            "Invalid url operation, please choose from 'union' and 'intersection'.")

    def search_evidence(self,
                        decontextualised_claims: List[str],
                        automatic_query_list: List[list],
                        path_save_evidence: str = "evidence.json",
                        save_web_text: bool = False) -> Dict[str, Dict[str, Any]]:
        """Retrieve, rank and persist evidence passages for each claim.

        Args:
            decontextualised_claims: the claims to verify.
            automatic_query_list: one query list per claim (parallel to
                ``decontextualised_claims``); [] marks a non-check-worthy
                claim, which gets an empty evidence list.
            path_save_evidence: path of the JSON file the results are
                written to.
            save_web_text: if True, store the full scraped page text
                alongside each evidence snippet.

        Returns:
            Mapping from claim text to a dict with keys "claim",
            "automatic_queries" and "evidence_list".
        """
        assert len(decontextualised_claims) == len(automatic_query_list)

        claim_info: Dict[str, Dict[str, Any]] = {}
        for claim, queries in zip(decontextualised_claims, automatic_query_list):
            if not queries:
                # Empty query list => the claim is an opinion / not check-worthy.
                claim_info[claim] = {"claim": claim, "automatic_queries": queries, "evidence_list": []}
                print("Claim: {} This is an opinion, not check-worthy.".format(claim))
                continue

            # For each check-worthy claim, first gather URLs of related pages.
            urls, url_query_dict = self.collect_claim_url_list(queries)

            # Scrape every URL; pages that fail to load are skipped.
            docs: List[dict] = []
            for url in urls:
                web_text, _ = scrape_url(url)
                if web_text is not None:
                    docs.append({"query": url_query_dict[url], "url": url, "web_text": web_text})
            print("Claim: {}\nWe retrieved {} urls, {} web pages are accessible.".format(claim, len(urls), len(docs)))

            if not docs:
                # No accessible web articles for this claim; move on.
                claim_info[claim] = {"claim": claim, "automatic_queries": queries, "evidence_list": []}
                continue

            # Select the most relevant top-k docs against the claim by
            # keyword coverage; indices refer to positions in ``docs``.
            docs_text = [doc["web_text"] for doc in docs]
            selected_docs_index = select_doc_by_keyword_coverage(claim, docs_text)
            print(selected_docs_index)

            selected_docs = [docs_text[idx] for idx in selected_docs_index]
            # Score passages and keep the top ones; ``passage_doc_id``
            # holds, for each passage, doc ids relative to selected_docs.
            topk_passages, passage_doc_id = select_passages_by_semantic_similarity(claim, selected_docs)

            # Map passage doc ids back to indices into the original ``docs``
            # list, which records the full per-doc information.
            passage_doc_index = [[selected_docs_index[doc_id] for doc_id in ids]
                                 for ids in passage_doc_id]

            evidence_list: List[dict] = []
            for pid, passage in enumerate(topk_passages):
                doc_ids = passage_doc_index[pid]
                evidence_list.append({
                    "evidence_id": pid,
                    "web_page_snippet_manual": passage,
                    "query": [docs[doc_id]["query"] for doc_id in doc_ids],
                    "url": [docs[doc_id]["url"] for doc_id in doc_ids],
                    # Full page text is large; keep it only on request.
                    "web_text": ([docs[doc_id]["web_text"] for doc_id in doc_ids]
                                 if save_web_text else []),
                })
            claim_info[claim] = {"claim": claim, "automatic_queries": queries, "evidence_list": evidence_list}

        # Persist the collected evidence as pretty-printed JSON.
        with open(path_save_evidence, "w") as outfile:
            json.dump(claim_info, outfile, indent=4)
        return claim_info