Spaces:
Running
Running
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union | |
from cassandra.cluster import ResponseFuture | |
from cassio.table.cql import DELETE_CQL_TEMPLATE, SELECT_CQL_TEMPLATE, CQLOpType | |
from cassio.table.table_types import ColumnSpecType, RowType, normalize_type_desc | |
from cassio.table.utils import ( | |
call_wrapped_async, | |
handle_multicolumn_packing, | |
handle_multicolumn_unpacking, | |
) | |
from .base_table import BaseTableMixin | |
# A partition id is either a single value (one partition-key column) or a
# tuple carrying one item per partition-key column. The tuple length follows
# the number of configured partition-key types, hence the variadic form.
PARTITION_ID_TYPE = Union[Any, Tuple[Any, ...]]
class ClusteredMixin(BaseTableMixin):
    """Table mixin that adds a partition key, exposed as ``partition_id``.

    With a single partition-key type the key lives in one column named
    ``partition_id``. With several types, the columns are named
    ``partition_id_0``, ``partition_id_1``, ... (see ``_schema_pk``) and the
    translation between the caller-facing ``partition_id`` value and those
    columns is delegated to ``handle_multicolumn_unpacking`` /
    ``handle_multicolumn_packing``.

    A ``partition_id`` given at construction time is used as the fallback by
    the operations defined here whenever a call does not supply its own.
    """

    def __init__(
        self,
        *pargs: Any,
        partition_id_type: Union[str, List[str]] = ["TEXT"],
        partition_id: Optional[PARTITION_ID_TYPE] = None,
        ordering_in_partition: Union[str, List[str]] = "ASC",
        **kwargs: Any,
    ) -> None:
        """Store the partition settings, then continue the ``__init__`` chain.

        Args:
            *pargs: forwarded untouched to the next ``__init__`` in the MRO.
            partition_id_type: type description(s) of the partition key; a
                single string or a list with one entry per key column. It is
                passed through ``normalize_type_desc``.
            partition_id: instance-level default partition id (optional).
            ordering_in_partition: one ordering string or a list of them;
                stored upper-cased. It is not used further inside this class.
            **kwargs: forwarded untouched to the next ``__init__`` in the MRO.
        """
        # Copy the normalized description so that the instance never aliases
        # the (shared) default-argument list nor a list owned by the caller.
        self.partition_id_type = list(normalize_type_desc(partition_id_type))
        self.partition_id = partition_id
        if isinstance(ordering_in_partition, str):
            self.ordering_in_partition = ordering_in_partition.upper()
        else:
            self.ordering_in_partition = [
                ordering.upper() for ordering in ordering_in_partition
            ]
        super().__init__(*pargs, **kwargs)

    def _schema_pk(self) -> List[ColumnSpecType]:
        """Return the (column name, type) pairs of the partition key.

        One configured type yields the single column ``partition_id``;
        several types yield ``partition_id_<index>`` columns, in order.
        """
        if len(self.partition_id_type) == 1:
            return [
                ("partition_id", self.partition_id_type[0]),
            ]
        else:
            return [
                (f"partition_id_{pk_i}", pk_typ)
                for pk_i, pk_typ in enumerate(self.partition_id_type)
            ]

    def _schema_cc(self) -> List[ColumnSpecType]:
        """Return the column specs provided by ``_schema_row_id()``.

        NOTE(review): presumably these are the clustering columns of the
        table; ``_schema_row_id`` is defined outside this class.
        """
        return self._schema_row_id()

    def _delete_partition(
        self, is_async: bool, partition_id: Optional[PARTITION_ID_TYPE] = None
    ) -> Union[None, ResponseFuture]:
        """Build and run the DELETE statement for one partition.

        Args:
            is_async: when true, run through ``execute_cql_async`` and return
                what it returns; otherwise run through ``execute_cql`` and
                return ``None``.
            partition_id: partition to delete; falls back to the
                instance-level ``partition_id`` when ``None``.

        Returns:
            The result of ``execute_cql_async`` in async mode, else ``None``.
        """
        _partition_id = self.partition_id if partition_id is None else partition_id
        # Translate the caller-facing "partition_id" into the actual
        # partition-key column(s) of the schema.
        _pid_dict = handle_multicolumn_unpacking(
            {"partition_id": _partition_id},
            "partition_id",
            [col for col, _ in self._schema_pk()],
        )
        (
            rest_kwargs,
            where_clause_blocks,
            delete_cql_vals,
        ) = self._extract_where_clause_blocks(_pid_dict)
        # Everything passed in must have been consumed by the WHERE builder.
        assert rest_kwargs == {}
        where_clause = "WHERE " + " AND ".join(where_clause_blocks)
        delete_cql = DELETE_CQL_TEMPLATE.format(
            where_clause=where_clause,
        )
        if is_async:
            return self.execute_cql_async(
                delete_cql, args=delete_cql_vals, op_type=CQLOpType.WRITE
            )
        else:
            self.execute_cql(delete_cql, args=delete_cql_vals, op_type=CQLOpType.WRITE)
            return None

    def delete_partition(
        self, partition_id: Optional[PARTITION_ID_TYPE] = None
    ) -> None:
        """Delete a partition through the synchronous execution path.

        ``partition_id`` falls back to the instance-level value when ``None``.
        """
        self._delete_partition(is_async=False, partition_id=partition_id)
        return None

    def delete_partition_async(
        self, partition_id: Optional[PARTITION_ID_TYPE] = None
    ) -> ResponseFuture:
        """Start a partition deletion via ``execute_cql_async``.

        Returns whatever ``execute_cql_async`` returns (annotated as a
        ``ResponseFuture``). ``partition_id`` falls back to the
        instance-level value when ``None``.
        """
        return self._delete_partition(is_async=True, partition_id=partition_id)

    async def adelete_partition(
        self, partition_id: Optional[PARTITION_ID_TYPE] = None
    ) -> None:
        """Awaitable deletion: wraps ``delete_partition_async`` with
        ``call_wrapped_async`` and awaits the outcome."""
        await call_wrapped_async(self.delete_partition_async, partition_id=partition_id)

    def _normalize_kwargs(self, args_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Inject the effective partition id into ``args_dict`` and normalize.

        A non-``None`` ``partition_id`` found in ``args_dict`` takes
        precedence over the instance-level one; a missing or ``None`` entry
        falls back to the instance value. The result then goes through the
        multicolumn unpacking and the parent class normalization.
        """
        # if partition id provided in call, takes precedence over instance value
        arg_pid = args_dict.get("partition_id")
        _partition_id = self.partition_id if arg_pid is None else arg_pid
        # The resolved value must win over the raw "partition_id" entry of
        # args_dict: merging args_dict on top would let an explicit None
        # overwrite the instance-level fallback computed just above.
        # "partition_id" is kept as the first key of the merged mapping.
        new_args_dict0 = {
            "partition_id": _partition_id,
            **{
                arg_k: arg_v
                for arg_k, arg_v in args_dict.items()
                if arg_k != "partition_id"
            },
        }
        # in case of multicolumn-key schema, do the tuple unpacking:
        new_args_dict = handle_multicolumn_unpacking(
            new_args_dict0,
            "partition_id",
            [col for col, _ in self._schema_pk()],
        )
        return super()._normalize_kwargs(new_args_dict)

    def _normalize_row(self, raw_row: Any) -> Dict[str, Any]:
        """Normalize a raw row via the parent class, then hand the
        partition-key column(s) to ``handle_multicolumn_packing`` under the
        key name ``partition_id``."""
        pre_normalized = super()._normalize_row(raw_row)
        repacked_row = handle_multicolumn_packing(
            unpacked_row=pre_normalized,
            key_name="partition_id",
            unpacked_keys=[col for col, _ in self._schema_pk()],
        )
        return repacked_row

    def _get_get_partition_cql(
        self,
        partition_id: Optional[PARTITION_ID_TYPE] = None,
        n: Optional[int] = None,
        **kwargs: Any,
    ) -> Tuple[str, Tuple[Any, ...]]:
        """Prepare the SELECT statement (and its values) for a partition read.

        Args:
            partition_id: partition to read; falls back to the instance-level
                ``partition_id`` when ``None``.
            n: optional row limit; when given, a ``LIMIT %s`` clause is added
                and ``n`` is appended to the bound values.
            **kwargs: further conditions, routed through ``_normalize_kwargs``
                and ``_extract_where_clause_blocks``; anything left unconsumed
                trips the assertion below.

        Returns:
            A ``(cql_statement, values)`` pair, ``values`` being a tuple with
            the WHERE values followed by the LIMIT value (if any).
        """
        _partition_id = self.partition_id if partition_id is None else partition_id
        # TODO: work on a columns: Optional[List[str]] = None
        # (but with nuanced handling of the column-magic we have here)
        # Until column selection exists, all columns are always selected.
        columns_desc = "*"
        # WHERE can admit other sources (e.g. metadata if the corresponding
        # mixin is there), so we escalate to the standard WHERE-creation route
        # and reinject the partition
        n_kwargs = self._normalize_kwargs(
            {
                "partition_id": _partition_id,
                **kwargs,
            }
        )
        (
            rest_kwargs,
            where_clause_blocks,
            select_cql_vals,
        ) = self._extract_where_clause_blocks(n_kwargs)
        # check for exhaustion:
        assert rest_kwargs == {}
        where_clause = "WHERE " + " AND ".join(where_clause_blocks)
        where_cql_vals = list(select_cql_vals)
        #
        limit_cql_vals: List[Any]
        if n is None:
            limit_clause = ""
            limit_cql_vals = []
        else:
            limit_clause = "LIMIT %s"
            limit_cql_vals = [n]
        #
        select_cql = SELECT_CQL_TEMPLATE.format(
            columns_desc=columns_desc,
            where_clause=where_clause,
            limit_clause=limit_clause,
        )
        get_p_cql_vals = tuple(where_cql_vals + limit_cql_vals)
        return select_cql, get_p_cql_vals

    def get_partition(
        self,
        partition_id: Optional[PARTITION_ID_TYPE] = None,
        n: Optional[int] = None,
        **kwargs: Any,
    ) -> Iterable[RowType]:
        """Read rows from a partition through ``execute_cql``.

        The statement is executed right away; the returned generator applies
        ``_normalize_row`` to each raw row lazily, as it is consumed.
        See ``_get_get_partition_cql`` for the meaning of the arguments.
        """
        select_cql, get_p_cql_vals = self._get_get_partition_cql(
            partition_id, n, **kwargs
        )
        raw_rows = self.execute_cql(
            select_cql,
            args=get_p_cql_vals,
            op_type=CQLOpType.READ,
        )
        return (self._normalize_row(raw_row) for raw_row in raw_rows)

    def get_partition_async(
        self,
        partition_id: Optional[PARTITION_ID_TYPE] = None,
        n: Optional[int] = None,
        **kwargs: Any,
    ) -> ResponseFuture:
        """Not available: always raises ``NotImplementedError``."""
        raise NotImplementedError("Asynchronous reads are not supported.")

    async def aget_partition(
        self,
        partition_id: Optional[PARTITION_ID_TYPE] = None,
        n: Optional[int] = None,
        **kwargs: Any,
    ) -> Iterable[RowType]:
        """Awaitable counterpart of ``get_partition``, using ``aexecute_cql``.

        The statement is awaited before returning; the returned generator
        applies ``_normalize_row`` to each raw row lazily, as it is consumed.
        """
        select_cql, get_p_cql_vals = self._get_get_partition_cql(
            partition_id, n, **kwargs
        )
        raw_rows = await self.aexecute_cql(
            select_cql,
            args=get_p_cql_vals,
            op_type=CQLOpType.READ,
        )
        return (self._normalize_row(raw_row) for raw_row in raw_rows)