"""CQL statement templates and a mock DB session for tracing/testing.

The ``*_TEMPLATE`` constants are ``str.format`` templates. ``MockDBSession``
records every statement it is asked to execute, so tests can assert on them.
"""

from enum import Enum
from typing import Any, List, Tuple, Union

from cassandra.query import PreparedStatement, SimpleStatement


class CQLOpType(Enum):
    """Category of a CQL operation: schema change, write, or read."""

    SCHEMA = 1
    WRITE = 2
    READ = 3


# NOTE: doubled braces such as ``{{table_fqname}}`` are format escapes. They
# come out of a first ``str.format`` call as ``{table_fqname}``. Presumably the
# table name is filled in by a later formatting pass (confirm with callers).
CREATE_TABLE_CQL_TEMPLATE = """CREATE TABLE IF NOT EXISTS {{table_fqname}} ({columns_spec} PRIMARY KEY {primkey_spec}) {options_clause};"""  # noqa: E501
TRUNCATE_TABLE_CQL_TEMPLATE = """TRUNCATE TABLE {{table_fqname}};"""
DELETE_CQL_TEMPLATE = """DELETE FROM {{table_fqname}} {where_clause};"""
SELECT_CQL_TEMPLATE = (
    """SELECT {columns_desc} FROM {{table_fqname}} {where_clause} {limit_clause};"""
)
INSERT_ROW_CQL_TEMPLATE = """INSERT INTO {{table_fqname}} ({columns_desc}) VALUES ({value_placeholders}) {ttl_spec} ;"""  # noqa: E501

CREATE_INDEX_CQL_PREFIX = "CREATE CUSTOM INDEX IF NOT EXISTS {index_name}_{{table_name}} ON {{table_fqname}} "  # noqa: E501
CREATE_INDEX_CQL_TEMPLATE = (
    CREATE_INDEX_CQL_PREFIX
    + "({index_column}) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' {options_clause};"  # noqa: E501
)
CREATE_KEYS_INDEX_CQL_TEMPLATE = (
    CREATE_INDEX_CQL_PREFIX
    + "(KEYS({index_column})) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';"  # noqa: E501
)
CREATE_ENTRIES_INDEX_CQL_TEMPLATE = (
    CREATE_INDEX_CQL_PREFIX
    + "(ENTRIES({index_column})) USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';"  # noqa: E501
)
# ``str.format`` leaves the ``%s`` after ``ANN OF`` untouched, so it remains a
# driver bind placeholder (presumably for the query vector).
SELECT_ANN_CQL_TEMPLATE = """SELECT {columns_desc} FROM {{table_fqname}} {where_clause} ORDER BY {vector_column} ANN OF %s {limit_clause};"""  # noqa: E501

CQLStatementType = Union[str, SimpleStatement, PreparedStatement]
StatementWithArgs = Tuple[CQLStatementType, Tuple[Any, ...]]
StatementStrWithArgs = Tuple[str, Tuple[Any, ...]]


class MockDBSession:
    """In-memory stand-in for a DB session that records executed statements.

    Nothing is sent to a database. ``execute`` appends
    ``(statement, arguments)`` to ``self.statements`` and returns an empty
    list. The helper methods let tests inspect and compare the recorded
    statements after normalization.
    """

    def __init__(self, verbose: bool = False):
        # When True, ``execute`` checks placeholder style/count and prints
        # each statement.
        self.verbose = verbose
        # Every executed (statement, arguments) pair, in call order.
        self.statements: List[StatementWithArgs] = []

    @staticmethod
    def get_statement_body(statement: CQLStatementType) -> str:
        """Return the CQL text of a str, SimpleStatement or PreparedStatement.

        Raises:
            ValueError: if ``statement`` is none of the supported types.
        """
        if isinstance(statement, str):
            return statement
        if isinstance(statement, (SimpleStatement, PreparedStatement)):
            return statement.query_string
        raise ValueError(
            f"Unsupported statement type: {type(statement).__name__}"
        )

    @staticmethod
    def normalize_cql_statement(statement: CQLStatementType) -> str:
        """Return a canonical, comparison-friendly form of a CQL statement.

        Semicolons are dropped. ``%s``, ``?``, ``=`` and parentheses become
        standalone tokens. All whitespace runs collapse to single spaces, and
        every token is lowercased.
        """
        body = MockDBSession.get_statement_body(statement)
        spaced = (
            body.replace(";", " ")
            .replace("%s", " %s ")
            .replace("?", " ? ")
            .replace("=", " = ")
            .replace(")", " ) ")
            .replace("(", " ( ")
        )
        # With no argument, str.split() splits on ANY whitespace run (spaces,
        # tabs, newlines, carriage returns) and never yields empty tokens.
        # Splitting on " " alone would leave tabs embedded inside tokens.
        return " ".join(token.lower() for token in spaced.split())

    @staticmethod
    def prepare(statement: str) -> PreparedStatement:
        """Wrap ``statement`` in a bare PreparedStatement for tracing/debugging.

        Only the query text (plus a dummy ``"keyspace"``) is meaningful. The
        object is not usable against a real cluster.
        """
        return PreparedStatement(None, 0, 0, statement, "keyspace", None, None, None)

    def execute(
        self, statement: CQLStatementType, arguments: Tuple[Any, ...] = tuple()
    ) -> List[Any]:
        """Record ``statement`` with its ``arguments``; return no rows.

        In verbose mode, these are checked with ``assert`` before the
        statement is printed:

        - str/SimpleStatement bodies must use ``%s`` placeholders (no ``?``).
        - PreparedStatement bodies must use ``?`` (no ``%s``).
        - The placeholder count must equal ``len(arguments)``.

        Raises:
            ValueError: in verbose mode, for an unsupported statement type.
            AssertionError: in verbose mode, when a placeholder check fails.
        """
        if self.verbose:
            st_body = self.get_statement_body(statement)
            # get_statement_body has already rejected any other type, so the
            # statement is either prepared, or a str/SimpleStatement.
            if isinstance(statement, PreparedStatement):
                st_type, placeholder, foreign_placeholder = "PRE", "?", "%s"
            else:
                st_type = "STR" if isinstance(statement, str) else "SIM"
                placeholder, foreign_placeholder = "%s", "?"
            assert foreign_placeholder not in st_body, (
                f"[{st_type}] statement uses {foreign_placeholder!r} placeholders"
            )
            placeholder_count = st_body.count(placeholder)
            assert placeholder_count == len(arguments), (
                f"[{st_type}] {placeholder_count} placeholders "
                f"but {len(arguments)} arguments"
            )
            print(f"CQL_EXECUTE [{st_type}]:")
            print(f" {st_body}")
            if arguments:
                print(f" {arguments}")
        self.statements.append((statement, arguments))
        return []

    def last_raw(self, n: int) -> List[StatementWithArgs]:
        """Return the last ``n`` recorded (statement, arguments) pairs as-is.

        A non-positive ``n`` yields ``[]``. A bare ``[-0:]`` slice would
        return everything instead.
        """
        if n <= 0:
            return []
        return self.statements[-n:]

    def last(self, n: int) -> List[StatementStrWithArgs]:
        """Like ``last_raw``, but with each statement normalized to a string."""
        return [
            (self.normalize_cql_statement(stmt), data)
            for stmt, data in self.last_raw(n)
        ]

    def assert_last_equal(
        self, expected_statements: List[StatementStrWithArgs]
    ) -> None:
        """Assert that the most recent statements match ``expected_statements``.

        Compares the last ``len(expected_statements)`` executed statements in
        order. The arguments must be equal, and the statement texts must be
        equal after ``normalize_cql_statement``. Intended for tests.

        Raises:
            AssertionError: on any count, argument or statement mismatch.
        """
        last_executed = self.last(len(expected_statements))
        assert len(last_executed) == len(expected_statements), (
            f"expected {len(expected_statements)} statements, "
            f"only {len(last_executed)} executed"
        )
        for (exe_cql, exe_args), (expe_stmt, expe_args) in zip(
            last_executed, expected_statements
        ):
            assert exe_args == expe_args, f"EXE#{exe_args}# != EXPE#{expe_args}#"
            # last() has already normalized the executed statement, so only
            # the expected one still needs normalizing.
            expe_cql = self.normalize_cql_statement(expe_stmt)
            assert exe_cql == expe_cql, f"EXE#{exe_cql}# != EXPE#{expe_cql}#"


# (option name, value) pairs. They are presumably SAI index analyzer options
# passed into an index ``{options_clause}``; confirm with callers.
STANDARD_ANALYZER = ("index_analyzer", "STANDARD")
LOWER_CASE_ANALYZER = ("case_sensitive", False)
NORMALIZE_ANALYZER = ("normalize", True)
ASCII_ANALYZER = ("ascii", True)