Draken007's picture
Upload 7228 files
2a0bc63 verified
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime, timedelta
import time
from cassandra.query import FETCH_SIZE_UNSET
from cassandra.cqlengine import columns
from cassandra.cqlengine import UnicodeMixin
from cassandra.cqlengine.functions import QueryValue
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator, IsNotNullOperator
class StatementException(Exception):
    """ error raised when a statement or clause is assembled from invalid parts """
class ValueQuoter(UnicodeMixin):
    """ wraps a value so that it can be rendered as a quoted CQL literal """

    def __init__(self, value):
        self.value = value

    def __unicode__(self):
        # the encoder is imported at call time rather than at module import
        from cassandra.encoder import cql_quote

        raw = self.value
        if isinstance(raw, (list, tuple)):
            # sequences render inside square brackets
            return '[' + ', '.join(cql_quote(item) for item in raw) + ']'
        if isinstance(raw, dict):
            # mappings render as key:value pairs inside braces
            pairs = (cql_quote(key) + ':' + cql_quote(val) for key, val in raw.items())
            return '{' + ', '.join(pairs) + '}'
        if isinstance(raw, set):
            # sets render inside braces
            return '{' + ', '.join(cql_quote(item) for item in raw) + '}'
        # anything else is quoted as a single value
        return cql_quote(raw)

    def __eq__(self, other):
        # two quoters are equal when they are the same kind and wrap equal values
        return isinstance(other, self.__class__) and self.value == other.value
class InQuoter(ValueQuoter):
    """ renders the wrapped iterable as a parenthesised list of quoted values """

    def __unicode__(self):
        from cassandra.encoder import cql_quote

        quoted_items = (cql_quote(item) for item in self.value)
        return '(' + ', '.join(quoted_items) + ')'
class BaseClause(UnicodeMixin):
    """ base class for query fragments that pair a field name with a value """

    def __init__(self, field, value):
        self.field = field
        self.value = value
        # no placeholder id is assigned until set_context_id() is called
        self.context_id = None

    def __unicode__(self):
        # subclasses provide the actual rendering
        raise NotImplementedError

    def __hash__(self):
        return hash(self.field) ^ hash(self.value)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.field == other.field and self.value == other.value

    def __ne__(self, other):
        # goes through __eq__ so subclass overrides of __eq__ are honoured
        return not self.__eq__(other)

    def get_context_size(self):
        """ returns how many entries this clause contributes to the query context """
        return 1

    def set_context_id(self, i):
        """ records the placeholder id used for this clause's value in the query """
        self.context_id = i

    def update_context(self, ctx):
        """ stores this clause's value in the query context under its placeholder id """
        assert isinstance(ctx, dict)
        placeholder = str(self.context_id)
        ctx[placeholder] = self.value
class WhereClause(BaseClause):
    """ one field/operator/value condition of a WHERE expression """

    def __init__(self, field, operator, value, quote_field=True):
        """
        :param field: name of the field the condition applies to
        :param operator: a BaseWhereOperator instance
        :param value: the value compared against; wrapped in a QueryValue unless it already is one
        :param quote_field: hack to get the token function rendering properly
        :raises StatementException: if operator is not a BaseWhereOperator
        """
        if not isinstance(operator, BaseWhereOperator):
            raise StatementException(
                "operator must be of type {0}, got {1}".format(BaseWhereOperator, type(operator))
            )
        super().__init__(field, value)
        self.operator = operator
        if isinstance(self.value, QueryValue):
            self.query_value = self.value
        else:
            self.query_value = QueryValue(self.value)
        self.quote_field = quote_field

    def __unicode__(self):
        template = '"{0}"' if self.quote_field else '{0}'
        rendered_field = template.format(self.field)
        return u'{0} {1} {2}'.format(rendered_field, self.operator, str(self.query_value))

    def __hash__(self):
        base_hash = super().__hash__()
        return base_hash ^ hash(self.operator)

    def __eq__(self, other):
        # field and value must match first; then the operator classes are compared
        if not super().__eq__(other):
            return False
        return self.operator.__class__ == other.operator.__class__

    def get_context_size(self):
        # the wrapped query value decides how many placeholders are needed
        return self.query_value.get_context_size()

    def set_context_id(self, i):
        super().set_context_id(i)
        # keep the wrapped query value on the same placeholder id
        self.query_value.set_context_id(i)

    def update_context(self, ctx):
        if not isinstance(self.operator, InOperator):
            self.query_value.update_context(ctx)
            return
        # IN conditions store the raw value wrapped in an InQuoter
        ctx[str(self.context_id)] = InQuoter(self.value)
class IsNotNullClause(WhereClause):
    """ a where condition built on IsNotNullOperator; it contributes no context values """

    def __init__(self, field):
        # the empty string is only a stand-in value; it is never written to the context
        super().__init__(field, IsNotNullOperator(), '')

    def __unicode__(self):
        template = '"{0}"' if self.quote_field else '{0}'
        rendered_field = template.format(self.field)
        return u'{0} {1}'.format(rendered_field, self.operator)

    def update_context(self, ctx):
        # nothing to add: the clause has no placeholder
        pass

    def get_context_size(self):
        return 0


# shorter alias for convenience
IsNotNull = IsNotNullClause
class AssignmentClause(BaseClause):
    """ a single variable set statement """

    def __unicode__(self):
        template = u'"{0}" = %({1})s'
        return template.format(self.field, self.context_id)

    def insert_tuple(self):
        """ returns the (field, context id) pair for this assignment """
        pair = (self.field, self.context_id)
        return pair
class ConditionalClause(BaseClause):
    """ a single variable condition of an IF expression """

    def __unicode__(self):
        template = u'"{0}" = %({1})s'
        return template.format(self.field, self.context_id)

    def insert_tuple(self):
        """ returns the (field, context id) pair for this condition """
        pair = (self.field, self.context_id)
        return pair
class ContainerUpdateTypeMapMeta(type):
    """ metaclass that keeps a ``type_map`` registry of classes keyed by their ``col_type`` """

    def __init__(cls, name, bases, dct):
        if hasattr(cls, 'type_map'):
            # a class up the hierarchy already owns the registry: register this one in it
            cls.type_map[cls.col_type] = cls
        else:
            # first class created with this metaclass: start an empty registry
            cls.type_map = {}
        super().__init__(name, bases, dct)
class ContainerUpdateClause(AssignmentClause, metaclass=ContainerUpdateTypeMapMeta):
    """ base class for collection update clauses; subclasses supply the analysis and rendering """

    def __init__(self, field, value, operation=None, previous=None):
        super().__init__(field, value)
        self._operation = operation
        self.previous = previous
        # filled in by the subclass' _analyze()
        self._assignments = None
        self._analyzed = False

    def _analyze(self):
        raise NotImplementedError

    def get_context_size(self):
        raise NotImplementedError

    def update_context(self, ctx):
        raise NotImplementedError
class SetUpdateClause(ContainerUpdateClause):
    """ updates a set collection

    Depending on ``value``, ``previous`` and the requested operation, the
    clause renders as a full assignment, an addition, a removal, or an
    addition followed by a removal. Each rendered part uses its own
    consecutive placeholder id starting at ``context_id``.
    """
    col_type = columns.Set

    # class-level defaults; _analyze() sets instance values when they apply
    _additions = None
    _removals = None

    def __unicode__(self):
        # Analyse lazily, as get_context_size() and update_context() do, so
        # rendering never depends on one of them having been called first.
        if not self._analyzed:
            self._analyze()
        qs = []
        ctx_id = self.context_id
        if (self.previous is None and
                self._assignments is None and
                self._additions is None and
                self._removals is None):
            # nothing to diff against and nothing to apply: plain assignment
            qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)]
        if self._assignments is not None:
            qs += ['"{0}" = %({1})s'.format(self.field, ctx_id)]
            ctx_id += 1
        if self._additions is not None:
            qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)]
            ctx_id += 1
        if self._removals is not None:
            qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)]
        return ', '.join(qs)

    def _analyze(self):
        """ works out the updates to be performed """
        if self.value is None or self.value == self.previous:
            # no value, or nothing changed: leave every part unset
            pass
        elif self._operation == "add":
            self._additions = self.value
        elif self._operation == "remove":
            self._removals = self.value
        elif self.previous is None:
            self._assignments = self.value
        else:
            # partial update: empty differences are stored as None
            self._additions = (self.value - self.previous) or None
            self._removals = (self.previous - self.value) or None
        self._analyzed = True

    def get_context_size(self):
        """ returns the number of context entries this clause needs """
        if not self._analyzed:
            self._analyze()
        if (self.previous is None and
                not self._assignments and
                self._additions is None and
                self._removals is None):
            # the plain-assignment case still needs one placeholder
            return 1
        return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals))

    def update_context(self, ctx):
        """ writes the values for each rendered part into the context dict """
        if not self._analyzed:
            self._analyze()
        ctx_id = self.context_id
        if (self.previous is None and
                self._assignments is None and
                self._additions is None and
                self._removals is None):
            # plain assignment with nothing to apply stores an empty set
            ctx[str(ctx_id)] = set()
        if self._assignments is not None:
            ctx[str(ctx_id)] = self._assignments
            ctx_id += 1
        if self._additions is not None:
            ctx[str(ctx_id)] = self._additions
            ctx_id += 1
        if self._removals is not None:
            ctx[str(ctx_id)] = self._removals
class ListUpdateClause(ContainerUpdateClause):
    """ updates a list collection """
    col_type = columns.List

    # class-level defaults; _analyze() sets instance values when they apply
    _append = None
    _prepend = None

    def __unicode__(self):
        if not self._analyzed:
            self._analyze()
        field = self.field
        next_id = self.context_id
        fragments = []
        if self._assignments is not None:
            fragments.append('"{0}" = %({1})s'.format(field, next_id))
            next_id += 1
        if self._prepend is not None:
            fragments.append('"{0}" = %({1})s + "{0}"'.format(field, next_id))
            next_id += 1
        if self._append is not None:
            fragments.append('"{0}" = "{0}" + %({1})s'.format(field, next_id))
        return ', '.join(fragments)

    def get_context_size(self):
        if not self._analyzed:
            self._analyze()
        # assignments count when set at all; append/prepend only when non-empty
        size = int(self._assignments is not None)
        size += int(bool(self._append))
        size += int(bool(self._prepend))
        return size

    def update_context(self, ctx):
        if not self._analyzed:
            self._analyze()
        next_id = self.context_id
        # same order as the rendered fragments: assignment, prepend, append
        if self._assignments is not None:
            ctx[str(next_id)] = self._assignments
            next_id += 1
        if self._prepend is not None:
            ctx[str(next_id)] = self._prepend
            next_id += 1
        if self._append is not None:
            ctx[str(next_id)] = self._append

    def _analyze(self):
        """ works out the updates to be performed """
        current = self.value
        earlier = self.previous
        if current is None or current == earlier:
            pass
        elif self._operation == "append":
            self._append = current
        elif self._operation == "prepend":
            self._prepend = current
        elif earlier is None:
            self._assignments = current
        elif len(current) < len(earlier):
            # elements were removed, so the whole list is rewritten
            self._assignments = current
        elif len(earlier) == 0:
            # updating from an empty list is a complete insert
            self._assignments = current
        else:
            # size of the window slid over the new list
            window = len(earlier)
            # number of start positions at which a full window still fits
            start_count = len(current) - max(0, window - 1)
            for start in range(start_count):
                stop = start + window
                candidate = current[start:stop]
                # compare the two ends first, then the whole window
                if (earlier[0] == candidate[0]
                        and earlier[-1] == candidate[-1]
                        and earlier == candidate):
                    self._prepend = current[:start] or None
                    self._append = current[stop:] or None
                    break
        # when neither an append nor a prepend was found,
        # the clause becomes a plain assignment of the value
        if self._prepend is None and self._append is None:
            self._assignments = self.value
        self._analyzed = True
class MapUpdateClause(ContainerUpdateClause):
    """ updates a map collection """
    col_type = columns.Map

    # class-level defaults; _analyze() sets instance values when they apply
    _updates = None
    _removals = None

    def _analyze(self):
        """ works out which keys are updated or removed """
        operation = self._operation
        if operation == "update":
            self._updates = self.value.keys()
        elif operation == "remove":
            self._removals = set(self.value.keys())
        elif self.previous is None:
            # no earlier state: every key counts as an update
            self._updates = sorted(key for key, _ in self.value.items())
        else:
            # only keys whose value differs from the earlier state
            changed = [key for key, val in self.value.items() if val != self.previous.get(key)]
            self._updates = sorted(changed) or None
        self._analyzed = True

    def get_context_size(self):
        if self.is_assignment:
            return 1
        # each updated key needs two entries (key and value); removals need one
        update_slots = len(self._updates or []) * 2
        return int(update_slots + int(bool(self._removals)))

    def update_context(self, ctx):
        next_id = self.context_id
        if self.is_assignment:
            ctx[str(next_id)] = {}
            return
        if self._removals is not None:
            ctx[str(next_id)] = self._removals
            return
        for key in self._updates or []:
            val = self.value.get(key)
            ctx[str(next_id)] = key
            ctx[str(next_id + 1)] = val
            next_id += 2

    @property
    def is_assignment(self):
        """ true when there is no earlier state and no updates or removals to apply """
        if not self._analyzed:
            self._analyze()
        if self.previous is not None:
            return False
        return not self._updates and not self._removals

    def __unicode__(self):
        next_id = self.context_id
        fragments = []
        if self.is_assignment:
            fragments.append('"{0}" = %({1})s'.format(self.field, next_id))
        elif self._removals is not None:
            fragments.append('"{0}" = "{0}" - %({1})s'.format(self.field, next_id))
        else:
            # one key/value placeholder pair per updated key
            for _ in self._updates or []:
                fragments.append('"{0}"[%({1})s] = %({2})s'.format(self.field, next_id, next_id + 1))
                next_id += 2
        return ', '.join(fragments)
class CounterUpdateClause(AssignmentClause):
    """ updates a counter by the difference between the new and the previous value """
    col_type = columns.Counter

    def __init__(self, field, value, previous=None):
        super().__init__(field, value)
        # a missing (or otherwise falsy) previous value counts as zero
        self.previous = previous or 0

    def get_context_size(self):
        return 1

    def update_context(self, ctx):
        # only the magnitude is stored; the sign is rendered in the statement
        difference = self.value - self.previous
        ctx[str(self.context_id)] = abs(difference)

    def __unicode__(self):
        difference = self.value - self.previous
        if difference < 0:
            sign = '-'
        else:
            sign = '+'
        return '"{0}" = "{0}" {1} %({2})s'.format(self.field, sign, self.context_id)
class BaseDeleteClause(BaseClause):
    """ common parent of the clauses that describe what a delete removes """
class FieldDeleteClause(BaseDeleteClause):
    """ deletes a field from a row """

    def __init__(self, field):
        # a field delete carries no value
        super().__init__(field, None)

    def __unicode__(self):
        quoted = '"{0}"'.format(self.field)
        return quoted

    def update_context(self, ctx):
        # no placeholder, so nothing is written to the context
        pass

    def get_context_size(self):
        return 0
class MapDeleteClause(BaseDeleteClause):
    """ removes keys from a map """

    def __init__(self, field, value, previous=None):
        super().__init__(field, value)
        # falsy inputs are normalised to empty dicts
        self.value = self.value or {}
        self.previous = previous or {}
        self._removals = None
        self._analyzed = False

    def _analyze(self):
        """ collects, in sorted order, the previous keys that are absent from the value """
        dropped = [key for key in self.previous if key not in self.value]
        dropped.sort()
        self._removals = dropped
        self._analyzed = True

    def update_context(self, ctx):
        if not self._analyzed:
            self._analyze()
        # one consecutive placeholder per removed key
        for offset, key in enumerate(self._removals):
            ctx[str(self.context_id + offset)] = key

    def get_context_size(self):
        if not self._analyzed:
            self._analyze()
        return len(self._removals)

    def __unicode__(self):
        if not self._analyzed:
            self._analyze()
        fragments = []
        for offset, _ in enumerate(self._removals):
            fragments.append('"{0}"[%({1})s]'.format(self.field, self.context_id + offset))
        return ', '.join(fragments)
class BaseCQLStatement(UnicodeMixin):
    """ The base cql statement class """

    def __init__(self, table, timestamp=None, where=None, fetch_size=None, conditionals=None):
        super().__init__()
        self.table = table
        self.context_id = 0
        # next free placeholder id; advanced as clauses are added
        self.context_counter = self.context_id
        self.timestamp = timestamp
        self.fetch_size = fetch_size or FETCH_SIZE_UNSET

        self.where_clauses = []
        for where_clause in where or []:
            self._add_where_clause(where_clause)

        self.conditionals = []
        for conditional_clause in conditionals or []:
            self.add_conditional_clause(conditional_clause)

    def _update_part_key_values(self, field_index_map, clauses, parts):
        """ copies the value of every clause whose field is in field_index_map into parts """
        for clause in clauses:
            if clause.field in field_index_map:
                parts[field_index_map[clause.field]] = clause.value

    def partition_key_values(self, field_index_map):
        """ returns a list of values positioned by field_index_map, taken from equality where clauses """
        parts = [None] * len(field_index_map)
        equality_clauses = (w for w in self.where_clauses if w.operator.__class__ == EqualsOperator)
        self._update_part_key_values(field_index_map, equality_clauses, parts)
        return parts

    def add_where(self, column, operator, value, quote_field=True):
        """ builds a WhereClause for the column and adds it to this statement """
        db_value = column.to_database(value)
        self._add_where_clause(WhereClause(column.db_field_name, operator, db_value, quote_field))

    def _add_where_clause(self, clause):
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.where_clauses.append(clause)

    def get_context(self):
        """
        returns the context dict for this statement
        :rtype: dict
        """
        ctx = {}
        for where_clause in self.where_clauses or []:
            where_clause.update_context(ctx)
        return ctx

    def add_conditional_clause(self, clause):
        """
        Adds an IF clause to this statement
        :param clause: The clause that will be added to the IF expression
        :type clause: ConditionalClause
        """
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.conditionals.append(clause)

    def _get_conditionals(self):
        rendered = [str(conditional) for conditional in self.conditionals]
        return 'IF {0}'.format(' AND '.join(rendered))

    def get_context_size(self):
        return len(self.get_context())

    def update_context_id(self, i):
        """ renumbers the where clauses' placeholders starting from i """
        self.context_id = i
        self.context_counter = self.context_id
        for where_clause in self.where_clauses:
            where_clause.set_context_id(self.context_counter)
            self.context_counter += where_clause.get_context_size()

    @property
    def timestamp_normalized(self):
        """
        we're expecting self.timestamp to be either a long, int, a datetime, or a timedelta
        :return: None for a falsy timestamp, the int itself, or the converted microsecond value
        """
        stamp = self.timestamp
        if not stamp:
            return None
        if isinstance(stamp, int):
            return stamp
        # a timedelta is taken relative to the current local time
        moment = datetime.now() + stamp if isinstance(stamp, timedelta) else stamp
        return int(time.mktime(moment.timetuple()) * 1e+6 + moment.microsecond)

    def __unicode__(self):
        raise NotImplementedError

    def __repr__(self):
        return self.__unicode__()

    @property
    def _where(self):
        rendered = [str(where_clause) for where_clause in self.where_clauses]
        return 'WHERE {0}'.format(' AND '.join(rendered))
class SelectStatement(BaseCQLStatement):
    """ a cql select statement """

    def __init__(self,
                 table,
                 fields=None,
                 count=False,
                 where=None,
                 order_by=None,
                 limit=None,
                 allow_filtering=False,
                 distinct_fields=None,
                 fetch_size=None):
        """
        :param where
        :type where list of cqlengine.statements.WhereClause
        """
        super().__init__(table, where=where, fetch_size=fetch_size)

        # a single field name is wrapped in a list; a missing one becomes an empty list
        if isinstance(fields, str):
            self.fields = [fields]
        else:
            self.fields = fields or []

        # a single ordering string is wrapped in a list as well
        if isinstance(order_by, str):
            self.order_by = [order_by]
        else:
            self.order_by = order_by

        self.distinct_fields = distinct_fields
        self.count = count
        self.limit = limit
        self.allow_filtering = allow_filtering

    def __unicode__(self):
        parts = ['SELECT']

        if self.distinct_fields:
            quoted_distinct = ', '.join(['"{0}"'.format(f) for f in self.distinct_fields])
            if self.count:
                parts.append('DISTINCT COUNT({0})'.format(quoted_distinct))
            else:
                parts.append('DISTINCT {0}'.format(quoted_distinct))
        elif self.count:
            parts.append('COUNT(*)')
        elif self.fields:
            parts.append(', '.join(['"{0}"'.format(f) for f in self.fields]))
        else:
            parts.append('*')

        parts.extend(['FROM', self.table])

        if self.where_clauses:
            parts.append(self._where)

        # ordering is left out of count queries
        if self.order_by and not self.count:
            parts.append('ORDER BY {0}'.format(', '.join(str(o) for o in self.order_by)))

        if self.limit:
            parts.append('LIMIT {0}'.format(self.limit))

        if self.allow_filtering:
            parts.append('ALLOW FILTERING')

        return ' '.join(parts)
class AssignmentStatement(BaseCQLStatement):
    """ value assignment statements """

    def __init__(self,
                 table,
                 assignments=None,
                 where=None,
                 ttl=None,
                 timestamp=None,
                 conditionals=None):
        super().__init__(table, where=where, conditionals=conditionals)
        self.ttl = ttl
        self.timestamp = timestamp

        # assignments are numbered after the where clauses and conditionals
        self.assignments = []
        for assignment_clause in assignments or []:
            self._add_assignment_clause(assignment_clause)

    def update_context_id(self, i):
        super().update_context_id(i)
        # continue numbering where the parent left off
        for assignment_clause in self.assignments:
            assignment_clause.set_context_id(self.context_counter)
            self.context_counter += assignment_clause.get_context_size()

    def partition_key_values(self, field_index_map):
        parts = super().partition_key_values(field_index_map)
        # assigned values are also copied into their mapped positions
        self._update_part_key_values(field_index_map, self.assignments, parts)
        return parts

    def add_assignment(self, column, value):
        """ builds an AssignmentClause for the column and adds it to this statement """
        db_value = column.to_database(value)
        self._add_assignment_clause(AssignmentClause(column.db_field_name, db_value))

    def _add_assignment_clause(self, clause):
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.assignments.append(clause)

    @property
    def is_empty(self):
        """ true when the statement has no assignments """
        return len(self.assignments) == 0

    def get_context(self):
        ctx = super().get_context()
        for assignment_clause in self.assignments:
            assignment_clause.update_context(ctx)
        return ctx
class InsertStatement(AssignmentStatement):
    """ a cql insert statement """

    def __init__(self,
                 table,
                 assignments=None,
                 where=None,
                 ttl=None,
                 timestamp=None,
                 if_not_exists=False):
        super().__init__(table,
                         assignments=assignments,
                         where=where,
                         ttl=ttl,
                         timestamp=timestamp)
        self.if_not_exists = if_not_exists

    def __unicode__(self):
        # split the (field, context id) pairs into column names and placeholder ids
        column_names, placeholder_ids = zip(*[a.insert_tuple() for a in self.assignments])

        quoted_columns = ', '.join(['"{0}"'.format(c) for c in column_names])
        placeholders = ', '.join(['%({0})s'.format(v) for v in placeholder_ids])

        parts = [
            'INSERT INTO {0}'.format(self.table),
            "({0})".format(quoted_columns),
            'VALUES',
            "({0})".format(placeholders),
        ]

        if self.if_not_exists:
            parts.append("IF NOT EXISTS")

        using_options = []
        if self.ttl:
            using_options.append("TTL {}".format(self.ttl))
        if self.timestamp:
            using_options.append("TIMESTAMP {}".format(self.timestamp_normalized))
        if using_options:
            parts.append("USING {}".format(" AND ".join(using_options)))

        return ' '.join(parts)
class UpdateStatement(AssignmentStatement):
    """ a cql update statement """

    def __init__(self,
                 table,
                 assignments=None,
                 where=None,
                 ttl=None,
                 timestamp=None,
                 conditionals=None,
                 if_exists=False):
        super().__init__(table,
                         assignments=assignments,
                         where=where,
                         ttl=ttl,
                         timestamp=timestamp,
                         conditionals=conditionals)
        self.if_exists = if_exists

    def __unicode__(self):
        parts = ['UPDATE', self.table]

        using_options = []
        if self.ttl:
            using_options.append("TTL {0}".format(self.ttl))
        if self.timestamp:
            using_options.append("TIMESTAMP {0}".format(self.timestamp_normalized))
        if using_options:
            parts.append("USING {0}".format(" AND ".join(using_options)))

        parts.append('SET')
        parts.append(', '.join([str(c) for c in self.assignments]))

        if self.where_clauses:
            parts.append(self._where)

        if self.conditionals:
            parts.append(self._get_conditionals())

        if self.if_exists:
            parts.append("IF EXISTS")

        return ' '.join(parts)

    def get_context(self):
        ctx = super().get_context()
        for conditional_clause in self.conditionals:
            conditional_clause.update_context(ctx)
        return ctx

    def update_context_id(self, i):
        super().update_context_id(i)
        # conditionals are renumbered after the where clauses and assignments
        for conditional_clause in self.conditionals:
            conditional_clause.set_context_id(self.context_counter)
            self.context_counter += conditional_clause.get_context_size()

    def add_update(self, column, value, operation=None, previous=None):
        """ picks the clause type that matches the column's type and adds it if it needs any context """
        db_value = column.to_database(value)
        column_type = type(column)
        container_clause_type = ContainerUpdateClause.type_map.get(column_type)
        if container_clause_type:
            db_previous = column.to_database(previous)
            clause = container_clause_type(column.db_field_name, db_value, operation, db_previous)
        elif column_type == columns.Counter:
            clause = CounterUpdateClause(column.db_field_name, db_value, previous)
        else:
            clause = AssignmentClause(column.db_field_name, db_value)
        # this is to exclude map removals from updates. Can go away if we drop
        # support for C* < 1.2.4 and remove two-phase updates
        if clause.get_context_size():
            self._add_assignment_clause(clause)
class DeleteStatement(BaseCQLStatement):
    """ a cql delete statement """

    def __init__(self, table, fields=None, where=None, timestamp=None, conditionals=None, if_exists=False):
        """
        :param table: name of the table to delete from
        :param fields: a field name, or an iterable of field names / clauses, to delete
        :param where: where clauses added to the statement
        :param timestamp: value rendered in the USING TIMESTAMP option
        :param conditionals: conditional clauses rendered in the IF expression
        :param if_exists: when true, IF EXISTS is appended to the statement
        """
        super(DeleteStatement, self).__init__(
            table,
            where=where,
            timestamp=timestamp,
            conditionals=conditionals
        )
        self.fields = []
        # a bare string is a single field name, not an iterable of names
        if isinstance(fields, str):
            fields = [fields]
        for field in fields or []:
            self.add_field(field)
        self.if_exists = if_exists

    def update_context_id(self, i):
        """ renumbers the placeholders: where clauses first, then fields, then conditionals """
        super(DeleteStatement, self).update_context_id(i)
        for field in self.fields:
            field.set_context_id(self.context_counter)
            self.context_counter += field.get_context_size()
        for t in self.conditionals:
            t.set_context_id(self.context_counter)
            self.context_counter += t.get_context_size()

    def get_context(self):
        """ returns the context dict including field and conditional values """
        ctx = super(DeleteStatement, self).get_context()
        for field in self.fields:
            field.update_context(ctx)
        for clause in self.conditionals:
            clause.update_context(ctx)
        return ctx

    def add_field(self, field):
        """
        Adds a field to delete.

        :param field: a field name (wrapped in a FieldDeleteClause) or a BaseClause instance
        :raises StatementException: if field is neither a string nor a BaseClause
        """
        if isinstance(field, str):
            field = FieldDeleteClause(field)
        if not isinstance(field, BaseClause):
            # the message names the types that are actually accepted by the check above
            raise StatementException("only strings or instances of BaseClause can be added to statements")
        field.set_context_id(self.context_counter)
        self.context_counter += field.get_context_size()
        self.fields.append(field)

    def __unicode__(self):
        qs = ['DELETE']
        if self.fields:
            qs += [', '.join([str(f) for f in self.fields])]
        qs += ['FROM', self.table]

        if self.timestamp:
            # no padding here: the final ' '.join already separates the parts
            qs += ["USING TIMESTAMP {0}".format(self.timestamp_normalized)]

        if self.where_clauses:
            qs += [self._where]

        if self.conditionals:
            qs += [self._get_conditionals()]

        if self.if_exists:
            qs += ["IF EXISTS"]

        return ' '.join(qs)