Spaces:
Sleeping
Sleeping
import ast | |
import base64 | |
import duckdb | |
import json | |
import re | |
import textwrap | |
from ulid import ULID | |
# File that Q.save appends to by default and that Q.from_history reads from.
HISTORY_FILE = "history.json"
# Row cap applied by Q.run_duckdb; a falsy value disables the cap.
MAX_ROWS = 10000
class SQLError(Exception):
    """Raised by Q.from_history when the history lookup itself fails."""
class NotFoundError(Exception):
    """Raised by Q.from_history when no saved query matches the id or alias."""
class Q(str):
    """A SQL query string that remembers the template and parameters it came from.

    The string value is the rendered query. Each instance also carries:

    - ``id``: a fresh ULID string identifying this instance.
    - ``alias``: optional human-friendly name (``None`` when not given).
    - ``template``: the dedented, stripped template text.
    - ``kwargs``: the parameters passed for formatting (without ``alias``).
    - ``definitions``: ``kwargs`` rendered as ``key = repr(value)`` lines.
    - ``rows`` / ``cols``: shape of the last successful result, else ``None``.
    - ``source_id``: id of the history entry this query was loaded from.
    - ``start`` / ``end``: ULIDs taken right before and right after ``run``.
    """

    # Keywords whose presence in a template makes ``is_safe`` return False.
    UNSAFE = ["CREATE", "DELETE", "DROP", "INSERT", "UPDATE"]
    rows = None

    def __new__(cls, template: str, **kwargs):
        """Create a new Q-string.

        Args:
            template: Query text with optional ``str.format`` placeholders.
            **kwargs: Values for the placeholders. The special key ``alias``
                is stored as ``instance.alias`` and not kept in ``kwargs``.

        If the template cannot be formatted with the given parameters
        (missing key, positional placeholder, unbalanced brace), the raw
        template is used as the query text.
        """
        _template = textwrap.dedent(template).strip()
        try:
            rendered = _template.format(**kwargs)
        except (KeyError, IndexError, ValueError):
            # Best effort: a query containing literal braces is kept verbatim.
            rendered = _template
        instance = str.__new__(cls, rendered)
        instance.id = str(ULID())
        # Always remove the key, so alias=None does not leak into kwargs
        # and definitions; a falsy alias is normalized to None.
        instance.alias = kwargs.pop("alias", None) or None
        instance.template = _template
        instance.kwargs = kwargs
        instance.definitions = "\n".join(f"{k} = {v!r}" for k, v in kwargs.items())
        for attr in ("rows", "cols", "source_id", "start", "end"):
            setattr(instance, attr, None)
        return instance

    def __repr__(self):
        """Neat repr for inspecting Q objects: one ``name: value`` entry per attribute."""
        strings = []
        for k, v in self.__dict__.items():
            text = str(v)
            # Multi-line values start on their own line, indented.
            value_repr = "\n" + textwrap.indent(text, " ") if "\n" in text else v
            strings.append(f"{k}: {value_repr}")
        return "\n".join(strings)

    def run(self, sql_engine=None, save=False, _raise=False):
        """Execute the query and record its timing and result shape.

        Args:
            sql_engine: When ``None`` the query runs through DuckDB; otherwise
                it is handed to ``self.run_sql(sql_engine)``.
            save: Append this query to the history file once it has finished,
                whether it succeeded or not.
            _raise: Re-raise execution errors instead of returning them.

        Returns:
            The result object on success. On failure (and ``_raise`` false)
            the error message as a plain string.
        """
        self.start = ULID()
        # Forget the shape of any earlier run so a failure is not mistaken
        # for a result by ``df``.
        self.rows = self.cols = None
        try:
            if sql_engine is None:
                res = self.run_duckdb()
            else:
                # NOTE(review): run_sql is not defined in the visible part of
                # this class - confirm it exists elsewhere.
                res = self.run_sql(sql_engine)
            self.rows, self.cols = res.shape
            return res
        except Exception as e:
            if _raise:
                raise
            return str(e)
        finally:
            self.end = ULID()
            if save:
                self.save()

    def run_duckdb(self):
        """Run the query with DuckDB, capped at ``MAX_ROWS`` rows when that is set."""
        if MAX_ROWS:
            # Drop trailing semicolons and keep the closing parenthesis on its
            # own line so a trailing ";" or "--" comment cannot break the wrapper.
            query = str(self).rstrip().rstrip(";")
            return duckdb.sql(f"WITH x AS (\n{query}\n) SELECT * FROM x LIMIT {MAX_ROWS}")
        return duckdb.sql(str(self))

    def df(self, sql_engine=None, save=False, _raise=False):
        """Run the query and return the result as a DataFrame tagged with ``.q``.

        Returns ``None`` when the run failed or produced no rows. Arguments
        are forwarded to ``run``.
        """
        res = self.run(sql_engine=sql_engine, save=save, _raise=_raise)
        if not self.rows:
            return None
        result_df = res.df()
        result_df.q = self
        return result_df

    def save(self, file=HISTORY_FILE):
        """Append this query to ``file`` as one JSON document on its own line."""
        with open(file, "a", encoding="utf-8") as f:
            f.write(self.json + "\n")

    @property
    def json(self):
        """JSON text holding ``id``, ``q`` (the rendered query) and every instance attribute."""
        serialized = {"id": self.id, "q": self}
        serialized.update(self.__dict__)
        # Values json cannot encode natively (the ULIDs in start/end) are
        # written as their timestamp with millisecond precision.
        return json.dumps(
            serialized,
            default=lambda x: x.datetime.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3],
        )

    @property
    def is_safe(self):
        """Whether the template contains none of the ``UNSAFE`` keywords.

        This is a case-insensitive substring test on the template only:
        identifiers such as ``updated_at`` are flagged as well, and parameter
        values substituted into the template are not inspected.
        """
        template = self.template.upper()
        return not any(cmd in template for cmd in self.UNSAFE)

    @classmethod
    def from_dict(cls, query_dict: dict):
        """Build a Q from a dict whose ``"q"`` entry is the template.

        All other entries are passed on as keyword arguments. The given dict
        is left unmodified.

        Raises:
            KeyError: If ``query_dict`` has no ``"q"`` entry.
        """
        params = dict(query_dict)
        template = params.pop("q")
        return cls(template, **params)

    @classmethod
    def from_template_and_definitions(cls, template: str, definitions: str, alias: str | None = None):
        """Build a Q from a template and a block of ``key=value`` definition lines.

        The definitions text is parsed with ``parse_definitions`` to obtain the
        parameters, and is stored verbatim as ``instance.definitions``.
        """
        query_dict = {"q": template, "alias": alias}
        query_dict.update(parse_definitions(definitions))
        instance = cls.from_dict(query_dict)
        instance.definitions = definitions
        return instance

    @staticmethod
    def _sql_literal(value) -> str:
        """Escape ``value`` for use inside a single-quoted SQL literal in a Q template."""
        # Quotes are doubled for SQL; braces are doubled because the text goes
        # through str.format when the Q is created.
        return str(value).replace("'", "''").replace("{", "{{").replace("}", "}}")

    @classmethod
    def from_history(cls, query_id=None, alias=None):
        """Recreate a saved query from the history file by id or alias.

        Args:
            query_id: The ``id`` of the saved query, if known.
            alias: The ``alias`` of the saved query, if known.

        Returns:
            A new Q built from the stored template and parameters, with
            ``source_id`` set to the id of the matched history entry.

        Raises:
            NotFoundError: If nothing matches, or neither argument is given.
            SQLError: If the lookup query itself failed.
        """
        conditions = []
        if query_id is not None:
            conditions.append(f"id='{cls._sql_literal(query_id)}'")
        if alias is not None:
            conditions.append(f"alias='{cls._sql_literal(alias)}'")
        if not conditions:
            raise NotFoundError(f"id '{query_id}' / alias '{alias}' not found")

        search_query = Q(f"""
            SELECT id, template, kwargs
            FROM '{HISTORY_FILE}'
            WHERE {" OR ".join(conditions)}
            LIMIT 1
        """)
        query = search_query.run()
        if search_query.rows == 1:
            source_id, template, kwargs = query.fetchall()[0]
            # NOTE(review): depending on how DuckDB infers the column type,
            # kwargs may presumably arrive as JSON text or NULL - confirm.
            if isinstance(kwargs, str):
                kwargs = json.loads(kwargs)
            kwargs = {k: v for k, v in (kwargs or {}).items() if v is not None}
            instance = cls(template, **kwargs)
            instance.source_id = source_id
            return instance
        elif search_query.rows == 0:
            raise NotFoundError(f"id '{query_id}' / alias '{alias}' not found")
        else:
            # run() returned an error message instead of a result.
            raise SQLError(query)

    @property
    def base64(self):
        """The rendered query, base64-encoded, as text."""
        return base64.b64encode(self.encode()).decode()

    @classmethod
    def from_base64(cls, b64):
        """Initializing from base64-encoded URL paths."""
        return cls(base64.b64decode(b64).decode())
def parse_definitions(definitions: str) -> dict:
    """Parse a string literal of "key=value" pairs, one per line, into kwargs.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    skipped. Each value is evaluated with ``ast.literal_eval``, so it must be
    a Python literal (number, quoted string, list, ...).

    Args:
        definitions: The definitions text; ``None`` or empty yields ``{}``.

    Returns:
        A dict mapping each key to its evaluated value.

    Raises:
        ValueError, SyntaxError: If a value is not a valid Python literal.
    """
    kwargs = {}
    if not definitions:
        return kwargs
    for raw_line in definitions.split("\n"):
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", maxsplit=1)
        # Whitespace is removed from the key only; the value is merely
        # trimmed, so spaces inside string literals survive.
        kwargs[re.sub(r"\s+", "", key)] = ast.literal_eval(value.strip())
    return kwargs
# Example 1: a template whose parameters come from "key=value" definition lines.
EX1 = Q.from_template_and_definitions(
    template="SELECT {x} AS {colname}",
    definitions=(
        "# Define variables: one '=' per line\n"
        "x=42\n"
        "colname='answer'"
    ),
    alias="example1",
)

# Example 2: selecting columns from a CSV file addressed by URL.
EX2 = Q(
    """
    SELECT
    Symbol,
    Number,
    Mass,
    Abundance
    FROM '{url}'
    """,
    url="https://raw.githubusercontent.com/ekwan/cctk/master/cctk/data/isotopes.csv",
    alias="example2",
)

# Example 3: listing the saved query history, newest id first.
EX3 = Q(
    """
    SELECT *
    FROM 'history.json'
    ORDER BY id DESC
    """,
    alias="example3",
)

# Example 4: selects a column named "nothing" with no FROM clause (see alias).
EX4 = Q("SELECT nothing", alias="bad_example")