################################ # Assumptions: # 1. sql is correct # 2. only table name has alias # 3. only one intersect/union/except # # val: number(float)/string(str)/sql(dict) # col_unit: (agg_id, col_id, isDistinct(bool)) # val_unit: (unit_op, col_unit1, col_unit2) # table_unit: (table_type, col_unit/sql) # cond_unit: (not_op, op_id, val_unit, val1, val2) # condition: [cond_unit1, 'and'/'or', cond_unit2, ...] # sql { # 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) # 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} # 'where': condition # 'groupBy': [col_unit1, col_unit2, ...] # 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) # 'having': condition # 'limit': None/limit value # 'intersect': None/sql # 'except': None/sql # 'union': None/sql # } ################################ import json import sqlite3 from nltk import word_tokenize CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') JOIN_KEYWORDS = ('join', 'on', 'as') WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') UNIT_OPS = ('none', '-', '+', "*", '/') AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') TABLE_TYPE = { 'sql': "sql", 'table_unit': "table_unit", } COND_OPS = ('and', 'or') SQL_OPS = ('intersect', 'union', 'except') ORDER_OPS = ('desc', 'asc') class Schema: """ Simple schema which maps table&column to a unique identifier """ def __init__(self, schema): self._schema = schema self._idMap = self._map(self._schema) @property def schema(self): return self._schema @property def idMap(self): return self._idMap def _map(self, schema): idMap = {'*': "__all__"} id = 1 for key, vals in schema.items(): for val in vals: idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__" id += 1 for key in schema: idMap[key.lower()] = "__" + key.lower() + "__" id += 1 return idMap def get_schema(db): """ Get database's schema, which is a dict with table name as key and list of column names as value :param db: database path :return: schema dict """ schema = {} conn = sqlite3.connect(db) cursor = conn.cursor() # fetch table names cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") tables = [str(table[0].lower()) for table in cursor.fetchall()] # fetch table info for table in tables: cursor.execute("PRAGMA table_info({})".format(table)) schema[table] = [str(col[1].lower()) for col in cursor.fetchall()] return schema def get_schema_from_json(fpath): with open(fpath) as f: data = json.load(f) schema = {} for entry in data: table = str(entry['table'].lower()) cols = [str(col['column_name'].lower()) for col in entry['col_data']] schema[table] = cols return schema def tokenize(string): string = str(string) string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem?? quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] assert len(quote_idxs) % 2 == 0, "Unexpected quote" # keep string value as token vals = {} for i in range(len(quote_idxs)-1, -1, -2): qidx1 = quote_idxs[i-1] qidx2 = quote_idxs[i] val = string[qidx1: qidx2+1] key = "__val_{}_{}__".format(qidx1, qidx2) string = string[:qidx1] + key + string[qidx2+1:] vals[key] = val # tokenize sql toks_tmp = [word.lower() for word in word_tokenize(string)] toks = [] for tok in toks_tmp: if tok.startswith('=__val_'): tok = tok[1:] toks.append('=') toks.append(tok) # replace with string value token for i in range(len(toks)): if toks[i] in vals: toks[i] = vals[toks[i]] # find if there exists !=, >=, <= eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] eq_idxs.reverse() prefix = ('!', '>', '<') for eq_idx in eq_idxs: pre_tok = toks[eq_idx-1] if pre_tok in prefix: toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] return toks def scan_alias(toks): """Scan the index of 'as' and build the map for all alias""" as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] alias = {} for idx in as_idxs: alias[toks[idx+1]] = toks[idx-1] return alias def get_tables_with_alias(schema, toks): tables = scan_alias(toks) for key in schema: assert key not in tables, "Alias {} has the same name in table".format(key) tables[key] = key return tables def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): """ :returns next idx, column id """ tok = toks[start_idx] if tok == "*": return start_idx + 1, schema.idMap[tok] if '.' in tok: # if token is a composite alias, col = tok.split('.') key = tables_with_alias[alias] + "." + col return start_idx+1, schema.idMap[key] assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" for alias in default_tables: table = tables_with_alias[alias] if tok in schema.schema[table]: key = table + "." + tok return start_idx+1, schema.idMap[key] assert False, "Error col: {}".format(tok) def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): """ :returns next idx, (agg_op id, col_id) """ idx = start_idx len_ = len(toks) isBlock = False isDistinct = False if toks[idx] == '(': isBlock = True idx += 1 if toks[idx] in AGG_OPS: agg_id = AGG_OPS.index(toks[idx]) idx += 1 assert idx < len_ and toks[idx] == '(' idx += 1 if toks[idx] == "distinct": idx += 1 isDistinct = True idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) assert idx < len_ and toks[idx] == ')' idx += 1 return idx, (agg_id, col_id, isDistinct) if toks[idx] == "distinct": idx += 1 isDistinct = True agg_id = AGG_OPS.index("none") idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) if isBlock: assert toks[idx] == ')' idx += 1 # skip ')' return idx, (agg_id, col_id, isDistinct) def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): idx = start_idx len_ = len(toks) isBlock = False if toks[idx] == '(': isBlock = True idx += 1 col_unit1 = None col_unit2 = None unit_op = UNIT_OPS.index('none') idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) if idx < len_ and toks[idx] in UNIT_OPS: unit_op = UNIT_OPS.index(toks[idx]) idx += 1 idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) if isBlock: assert toks[idx] == ')' idx += 1 # skip ')' return idx, (unit_op, col_unit1, col_unit2) def parse_table_unit(toks, start_idx, tables_with_alias, schema): """ :returns next idx, table id, table name """ idx = start_idx len_ = len(toks) key = tables_with_alias[toks[idx]] if idx + 1 < len_ and toks[idx+1] == "as": idx += 3 else: idx += 1 return idx, schema.idMap[key], key def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): idx = start_idx len_ = len(toks) isBlock = False if toks[idx] == '(': isBlock = True idx += 1 if toks[idx] == 'select': idx, val = parse_sql(toks, idx, tables_with_alias, schema) elif "\"" in toks[idx]: # token is a string value val = toks[idx] idx += 1 else: try: val = float(toks[idx]) idx += 1 except: end_idx = idx while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS: end_idx += 1 idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) idx = end_idx if isBlock: assert toks[idx] == ')' idx += 1 return idx, val def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): idx = start_idx len_ = len(toks) conds = [] while idx < len_: idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) not_op = False if toks[idx] == 'not': not_op = True idx += 1 assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) op_id = WHERE_OPS.index(toks[idx]) idx += 1 val1 = val2 = None if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) assert toks[idx] == 'and' idx += 1 idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables) else: # normal case: single value idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) val2 = None conds.append((not_op, op_id, val_unit, val1, val2)) if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS): break if idx < len_ and toks[idx] in COND_OPS: conds.append(toks[idx]) idx += 1 # skip and/or return idx, conds def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): idx = start_idx len_ = len(toks) assert toks[idx] == 'select', "'select' not found" idx += 1 isDistinct = False if idx < len_ and toks[idx] == 'distinct': idx += 1 isDistinct = True val_units = [] while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: agg_id = AGG_OPS.index("none") if toks[idx] in AGG_OPS: agg_id = AGG_OPS.index(toks[idx]) idx += 1 idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) val_units.append((agg_id, val_unit)) if idx < len_ and toks[idx] == ',': idx += 1 # skip ',' return idx, (isDistinct, val_units) def parse_from(toks, start_idx, tables_with_alias, schema): """ Assume in the from clause, all table units are combined with join """ assert 'from' in toks[start_idx:], "'from' not found" len_ = len(toks) idx = toks.index('from', start_idx) + 1 default_tables = [] table_units = [] conds = [] while idx < len_: isBlock = False if toks[idx] == '(': isBlock = True idx += 1 if toks[idx] == 'select': idx, sql = parse_sql(toks, idx, tables_with_alias, schema) table_units.append((TABLE_TYPE['sql'], sql)) else: if idx < len_ and toks[idx] == 'join': idx += 1 # skip join idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) table_units.append((TABLE_TYPE['table_unit'],table_unit)) default_tables.append(table_name) if idx < len_ and toks[idx] == "on": idx += 1 # skip on idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) if len(conds) > 0: conds.append('and') conds.extend(this_conds) if isBlock: assert toks[idx] == ')' idx += 1 if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): break return idx, table_units, conds, default_tables def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): idx = start_idx len_ = len(toks) if idx >= len_ or toks[idx] != 'where': return idx, [] idx += 1 idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) return idx, conds def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): idx = start_idx len_ = len(toks) col_units = [] if idx >= len_ or toks[idx] != 'group': return idx, col_units idx += 1 assert toks[idx] == 'by' idx += 1 while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) col_units.append(col_unit) if idx < len_ and toks[idx] == ',': idx += 1 # skip ',' else: break return idx, col_units def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): idx = start_idx len_ = len(toks) val_units = [] order_type = 'asc' # default type is 'asc' if idx >= len_ or toks[idx] != 'order': return idx, val_units idx += 1 assert toks[idx] == 'by' idx += 1 while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) val_units.append(val_unit) if idx < len_ and toks[idx] in ORDER_OPS: order_type = toks[idx] idx += 1 if idx < len_ and toks[idx] == ',': idx += 1 # skip ',' else: break return idx, (order_type, val_units) def parse_having(toks, start_idx, tables_with_alias, schema, default_tables): idx = start_idx len_ = len(toks) if idx >= len_ or toks[idx] != 'having': return idx, [] idx += 1 idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) return idx, conds def parse_limit(toks, start_idx): idx = start_idx len_ = len(toks) if idx < len_ and toks[idx] == 'limit': idx += 2 # make limit value can work, cannot assume put 1 as a fake limit number if type(toks[idx-1]) != int: return idx, 1 return idx, int(toks[idx-1]) return idx, None def parse_sql(toks, start_idx, tables_with_alias, schema): isBlock = False # indicate whether this is a block of sql/sub-sql len_ = len(toks) idx = start_idx sql = {} if toks[idx] == '(': isBlock = True idx += 1 # parse from clause in order to get default tables from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema) sql['from'] = {'table_units': table_units, 'conds': conds} # select clause _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables) idx = from_end_idx sql['select'] = select_col_units # where clause idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables) sql['where'] = where_conds # group by clause idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables) sql['groupBy'] = group_col_units # having clause idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables) sql['having'] = having_conds # order by clause idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables) sql['orderBy'] = order_col_units # limit clause idx, limit_val = parse_limit(toks, idx) sql['limit'] = limit_val idx = skip_semicolon(toks, idx) if isBlock: assert toks[idx] == ')' idx += 1 # skip ')' idx = skip_semicolon(toks, idx) # intersect/union/except clause for op in SQL_OPS: # initialize IUE sql[op] = None if idx < len_ and toks[idx] in SQL_OPS: sql_op = toks[idx] idx += 1 idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema) sql[sql_op] = IUE_sql return idx, sql def load_data(fpath): with open(fpath) as f: data = json.load(f) return data def get_sql(schema, query): toks = tokenize(query) tables_with_alias = get_tables_with_alias(schema.schema, toks) _, sql = parse_sql(toks, 0, tables_with_alias, schema) return sql def skip_semicolon(toks, start_idx): idx = start_idx while idx < len(toks) and toks[idx] == ";": idx += 1 return idx def get_schemas_from_json(fpath): with open(fpath) as f: data = json.load(f) db_names = [db['db_id'] for db in data] tables = {} schemas = {} for db in data: db_id = db['db_id'] schema = {} #{'table': [col.lower, ..., ]} * -> __all__ column_names_original = db['column_names_original'] table_names_original = db['table_names_original'] tables[db_id] = {'column_names_original': column_names_original, 'table_names_original': table_names_original} for i, tabn in enumerate(table_names_original): table = str(tabn.lower()) cols = [str(col.lower()) for td, col in column_names_original if td == i] schema[table] = cols schemas[db_id] = schema return schemas, db_names, tables