Spaces:
Runtime error
Runtime error
| from enum import Enum | |
| from typing import Any, List, Tuple, Union | |
| from cassandra.query import PreparedStatement, SimpleStatement | |
class CQLOpType(Enum):
    """Enumerate the kinds of CQL operation distinguished by this module."""

    # Member names and integer values are part of the public interface.
    SCHEMA = 1
    WRITE = 2
    READ = 3
# CQL statement templates.
#
# Single-braced fields such as ``{columns_spec}`` are substituted by the first
# ``str.format`` call; double-braced fields such as ``{{table_fqname}}`` come
# out of that call as ``{table_fqname}`` and so remain available for a later
# formatting pass.
CREATE_TABLE_CQL_TEMPLATE = (
    "CREATE TABLE IF NOT EXISTS {{table_fqname}} "
    "({columns_spec} PRIMARY KEY {primkey_spec}) "
    "{options_clause};"
)

TRUNCATE_TABLE_CQL_TEMPLATE = "TRUNCATE TABLE {{table_fqname}};"

DELETE_CQL_TEMPLATE = "DELETE FROM {{table_fqname}} {where_clause};"

SELECT_CQL_TEMPLATE = (
    "SELECT {columns_desc} FROM {{table_fqname}} "
    "{where_clause} {limit_clause};"
)

# The space before the final semicolon is intentional: it is part of the
# template text as originally written.
INSERT_ROW_CQL_TEMPLATE = (
    "INSERT INTO {{table_fqname}} ({columns_desc}) "
    "VALUES ({value_placeholders}) {ttl_spec} ;"
)

# Clause shared by every index template below.
_SAI_USING_CLAUSE = "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex'"

# Common head of the index-creation statements (note the trailing space).
CREATE_INDEX_CQL_PREFIX = (
    "CREATE CUSTOM INDEX IF NOT EXISTS {index_name}_{{table_name}} "
    "ON {{table_fqname}} "
)

# Index on a plain column; the only index template taking an options clause.
CREATE_INDEX_CQL_TEMPLATE = (
    CREATE_INDEX_CQL_PREFIX
    + "({index_column}) "
    + _SAI_USING_CLAUSE
    + " {options_clause};"
)

# Index on KEYS(column).
CREATE_KEYS_INDEX_CQL_TEMPLATE = (
    CREATE_INDEX_CQL_PREFIX + "(KEYS({index_column})) " + _SAI_USING_CLAUSE + ";"
)

# Index on ENTRIES(column).
CREATE_ENTRIES_INDEX_CQL_TEMPLATE = (
    CREATE_INDEX_CQL_PREFIX + "(ENTRIES({index_column})) " + _SAI_USING_CLAUSE + ";"
)

# ANN query; the literal ``%s`` is a bind placeholder left in the statement.
SELECT_ANN_CQL_TEMPLATE = (
    "SELECT {columns_desc} FROM {{table_fqname}} {where_clause} "
    "ORDER BY {vector_column} ANN OF %s {limit_clause};"
)
# A statement may be raw CQL text or one of the two driver statement objects.
CQLStatementType = Union[str, SimpleStatement, PreparedStatement]

# Bind values travel as a tuple of arbitrary length.
_StatementArgs = Tuple[Any, ...]

# (statement, bind values) with the statement in any accepted form.
StatementWithArgs = Tuple[CQLStatementType, _StatementArgs]
# (statement, bind values) with the statement as a plain string.
StatementStrWithArgs = Tuple[str, _StatementArgs]
| # Mock DB session | |
class MockDBSession:
    """Recording stand-in for a database session, meant for tests and tracing.

    No statement is ever sent anywhere: ``execute`` stores each
    ``(statement, arguments)`` pair in ``self.statements`` and returns an
    empty result. The ``last*`` helpers and ``assert_last_equal`` inspect
    that history.
    """

    def __init__(self, verbose: bool = False):
        """Create an empty session.

        Args:
            verbose: when true, ``execute`` sanity-checks placeholders and
                prints each statement it receives.
        """
        self.verbose = verbose
        # Chronological log of everything passed to ``execute``.
        self.statements: List[StatementWithArgs] = []

    @staticmethod
    def get_statement_body(statement: CQLStatementType) -> str:
        """Return the CQL text carried by ``statement``.

        Args:
            statement: a plain string, a ``SimpleStatement`` or a
                ``PreparedStatement``.

        Raises:
            ValueError: if ``statement`` is none of the supported types.
        """
        if isinstance(statement, str):
            return statement
        if isinstance(statement, (SimpleStatement, PreparedStatement)):
            return statement.query_string
        raise ValueError(
            f"Unsupported statement type: {type(statement).__name__}"
        )

    @staticmethod
    def normalize_cql_statement(statement: CQLStatementType) -> str:
        """Reduce a statement to a canonical, comparison-friendly string.

        Semicolons are dropped; placeholders (``%s``, ``?``), ``=`` and
        parentheses are padded so they become separate tokens; the text is
        then split on any whitespace, lower-cased and re-joined with single
        spaces. The operation is idempotent.
        """
        body = MockDBSession.get_statement_body(statement)
        spaced = (
            body.replace(";", " ")
            .replace("%s", " %s ")
            .replace("?", " ? ")
            .replace("=", " = ")
            .replace(")", " ) ")
            .replace("(", " ( ")
        )
        # str.split() with no argument splits on every whitespace run
        # (spaces, newlines, tabs, ...) and never yields empty tokens.
        return " ".join(token.lower() for token in spaced.split())

    @staticmethod
    def prepare(statement: str) -> PreparedStatement:
        """Wrap ``statement`` in a ``PreparedStatement`` shell.

        The object is a very unusable 'prepared statement', good only for
        tracing/debugging: just its query string (and a dummy keyspace) is
        filled in.
        """
        return PreparedStatement(None, 0, 0, statement, "keyspace", None, None, None)

    def execute(
        self, statement: CQLStatementType, arguments: Tuple[Any, ...] = tuple()
    ) -> List[Any]:
        """Record a statement and its arguments; always return ``[]``.

        In verbose mode the statement is first checked (via ``assert``) for
        the placeholder style matching its type -- ``%s`` for strings and
        simple statements, ``?`` for prepared statements -- and for a
        placeholder count equal to ``len(arguments)``; it is then printed.

        Raises:
            ValueError: in verbose mode, for an unsupported statement type.
            AssertionError: in verbose mode, when a placeholder check fails.
        """
        if self.verbose:
            # Raises ValueError for unsupported types, so the final ``else``
            # below can only be reached by a PreparedStatement.
            st_body = self.get_statement_body(statement)
            if isinstance(statement, str):
                st_type, placeholder, foreign = "STR", "%s", "?"
            elif isinstance(statement, SimpleStatement):
                st_type, placeholder, foreign = "SIM", "%s", "?"
            else:
                st_type, placeholder, foreign = "PRE", "?", "%s"
            assert foreign not in st_body
            assert st_body.count(placeholder) == len(arguments)
            print(f"CQL_EXECUTE [{st_type}]:")
            print(f" {st_body}")
            if arguments:
                print(f" {str(arguments)}")
        self.statements.append((statement, arguments))
        return []

    def last_raw(self, n: int) -> List[StatementWithArgs]:
        """Return the ``n`` most recent recorded entries, oldest first.

        A non-positive ``n`` yields an empty list; an ``n`` larger than the
        history yields the whole history.
        """
        if n <= 0:
            # Guard needed: statements[-0:] would return the full history.
            return []
        return self.statements[-n:]

    def last(self, n: int) -> List[StatementStrWithArgs]:
        """Like ``last_raw`` but with every statement normalized to a string."""
        return [
            (self.normalize_cql_statement(stmt), data)
            for stmt, data in self.last_raw(n)
        ]

    def assert_last_equal(
        self, expected_statements: List[StatementStrWithArgs]
    ) -> None:
        """Assert that the latest executions match ``expected_statements``.

        Used for testing. Arguments are compared as-is; statements are
        compared after normalization, so spacing and letter case of the
        expected CQL do not matter.

        Raises:
            AssertionError: on a length, argument or statement mismatch.
        """
        last_executed = self.last(len(expected_statements))
        assert len(last_executed) == len(expected_statements)
        # ``last`` already returns normalized statements, so only the
        # expected side needs normalizing here.
        for (exe_cql, exe_args), s_expe in zip(last_executed, expected_statements):
            assert exe_args == s_expe[1], f"EXE#{str(exe_args)}# != EXPE#{s_expe[1]}#"
            expe_cql = self.normalize_cql_statement(s_expe[0])
            assert exe_cql == expe_cql, f"EXE#{exe_cql}# != EXPE#{expe_cql}#"
# Analyzer settings, each expressed as a (name, value) pair.
# NOTE(review): presumably these feed the index options clause -- confirm
# against the code that consumes them.
STANDARD_ANALYZER: Tuple[str, str] = ("index_analyzer", "STANDARD")
LOWER_CASE_ANALYZER: Tuple[str, bool] = ("case_sensitive", False)
NORMALIZE_ANALYZER: Tuple[str, bool] = ("normalize", True)
ASCII_ANALYZER: Tuple[str, bool] = ("ascii", True)