# Draken007's picture
# Upload 7228 files
# 2a0bc63 verified
from enum import Enum
from typing import Any, List, Tuple, Union
from cassandra.query import PreparedStatement, SimpleStatement
class CQLOpType(Enum):
    """Category of a CQL operation: schema change, write, or read.

    NOTE(review): no use of this enum is visible in this chunk; presumably
    callers use it to classify statements before executing them — confirm.
    """

    # Member order is the enum's iteration order; the integer values are fixed.
    SCHEMA = 1
    WRITE = 2
    READ = 3
# CQL statement templates.
#
# Single-brace fields ({columns_desc}, {where_clause}, ...) are replaced by a
# first ``str.format`` call. Doubled braces ({{table_fqname}}, {{table_name}})
# come out of that call as single-brace fields, presumably so that a second
# formatting pass can supply the table name (confirm against callers).
# A bare ``%s`` is a bind-value placeholder and is untouched by ``format``.
CREATE_TABLE_CQL_TEMPLATE = (
    "CREATE TABLE IF NOT EXISTS {{table_fqname}} "
    "({columns_spec} PRIMARY KEY {primkey_spec}) "
    "{options_clause};"
)
TRUNCATE_TABLE_CQL_TEMPLATE = "TRUNCATE TABLE {{table_fqname}};"
DELETE_CQL_TEMPLATE = "DELETE FROM {{table_fqname}} {where_clause};"
SELECT_CQL_TEMPLATE = (
    "SELECT {columns_desc} FROM {{table_fqname}} "
    "{where_clause} {limit_clause};"
)
INSERT_ROW_CQL_TEMPLATE = (
    "INSERT INTO {{table_fqname}} ({columns_desc}) "
    "VALUES ({value_placeholders}) "
    "{ttl_spec} ;"
)

# All index templates share this prefix and create Storage-Attached Indexes.
CREATE_INDEX_CQL_PREFIX = (
    "CREATE CUSTOM INDEX IF NOT EXISTS {index_name}_{{table_name}} "
    "ON {{table_fqname}} "
)
# Index on a plain column; the only index variant taking an {options_clause}.
CREATE_INDEX_CQL_TEMPLATE = CREATE_INDEX_CQL_PREFIX + (
    "({index_column}) "
    "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' "
    "{options_clause};"
)
# Index on the keys of a (map) column, via CQL KEYS(...).
CREATE_KEYS_INDEX_CQL_TEMPLATE = CREATE_INDEX_CQL_PREFIX + (
    "(KEYS({index_column})) "
    "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';"
)
# Index on the entries of a (map) column, via CQL ENTRIES(...).
CREATE_ENTRIES_INDEX_CQL_TEMPLATE = CREATE_INDEX_CQL_PREFIX + (
    "(ENTRIES({index_column})) "
    "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';"
)
# Approximate-nearest-neighbour query; "%s" is bound to the ANN OF operand
# (presumably the query vector) at execution time.
SELECT_ANN_CQL_TEMPLATE = (
    "SELECT {columns_desc} FROM {{table_fqname}} {where_clause} "
    "ORDER BY {vector_column} ANN OF %s {limit_clause};"
)
# Any statement form the mock session below understands: raw CQL text, a
# driver SimpleStatement, or a driver PreparedStatement.
CQLStatementType = Union[str, SimpleStatement, PreparedStatement]
# A statement paired with its tuple of positional bind arguments.
StatementWithArgs = Tuple[CQLStatementType, Tuple[Any, ...]]
# Like StatementWithArgs, but with the statement reduced to a CQL string.
StatementStrWithArgs = Tuple[str, Tuple[Any, ...]]
# Mock DB session
class MockDBSession:
    """In-memory stand-in for a Cassandra session, for tests and tracing.

    Nothing is sent to a database: ``execute`` only records each
    ``(statement, arguments)`` pair in ``self.statements`` (optionally
    validating and printing it) and always returns an empty list. Tests then
    inspect what was executed through ``last_raw``, ``last`` and
    ``assert_last_equal``.
    """

    def __init__(self, verbose: bool = False):
        """Create an empty recorder.

        Args:
            verbose: if True, ``execute`` also checks placeholder/argument
                consistency and prints every statement it receives.
        """
        self.verbose = verbose
        # Every executed (statement, arguments) pair, oldest first.
        self.statements: List[StatementWithArgs] = []

    @staticmethod
    def get_statement_body(statement: CQLStatementType) -> str:
        """Return the CQL text of a str, SimpleStatement or PreparedStatement.

        Raises:
            ValueError: if ``statement`` is none of the supported types.
        """
        if isinstance(statement, str):
            return statement
        if isinstance(statement, (SimpleStatement, PreparedStatement)):
            return statement.query_string
        raise ValueError(
            f"Unsupported CQL statement type: {type(statement).__name__}"
        )

    @staticmethod
    def normalize_cql_statement(statement: CQLStatementType) -> str:
        """Return a canonical form of a statement, for comparisons.

        Semicolons are dropped. Placeholders (``%s``, ``?``), ``=`` and
        parentheses are padded into standalone tokens. Every run of
        whitespace collapses to one space, and tokens are lowercased, so
        purely cosmetic formatting differences compare equal.
        """
        body = MockDBSession.get_statement_body(statement)
        padded = (
            body.replace(";", " ")
            .replace("%s", " %s ")
            .replace("?", " ? ")
            .replace("=", " = ")
            .replace(")", " ) ")
            .replace("(", " ( ")
        )
        # split() with no argument splits on any whitespace (spaces, tabs,
        # CR/LF, ...), so e.g. tab-indented statements normalize the same way
        # as space-separated ones.
        return " ".join(tok.lower() for tok in padded.split())

    @staticmethod
    def prepare(statement: str) -> PreparedStatement:
        """Wrap ``statement`` in a PreparedStatement without contacting a server.

        The statement text is stored as the query (it is read back through
        ``query_string`` by ``get_statement_body``). The other constructor
        arguments are dummy values, so the result is only useful for
        tracing/debugging, not for real binding.
        """
        # A very unusable 'prepared statement' just for tracing/debugging:
        return PreparedStatement(None, 0, 0, statement, "keyspace", None, None, None)

    def execute(
        self, statement: CQLStatementType, arguments: Tuple[Any, ...] = tuple()
    ) -> List[Any]:
        """Record ``(statement, arguments)`` and return an empty result list.

        In verbose mode the statement is first checked and printed:
        str/SimpleStatement bodies must use ``%s`` placeholders, and
        PreparedStatement bodies ``?`` placeholders. The placeholder count
        must equal ``len(arguments)``.

        Raises:
            ValueError: (verbose only) unsupported statement type.
            AssertionError: (verbose only) placeholder style/count mismatch.
        """
        if self.verbose:
            # Raises ValueError for unsupported types, so st_type and
            # placeholder_count are always bound below.
            st_body = self.get_statement_body(statement)
            if isinstance(statement, str):
                st_type = "STR"
                placeholder_count = st_body.count("%s")
                assert "?" not in st_body
            elif isinstance(statement, SimpleStatement):
                st_type = "SIM"
                placeholder_count = st_body.count("%s")
                assert "?" not in st_body
            elif isinstance(statement, PreparedStatement):
                st_type = "PRE"
                placeholder_count = st_body.count("?")
                assert "%s" not in st_body
            assert placeholder_count == len(arguments), (
                f"{placeholder_count} placeholder(s) but "
                f"{len(arguments)} argument(s)"
            )
            print(f"CQL_EXECUTE [{st_type}]:")
            print(f" {st_body}")
            if arguments:
                print(f" {str(arguments)}")
        self.statements.append((statement, arguments))
        return []

    def last_raw(self, n: int) -> List[StatementWithArgs]:
        """Return the ``n`` most recent (statement, arguments) pairs, oldest first.

        For ``n <= 0`` this returns ``[]`` (a ``[-0:]`` slice would return
        everything). If fewer than ``n`` statements exist, all are returned.
        """
        if n <= 0:
            return []
        return self.statements[-n:]

    def last(self, n: int) -> List[StatementStrWithArgs]:
        """Like ``last_raw``, with each statement passed through ``normalize_cql_statement``."""
        return [
            (self.normalize_cql_statement(stmt), data)
            for stmt, data in self.last_raw(n)
        ]

    def assert_last_equal(
        self, expected_statements: List[StatementStrWithArgs]
    ) -> None:
        """Assert that the latest executions match ``expected_statements``.

        The ``len(expected_statements)`` most recent executions are compared
        in order. Arguments must be equal, and statements must be equal after
        ``normalize_cql_statement`` is applied once to each side.
        """
        # used for testing
        # Use the raw statements so that both sides are normalized exactly
        # once (last() would normalize the executed side a second time).
        last_executed = self.last_raw(len(expected_statements))
        assert len(last_executed) == len(expected_statements), (
            f"{len(last_executed)} statement(s) executed, "
            f"{len(expected_statements)} expected"
        )
        for s_exe, s_expe in zip(last_executed, expected_statements):
            assert s_exe[1] == s_expe[1], f"EXE#{str(s_exe[1])}# != EXPE#{s_expe[1]}#"
            exe_cql = self.normalize_cql_statement(s_exe[0])
            expe_cql = self.normalize_cql_statement(s_expe[0])
            assert exe_cql == expe_cql, f"EXE#{exe_cql}# != EXPE#{expe_cql}#"
        return None
# Text-analysis option pairs, each written as (option_name, option_value).
# NOTE(review): presumably merged into the options of an SAI index; confirm
# against the callers.
STANDARD_ANALYZER = (
    "index_analyzer",
    "STANDARD",
)
# case_sensitive=False: matching ignores letter case.
LOWER_CASE_ANALYZER = (
    "case_sensitive",
    False,
)
NORMALIZE_ANALYZER = (
    "normalize",
    True,
)
ASCII_ANALYZER = (
    "ascii",
    True,
)