|
import ast |
|
import base64 |
|
import duckdb |
|
import json |
|
import re |
|
import textwrap |
|
from ulid import ULID |
|
|
|
# File that Q.save() appends one JSON record per line to, and that
# Q.from_history() reads back through DuckDB.
HISTORY_FILE: str = "history.json"
# Row cap applied by Q.run_duckdb(); a falsy value disables the cap.
MAX_ROWS: int = 10_000
|
|
|
class SQLError(Exception):
    """Raised when a SQL lookup fails to run or yields an unexpected result."""
|
|
|
class NotFoundError(Exception):
    """Raised when a history lookup matches no stored query."""
|
|
|
class Q(str):
    """A SQL query string that also records how it was built and run.

    The string value is the dedented, stripped ``template`` with
    ``str.format`` placeholders filled from keyword arguments (or the
    template verbatim when formatting fails).  Instances also carry a ULID
    ``id``, the ``template``, ``kwargs``, a rendered ``definitions`` listing,
    an optional ``alias`` and, after :meth:`run`, ``rows``/``cols`` of the
    result plus ``start``/``end`` ULIDs.  ``source_id`` is set when the
    query was rebuilt by :meth:`from_history`.
    """

    # Keywords that make ``is_safe`` return False when they appear as whole
    # words in the rendered SQL.
    # NOTE(review): DuckDB has other state-changing statements (ALTER, COPY,
    # ATTACH, INSTALL, ...) - extend this list if is_safe guards untrusted SQL.
    UNSAFE = ["CREATE", "DELETE", "DROP", "INSERT", "UPDATE"]

    rows = None

    def __new__(cls, template: str, **kwargs):
        """Create a new Q-string.

        Args:
            template: SQL text. Common leading indentation and surrounding
                whitespace are removed, then ``str.format`` placeholders are
                filled from ``kwargs``.
            **kwargs: Placeholder values. The key ``alias`` is also stored as
                ``self.alias`` and removed from ``self.kwargs``.

        If formatting fails (a missing key, a positional ``{}``, or a stray
        brace, all common in SQL struct/JSON literals), the dedented template
        is used verbatim.
        """
        _template = textwrap.dedent(template).strip()
        try:
            instance = str.__new__(cls, _template.format(**kwargs))
        except (KeyError, IndexError, ValueError):
            instance = str.__new__(cls, _template)
        instance.id = str(ULID())
        # Always remove "alias" so a falsy alias (None / "") does not leak into
        # kwargs and definitions. A falsy alias is normalized to None.
        instance.alias = kwargs.pop("alias", None) or None
        instance.template = _template
        instance.kwargs = kwargs
        instance.definitions = "\n".join(f"{k} = {v!r}" for k, v in kwargs.items())
        for attr in ("rows", "cols", "source_id", "start", "end"):
            setattr(instance, attr, None)
        return instance

    def __repr__(self):
        """Neat repr for inspecting Q objects: one ``key: value`` line per attribute."""
        strings = []
        for k, v in self.__dict__.items():
            text = str(v)
            # Multi-line values (template, definitions) go on their own indented
            # lines. Indent the str() form because textwrap.indent needs a str.
            value_repr = "\n" + textwrap.indent(text, " ") if "\n" in text else v
            strings.append(f"{k}: {value_repr}")
        return "\n".join(strings)

    def run(self, sql_engine=None, save=False, _raise=False):
        """Execute the query and record its result shape and timing.

        Args:
            sql_engine: When None, run with DuckDB via :meth:`run_duckdb`;
                otherwise delegate to ``self.run_sql(sql_engine)``.
            save: Append this query's JSON record to the history file after
                running, whether or not the run succeeded.
            _raise: Re-raise the exception instead of returning its text.

        Returns:
            The result object on success, or ``str(exception)`` on failure
            (unless ``_raise`` is set).
        """
        self.start = ULID()
        # Clear the previous run's shape so a failed re-run cannot look like a
        # success to df() or from_history().
        self.rows = self.cols = None
        try:
            if sql_engine is None:
                res = self.run_duckdb()
            else:
                # NOTE(review): run_sql is not defined on Q here. This branch
                # raises AttributeError unless a subclass provides it.
                res = self.run_sql(sql_engine)
            self.rows, self.cols = res.shape
            return res
        except Exception as e:
            if _raise:
                raise
            return str(e)
        finally:
            self.end = ULID()
            if save:
                self.save()

    def run_duckdb(self):
        """Run the query on DuckDB's default connection.

        When MAX_ROWS is truthy, the query is wrapped in a CTE and capped with
        ``LIMIT MAX_ROWS``. Otherwise it runs as-is.
        """
        if not MAX_ROWS:
            return duckdb.sql(self)
        # Newlines around the inner query keep a trailing "-- comment" from
        # swallowing the closing parenthesis. A trailing ";" is not valid inside
        # a CTE, so it is dropped.
        inner = str(self).rstrip().rstrip(";").rstrip()
        return duckdb.sql(f"WITH x AS (\n{inner}\n) SELECT * FROM x LIMIT {MAX_ROWS}")

    def df(self, sql_engine=None, save=False, _raise=False):
        """Run the query and return ``res.df()`` with a ``q`` back-reference.

        For DuckDB results this is a pandas DataFrame. Returns None when the
        run failed (``self.rows`` stays None). A successful query with zero
        rows yields an empty frame rather than None.
        """
        res = self.run(sql_engine=sql_engine, save=save, _raise=_raise)
        if self.rows is None:
            return None
        result_df = res.df()
        result_df.q = self
        return result_df

    def save(self, file=HISTORY_FILE):
        """Append this query's JSON record to ``file`` as a single line."""
        with open(file, "a", encoding="utf-8") as f:
            f.write(self.json + "\n")

    @staticmethod
    def _json_default(value):
        """``json.dumps`` fallback: render ULID-like values as millisecond timestamps."""
        dt = getattr(value, "datetime", None)
        if dt is not None:
            # Explicit codes instead of %F/%T, which are not portable (e.g. Windows).
            return dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    @property
    def json(self):
        """JSON text with ``id``, the rendered SQL as ``q``, and every instance attribute.

        ``start``/``end`` ULIDs are written as ``YYYY-MM-DD HH:MM:SS.mmm``.
        """
        serialized = {"id": self.id, "q": self}
        serialized.update(self.__dict__)
        return json.dumps(serialized, default=self._json_default)

    @property
    def is_safe(self):
        """True when no UNSAFE keyword appears as a whole word in the rendered SQL.

        The rendered string (``str(self)``) is checked rather than the template
        alone, so a keyword passed in through a format argument is caught.
        Whole-word matching avoids false positives such as ``created_at``.
        This is a keyword screen, not a SQL parser.
        """
        if not self.UNSAFE:
            return True
        pattern = r"\b(?:" + "|".join(map(re.escape, self.UNSAFE)) + r")\b"
        return re.search(pattern, str(self), flags=re.IGNORECASE) is None

    @classmethod
    def from_dict(cls, query_dict: dict):
        """Build a Q from ``{"q": template, **kwargs}`` without mutating ``query_dict``.

        Raises:
            KeyError: ``query_dict`` has no ``"q"`` entry.
        """
        kwargs = dict(query_dict)
        template = kwargs.pop("q")
        return cls(template, **kwargs)

    @classmethod
    def from_template_and_definitions(cls, template: str, definitions: str, alias: str|None = None):
        """Build a Q from a template plus ``key=value`` definition lines.

        ``definitions`` is parsed with ``parse_definitions`` into format
        arguments. The raw text is kept verbatim as ``instance.definitions``.
        """
        query_dict = {"q": template, "alias": alias}
        query_dict.update(parse_definitions(definitions))
        instance = cls.from_dict(query_dict)
        instance.definitions = definitions
        return instance

    @staticmethod
    def _sql_literal(value):
        """Render ``value`` as a SQL string literal, or ``NULL`` for None.

        Single quotes are doubled so the value cannot terminate the literal.
        Braces are doubled so the Q constructor's ``str.format`` pass turns
        them back into single braces instead of treating them as placeholders.
        """
        if value is None:
            return "NULL"
        text = str(value).replace("'", "''").replace("{", "{{").replace("}", "}}")
        return f"'{text}'"

    @classmethod
    def from_history(cls, query_id=None, alias=None):
        """Rebuild a query saved in HISTORY_FILE, looked up by id or alias.

        The rebuilt instance gets a fresh ``id``. The stored id is kept as
        ``source_id``.

        Raises:
            NotFoundError: no history record matches.
            SQLError: the lookup query failed (e.g. the history file is missing).
        """
        search_query = Q(f"""
            SELECT id, template, kwargs
            FROM '{HISTORY_FILE}'
            WHERE id={cls._sql_literal(query_id)} OR alias={cls._sql_literal(alias)}
            LIMIT 1
        """)
        # NOTE(review): LIMIT 1 has no ORDER BY, so which record wins when an
        # alias was saved several times is not specified here.
        query = search_query.run()
        if search_query.rows == 1:
            source_id, template, kwargs = query.fetchall()[0]
            # Defensive: depending on how DuckDB infers the history schema, the
            # kwargs column may come back as JSON text or NULL instead of a dict.
            if isinstance(kwargs, str):
                kwargs = json.loads(kwargs)
            kwargs = {k: v for k, v in (kwargs or {}).items() if v is not None}
            instance = cls(template, **kwargs)
            instance.source_id = source_id
            return instance
        if search_query.rows == 0:
            raise NotFoundError(f"id '{query_id}' / alias '{alias}' not found")
        # run() returned the error message instead of a result.
        raise SQLError(query)

    @property
    def base64(self):
        """The rendered SQL, UTF-8 encoded, as standard-alphabet base64 text."""
        return base64.b64encode(self.encode("utf-8")).decode("ascii")

    @classmethod
    def from_base64(cls, b64):
        """Initializing from base64-encoded URL paths.

        Accepts ``str`` or ``bytes`` in the standard or URL-safe alphabet
        (``-``/``_`` in place of ``+``/``/``), with or without ``=`` padding.
        URLs commonly use the URL-safe form and drop the padding.
        """
        text = b64.decode("ascii") if isinstance(b64, (bytes, bytearray)) else str(b64)
        text = "".join(text.split()).replace("-", "+").replace("_", "/")
        text += "=" * (-len(text) % 4)
        return cls(base64.b64decode(text).decode("utf-8"))
|
|
|
|
|
def parse_definitions(definitions) -> dict:
    """Parse a string literal of "key=value" pairs, one per line, into kwargs.

    Blank lines, lines starting with ``#``, and lines without ``=`` are
    skipped. Each remaining line is split on its first ``=``. The key and
    value are stripped of surrounding whitespace, and the value is evaluated
    with ``ast.literal_eval`` (only Python literals are accepted). Whitespace
    inside a value is preserved, so ``name = 'my answer'`` yields
    ``{"name": "my answer"}``.

    Raises:
        ValueError, SyntaxError: propagated from ``ast.literal_eval`` for a
            malformed value.
    """
    kwargs = {}
    for raw_line in definitions.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if sep:
            kwargs[key.strip()] = ast.literal_eval(value.strip())
    return kwargs
|
|
|
|
|
# Example queries. Constructing a Q only renders the SQL; nothing runs here.
EX1 = Q.from_template_and_definitions(
    template="SELECT {x} AS {colname}",
    definitions="\n".join([
        "# Define variables: one '=' per line",
        "x=42",
        "colname='answer'",
    ]),
    alias="example1",
)

EX2 = Q(
    """
    SELECT
        Symbol,
        Number,
        Mass,
        Abundance
    FROM '{url}'
    """,
    url="https://raw.githubusercontent.com/ekwan/cctk/master/cctk/data/isotopes.csv",
    alias="example2",
)

# Reads the query history. Uses HISTORY_FILE so the example follows the module
# setting instead of hard-coding the file name.
EX3 = Q(
    f"""
    SELECT *
    FROM '{HISTORY_FILE}'
    ORDER BY id DESC
    """,
    alias="example3",
)

# Intentionally bad query (alias "bad_example"), presumably for exercising the
# error path of Q.run().
EX4 = Q("SELECT nothing", alias="bad_example")