Spaces:
Runtime error
Runtime error
File size: 8,108 Bytes
f6f97d8 9a01777 f6f97d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import copy
import os
import sqlite3
import records
import sqlalchemy
import pandas as pd
from typing import Dict, List
import uuid
from utils.normalizer import convert_df_type, prepare_df_for_neuraldb_from_table
from utils.mmqa.image_stuff import get_caption
def check_in_and_return(key: str, source: dict):
# `` wrapped means as a whole
if key.startswith("`") and key.endswith("`"):
key = key[1:-1]
if key in source.keys():
return source[key]
else:
for _k, _v in source.items():
if _k.lower() == key.lower():
return _v
raise ValueError("{} not in {}".format(key, source))
class NeuralDB(object):
def __init__(self, tables: List[Dict[str, Dict]], passages=None, images=None):
self.raw_tables = copy.deepcopy(tables)
self.passages = {}
self.images = {}
self.image_captions = {}
self.passage_linker = {} # The links from cell value to passage
self.image_linker = {} # The links from cell value to images
# Get passages
if passages:
for passage in passages:
title, passage_content = passage['title'], passage['text']
self.passages[title] = passage_content
# Get images
if images:
for image in images:
_id, title, picture = image['id'], image['title'], image['pic']
self.images[title] = picture
self.image_captions[title] = get_caption(_id)
# Link grounding resources from other modalities(passages, images).
if self.raw_tables[0]['table'].get('rows_with_links', None):
rows = self.raw_tables[0]['table']['rows']
rows_with_links = self.raw_tables[0]['table']['rows_with_links']
link_title2cell_map = {}
for row_id in range(len(rows)):
for col_id in range(len(rows[row_id])):
cell = rows_with_links[row_id][col_id]
for text, title, url in zip(cell[0], cell[1], cell[2]):
text = text.lower().strip()
link_title2cell_map[title] = text
# Link Passages
for passage in passages:
title, passage_content = passage['title'], passage['text']
linked_cell = link_title2cell_map.get(title, None)
if linked_cell:
self.passage_linker[linked_cell] = title
# Images
for image in images:
title, picture = image['title'], image['pic']
linked_cell = link_title2cell_map.get(title, None)
if linked_cell:
self.image_linker[linked_cell] = title
for table_info in tables:
table_info['table'] = prepare_df_for_neuraldb_from_table(table_info['table'])
self.tables = tables
# Connect to SQLite database
self.tmp_path = "tmp"
os.makedirs(self.tmp_path, exist_ok=True)
# self.db_path = os.path.join(self.tmp_path, '{}.db'.format(hash(time.time())))
self.db_path = os.path.join(self.tmp_path, '{}.db'.format(uuid.uuid4()))
self.sqlite_conn = sqlite3.connect(self.db_path, check_same_thread=False)
# Create DB
assert len(tables) >= 1, "DB has no table inside"
table_0 = tables[0]
if len(tables) > 1:
raise ValueError("More than one table not support yet.")
else:
table_0["table"].to_sql("w", self.sqlite_conn)
self.table_name = "w"
self.table_title = table_0.get('title', None)
# Records conn
self.db = records.Database('sqlite:///{}'.format(self.db_path))
self.records_conn = self.db.get_connection()
def __str__(self):
return str(self.execute_query("SELECT * FROM {}".format(self.table_name)))
def get_table(self, table_name=None):
table_name = self.table_name if not table_name else table_name
sql_query = "SELECT * FROM {}".format(table_name)
_table = self.execute_query(sql_query)
return _table
def get_header(self, table_name=None):
_table = self.get_table(table_name)
return _table['header']
def get_rows(self, table_name):
_table = self.get_table(table_name)
return _table['rows']
def get_table_df(self):
return self.tables[0]['table']
def get_table_raw(self):
return self.raw_tables[0]['table']
def get_table_title(self):
return self.tables[0]['title']
def get_passages_titles(self):
return list(self.passages.keys())
def get_images_titles(self):
return list(self.images.keys())
def get_passage_by_title(self, title: str):
return check_in_and_return(title, self.passages)
def get_image_by_title(self, title):
return check_in_and_return(title, self.images)
def get_image_caption_by_title(self, title):
return check_in_and_return(title, self.image_captions)
def get_image_linker(self):
return copy.deepcopy(self.image_linker)
def get_passage_linker(self):
return copy.deepcopy(self.passage_linker)
def execute_query(self, sql_query: str):
"""
Basic operation. Execute the sql query on the database we hold.
@param sql_query:
@return:
"""
# When the sql query is a column name (@deprecated: or a certain value with '' and "" surrounded).
if len(sql_query.split(' ')) == 1 or (sql_query.startswith('`') and sql_query.endswith('`')):
col_name = sql_query
new_sql_query = r"SELECT row_id, {} FROM {}".format(col_name, self.table_name)
# Here we use a hack that when a value is surrounded by '' or "", the sql will return a column of the value,
# while for variable, no ''/"" surrounded, this sql will query for the column.
out = self.records_conn.query(new_sql_query)
# When the sql query wants all cols or col_id, which is no need for us to add 'row_id'.
elif sql_query.lower().startswith("select *") or sql_query.startswith("select col_id"):
out = self.records_conn.query(sql_query)
else:
try:
# SELECT row_id in addition, needed for result and old table alignment.
new_sql_query = "SELECT row_id, " + sql_query[7:]
out = self.records_conn.query(new_sql_query)
except sqlalchemy.exc.OperationalError as e:
# Execute normal SQL, and in this case the row_id is actually in no need.
out = self.records_conn.query(sql_query)
results = out.all()
unmerged_results = []
merged_results = []
headers = out.dataset.headers
for i in range(len(results)):
unmerged_results.append(list(results[i].values()))
merged_results.extend(results[i].values())
return {"header": headers, "rows": unmerged_results}
def add_sub_table(self, sub_table, table_name=None, verbose=True):
"""
Add sub_table into the table.
@return:
"""
table_name = self.table_name if not table_name else table_name
sql_query = "SELECT * FROM {}".format(table_name)
oring_table = self.execute_query(sql_query)
old_table = pd.DataFrame(oring_table["rows"], columns=oring_table["header"])
# concat the new column into old table
sub_table_df_normed = convert_df_type(pd.DataFrame(data=sub_table['rows'], columns=sub_table['header']))
new_table = old_table.merge(sub_table_df_normed,
how='left', on='row_id') # do left join
new_table.to_sql(table_name, self.sqlite_conn, if_exists='replace',
index=False)
if verbose:
print("Insert column(s) {} (dtypes: {}) into table.\n".format(', '.join([_ for _ in sub_table['header']]),
sub_table_df_normed.dtypes))
|