from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from cassandra.cluster import ResponseFuture

from cassio.table.cql import DELETE_CQL_TEMPLATE, SELECT_CQL_TEMPLATE, CQLOpType
from cassio.table.table_types import ColumnSpecType, RowType, normalize_type_desc
from cassio.table.utils import (
    call_wrapped_async,
    handle_multicolumn_packing,
    handle_multicolumn_unpacking,
)

from .base_table import BaseTableMixin

PARTITION_ID_TYPE = Union[Any, Tuple[Any]]


class ClusteredMixin(BaseTableMixin):
    """Mixin adding a (possibly multi-column) 'partition_id' partition key, a
    clustering order within each partition, and partition-level read/delete
    helpers on top of BaseTableMixin."""

    def __init__(
        self,
        *pargs: Any,
        partition_id_type: Union[str, List[str]] = ["TEXT"],
        partition_id: Optional[PARTITION_ID_TYPE] = None,
        ordering_in_partition: Union[str, List[str]] = "ASC",
        **kwargs: Any,
    ) -> None:
        self.partition_id_type = normalize_type_desc(partition_id_type)
        self.partition_id = partition_id
        if isinstance(ordering_in_partition, str):
            self.ordering_in_partition = ordering_in_partition.upper()
        else:
            self.ordering_in_partition = [
                ordering.upper() for ordering in ordering_in_partition
            ]
        super().__init__(*pargs, **kwargs)

    def _schema_pk(self) -> List[ColumnSpecType]:
        if len(self.partition_id_type) == 1:
            return [
                ("partition_id", self.partition_id_type[0]),
            ]
        else:
            return [
                (f"partition_id_{pk_i}", pk_typ)
                for pk_i, pk_typ in enumerate(self.partition_id_type)
            ]
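
    # Illustrative note (not in the original source): with the default
    # partition_id_type=["TEXT"], _schema_pk() returns
    #     [("partition_id", "TEXT")]
    # while a composite key such as ["TEXT", "INT"] yields
    #     [("partition_id_0", "TEXT"), ("partition_id_1", "INT")],
    # which is the column layout the multicolumn packing/unpacking helpers
    # used below rely on.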

    def _schema_cc(self) -> List[ColumnSpecType]:
        return self._schema_row_id()

    def _delete_partition(
        self, is_async: bool, partition_id: Optional[PARTITION_ID_TYPE] = None
    ) -> Union[None, ResponseFuture]:
        _partition_id = self.partition_id if partition_id is None else partition_id
        #
        _pid_dict = handle_multicolumn_unpacking(
            {"partition_id": _partition_id},
            "partition_id",
            [col for col, _ in self._schema_pk()],
        )
        (
            rest_kwargs,
            where_clause_blocks,
            delete_cql_vals,
        ) = self._extract_where_clause_blocks(_pid_dict)
        assert rest_kwargs == {}
        where_clause = "WHERE " + " AND ".join(where_clause_blocks)
        delete_cql = DELETE_CQL_TEMPLATE.format(
            where_clause=where_clause,
        )
        if is_async:
            return self.execute_cql_async(
                delete_cql, args=delete_cql_vals, op_type=CQLOpType.WRITE
            )
        else:
            self.execute_cql(delete_cql, args=delete_cql_vals, op_type=CQLOpType.WRITE)
            return None

    def delete_partition(
        self, partition_id: Optional[PARTITION_ID_TYPE] = None
    ) -> None:
        self._delete_partition(is_async=False, partition_id=partition_id)
        return None

    def delete_partition_async(
        self, partition_id: Optional[PARTITION_ID_TYPE] = None
    ) -> ResponseFuture:
        return self._delete_partition(is_async=True, partition_id=partition_id)

    async def adelete_partition(
        self, partition_id: Optional[PARTITION_ID_TYPE] = None
    ) -> None:
        await call_wrapped_async(self.delete_partition_async, partition_id=partition_id)
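
    # Usage sketch (illustrative only; assumes an already-initialized concrete
    # table instance named `table` that composes this mixin):
    #
    #     table.delete_partition(partition_id="user-42")     # blocking
    #     fut = table.delete_partition_async("user-42")      # driver ResponseFuture
    #     await table.adelete_partition("user-42")           # asyncio wrapper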

    def _normalize_kwargs(self, args_dict: Dict[str, Any]) -> Dict[str, Any]:
        # A partition id passed in the call takes precedence over the instance-level value.
        arg_pid = args_dict.get("partition_id")
        instance_pid = self.partition_id
        _partition_id = instance_pid if arg_pid is None else arg_pid
        new_args_dict0 = {
            **{"partition_id": _partition_id},
            **args_dict,
        }
        # in case of multicolumn-key schema, do the tuple unpacking:
        new_args_dict = handle_multicolumn_unpacking(
            new_args_dict0,
            "partition_id",
            [col for col, _ in self._schema_pk()],
        )

        return super()._normalize_kwargs(new_args_dict)

    def _normalize_row(self, raw_row: Any) -> Dict[str, Any]:
        pre_normalized = super()._normalize_row(raw_row)
        repacked_row = handle_multicolumn_packing(
            unpacked_row=pre_normalized,
            key_name="partition_id",
            unpacked_keys=[col for col, _ in self._schema_pk()],
        )
        return repacked_row
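
    # Illustrative note (not in the original source): for a two-column partition
    # key, a raw row carrying partition_id_0="users" and partition_id_1=7 is
    # repacked here so that callers see a single partition_id=("users", 7) field.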

    def _get_get_partition_cql(
        self,
        partition_id: Optional[PARTITION_ID_TYPE] = None,
        n: Optional[int] = None,
        **kwargs: Any,
    ) -> Tuple[str, Tuple[Any, ...]]:
        _partition_id = self.partition_id if partition_id is None else partition_id
        #
        # TODO: work on a columns: Optional[List[str]] = None
        # (but with nuanced handling of the column-magic we have here)
        columns = None
        if columns is None:
            columns_desc = "*"
        else:
            # TODO: handle translations here?
            # columns_desc = ", ".join(columns)
            raise NotImplementedError("Column selection is not implemented.")
        # The WHERE clause can draw on other sources (e.g. metadata, if the
        # corresponding mixin is present), so we go through the standard
        # WHERE-construction route and re-inject the partition id there.
        n_kwargs = self._normalize_kwargs(
            {
                **{"partition_id": _partition_id},
                **kwargs,
            }
        )
        (
            rest_kwargs,
            where_clause_blocks,
            select_cql_vals,
        ) = self._extract_where_clause_blocks(n_kwargs)

        # check for exhaustion:
        assert rest_kwargs == {}
        where_clause = "WHERE " + " AND ".join(where_clause_blocks)
        where_cql_vals = list(select_cql_vals)
        #
        if n is None:
            limit_clause = ""
            limit_cql_vals = []
        else:
            limit_clause = "LIMIT %s"
            limit_cql_vals = [n]
        #
        select_cql = SELECT_CQL_TEMPLATE.format(
            columns_desc=columns_desc,
            where_clause=where_clause,
            limit_clause=limit_clause,
        )
        get_p_cql_vals = tuple(where_cql_vals + limit_cql_vals)
        return select_cql, get_p_cql_vals
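
    # Illustrative note (not in the original source): the query assembled above
    # pairs a WHERE clause on the partition id (plus any extra predicates from
    # the kwargs) with an optional "LIMIT %s" clause, and is returned together
    # with its bound values as a single tuple.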

    def get_partition(
        self,
        partition_id: Optional[PARTITION_ID_TYPE] = None,
        n: Optional[int] = None,
        **kwargs: Any,
    ) -> Iterable[RowType]:
        select_cql, get_p_cql_vals = self._get_get_partition_cql(
            partition_id, n, **kwargs
        )
        return (
            self._normalize_row(raw_row)
            for raw_row in self.execute_cql(
                select_cql,
                args=get_p_cql_vals,
                op_type=CQLOpType.READ,
            )
        )

    def get_partition_async(
        self,
        partition_id: Optional[PARTITION_ID_TYPE] = None,
        n: Optional[int] = None,
        **kwargs: Any,
    ) -> ResponseFuture:
        raise NotImplementedError("Asynchronous reads are not supported.")

    async def aget_partition(
        self,
        partition_id: Optional[PARTITION_ID_TYPE] = None,
        n: Optional[int] = None,
        **kwargs: Any,
    ) -> Iterable[RowType]:
        select_cql, get_p_cql_vals = self._get_get_partition_cql(
            partition_id, n, **kwargs
        )
        return (
            self._normalize_row(raw_row)
            for raw_row in await self.aexecute_cql(
                select_cql,
                args=get_p_cql_vals,
                op_type=CQLOpType.READ,
            )
        )
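
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; assumes a concrete table class that
# composes this mixin, e.g. a ClusteredCassandraTable, and an open Cassandra
# `session` / `keyspace` -- the class name and constructor parameters below
# are assumptions, not guaranteed by this module):
#
#     table = ClusteredCassandraTable(
#         session=session,
#         keyspace=keyspace,
#         table="events_by_user",
#         partition_id_type=["TEXT", "INT"],
#         ordering_in_partition="DESC",
#     )
#     rows = list(table.get_partition(partition_id=("user", 42), n=20))
#     table.delete_partition(partition_id=("user", 42))
# ---------------------------------------------------------------------------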