import asyncio
import json
import logging
from asyncio import InvalidStateError, Task
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union, cast

from cassandra.cluster import ResponseFuture, ResultSet
from cassandra.query import PreparedStatement, SimpleStatement

from cassio.config import check_resolve_keyspace, check_resolve_session
from cassio.table.cql import (
    CREATE_INDEX_CQL_TEMPLATE,
    CREATE_TABLE_CQL_TEMPLATE,
    DELETE_CQL_TEMPLATE,
    INSERT_ROW_CQL_TEMPLATE,
    SELECT_CQL_TEMPLATE,
    TRUNCATE_TABLE_CQL_TEMPLATE,
    CQLOpType,
)
from cassio.table.query import Predicate
from cassio.table.table_types import (
    ColumnSpecType,
    RowType,
    SessionType,
    normalize_type_desc,
)
from cassio.table.utils import (
    call_wrapped_async,
    handle_multicolumn_packing,
    handle_multicolumn_unpacking,
)

# Numeric value of the custom TRACE log level (below logging.DEBUG, which is 10).
_TRACE_LEVEL = 5


class CustomLogger(logging.Logger):
    """Logger with an extra ``trace`` method for the custom TRACE level."""

    def trace(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Log ``msg % args`` at the TRACE level, if that level is enabled."""
        if self.isEnabledFor(_TRACE_LEVEL):
            self._log(_TRACE_LEVEL, msg, args, **kwargs)


logging.addLevelName(_TRACE_LEVEL, "TRACE")
# setLoggerClass applies to loggers instantiated after this call; this is what
# gives the module logger below its ``trace`` method.
logging.setLoggerClass(CustomLogger)
logger = logging.getLogger(__name__)


class BaseTable:
    """Base class managing one Cassandra table.

    The schema is split into partition-key ("pk"), clustering ("cc") and
    data ("da") column groups, each produced by a ``_schema_*`` method. Here
    the partition key is the row id (``row_id`` or ``row_id_0``, ...), there
    are no clustering columns, and the only data column is ``body_blob``
    (TEXT). The ``_schema_*`` methods are the extension points: e.g. a
    non-empty ``_schema_cc`` is fully handled by the table-creation logic.

    CQL statements are built as "semitemplates": ``{table_fqname}`` /
    ``{table_name}`` are filled in by ``_finalize_cql_semitemplate`` right
    before execution, and values are bound through ``%s`` placeholders.
    """

    # Ordering(s) for the clustering columns in "CLUSTERING ORDER BY": a
    # single string applies to every clustering column, a list gives one
    # ordering per column. Required when the schema has clustering columns.
    ordering_in_partition: Optional[Union[str, List[str]]] = None

    def __init__(
        self,
        table: str,
        session: Optional[SessionType] = None,
        keyspace: Optional[str] = None,
        ttl_seconds: Optional[int] = None,
        row_id_type: Union[str, List[str]] = ["TEXT"],
        skip_provisioning: bool = False,
        async_setup: bool = False,
        body_index_options: Optional[List[Tuple[str, Any]]] = None,
    ) -> None:
        """Bind to the table and, unless skipped, provision its schema.

        Args:
            table: table name.
            session: DB session, resolved through ``check_resolve_session``.
            keyspace: keyspace, resolved through ``check_resolve_keyspace``.
            ttl_seconds: default TTL for writes; ``None`` means no TTL.
            row_id_type: CQL type(s) of the row id, normalized with
                ``normalize_type_desc``. One type gives a ``row_id``
                column; several give ``row_id_0``, ``row_id_1``, ...
            skip_provisioning: if true, schema (DDL) statements are not
                executed by this instance.
            async_setup: if true, the schema setup is scheduled as an
                asyncio task (this needs a running event loop); otherwise
                it runs synchronously here.
            body_index_options: if given, an index on ``body_blob`` is
                created with these ``(key, value)`` options.
        """
        self.session = check_resolve_session(session)
        self.keyspace = check_resolve_keyspace(keyspace)
        self.table = table
        self.ttl_seconds = ttl_seconds
        self.row_id_type = normalize_type_desc(row_id_type)
        self.skip_provisioning = skip_provisioning
        # Cache of prepared statements, keyed by CQL text (with "?" markers).
        self._prepared_statements: Dict[str, PreparedStatement] = {}
        self._body_index_options = body_index_options
        self.db_setup_task: Optional[Task[None]] = None
        if async_setup:
            self.db_setup_task = asyncio.create_task(self.adb_setup())
        else:
            self.db_setup()

    # ------------------------------------------------------------------
    # Schema description
    # ------------------------------------------------------------------

    def _schema_row_id(self) -> List[ColumnSpecType]:
        """Return the row-id column spec(s).

        A single row-id type yields one ``row_id`` column; several types
        yield ``row_id_0``, ``row_id_1``, ... in order.
        """
        if len(self.row_id_type) == 1:
            return [
                ("row_id", self.row_id_type[0]),
            ]
        return [
            (f"row_id_{row_i}", row_typ)
            for row_i, row_typ in enumerate(self.row_id_type)
        ]

    def _schema_pk(self) -> List[ColumnSpecType]:
        """Return the partition-key columns (here: the row-id columns)."""
        return self._schema_row_id()

    def _schema_cc(self) -> List[ColumnSpecType]:
        """Return the clustering columns (none in the base table)."""
        return []

    def _schema_da(self) -> List[ColumnSpecType]:
        """Return the data (non-key) columns."""
        return [
            ("body_blob", "TEXT"),
        ]

    async def _aschema_da(self) -> List[ColumnSpecType]:
        """Async counterpart of ``_schema_da``, used by ``_aschema``."""
        return self._schema_da()

    def _schema(self) -> Dict[str, List[ColumnSpecType]]:
        """Return the full schema as ``{"pk": ..., "cc": ..., "da": ...}``."""
        return {
            "pk": self._schema_pk(),
            "cc": self._schema_cc(),
            "da": self._schema_da(),
        }

    async def _aschema(self) -> Dict[str, List[ColumnSpecType]]:
        """Async counterpart of ``_schema`` (awaits ``_aschema_da``)."""
        return {
            "pk": self._schema_pk(),
            "cc": self._schema_cc(),
            "da": await self._aschema_da(),
        }

    def _schema_primary_key(self) -> List[ColumnSpecType]:
        """Return all primary-key columns: partition key, then clustering."""
        return self._schema_pk() + self._schema_cc()

    def _schema_collist(self) -> List[ColumnSpecType]:
        """Return every column spec: data, then clustering, then partition."""
        return self._schema_da() + self._schema_cc() + self._schema_pk()

    def _schema_colnameset(self) -> Set[str]:
        """Return the set of all column names (asserting none is repeated)."""
        full_list = self._schema_collist()
        full_set = {col for col, _ in full_list}
        assert len(full_list) == len(full_set)
        return full_set

    def _desc_table(self) -> str:
        """Return a compact description: ``[(pk...) cc...] da...``."""
        columns = self._schema()
        pk_desc = ", ".join("%s(%s)" % colspec for colspec in columns["pk"])
        cc_desc = ", ".join("%s(%s)" % colspec for colspec in columns["cc"])
        da_desc = ", ".join("%s(%s)" % colspec for colspec in columns["da"])
        return "[(" + pk_desc + ") " + cc_desc + "] " + da_desc

    # ------------------------------------------------------------------
    # Argument / row normalization
    # ------------------------------------------------------------------

    def _extract_where_clause_blocks(
        self, args_dict: Any
    ) -> Tuple[Any, List[str], Tuple[Any, ...]]:
        """Split ``args_dict`` into residual args and WHERE-clause pieces.

        Keys naming a schema column become ``"<col> = %s"`` blocks, or
        ``"<col> <op> %s"`` when the value is a ``Predicate`` (rendered with
        ``Predicate.render()``). Columns are processed in sorted name order.

        Returns:
            ``(residual_args, where_clause_blocks, where_clause_vals)``,
            where ``residual_args`` holds the keys that are not columns.
        """
        allowed_colspecs = self._schema_collist()
        passed_columns = sorted(
            [col for col, _ in allowed_colspecs if col in args_dict]
        )
        passed_set = set(passed_columns)
        residual_args = {k: v for k, v in args_dict.items() if k not in passed_set}
        where_clause_blocks = []
        where_clause_vals = []
        for col in passed_columns:
            value = args_dict[col]
            if isinstance(value, Predicate):
                pred_op_name, pred_value = value.render()
                where_clause_blocks.append(f"{col} {pred_op_name} %s")
                where_clause_vals.append(pred_value)
            else:
                where_clause_blocks.append(f"{col} = %s")
                where_clause_vals.append(value)
        return (
            residual_args,
            where_clause_blocks,
            tuple(where_clause_vals),
        )

    def _normalize_kwargs(self, args_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Unpack a ``row_id`` kwarg onto the individual row-id columns.

        Delegates to ``handle_multicolumn_unpacking`` (its exact handling of
        the single-column case lives in ``cassio.table.utils``).
        """
        return handle_multicolumn_unpacking(
            args_dict,
            "row_id",
            [col for col, _ in self._schema_row_id()],
        )

    def _normalize_row(self, raw_row: Any) -> Dict[str, Any]:
        """Turn a driver row into a dict, re-packing row-id columns.

        ``raw_row`` may already be a dict; otherwise its ``_asdict()`` is
        used. Row-id columns are packed back under ``row_id`` through
        ``handle_multicolumn_packing``.
        """
        dict_row = raw_row if isinstance(raw_row, dict) else raw_row._asdict()
        return handle_multicolumn_packing(
            unpacked_row=dict_row,
            key_name="row_id",
            unpacked_keys=[col for col, _ in self._schema_row_id()],
        )

    # ------------------------------------------------------------------
    # Delete / clear
    # ------------------------------------------------------------------

    def _delete(self, is_async: bool, **kwargs: Any) -> Union[None, ResponseFuture]:
        """Run a DELETE whose WHERE clause is built from ``kwargs``.

        Every kwarg (after ``row_id`` unpacking) must name a schema column.
        Returns the driver future when ``is_async``, otherwise ``None``.
        """
        n_kwargs = self._normalize_kwargs(kwargs)
        (
            rest_kwargs,
            where_clause_blocks,
            delete_cql_vals,
        ) = self._extract_where_clause_blocks(n_kwargs)
        assert rest_kwargs == {}
        where_clause = "WHERE " + " AND ".join(where_clause_blocks)
        delete_cql = DELETE_CQL_TEMPLATE.format(
            where_clause=where_clause,
        )
        if is_async:
            return self.execute_cql_async(
                delete_cql, args=delete_cql_vals, op_type=CQLOpType.WRITE
            )
        self.execute_cql(delete_cql, args=delete_cql_vals, op_type=CQLOpType.WRITE)
        return None

    def delete(self, **kwargs: Any) -> None:
        """Delete the row(s) matching ``kwargs`` (synchronously)."""
        self._ensure_db_setup()
        self._delete(is_async=False, **kwargs)
        return None

    def delete_async(self, **kwargs: Any) -> ResponseFuture:
        """Start deleting the row(s) matching ``kwargs``; return the future."""
        self._ensure_db_setup()
        return self._delete(is_async=True, **kwargs)

    async def adelete(self, **kwargs: Any) -> None:
        """Asyncio variant of ``delete``."""
        await self._aensure_db_setup()
        await call_wrapped_async(self.delete_async, **kwargs)

    def _clear(self, is_async: bool) -> Union[None, ResponseFuture]:
        """Run a TRUNCATE on the table; return the future when ``is_async``."""
        truncate_table_cql = TRUNCATE_TABLE_CQL_TEMPLATE.format()
        if is_async:
            return self.execute_cql_async(
                truncate_table_cql, args=tuple(), op_type=CQLOpType.WRITE
            )
        self.execute_cql(truncate_table_cql, args=tuple(), op_type=CQLOpType.WRITE)
        return None

    def clear(self) -> None:
        """Truncate the table (synchronously)."""
        self._ensure_db_setup()
        self._clear(is_async=False)
        return None

    def clear_async(self) -> ResponseFuture:
        """Start truncating the table; return the driver future."""
        self._ensure_db_setup()
        return self._clear(is_async=True)

    async def aclear(self) -> None:
        """Asyncio variant of ``clear``."""
        await self._aensure_db_setup()
        await call_wrapped_async(self.clear_async)

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def _has_index_analyzers(self) -> bool:
        """Return True if any body-index option has key ``index_analyzer``."""
        if not self._body_index_options:
            return False
        return any(
            option[0] == "index_analyzer" for option in self._body_index_options
        )

    def _extract_index_analyzers(
        self, args_dict: Any
    ) -> Tuple[Any, List[str], Tuple[Any, ...]]:
        """Turn a ``body_search`` arg into ``body_blob : %s`` clauses.

        ``body_search`` may be a single value or a list of values; each one
        yields its own clause. Other args are returned untouched.

        Raises:
            ValueError: if ``body_search`` is given but no index analyzer
                was configured in ``body_index_options``.
        """
        rest_args = args_dict.copy()
        where_clause_blocks: List[str] = []
        where_clause_vals: List[Any] = []
        if "body_search" in args_dict:
            if not self._has_index_analyzers():
                raise ValueError(
                    "Cannot do body search because no index analyzer "
                    "was configured on the table"
                )
            body_search_texts = rest_args.pop("body_search")
            if not isinstance(body_search_texts, list):
                body_search_texts = [body_search_texts]
            for text in body_search_texts:
                where_clause_blocks.append("body_blob : %s")
                where_clause_vals.append(text)
        return rest_args, where_clause_blocks, tuple(where_clause_vals)

    def _parse_select_core_params(
        self, **kwargs: Any
    ) -> Tuple[str, str, Tuple[Any, ...]]:
        """Build the columns list, WHERE clause and values of a SELECT.

        Column kwargs and ``body_search`` are the only accepted arguments.

        Returns:
            ``(columns_desc, where_clause, values)``; ``where_clause`` is
            empty when there are no conditions.
        """
        n_kwargs = self._normalize_kwargs(kwargs)
        # TODO: work on a columns: Optional[List[str]] = None
        # (but with nuanced handling of the column-magic we have here;
        # translations of packed columns such as row_id would be needed).
        columns_desc = "*"
        (
            rest_kwargs,
            where_clause_blocks,
            select_cql_vals,
        ) = self._extract_where_clause_blocks(n_kwargs)
        (
            rest_kwargs,
            analyzer_clause_blocks,
            analyzer_cql_vals,
        ) = self._extract_index_analyzers(rest_kwargs)
        assert rest_kwargs == {}
        all_where_clauses = where_clause_blocks + analyzer_clause_blocks
        if all_where_clauses:
            where_clause = "WHERE " + " AND ".join(all_where_clauses)
        else:
            where_clause = ""
        return columns_desc, where_clause, select_cql_vals + analyzer_cql_vals

    def _get_select_cql(self, **kwargs: Any) -> Tuple[str, Tuple[Any, ...]]:
        """Return the SELECT semitemplate and bound values for ``kwargs``."""
        columns_desc, where_clause, get_cql_vals = self._parse_select_core_params(
            **kwargs
        )
        # No LIMIT is applied by this base implementation.
        limit_clause = ""
        select_cql = SELECT_CQL_TEMPLATE.format(
            columns_desc=columns_desc,
            where_clause=where_clause,
            limit_clause=limit_clause,
        )
        return select_cql, tuple(get_cql_vals)

    def _normalize_result_set(
        self, result_set: Iterable[RowType]
    ) -> Optional[Dict[str, Any]]:
        """Return the first row of a driver ``ResultSet`` as a dict, or None.

        Anything that is not a ``ResultSet`` (e.g. the ``[]`` returned for a
        skipped statement) yields ``None``.
        """
        result = result_set.one() if isinstance(result_set, ResultSet) else None
        return None if result is None else self._normalize_row(result)

    def get(self, **kwargs: Any) -> Optional[RowType]:
        """Return the row matching ``kwargs`` (first one), or None."""
        self._ensure_db_setup()
        select_cql, select_vals = self._get_select_cql(**kwargs)
        # dancing around the result set (to comply with type checking):
        result_set = self.execute_cql(
            select_cql, args=select_vals, op_type=CQLOpType.READ
        )
        return self._normalize_result_set(result_set)

    def get_async(self, **kwargs: Any) -> ResponseFuture:
        """Not supported: always raises ``NotImplementedError``."""
        raise NotImplementedError("Asynchronous reads are not supported.")

    async def aget(self, **kwargs: Any) -> Optional[RowType]:
        """Asyncio variant of ``get``."""
        await self._aensure_db_setup()
        select_cql, select_vals = self._get_select_cql(**kwargs)
        # dancing around the result set (to comply with type checking):
        result_set = await self.aexecute_cql(
            select_cql, args=select_vals, op_type=CQLOpType.READ
        )
        return self._normalize_result_set(result_set)

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def _put(self, is_async: bool, **kwargs: Any) -> Union[None, ResponseFuture]:
        """Run an INSERT of the schema columns present in ``kwargs``.

        All primary-key columns must be supplied. A ``ttl_seconds`` kwarg
        overrides the table default TTL (``None`` disables the TTL).
        Returns the driver future when ``is_async``, otherwise ``None``.
        """
        n_kwargs = self._normalize_kwargs(kwargs)
        primary_key = self._schema_primary_key()
        assert {col for col, _ in primary_key} - set(n_kwargs.keys()) == set()
        columns = [col for col, _ in self._schema_collist() if col in n_kwargs]
        columns_desc = ", ".join(columns)
        insert_cql_vals = [n_kwargs[col] for col in columns]
        value_placeholders = ", ".join("%s" for _ in columns)
        ttl_seconds = n_kwargs.get("ttl_seconds", self.ttl_seconds)
        if ttl_seconds is not None:
            ttl_spec = "USING TTL %s"
            ttl_vals = [ttl_seconds]
        else:
            ttl_spec = ""
            ttl_vals = []
        insert_cql_args = tuple(insert_cql_vals + ttl_vals)
        insert_cql = INSERT_ROW_CQL_TEMPLATE.format(
            columns_desc=columns_desc,
            value_placeholders=value_placeholders,
            ttl_spec=ttl_spec,
        )
        if is_async:
            return self.execute_cql_async(
                insert_cql, args=insert_cql_args, op_type=CQLOpType.WRITE
            )
        self.execute_cql(insert_cql, args=insert_cql_args, op_type=CQLOpType.WRITE)
        return None

    def put(self, **kwargs: Any) -> None:
        """Insert (write) a row given as column kwargs (synchronously)."""
        self._ensure_db_setup()
        self._put(is_async=False, **kwargs)
        return None

    def put_async(self, **kwargs: Any) -> ResponseFuture:
        """Start inserting a row; return the driver future."""
        self._ensure_db_setup()
        return self._put(is_async=True, **kwargs)

    async def aput(self, **kwargs: Any) -> None:
        """Asyncio variant of ``put``."""
        await self._aensure_db_setup()
        await call_wrapped_async(self.put_async, **kwargs)

    # ------------------------------------------------------------------
    # Schema provisioning
    # ------------------------------------------------------------------

    def _get_db_setup_cql(self, schema: Dict[str, List[ColumnSpecType]]) -> str:
        """Build the CREATE TABLE semitemplate for ``schema``.

        Raises:
            ValueError: if there are clustering columns and
                ``ordering_in_partition`` is unset, or is a list whose length
                differs from the number of clustering columns.
        """
        column_specs = [
            f"{col_spec[0]} {col_spec[1]}"
            for schema_grp in ("pk", "cc", "da")
            for col_spec in schema[schema_grp]
        ]
        pk_spec = ", ".join(col for col, _ in schema["pk"])
        cc_spec = ", ".join(col for col, _ in schema["cc"])
        primkey_spec = f"( ( {pk_spec} ) {',' if schema['cc'] else ''} {cc_spec} )"
        table_options = []
        if schema["cc"]:
            if self.ordering_in_partition is None:
                raise ValueError("Unspecified ordering for clustering column(s)")
            if isinstance(self.ordering_in_partition, str):
                cc_orderings = [self.ordering_in_partition for _ in schema["cc"]]
            else:
                # a list: exactly one ordering per clustering column
                if len(self.ordering_in_partition) != len(schema["cc"]):
                    raise ValueError(
                        "Mismatched number of orderings for clustering column(s)"
                    )
                cc_orderings = self.ordering_in_partition
            clu_core = ", ".join(
                f"{col} {ordering}"
                for (col, _), ordering in zip(schema["cc"], cc_orderings)
            )
            table_options.append(f"CLUSTERING ORDER BY ({clu_core})")
        options_clause = (
            ("WITH " + " AND ".join(table_options)) if table_options else ""
        )
        return CREATE_TABLE_CQL_TEMPLATE.format(
            columns_spec=" ".join(f" {cs}," for cs in column_specs),
            primkey_spec=primkey_spec,
            options_clause=options_clause,
        )

    @staticmethod
    def _get_create_index_cql(
        index_name: str, index_column: str, index_options: List[Tuple[str, Any]]
    ) -> str:
        """Build the CREATE INDEX semitemplate for ``index_column``.

        ``index_options`` are ``(key, value)`` pairs rendered (sorted) into
        ``WITH OPTIONS = {...}``: dict values are JSON-encoded and quoted,
        string values are quoted, booleans become bare ``true``/``false``.
        Single quotes inside keys/values are escaped CQL-style (``''``).

        The returned string is still passed through ``str.format`` by
        ``_finalize_cql_semitemplate`` before execution, so every literal
        brace - the options map itself and any braces in the values (JSON!)
        - is doubled here to survive that pass.

        Raises:
            ValueError: for option values of any other type.
        """

        def _quote(text: str) -> str:
            # CQL string literal; an embedded single quote is written as ''.
            return "'" + text.replace("'", "''") + "'"

        options_clause = ""
        if index_options:
            formatted_options = []
            for key, value in index_options:
                if isinstance(value, dict):
                    formatted_options.append(
                        f"{_quote(key)}: {_quote(json.dumps(value))}"
                    )
                elif isinstance(value, str):
                    formatted_options.append(f"{_quote(key)}: {_quote(value)}")
                elif isinstance(value, bool):
                    bool_text = "true" if value else "false"
                    formatted_options.append(f"{_quote(key)}: {bool_text}")
                else:
                    raise ValueError("Unsupported index_option format")
            formatted_options.sort()
            options_text = ", ".join(formatted_options)
            # Escape braces coming from the values for the later format pass.
            escaped_text = options_text.replace("{", "{{").replace("}", "}}")
            # "{{{{" in this f-string yields "{{", which the final format pass
            # turns into the single "{" opening the CQL options map.
            options_clause = f"WITH OPTIONS = {{{{ {escaped_text} }}}}"
        return CREATE_INDEX_CQL_TEMPLATE.format(
            index_name=index_name,
            index_column=index_column,
            options_clause=options_clause,
        )

    @staticmethod
    def _get_create_analyzer_index_cql(index_options: List[Tuple[str, Any]]) -> str:
        """Build the CREATE INDEX semitemplate for the ``body_blob`` index."""
        return BaseTable._get_create_index_cql(
            index_name="idx_body",
            index_column="body_blob",
            index_options=index_options,
        )

    def db_setup(self) -> None:
        """Create the table and, if body-index options were given, its index.

        Both are schema statements, hence skipped by ``execute_cql`` when
        ``skip_provisioning`` is set.
        """
        create_table_cql = self._get_db_setup_cql(self._schema())
        self.execute_cql(create_table_cql, op_type=CQLOpType.SCHEMA)
        if self._body_index_options:
            self.execute_cql(
                self._get_create_analyzer_index_cql(self._body_index_options),
                op_type=CQLOpType.SCHEMA,
            )

    async def adb_setup(self) -> None:
        """Asyncio variant of ``db_setup`` (uses ``_aschema``)."""
        schema = await self._aschema()
        create_table_cql = self._get_db_setup_cql(schema)
        await self.aexecute_cql(create_table_cql, op_type=CQLOpType.SCHEMA)
        if self._body_index_options:
            await self.aexecute_cql(
                self._get_create_analyzer_index_cql(self._body_index_options),
                op_type=CQLOpType.SCHEMA,
            )

    def _ensure_db_setup(self) -> None:
        """Check, without waiting, that any async schema setup has finished.

        ``Task.result()`` re-raises the setup's own exception if it failed.

        Raises:
            ValueError: if the async setup task is still pending.
        """
        if self.db_setup_task:
            try:
                self.db_setup_task.result()
            except InvalidStateError:
                raise ValueError(
                    "Asynchronous setup of the DB not finished. "
                    "NB: Table sync methods shouldn't be called from the "
                    "event loop. Consider using their async equivalents."
                )

    async def _aensure_db_setup(self) -> None:
        """Await the async schema setup task, if there is one."""
        if self.db_setup_task:
            await self.db_setup_task

    # ------------------------------------------------------------------
    # Statement execution
    # ------------------------------------------------------------------

    def _finalize_cql_semitemplate(self, cql_semitemplate: str) -> str:
        """Fill ``{table_fqname}`` (keyspace.table) and ``{table_name}``."""
        return cql_semitemplate.format(
            table_fqname=f"{self.keyspace}.{self.table}",
            table_name=self.table,
        )

    def _obtain_prepared_statement(self, final_cql: str) -> PreparedStatement:
        """Return a cached prepared statement for ``final_cql``.

        ``%s`` placeholders are replaced with ``?`` before preparing; the
        statement is prepared on first use and cached per instance.
        """
        # TODO: improve this placeholder handling
        preparable_cql = final_cql.replace("%s", "?")
        if preparable_cql not in self._prepared_statements:
            logger.debug('Preparing statement "%s"', preparable_cql)
            self._prepared_statements[preparable_cql] = self.session.prepare(
                preparable_cql
            )
        return self._prepared_statements[preparable_cql]

    def execute_cql(
        self,
        cql_semitemplate: str,
        op_type: CQLOpType,
        args: Tuple[Any, ...] = tuple(),
    ) -> Iterable[RowType]:
        """Finalize and synchronously execute a statement.

        Schema statements run as unprepared ``SimpleStatement``s, or are
        skipped (returning ``[]``) when ``skip_provisioning`` is set; all
        other statements use a cached prepared statement bound to ``args``.
        """
        final_cql = self._finalize_cql_semitemplate(cql_semitemplate)
        if op_type == CQLOpType.SCHEMA and self.skip_provisioning:
            # these operations are not executed for this instance:
            logger.debug('Not executing statement "%s"', final_cql)
            return []
        if op_type == CQLOpType.SCHEMA:
            # schema operations are not to be 'prepared'
            statement = SimpleStatement(final_cql)
            logger.debug('Executing statement "%s" as simple (unprepared)', final_cql)
        else:
            statement = self._obtain_prepared_statement(final_cql)
            logger.debug('Executing statement "%s" as prepared', final_cql)
        # lazy args: str(args) is only computed when TRACE is enabled
        logger.trace('Statement "%s" has args: "%s"', final_cql, args)  # type: ignore
        return cast(Iterable[RowType], self.session.execute(statement, args))

    def execute_cql_async(
        self,
        cql_semitemplate: str,
        op_type: CQLOpType,
        args: Tuple[Any, ...] = tuple(),
    ) -> ResponseFuture:
        """Finalize and start executing a (prepared) statement.

        Raises:
            RuntimeError: for schema statements, which are never run this way.
        """
        final_cql = self._finalize_cql_semitemplate(cql_semitemplate)
        if op_type == CQLOpType.SCHEMA:
            raise RuntimeError("Schema operations cannot be asynchronous")
        statement = self._obtain_prepared_statement(final_cql)
        logger.debug('Executing_async statement "%s" as prepared', final_cql)
        logger.trace('Statement "%s" has args: "%s"', final_cql, args)  # type: ignore
        return self.session.execute_async(statement, args)

    async def aexecute_cql(
        self,
        cql_semitemplate: str,
        op_type: CQLOpType,
        args: Tuple[Any, ...] = tuple(),
    ) -> Iterable[RowType]:
        """Asyncio variant of ``execute_cql``.

        Same skip/prepare rules; the driver's ``execute_async`` is awaited
        through ``call_wrapped_async``.
        """
        final_cql = self._finalize_cql_semitemplate(cql_semitemplate)
        if op_type == CQLOpType.SCHEMA and self.skip_provisioning:
            # these operations are not executed for this instance:
            logger.debug('Not aexecuting statement "%s"', final_cql)
            return []
        if op_type == CQLOpType.SCHEMA:
            # schema operations are not to be 'prepared'
            statement = SimpleStatement(final_cql)
            logger.debug(
                'aExecuting statement "%s" as simple (unprepared)', final_cql
            )
        else:
            statement = self._obtain_prepared_statement(final_cql)
            logger.debug('aExecuting statement "%s" as prepared', final_cql)
        logger.trace('Statement "%s" has args: "%s"', final_cql, args)  # type: ignore
        return cast(
            Iterable[RowType],
            await call_wrapped_async(self.session.execute_async, statement, args),
        )