File size: 3,459 Bytes
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Retriever to retrieve relevant examples from annotations.
"""

import copy
from typing import Dict, List, Tuple, Any
import nltk
from nltk.stem import SnowballStemmer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

from utils.normalizer import normalize
from retrieval.retrieve_pool import OpenAIQARetrievePool, QAItem


class OpenAIQARetriever(object):
    """
    Retrieve few-shot QA examples from a pool by lexical similarity.

    A pool item's score is a smoothed, stemmed sentence-level BLEU between
    the question texts plus a weighted BLEU between the table headers.
    """

    def __init__(self, retrieve_pool: OpenAIQARetrievePool):
        """
        Args:
            retrieve_pool: iterable of annotated QA items to search. Each item
                is read for ``qa_question`` (expected as
                ``"<qa_type>@<question>"``) and ``table['header']``.
        """
        self.retrieve_pool = retrieve_pool

    @staticmethod
    def _string_bleu(q1: str, q2: str, stop_words=None, stemmer=None) -> float:
        """
        Smoothed sentence-level BLEU of ``q2`` (candidate) against ``q1`` (reference).

        Both strings are passed through ``normalize`` and tokenized with
        ``nltk.word_tokenize``; when ``stemmer`` is given every token is stemmed.

        Args:
            q1: reference string.
            q2: candidate string.
            stop_words: currently unused; kept for backward compatibility.
            stemmer: optional object exposing ``stem(token) -> str``
                (e.g. NLTK's ``SnowballStemmer``).

        Returns:
            The BLEU score.
        """
        reference_tokens = nltk.word_tokenize(normalize(q1))
        candidate_tokens = nltk.word_tokenize(normalize(q2))
        if stemmer is not None:
            reference_tokens = [stemmer.stem(tk) for tk in reference_tokens]
            candidate_tokens = [stemmer.stem(tk) for tk in candidate_tokens]

        # Chen & Cherry smoothing (method1) avoids a hard zero when some
        # n-gram order has no overlap between reference and candidate.
        smoothing = SmoothingFunction()
        return sentence_bleu(
            [reference_tokens],
            candidate_tokens,
            weights=(0.25, 0.3, 0.3, 0.15),  # 1- to 4-gram weights, summing to 1.0
            smoothing_function=smoothing.method1,
        )

    @staticmethod
    def _split_pool_question(qa_question: str) -> Tuple[str, str]:
        """
        Split a pool question of the form ``"<qa_type>@<question>"``.

        Only the first ``@`` is a separator, so a question that itself contains
        ``@`` is kept intact. Without any ``@`` the question part is ``''``.
        """
        qa_type, _, question = qa_question.partition('@')
        return qa_type, question

    def _qh2qh_similarity(
            self,
            item: QAItem,
            num_retrieve_samples: int,
            score_func: str,
            qa_type: str,
            weight_h: float = 0.2,
            verbose: bool = False
    ) -> List[QAItem]:
        """
        Retrieve the top-K pool items by query&header to query&header similarity.

        Every pool item whose type prefix equals ``qa_type`` is scored as
        ``BLEU(question) + weight_h * BLEU(header)`` against ``item``; the
        ``num_retrieve_samples`` highest-scoring items are returned. Ties keep
        pool order because the sort is stable.

        Args:
            item: the query item; reads ``qa_question`` and ``table['header']``.
            num_retrieve_samples: maximum number of items to return (K).
            score_func: similarity function; only ``'bleu'`` is supported.
            qa_type: only pool items whose ``qa_question`` is prefixed with
                ``"<qa_type>@"`` are considered.
            weight_h: weight of the header similarity relative to the question one.
            verbose: if True, print the retrieved items.

        Returns:
            At most ``num_retrieve_samples`` pool items, best first.

        Raises:
            ValueError: if ``score_func`` is not supported.
        """
        if score_func != 'bleu':
            raise ValueError(f'Score function {score_func} is not supported.')

        query_question = item.qa_question
        # 'row_id' is excluded from the query header; filtering (rather than
        # list.remove) keeps tables without that column from raising.
        query_header = ' '.join(col for col in item.table['header'] if col != 'row_id')
        stemmer = SnowballStemmer('english')

        scored = []
        # One pass over the pool: question and header scores are computed
        # together, so they cannot drift out of alignment and a single-pass
        # iterable pool still works.
        for candidate in self.retrieve_pool:
            candidate_type, candidate_question = self._split_pool_question(candidate.qa_question)
            if candidate_type != qa_type:
                continue
            q_score = self._string_bleu(query_question, candidate_question, stemmer=stemmer)
            # NOTE(review): the candidate header is used as-is ('row_id' is not
            # stripped here, unlike the query side) — confirm this is intended.
            h_score = self._string_bleu(query_header, ' '.join(candidate.table['header']), stemmer=stemmer)
            scored.append((candidate, q_score + weight_h * h_score))

        scored.sort(key=lambda pair: pair[1], reverse=True)
        retrieve_list = [candidate for candidate, _ in scored[:num_retrieve_samples]]

        if verbose:
            print(retrieve_list)

        return retrieve_list

    def retrieve(
            self,
            item: QAItem,
            num_shots: int,
            method: str = 'qh2qh_bleu',
            qa_type: str = 'map',
            verbose: bool = False,
            weight_h: float = 0.2
    ) -> List[QAItem]:
        """
        Retrieve a list of relevant QA samples.

        Args:
            item: the query item.
            num_shots: maximum number of examples to return.
            method: retrieval method; only ``'qh2qh_bleu'`` is supported.
            qa_type: restrict retrieval to pool items of this type.
            verbose: if True, print the retrieved items.
            weight_h: header-similarity weight forwarded to the scorer
                (defaults to the previously hard-wired 0.2).

        Returns:
            Up to ``num_shots`` pool items, most similar first.

        Raises:
            ValueError: if ``method`` (or the underlying score function) is not supported.
        """
        if method != 'qh2qh_bleu':
            raise ValueError(f'Retrieve method {method} is not supported.')
        return self._qh2qh_similarity(
            item=item,
            num_retrieve_samples=num_shots,
            score_func='bleu',
            qa_type=qa_type,
            weight_h=weight_h,
            verbose=verbose
        )