Spaces:
Runtime error
Runtime error
File size: 3,459 Bytes
f6f97d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
"""
Retriever that selects relevant examples from a pool of annotated QA items.
"""
import copy
from typing import Dict, List, Tuple, Any
import nltk
from nltk.stem import SnowballStemmer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from utils.normalizer import normalize
from retrieval.retrieve_pool import OpenAIQARetrievePool, QAItem
class OpenAIQARetriever(object):
    """
    Select the QA examples in a pool that are most similar to a query item.

    Similarity is a smoothed BLEU score between the questions plus a
    down-weighted smoothed BLEU score between the space-joined table headers.
    """

    def __init__(self, retrieve_pool: "OpenAIQARetrievePool"):
        """
        :param retrieve_pool: iterable of candidate QA items. Each candidate is
            read through its ``qa_question`` and ``table['header']`` fields.
        """
        self.retrieve_pool = retrieve_pool

    @staticmethod
    def _string_bleu(q1: str, q2: str, stop_words=None, stemmer=None):
        """
        Smoothed sentence-level BLEU of ``q2`` (candidate) against ``q1`` (reference).

        Both strings are normalized and word-tokenized first; the tokens are
        additionally stemmed when ``stemmer`` is given.

        :param q1: reference string.
        :param q2: candidate string.
        :param stop_words: kept for interface compatibility; currently unused.
        :param stemmer: optional object exposing ``stem(token)``.
        :return: the score returned by ``sentence_bleu``.
        """
        q1, q2 = normalize(q1), normalize(q2)
        reference_tokens = nltk.word_tokenize(q1)
        candidate_tokens = nltk.word_tokenize(q2)
        if stemmer is not None:
            reference_tokens = [stemmer.stem(tk) for tk in reference_tokens]
            candidate_tokens = [stemmer.stem(tk) for tk in candidate_tokens]
        # Smoothing avoids a hard zero score when some n-gram order has no overlap.
        smoothing = SmoothingFunction()
        return sentence_bleu(
            [reference_tokens],  # sentence_bleu expects a list of references
            candidate_tokens,
            weights=(0.25, 0.3, 0.3, 0.15),
            smoothing_function=smoothing.method1
        )

    def _qh2qh_similarity(
            self,
            item: "QAItem",
            num_retrieve_samples: int,
            score_func: str,
            qa_type: str,
            weight_h: float = 0.2,
            verbose: bool = False
    ):
        """
        Retrieve the top K pool items by query&header to query&header similarity.

        :param item: query item; its ``qa_question`` and ``table['header']`` are used.
        :param num_retrieve_samples: number of best-scoring pool items to return.
        :param score_func: similarity function name; only ``'bleu'`` is supported.
        :param qa_type: only pool items whose question is prefixed with
            ``'<qa_type>@'`` are considered.
        :param weight_h: weight of the header score relative to the question score.
        :param verbose: print the retrieved items when true.
        :return: pool items sorted by descending score, truncated to
            ``num_retrieve_samples``.
        :raises ValueError: if ``score_func`` is not supported.
        """
        # Validate first so a bad argument fails fast with a clear message.
        if score_func != 'bleu':
            raise ValueError(f'Score function {score_func} is not supported.')

        question = item.qa_question
        # Work on a copy so the caller's table is not mutated; the query header
        # is compared without its 'row_id' column when it has one.
        header = list(item.table['header'])
        if 'row_id' in header:
            header.remove('row_id')
        header_text = ' '.join(header)
        stemmer = SnowballStemmer('english')

        scored = []
        for candidate in self.retrieve_pool:
            # Pool questions look like '<qa_type>@<question>'. Split on the first
            # '@' only, so an '@' inside the question text does not truncate it.
            candidate_type, sep, candidate_question = candidate.qa_question.partition('@')
            if not sep or candidate_type != qa_type:
                continue
            question_score = self._string_bleu(question, candidate_question, stemmer=stemmer)
            # NOTE(review): candidate headers are joined as stored (no 'row_id'
            # removal), matching the existing scoring - confirm this is intended.
            header_score = self._string_bleu(
                header_text, ' '.join(candidate.table['header']), stemmer=stemmer
            )
            scored.append((candidate, question_score + weight_h * header_score))

        # The sort is stable, so equally scored items keep their pool order.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        retrieve_list = [candidate for candidate, _ in scored[:num_retrieve_samples]]
        if verbose:
            print(retrieve_list)
        return retrieve_list

    def retrieve(
            self,
            item: "QAItem",
            num_shots: int,
            method: str = 'qh2qh_bleu',
            qa_type: str = 'map',
            verbose: bool = False
    ) -> List["QAItem"]:
        """
        Retrieve a list of relevant QA samples.

        :param item: query item to find similar examples for.
        :param num_shots: number of examples to return.
        :param method: retrieval method; only ``'qh2qh_bleu'`` is supported.
        :param qa_type: question-type prefix used to filter the pool.
        :param verbose: print the retrieved items when true.
        :return: the retrieved pool items, best match first.
        :raises ValueError: if ``method`` is not supported.
        """
        if method != 'qh2qh_bleu':
            raise ValueError(f'Retrieve method {method} is not supported.')
        return self._qh2qh_similarity(
            item=item,
            num_retrieve_samples=num_shots,
            score_func='bleu',
            qa_type=qa_type,
            verbose=verbose
        )
|