File size: 6,594 Bytes
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from typing import List
import re
import sqlparse


class TreeNode(object):
    """A node of the QA-call tree.

    ``name`` is the node's original text span. ``rename`` starts as the same
    text and is rewritten in place when a child replaces its own span inside
    it with a column name (``rename_father_col``) or with literal value(s)
    (``rename_father_val``).

    Equality and hashing are based on ``rename``. Because children mutate
    their father's ``rename``, a node's hash can change after it has been
    placed in a set or used as a dict key.
    """

    def __init__(self, name=None, father=None):
        self.name: str = name
        self.rename: str = name
        self.father: TreeNode = father
        self.children: List = []
        # Column name(s) this node's result is exposed under; set by rename_father_col.
        self.produced_col_name_s = None

    def __eq__(self, other):
        # Only compare with other nodes. For anything else let Python fall back,
        # so ``node == None`` is False instead of raising AttributeError.
        if not isinstance(other, TreeNode):
            return NotImplemented
        return self.rename == other.rename

    def __hash__(self):
        # Kept consistent with __eq__ (both use ``rename``).
        return hash(self.rename)

    def set_name(self, name):
        """Set both ``name`` and ``rename`` to ``name``."""
        self.name = name
        self.rename = name

    def add_child(self, child):
        """Append ``child`` to ``children`` and make this node its father."""
        self.children.append(child)
        child.father = self

    def rename_father_col(self, col_idx: int, col_prefix: str = "col_"):
        """Replace this node's span in the father's ``rename`` with a column name.

        @param col_idx: index appended to the prefix, e.g. ``3`` -> ``col_3``.
        @param col_prefix: prefix of the generated column name.
        The generated name is also stored in ``produced_col_name_s``.
        """
        new_col_name = "{}{}".format(col_prefix, col_idx)
        self.father.rename = self.father.rename.replace(self.name, new_col_name)
        self.produced_col_name_s = [new_col_name]  # fixme when multiple outputs for a qa func

    def rename_father_val(self, val_names):
        """Replace this node's span in the father's ``rename`` with literal value(s).

        A single value is single-quoted when ``convert_type`` leaves it a
        string, and inserted as-is otherwise (e.g. numbers). Any other number
        of values is rendered as a parenthesised, comma-separated list of
        single-quoted items.
        NOTE(review): values containing a single quote are not escaped.
        """
        if len(val_names) == 1:
            val_name = val_names[0]
            if isinstance(convert_type(val_name), str):
                new_val_equals_str = "'{}'".format(val_name)
            else:
                new_val_equals_str = "{}".format(val_name)
        else:
            new_val_equals_str = '({})'.format(', '.join(["'{}'".format(val_name) for val_name in val_names]))
        self.father.rename = self.father.rename.replace(self.name, new_val_equals_str)


def get_cfg_tree(nsql: str):
    """
    Build a tree of the nested QA(...) calls in ``nsql`` to guide execution.

    The root node is named after the whole ``nsql`` string. Every "QA(" that
    does not start at index 0 becomes a child of the innermost enclosing open
    QA node (or of the root). When its matching ")" is reached, the node is
    named after its full "QA(...)" span. Parentheses are matched by plain
    counting; quotes are not taken into account.
    @param nsql: the NSQL string to parse.
    @return: the root TreeNode. If a nested QA is never closed, the node of
        that unclosed clause is returned instead.
    @raises IndexError: if a ")" has no matching "(".
    """

    def _opens_nested_qa(paren_pos: int) -> bool:
        # True when the "(" at paren_pos ends a "QA(" that is not at the very
        # start of nsql (the leading QA is represented by the root itself).
        qa_start = paren_pos - 2
        return qa_start > 0 and nsql[qa_start:paren_pos + 1] == "QA("

    open_paren_positions: List = []  # indices of "(" not yet closed
    enclosing_nodes: List = []  # ancestors to return to when a QA clause closes
    node = TreeNode(name=nsql)

    for pos, char in enumerate(nsql):
        if char == "(":
            open_paren_positions.append(pos)
            if _opens_nested_qa(pos):
                child = TreeNode()
                node.add_child(child)
                enclosing_nodes.append(node)
                node = child
        elif char == ")":
            open_pos = open_paren_positions.pop()
            if _opens_nested_qa(open_pos):
                # Closing a QA clause: name the node after its full span.
                node.set_name(nsql[open_pos - 2:pos + 1])
                node = enclosing_nodes.pop()

    return node


def get_steps(tree_node: "TreeNode", steps: List):
    """Append every node of the tree rooted at ``tree_node`` to ``steps`` in post-order.

    Children are visited recursively, in list order, before their parent. Each
    node therefore appears after all of its descendants, and ``tree_node``
    itself is appended last. ``steps`` is mutated in place; nothing is
    returned.
    @param tree_node: root of the (sub)tree to traverse.
    @param steps: list that receives the nodes.
    """
    for child in tree_node.children:
        get_steps(child, steps)
    steps.append(tree_node)


def parse_question_paras(nsql: str, qa_model):
    """Split a flat ``QA("question"; para1; para2; ...)`` call into its parts.

    The question is the text between the first two double quotes. Everything
    after the second quote, up to the closing parenthesis, is split into
    parameters with ``sqlparse.split``, each stripped of spaces and ';'.
    A double quote inside the question would therefore end it early.
    @param nsql: a single QA(...) call with no nested QA inside; surrounding
        spaces and ';' are ignored.
    @param qa_model: not used by this function; kept for interface compatibility.
    @return: ``(question, paras)``.
    @raises AssertionError: if ``nsql`` is not of the form ``QA(...)`` or
        contains a nested QA.
    @raises IndexError: if ``nsql`` contains fewer than two double quotes.
    """
    nsql = nsql.strip(" ;")
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    # AssertionError is kept so callers see the same exception type as before.
    if not (nsql.startswith("QA(") and nsql.endswith(")")):
        raise AssertionError("must start with QA( symbol and end with )")
    if "QA" in nsql[2:-1]:
        raise AssertionError("must have no nested qa inside")

    # Get question and the remaining part (paras_raw_str).
    all_quote_idx = [match.start() for match in re.finditer('"', nsql)]
    question = nsql[all_quote_idx[0] + 1: all_quote_idx[1]]
    paras_raw_str = nsql[all_quote_idx[1] + 1:-1].strip(" ;")

    # Split parameters (SQL/column/value) from all parameters.
    paras = [_para.strip(' ;') for _para in sqlparse.split(paras_raw_str)]
    return question, paras


def convert_type(value):
    """Return ``eval(value)``, or ``value`` unchanged if evaluation raises.

    Turns literal text such as ``"3"`` or ``"'abc'"`` into the corresponding
    Python value. Text that fails to evaluate (e.g. ``"hello world"``, or an
    undefined name) and non-string inputs come back unchanged. Only
    ``Exception`` subclasses are caught.

    SECURITY NOTE(review): eval() executes arbitrary code. If ``value`` can
    come from model output or table contents, ``ast.literal_eval`` would be
    safer. It is not swapped in here because it rejects expressions that
    eval accepts, and callers may rely on that.
    """
    try:
        return eval(value)
    except Exception:
        return value


def nsql_role_recognize(nsql_like_str, all_headers, all_passage_titles, all_image_titles):
    """Recognize the role of one NSQL argument: column, title, value or SQL.

    Checks are tried in this priority order:

    1. Column: the text (with one pair of surrounding backticks removed)
       equals a header, or equals a lower-cased header.
    2. Titles, compared case-insensitively. For each sub-case the raw text is
       tried first, then ``str(eval(text))``:
       2.1 both a passage title and an image title,
       2.2 a passage title,
       2.3 an image title.
    3. Value: the text can be ``eval``-ed.
    4. Otherwise it is treated as a complete SQL statement.

    @param nsql_like_str: argument text taken from a QA(...) call.
    @param all_headers: table column names.
    @param all_passage_titles: titles of the available passages.
    @param all_image_titles: titles of the available images.
    @return: ``(role, text)``. ``role`` is one of 'col',
        'passage_title_and_image_title', 'passage_title', 'image_title',
        'val', 'complete_sql'. ``text`` is the original argument, except for
        title matches found via eval, where it is the evaluated string.
    """
    orig_nsql_like_str = nsql_like_str

    # Strip one pair of surrounding '`'.
    if nsql_like_str.startswith('`') and nsql_like_str.endswith('`'):
        nsql_like_str = nsql_like_str[1:-1]

    # Case 1: column name.
    # NOTE(review): only the headers are lower-cased here, so "name" matches
    # header "Name" but "NAME" does not, unlike the title checks below, which
    # lower-case both sides. Kept as-is; confirm which behavior is intended.
    if nsql_like_str in all_headers or nsql_like_str in {header.lower() for header in all_headers}:
        return 'col', orig_nsql_like_str

    # Lower-case each title collection once rather than once per check. This
    # also makes one-shot iterables (e.g. generators) behave like lists.
    passage_titles = {title.lower() for title in all_passage_titles}
    image_titles = {title.lower() for title in all_image_titles}

    # Evaluate the text once; the result is shared by Cases 2.x and 3.
    # SECURITY NOTE(review): eval() runs arbitrary Python. If this text can
    # come from model output or user data, consider ast.literal_eval.
    try:
        evaled_str = str(eval(nsql_like_str))
    except Exception:
        evaled_str = None

    raw_lower = nsql_like_str.lower()
    evaled_lower = None if evaled_str is None else evaled_str.lower()

    # Cases 2.1-2.3 in priority order; within each, raw text before evaluated text.
    # fixme: add case when this nsql_like_str is both a table header and a title.
    title_cases = (
        ("passage_title_and_image_title", lambda text: text in passage_titles and text in image_titles),
        ("passage_title", lambda text: text in passage_titles),
        ("image_title", lambda text: text in image_titles),
    )
    for role, is_match in title_cases:
        if is_match(raw_lower):
            return role, orig_nsql_like_str
        if evaled_lower is not None and is_match(evaled_lower):
            return role, evaled_str

    # Case 3: if it can be evaluated, it is a value.
    if evaled_str is not None:
        return 'val', orig_nsql_like_str

    # Case 4: otherwise it should be SQL; if it isn't, execution will fail later.
    return 'complete_sql', orig_nsql_like_str


def remove_duplicate(original_list):
    """Return a new list of the items in ``original_list`` without duplicates.

    The first occurrence of each item is kept, in original order. Membership
    is checked with ``==`` against a list rather than a set, so unhashable
    items (e.g. lists) are supported, at O(n^2) cost.
    """
    no_duplicate_list = []
    for item in original_list:
        if item not in no_duplicate_list:
            no_duplicate_list.append(item)
    return no_duplicate_list


def extract_answers(sub_table):
    """Flatten the cell values of ``sub_table`` into one list, row by row.

    @param sub_table: a mapping with 'header' (column names) and 'rows'
        (sequences of cell values), or a falsy value.
    @return: all cell values in row-major order, excluding the 'row_id'
        column when the header contains one. Empty when ``sub_table`` is
        falsy or its header is None.
    """
    if not sub_table or sub_table['header'] is None:
        return []
    header = list(sub_table['header'])
    # Locate row_id by name instead of assuming it is the first column.
    row_id_idx = header.index('row_id') if 'row_id' in header else None
    answer = []
    for _row in sub_table['rows']:
        if row_id_idx is None:
            answer.extend(_row)
        else:
            answer.extend(cell for idx, cell in enumerate(_row) if idx != row_id_idx)
    return answer