################################
# Assumptions:
#   1. sql is correct
#   2. only table name has alias
#   3. only one intersect/union/except
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################
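# Illustrative example (not from the original file): for a database with a
# singer(name, age) table, the query
#     SELECT name FROM singer WHERE age > 20
# parses to roughly the following dict (column/table ids come from Schema.idMap,
# and 3 == WHERE_OPS.index('>')):
#     {'select': (False, [(0, (0, (0, '__singer.name__', False), None))]),
#      'from': {'table_units': [('table_unit', '__singer__')], 'conds': []},
#      'where': [(False, 3, (0, (0, '__singer.age__', False), None), 20.0, None)],
#      'groupBy': [], 'having': [], 'orderBy': [], 'limit': None,
#      'intersect': None, 'union': None, 'except': None}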
import json
import sqlite3
from nltk import word_tokenize

CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')

WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
    'sql': "sql",
    'table_unit': "table_unit",
}

COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')

class Schema:
    """
    Simple schema which maps table&column to a unique identifier
    """
    def __init__(self, schema):
        self._schema = schema
        self._idMap = self._map(self._schema)

    # exposed as properties so callers can write schema.schema[...] and
    # schema.idMap[...] (the rest of this module relies on attribute access)
    @property
    def schema(self):
        return self._schema

    @property
    def idMap(self):
        return self._idMap

    def _map(self, schema):
        # map '*', every 'table.column', and every table name to a unique key
        idMap = {'*': "__all__"}
        for key, vals in schema.items():
            for val in vals:
                idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"

        for key in schema:
            idMap[key.lower()] = "__" + key.lower() + "__"

        return idMap

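# Minimal usage sketch (hypothetical table, not part of the original file):
#     schema = Schema({'singer': ['name', 'age']})
#     schema.idMap == {'*': '__all__',
#                      'singer.name': '__singer.name__',
#                      'singer.age': '__singer.age__',
#                      'singer': '__singer__'}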
def get_schema(db):
    """
    Get database's schema, which is a dict with table name as key
    and list of column names as value
    :param db: database path
    :return: schema dict
    """
    schema = {}
    conn = sqlite3.connect(db)
    cursor = conn.cursor()

    # fetch table names
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [str(table[0].lower()) for table in cursor.fetchall()]

    # fetch table info
    for table in tables:
        cursor.execute("PRAGMA table_info({})".format(table))
        schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]

    return schema

def get_schema_from_json(fpath):
    with open(fpath) as f:
        data = json.load(f)

    schema = {}
    for entry in data:
        table = str(entry['table'].lower())
        cols = [str(col['column_name'].lower()) for col in entry['col_data']]
        schema[table] = cols

    return schema

def tokenize(string):
    string = str(string)
    string = string.replace("'", "\"")  # ensure all string values are wrapped by double quotes
    quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
    assert len(quote_idxs) % 2 == 0, "Unexpected quote"

    # keep each quoted string value as a single token
    vals = {}
    for i in range(len(quote_idxs)-1, -1, -2):
        qidx1 = quote_idxs[i-1]
        qidx2 = quote_idxs[i]
        val = string[qidx1: qidx2+1]
        key = "__val_{}_{}__".format(qidx1, qidx2)
        string = string[:qidx1] + key + string[qidx2+1:]
        vals[key] = val

    # tokenize sql
    toks_tmp = [word.lower() for word in word_tokenize(string)]
    toks = []
    for tok in toks_tmp:
        if tok.startswith('=__val_'):
            # word_tokenize may glue '=' onto a value placeholder; split them apart
            tok = tok[1:]
            toks.append('=')
        toks.append(tok)

    # replace placeholders with the original string value tokens
    for i in range(len(toks)):
        if toks[i] in vals:
            toks[i] = vals[toks[i]]

    # merge !=, >=, <= that word_tokenize split into two tokens
    eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
    eq_idxs.reverse()
    prefix = ('!', '>', '<')
    for eq_idx in eq_idxs:
        pre_tok = toks[eq_idx-1]
        if pre_tok in prefix:
            toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1:]

    return toks

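# Illustrative example: quoted strings stay single tokens and split comparison
# operators are re-merged, so (assuming nltk's default word_tokenize behavior):
#     tokenize('SELECT name FROM singer WHERE age >= 20 AND name = "Joe"')
#     -> ['select', 'name', 'from', 'singer', 'where', 'age', '>=', '20',
#         'and', 'name', '=', '"Joe"']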
def scan_alias(toks):
    """Scan the index of 'as' and build the map for all aliases"""
    as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
    alias = {}
    for idx in as_idxs:
        alias[toks[idx+1]] = toks[idx-1]
    return alias

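# Example (illustrative): for the tokens of "... FROM singer AS t1 ...":
#     scan_alias(['from', 'singer', 'as', 't1']) -> {'t1': 'singer'}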
def get_tables_with_alias(schema, toks):
    # every table is also mapped to itself, so unaliased names resolve too
    tables = scan_alias(toks)
    for key in schema:
        assert key not in tables, "Alias {} conflicts with a table name".format(key)
        tables[key] = key
    return tables

def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    :returns next idx, column id
    """
    tok = toks[start_idx]
    if tok == "*":
        return start_idx + 1, schema.idMap[tok]

    if '.' in tok:  # if token is a composite
        alias, col = tok.split('.')
        key = tables_with_alias[alias] + "." + col
        return start_idx+1, schema.idMap[key]

    assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"

    for alias in default_tables:
        table = tables_with_alias[alias]
        if tok in schema.schema[table]:
            key = table + "." + tok
            return start_idx+1, schema.idMap[key]

    assert False, "Error col: {}".format(tok)

def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
    """
    :returns next idx, (agg_id, col_id, isDistinct)
    """
    idx = start_idx
    len_ = len(toks)
    isBlock = False
    isDistinct = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if toks[idx] in AGG_OPS:
        agg_id = AGG_OPS.index(toks[idx])
        idx += 1
        assert idx < len_ and toks[idx] == '('
        idx += 1
        if toks[idx] == "distinct":
            idx += 1
            isDistinct = True
        idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
        assert idx < len_ and toks[idx] == ')'
        idx += 1
        return idx, (agg_id, col_id, isDistinct)

    if toks[idx] == "distinct":
        idx += 1
        isDistinct = True
    agg_id = AGG_OPS.index("none")
    idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)

    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'

    return idx, (agg_id, col_id, isDistinct)

def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    isBlock = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    col_unit1 = None
    col_unit2 = None
    unit_op = UNIT_OPS.index('none')

    idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
    if idx < len_ and toks[idx] in UNIT_OPS:
        unit_op = UNIT_OPS.index(toks[idx])
        idx += 1
        idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)

    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'

    return idx, (unit_op, col_unit1, col_unit2)

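# Example (illustrative): for the tokens of "singer.age - singer.id",
# parse_val_unit returns the val_unit
#     (1, (0, '__singer.age__', False), (0, '__singer.id__', False))
# where 1 == UNIT_OPS.index('-') and each col_unit is (agg_id, col_id, isDistinct).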
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
    """
    :returns next idx, table id, table name
    """
    idx = start_idx
    len_ = len(toks)
    key = tables_with_alias[toks[idx]]

    if idx + 1 < len_ and toks[idx+1] == "as":
        idx += 3  # skip "<table> as <alias>"
    else:
        idx += 1

    return idx, schema.idMap[key], key

def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)

    isBlock = False
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    if toks[idx] == 'select':
        idx, val = parse_sql(toks, idx, tables_with_alias, schema)
    elif "\"" in toks[idx]:  # token is a string value
        val = toks[idx]
        idx += 1
    else:
        try:
            val = float(toks[idx])
            idx += 1
        except ValueError:
            # not a number: treat the value as a column reference
            end_idx = idx
            while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\
                and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
                    end_idx += 1

            idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables)
            idx = end_idx

    if isBlock:
        assert toks[idx] == ')'
        idx += 1

    return idx, val

def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)
    conds = []

    while idx < len_:
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        not_op = False
        if toks[idx] == 'not':
            not_op = True
            idx += 1

        assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
        op_id = WHERE_OPS.index(toks[idx])
        idx += 1
        val1 = val2 = None
        if op_id == WHERE_OPS.index('between'):  # between..and... special case: dual values
            idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
            assert toks[idx] == 'and'
            idx += 1
            idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
        else:  # normal case: single value
            idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)

        conds.append((not_op, op_id, val_unit, val1, val2))

        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
            break

        if idx < len_ and toks[idx] in COND_OPS:
            conds.append(toks[idx])
            idx += 1  # skip and/or

    return idx, conds

def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
    idx = start_idx
    len_ = len(toks)

    assert toks[idx] == 'select', "'select' not found"
    idx += 1
    isDistinct = False
    if idx < len_ and toks[idx] == 'distinct':
        idx += 1
        isDistinct = True
    val_units = []

    while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
        agg_id = AGG_OPS.index("none")
        if toks[idx] in AGG_OPS:
            agg_id = AGG_OPS.index(toks[idx])
            idx += 1
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        val_units.append((agg_id, val_unit))
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','

    return idx, (isDistinct, val_units)

def parse_from(toks, start_idx, tables_with_alias, schema):
    """
    Assume in the from clause, all table units are combined with join
    """
    assert 'from' in toks[start_idx:], "'from' not found"

    len_ = len(toks)
    idx = toks.index('from', start_idx) + 1
    default_tables = []
    table_units = []
    conds = []

    while idx < len_:
        isBlock = False
        if toks[idx] == '(':
            isBlock = True
            idx += 1

        if toks[idx] == 'select':
            idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
            table_units.append((TABLE_TYPE['sql'], sql))
        else:
            if idx < len_ and toks[idx] == 'join':
                idx += 1  # skip join
            idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
            table_units.append((TABLE_TYPE['table_unit'], table_unit))
            default_tables.append(table_name)
        if idx < len_ and toks[idx] == "on":
            idx += 1  # skip on
            idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
            if len(conds) > 0:
                conds.append('and')
            conds.extend(this_conds)

        if isBlock:
            assert toks[idx] == ')'
            idx += 1
        if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
            break

    return idx, table_units, conds, default_tables

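# Example (illustrative): for
#     "... FROM singer JOIN concert ON singer.singer_id = concert.singer_id"
# table_units becomes [('table_unit', '__singer__'), ('table_unit', '__concert__')],
# default_tables becomes ['singer', 'concert'], and conds holds the single
# equality condition from the ON clause.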
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)

    if idx >= len_ or toks[idx] != 'where':
        return idx, []

    idx += 1
    idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
    return idx, conds

def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)
    col_units = []

    if idx >= len_ or toks[idx] != 'group':
        return idx, col_units

    idx += 1
    assert toks[idx] == 'by'
    idx += 1

    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
        idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
        col_units.append(col_unit)
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','
        else:
            break

    return idx, col_units

def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)
    val_units = []
    order_type = 'asc'  # default type is 'asc'

    if idx >= len_ or toks[idx] != 'order':
        return idx, val_units

    idx += 1
    assert toks[idx] == 'by'
    idx += 1

    while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
        idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
        val_units.append(val_unit)
        if idx < len_ and toks[idx] in ORDER_OPS:
            order_type = toks[idx]
            idx += 1
        if idx < len_ and toks[idx] == ',':
            idx += 1  # skip ','
        else:
            break

    return idx, (order_type, val_units)

def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
    idx = start_idx
    len_ = len(toks)

    if idx >= len_ or toks[idx] != 'having':
        return idx, []

    idx += 1
    idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
    return idx, conds

def parse_limit(toks, start_idx):
    idx = start_idx
    len_ = len(toks)

    if idx < len_ and toks[idx] == 'limit':
        idx += 2
        # return the actual limit value when it is a number;
        # fall back to 1 when it is a non-numeric placeholder token
        if not toks[idx-1].isdigit():
            return idx, 1

        return idx, int(toks[idx-1])

    return idx, None

def parse_sql(toks, start_idx, tables_with_alias, schema):
    isBlock = False  # indicate whether this is a block of sql/sub-sql
    len_ = len(toks)
    idx = start_idx

    sql = {}
    if toks[idx] == '(':
        isBlock = True
        idx += 1

    # parse from clause first in order to get default tables
    from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
    sql['from'] = {'table_units': table_units, 'conds': conds}
    # select clause
    _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
    idx = from_end_idx
    sql['select'] = select_col_units
    # where clause
    idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
    sql['where'] = where_conds
    # group by clause
    idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
    sql['groupBy'] = group_col_units
    # having clause
    idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
    sql['having'] = having_conds
    # order by clause
    idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
    sql['orderBy'] = order_col_units
    # limit clause
    idx, limit_val = parse_limit(toks, idx)
    sql['limit'] = limit_val

    idx = skip_semicolon(toks, idx)
    if isBlock:
        assert toks[idx] == ')'
        idx += 1  # skip ')'
    idx = skip_semicolon(toks, idx)

    # intersect/union/except clause
    for op in SQL_OPS:  # initialize IUE
        sql[op] = None
    if idx < len_ and toks[idx] in SQL_OPS:
        sql_op = toks[idx]
        idx += 1
        idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
        sql[sql_op] = IUE_sql
    return idx, sql

def load_data(fpath):
    with open(fpath) as f:
        data = json.load(f)
    return data


def get_sql(schema, query):
    toks = tokenize(query)
    tables_with_alias = get_tables_with_alias(schema.schema, toks)
    _, sql = parse_sql(toks, 0, tables_with_alias, schema)
    return sql


def skip_semicolon(toks, start_idx):
    idx = start_idx
    while idx < len(toks) and toks[idx] == ";":
        idx += 1
    return idx

def get_schemas_from_json(fpath):
    with open(fpath) as f:
        data = json.load(f)
    db_names = [db['db_id'] for db in data]

    tables = {}
    schemas = {}
    for db in data:
        db_id = db['db_id']
        schema = {}  # {'table': [col.lower(), ...]}; '*' is handled by Schema as __all__
        column_names_original = db['column_names_original']
        table_names_original = db['table_names_original']
        tables[db_id] = {'column_names_original': column_names_original, 'table_names_original': table_names_original}
        for i, tabn in enumerate(table_names_original):
            table = str(tabn.lower())
            cols = [str(col.lower()) for td, col in column_names_original if td == i]
            schema[table] = cols
        schemas[db_id] = schema

    return schemas, db_names, tables

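# Minimal end-to-end sketch (hypothetical path and query, not part of the
# original file): build a Schema from a SQLite database, then parse one query.
if __name__ == '__main__':
    db_path = 'path/to/database.sqlite'  # hypothetical database file
    schema = Schema(get_schema(db_path))
    query = 'SELECT name FROM singer WHERE age > 20'  # hypothetical query
    print(get_sql(schema, query))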