File size: 8,108 Bytes
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a01777
f6f97d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import copy
import os
import sqlite3
import records
import sqlalchemy
import pandas as pd
from typing import Dict, List
import uuid

from utils.normalizer import convert_df_type, prepare_df_for_neuraldb_from_table
from utils.mmqa.image_stuff import get_caption


def check_in_and_return(key: str, source: dict):
    """
    Look up `key` in `source`, tolerating backtick wrapping and case differences.
    @param key: the name to look up; a leading and trailing backtick pair is stripped first.
    @param source: the mapping to search (its keys must be strings for the fallback to work).
    @return: the value stored under the exact key if present, otherwise the value of the
        first key (in iteration order) that equals `key` case-insensitively.
    @raise ValueError: when neither an exact nor a case-insensitive match exists.
    """
    # A name wrapped in `` is meant to be taken as a whole, so drop the wrapping.
    is_wrapped = key.startswith("`") and key.endswith("`")
    name = key[1:-1] if is_wrapped else key

    # Exact match wins over any case-insensitive candidate.
    if name in source:
        return source[name]

    lowered_name = name.lower()
    for candidate, value in source.items():
        if candidate.lower() == lowered_name:
            return value

    raise ValueError("{} not in {}".format(name, source))


class NeuralDB(object):
    """
    One table stored in a temporary SQLite file, together with the passages and images
    that are linked to its cells.

    The table is written to SQLite under the name "w" and queried through a `records`
    connection. Passages, images and image captions are kept in dictionaries keyed by title.
    Call `close()` when the instance is no longer needed to release the database handles.
    """

    def __init__(self, tables: List[Dict[str, Dict]], passages=None, images=None):
        """
        Build the database from exactly one table plus optional passages and images.
        @param tables: list holding one dict with a 'table' entry and an optional 'title'
            entry. The raw 'table' is a dict with 'rows' and, optionally, 'rows_with_links'.
            NOTE: each dict's 'table' entry is replaced IN PLACE by the result of
            `prepare_df_for_neuraldb_from_table`; an untouched deep copy is kept in
            `self.raw_tables`.
        @param passages: optional iterable of dicts with 'title' and 'text'.
        @param images: optional iterable of dicts with 'id', 'title' and 'pic'.
        @raise ValueError: when `tables` is empty or holds more than one table.
        """
        # Validate before touching anything, so a bad call neither fails with an opaque
        # IndexError nor leaves a stray database file behind.
        if not tables:
            raise ValueError("DB has no table inside")
        if len(tables) > 1:
            raise ValueError("More than one table not support yet.")

        self.raw_tables = copy.deepcopy(tables)
        self.passages = {}
        self.images = {}
        self.image_captions = {}
        self.passage_linker = {}  # The links from cell value to passage
        self.image_linker = {}  # The links from cell value to images

        # None (the default) means "no resources of this modality"; normalise once so the
        # loops below never iterate over None.
        passages = passages or []
        images = images or []

        # Get passages
        for passage in passages:
            self.passages[passage['title']] = passage['text']

        # Get images
        for image in images:
            title = image['title']
            self.images[title] = image['pic']
            self.image_captions[title] = get_caption(image['id'])

        # Link grounding resources from other modalities(passages, images).
        raw_table = self.raw_tables[0]['table']
        rows_with_links = raw_table.get('rows_with_links', None)
        if rows_with_links:
            # Map each link title to the (lower-cased, stripped) cell text it appears in.
            # A title seen in several cells keeps the last one.
            link_title2cell_map = {}
            for row_id, row in enumerate(raw_table['rows']):
                for col_id in range(len(row)):
                    cell = rows_with_links[row_id][col_id]
                    for text, title, _url in zip(cell[0], cell[1], cell[2]):
                        link_title2cell_map[title] = text.lower().strip()

            # Link Passages
            for passage in passages:
                linked_cell = link_title2cell_map.get(passage['title'], None)
                if linked_cell:
                    self.passage_linker[linked_cell] = passage['title']

            # Link Images
            for image in images:
                linked_cell = link_title2cell_map.get(image['title'], None)
                if linked_cell:
                    self.image_linker[linked_cell] = image['title']

        # Normalise the table for SQLite. This mutates the caller's dicts (see docstring).
        for table_info in tables:
            table_info['table'] = prepare_df_for_neuraldb_from_table(table_info['table'])

        self.tables = tables

        # Connect to SQLite database; a uuid file name keeps instances from colliding.
        self.tmp_path = "tmp"
        os.makedirs(self.tmp_path, exist_ok=True)
        self.db_path = os.path.join(self.tmp_path, '{}.db'.format(uuid.uuid4()))
        self.sqlite_conn = sqlite3.connect(self.db_path, check_same_thread=False)

        # Create DB
        table_0 = tables[0]
        table_0["table"].to_sql("w", self.sqlite_conn)
        self.table_name = "w"
        self.table_title = table_0.get('title', None)

        # Records conn
        self.db = records.Database('sqlite:///{}'.format(self.db_path))
        self.records_conn = self.db.get_connection()

    def close(self, remove_db_file=False):
        """
        Release the database handles opened by `__init__`.
        @param remove_db_file: when True, also delete the SQLite file at `self.db_path`.
        @return: None
        """
        # getattr guards let this run on a partially constructed instance as well.
        for attr_name in ("records_conn", "db", "sqlite_conn"):
            handle = getattr(self, attr_name, None)
            if handle is not None:
                handle.close()

        if remove_db_file:
            db_path = getattr(self, "db_path", None)
            if db_path and os.path.exists(db_path):
                os.remove(db_path)

    def __str__(self):
        """Return the string form of the full-table query result."""
        return str(self.execute_query("SELECT * FROM {}".format(self.table_name)))

    def get_table(self, table_name=None):
        """
        Fetch every row of a table.
        @param table_name: table to read; falls back to `self.table_name` when falsy.
        @return: dict with "header" and "rows", as produced by `execute_query`.
        """
        table_name = self.table_name if not table_name else table_name
        sql_query = "SELECT * FROM {}".format(table_name)
        return self.execute_query(sql_query)

    def get_header(self, table_name=None):
        """Return the "header" part of `get_table(table_name)`."""
        return self.get_table(table_name)['header']

    def get_rows(self, table_name=None):
        """Return the "rows" part of `get_table(table_name)`."""
        return self.get_table(table_name)['rows']

    def get_table_df(self):
        """Return the prepared 'table' entry of the first (only) table."""
        return self.tables[0]['table']

    def get_table_raw(self):
        """Return the raw 'table' entry as it was passed in (from the deep copy)."""
        return self.raw_tables[0]['table']

    def get_table_title(self):
        """Return the 'title' entry of the first table; raises KeyError when absent."""
        return self.tables[0]['title']

    def get_passages_titles(self):
        """Return the titles of all stored passages as a list."""
        return list(self.passages.keys())

    def get_images_titles(self):
        """Return the titles of all stored images as a list."""
        return list(self.images.keys())

    def get_passage_by_title(self, title: str):
        """Return the passage text for `title` (backtick/case tolerant lookup)."""
        return check_in_and_return(title, self.passages)

    def get_image_by_title(self, title):
        """Return the image for `title` (backtick/case tolerant lookup)."""
        return check_in_and_return(title, self.images)

    def get_image_caption_by_title(self, title):
        """Return the image caption for `title` (backtick/case tolerant lookup)."""
        return check_in_and_return(title, self.image_captions)

    def get_image_linker(self):
        """Return a deep copy of the cell-text -> image-title mapping."""
        return copy.deepcopy(self.image_linker)

    def get_passage_linker(self):
        """Return a deep copy of the cell-text -> passage-title mapping."""
        return copy.deepcopy(self.passage_linker)

    def execute_query(self, sql_query: str):
        """
        Basic operation. Execute the sql query on the database we hold.
        @param sql_query: a full SQL statement, or a single column name (optionally wrapped
            in backticks), which is expanded to a SELECT of `row_id` and that column.
        @return: dict with "header" (column names of the result) and "rows" (one list of
            values per result row).
        """
        # When the sql query is a column name (@deprecated: or a certain value with '' and "" surrounded).
        is_bare_column = len(sql_query.split(' ')) == 1 or (
            sql_query.startswith('`') and sql_query.endswith('`'))

        if is_bare_column:
            # Hack: a value surrounded by ''/"" makes the sql return a column of that value,
            # while an unquoted name makes it query the column.
            new_sql_query = r"SELECT row_id, {} FROM {}".format(sql_query, self.table_name)
            out = self.records_conn.query(new_sql_query)
        # When the sql query wants all cols or col_id, which is no need for us to add 'row_id'.
        # NOTE(review): the "select col_id" test is case-sensitive while the "select *" test
        # is not -- kept as-is to preserve current results; confirm whether that is intended.
        elif sql_query.lower().startswith("select *") or sql_query.startswith("select col_id"):
            out = self.records_conn.query(sql_query)
        else:
            # SELECT row_id in addition, needed for result and old table alignment.
            # The slice drops the leading 7 characters, i.e. it assumes a "SELECT " prefix.
            new_sql_query = "SELECT row_id, " + sql_query[7:]
            try:
                out = self.records_conn.query(new_sql_query)
            except sqlalchemy.exc.OperationalError:
                # Execute normal SQL, and in this case the row_id is actually in no need.
                out = self.records_conn.query(sql_query)

        results = out.all()
        headers = out.dataset.headers
        unmerged_results = [list(record.values()) for record in results]

        return {"header": headers, "rows": unmerged_results}

    def add_sub_table(self, sub_table, table_name=None, verbose=True):
        """
        Add sub_table into the table.
        @param sub_table: dict with 'header' and 'rows'; it must include a 'row_id' column,
            which is the join key.
        @param table_name: table to extend; falls back to `self.table_name` when falsy.
        @param verbose: when True, print the inserted column names and their dtypes.
        @return: None. The table in SQLite is replaced by the joined result.
        """
        table_name = self.table_name if not table_name else table_name
        origin_table = self.get_table(table_name)
        old_table = pd.DataFrame(origin_table["rows"], columns=origin_table["header"])

        # concat the new column into old table
        sub_table_df_normed = convert_df_type(
            pd.DataFrame(data=sub_table['rows'], columns=sub_table['header']))
        # Left join keeps every existing row even when sub_table has no entry for it.
        new_table = old_table.merge(sub_table_df_normed, how='left', on='row_id')
        new_table.to_sql(table_name, self.sqlite_conn, if_exists='replace', index=False)

        if verbose:
            print("Insert column(s) {} (dtypes: {}) into table.\n".format(', '.join(sub_table['header']),
                                                                          sub_table_df_normed.dtypes))