File size: 5,392 Bytes
2a0bc63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from enum import Enum
from typing import Any, List, Tuple, Union

from cassandra.query import PreparedStatement, SimpleStatement


class CQLOpType(Enum):
    """Closed set of CQL operation categories.

    NOTE(review): no use of this enum is visible in this part of the file; the
    member names suggest schema changes, writes and reads -- confirm against
    callers before relying on that reading.
    """

    SCHEMA = 1
    WRITE = 2
    READ = 3


# CQL statement templates meant for ``str.format``.
#
# Single-brace fields are filled by a ``str.format`` call. A double-brace field
# such as ``{{table_fqname}}`` comes out of that call as the literal text
# ``{table_fqname}``, i.e. it is still a placeholder afterwards (presumably
# filled in a later pass once the target table is known -- confirm with the
# calling code).

CREATE_TABLE_CQL_TEMPLATE = (
    "CREATE TABLE IF NOT EXISTS {{table_fqname}} "
    "({columns_spec} PRIMARY KEY {primkey_spec}) "
    "{options_clause};"
)

TRUNCATE_TABLE_CQL_TEMPLATE = "TRUNCATE TABLE {{table_fqname}};"

DELETE_CQL_TEMPLATE = "DELETE FROM {{table_fqname}} {where_clause};"

SELECT_CQL_TEMPLATE = (
    "SELECT {columns_desc} "
    "FROM {{table_fqname}} "
    "{where_clause} {limit_clause};"
)

# Note the space before the final semicolon: it is part of the template text.
INSERT_ROW_CQL_TEMPLATE = (
    "INSERT INTO {{table_fqname}} ({columns_desc}) "
    "VALUES ({value_placeholders}) "
    "{ttl_spec} ;"
)

# Templates for CREATE CUSTOM INDEX statements using the storage-attached
# index class. All of them share the same leading part; ``{{table_name}}`` and
# ``{{table_fqname}}`` are double-braced, so a ``str.format`` call leaves them
# behind as ``{table_name}`` / ``{table_fqname}`` placeholders.
CREATE_INDEX_CQL_PREFIX = (
    "CREATE CUSTOM INDEX IF NOT EXISTS {index_name}_{{table_name}} "
    "ON {{table_fqname}} "
)

# Index on a plain column, with a trailing options clause.
CREATE_INDEX_CQL_TEMPLATE = (
    CREATE_INDEX_CQL_PREFIX
    + "({index_column}) "
    + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex' "
    + "{options_clause};"
)

# Index on KEYS(column); no options clause in this variant.
CREATE_KEYS_INDEX_CQL_TEMPLATE = (
    CREATE_INDEX_CQL_PREFIX
    + "(KEYS({index_column})) "
    + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';"
)

# Index on ENTRIES(column); no options clause in this variant.
CREATE_ENTRIES_INDEX_CQL_TEMPLATE = (
    CREATE_INDEX_CQL_PREFIX
    + "(ENTRIES({index_column})) "
    + "USING 'org.apache.cassandra.index.sai.StorageAttachedIndex';"
)

# SELECT ordered by ANN on a vector column. The ``%s`` after ``ANN OF`` is not a
# ``str.format`` field, so it stays in the statement text (presumably as the
# bind placeholder for the query vector -- confirm with the calling code).
SELECT_ANN_CQL_TEMPLATE = (
    "SELECT {columns_desc} "
    "FROM {{table_fqname}} "
    "{where_clause} "
    "ORDER BY {vector_column} ANN OF %s "
    "{limit_clause};"
)

# A statement in any of three forms: a raw CQL string, or a driver
# SimpleStatement / PreparedStatement object.
CQLStatementType = Union[str, SimpleStatement, PreparedStatement]
# A statement (any form) paired with the tuple of values passed alongside it.
StatementWithArgs = Tuple[CQLStatementType, Tuple[Any, ...]]
# The same pairing, with the statement given as a plain string.
StatementStrWithArgs = Tuple[str, Tuple[Any, ...]]


# Mock DB session
class MockDBSession:
    """In-memory stand-in for a database session that records statements.

    Nothing is sent to a database: ``execute`` stores each
    ``(statement, arguments)`` pair in ``self.statements`` and returns an empty
    list. The other helpers retrieve the most recent statements, optionally in
    a normalized textual form, so tests can compare them with expectations.
    """

    def __init__(self, verbose: bool = False):
        """Create an empty recorder.

        Args:
            verbose: when true, ``execute`` additionally checks the number of
                placeholders against the arguments and prints each statement.
        """
        self.verbose = verbose
        # Every executed (statement, arguments) pair, oldest first.
        self.statements: List[StatementWithArgs] = []

    @staticmethod
    def get_statement_body(statement: CQLStatementType) -> str:
        """Return the CQL text carried by ``statement``.

        Args:
            statement: a plain string, a ``SimpleStatement`` or a
                ``PreparedStatement``.

        Returns:
            The string itself, or the statement object's ``query_string``.

        Raises:
            ValueError: if ``statement`` is of any other type.
        """
        if isinstance(statement, str):
            return statement
        if isinstance(statement, (SimpleStatement, PreparedStatement)):
            return statement.query_string
        raise ValueError(
            f"Unsupported statement type: {type(statement).__name__}"
        )

    @staticmethod
    def normalize_cql_statement(statement: CQLStatementType) -> str:
        """Return a canonical, lowercase, single-spaced form of a statement.

        Semicolons are dropped; ``%s``, ``?``, ``=``, ``(`` and ``)`` are
        surrounded by spaces so they become tokens of their own; the text is
        then split on whitespace, re-joined with single spaces and lowercased.

        Raises:
            ValueError: if ``statement`` is of an unsupported type.
        """
        body = MockDBSession.get_statement_body(statement)
        padded = (
            body.replace(";", " ")
            .replace("%s", " %s ")
            .replace("?", " ? ")
            .replace("=", " = ")
            .replace(")", " ) ")
            .replace("(", " ( ")
        )
        # split() without arguments breaks on *any* whitespace run (spaces,
        # newlines, tabs, carriage returns), so statements that differ only in
        # layout normalize to the same text.
        return " ".join(padded.split()).lower()

    @staticmethod
    def prepare(statement: str) -> PreparedStatement:
        """Wrap ``statement`` in a ``PreparedStatement`` for tracing purposes.

        The object is built directly with placeholder constructor arguments, so
        it is only good for carrying the query string through this mock.
        """
        # A very unusable 'prepared statement' just for tracing/debugging:
        return PreparedStatement(None, 0, 0, statement, "keyspace", None, None, None)

    def execute(
        self, statement: CQLStatementType, arguments: Tuple[Any, ...] = ()
    ) -> List[Any]:
        """Record a statement and its arguments; return an empty result.

        In verbose mode the statement is also sanity-checked and printed:
        strings and ``SimpleStatement`` objects must use ``%s`` placeholders
        only, ``PreparedStatement`` objects must use ``?`` only, and the number
        of placeholders must equal ``len(arguments)``.

        Args:
            statement: the statement to record.
            arguments: values accompanying the statement.

        Returns:
            Always an empty list.

        Raises:
            ValueError: (verbose mode) if the statement type is unsupported.
            AssertionError: (verbose mode) if a placeholder check fails.
        """
        if self.verbose:
            # Raises ValueError for anything but the three supported types,
            # so the final ``else`` below can only be a PreparedStatement.
            st_body = self.get_statement_body(statement)
            if isinstance(statement, str):
                st_type = "STR"
            elif isinstance(statement, SimpleStatement):
                st_type = "SIM"
            else:
                st_type = "PRE"
            if st_type == "PRE":
                placeholder, foreign = "?", "%s"
            else:
                placeholder, foreign = "%s", "?"
            assert (
                foreign not in st_body
            ), f"unexpected '{foreign}' in [{st_type}] statement: {st_body}"
            placeholder_count = st_body.count(placeholder)
            assert placeholder_count == len(arguments), (
                f"{placeholder_count} placeholder(s) but "
                f"{len(arguments)} argument(s): {st_body}"
            )
            print(f"CQL_EXECUTE [{st_type}]:")
            print(f"    {st_body}")
            if arguments:
                print(f"    {str(arguments)}")
        self.statements.append((statement, arguments))
        return []

    def last_raw(self, n: int) -> List[StatementWithArgs]:
        """Return the ``n`` most recent recorded pairs, oldest first.

        Returns an empty list when ``n`` is zero or negative, and fewer than
        ``n`` items when fewer have been recorded.
        """
        # The guard matters: a bare ``[-0:]`` slice would return everything.
        return self.statements[-n:] if n > 0 else []

    def last(self, n: int) -> List[StatementStrWithArgs]:
        """Return the ``n`` most recent pairs with normalized statement text."""
        return [
            (self.normalize_cql_statement(stmt), data)
            for stmt, data in self.last_raw(n)
        ]

    def assert_last_equal(
        self, expected_statements: List[StatementStrWithArgs]
    ) -> None:
        """Assert that the latest statements match ``expected_statements``.

        The last ``len(expected_statements)`` recorded pairs are compared, in
        order, with the expected ones: arguments must be equal, and statement
        texts must be equal after normalization (each side normalized once).

        Raises:
            AssertionError: on a count, argument or statement mismatch.
        """
        # used for testing
        recorded = self.last_raw(len(expected_statements))
        assert len(recorded) == len(expected_statements), (
            f"expected {len(expected_statements)} statement(s), "
            f"but only {len(recorded)} recorded"
        )
        for (exe_stmt, exe_args), (expe_stmt, expe_args) in zip(
            recorded, expected_statements
        ):
            assert exe_args == expe_args, f"EXE#{str(exe_args)}# != EXPE#{expe_args}#"
            exe_cql = self.normalize_cql_statement(exe_stmt)
            expe_cql = self.normalize_cql_statement(expe_stmt)
            assert exe_cql == expe_cql, f"EXE#{exe_cql}# != EXPE#{expe_cql}#"


# Two-element (name, value) tuples.
# NOTE(review): these look like index analyzer settings given as
# (option key, option value) -- presumably combined into an index options
# clause by code outside this view; confirm against callers.
STANDARD_ANALYZER = ("index_analyzer", "STANDARD")
LOWER_CASE_ANALYZER = ("case_sensitive", False)
NORMALIZE_ANALYZER = ("normalize", True)
ASCII_ANALYZER = ("ascii", True)